1818# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919import os
2020import logging
21- from typing import Any , List
21+ from typing import Any , List , Optional , Dict
2222from concurrent import futures
2323
2424from . import _types
3232from sqlalchemy .sql .expression import TextClause
3333
3434from google .cloud import firestore_admin_v1
35- from google .cloud .firestore_admin_v1 .types import ListDatabasesResponse
36- from google .api_core .exceptions import GoogleAPIError
37- from google .api_core import client_info
35+ from google .cloud .firestore_admin_v1 .types import ListDatabasesResponse , Database
36+ from google .oauth2 import service_account
3837
3938logger = logging .getLogger ('sqlalchemy.dialects.CloudDatastore' )
4039
@@ -354,6 +353,7 @@ def __init__(
354353 self .dataset_id = None
355354 self .list_tables_page_size = list_tables_page_size
356355 self ._client = None
356+ self .credentials = None
357357
358358 @classmethod
359359 def dbapi (cls ):
@@ -363,7 +363,8 @@ def dbapi(cls):
363363 def do_ping (self , dbapi_connection ):
364364 """Performs a simple operation to check if the connection is still alive."""
365365 try :
366- # Basic connectivity check
366+ query = self ._client .query (kind = "__kind__" )
367+ query .fetch (limit = 1 , timeout = 30 )
367368 return True
368369 except Exception :
369370 return False
@@ -406,11 +407,12 @@ def create_connect_args(self, url):
406407 if user_supplied_client :
407408 return ([], {})
408409 else :
409- client = create_datastore_client (
410+ client , credentials = create_datastore_client (
410411 credentials_path = self .credentials_path ,
411412 credentials_info = self .credentials_info ,
412413 credentials_base64 = self .credentials_base64 ,
413414 project_id = self .billing_project_id ,
415+ database = None ,
414416 )
415417 self .project_id = self .project_id if self .project_id else client .project
416418 self .billing_project_id = (
@@ -423,13 +425,14 @@ def create_connect_args(self, url):
423425 )
424426
425427 self ._client = client
428+ self .credentials = credentials
426429 setattr (self ._client , "credentials_path" , self .credentials_path )
427430 setattr (self ._client , "credentials_info" , self .credentials_info )
428431 setattr (self ._client , "credentials_base64" , self .credentials_base64 )
429432 return ([], {"client" : client })
430433
431- def get_schema_names (self , connection , ** kw ):
432- return self ._list_datastore_databases (self .credentials_info , self .project_id )
434+ def get_schema_names (self , connection : Connection , ** kw ) -> Optional [ List [ str ]] :
435+ return self ._list_datastore_databases (self .credentials , self .project_id )
433436
434437 def _list_datastore_databases (self , cred : service_account .Credentials , project_id : str ) -> Optional [List [str ]]:
435438 """Lists all Datastore databases for a given Google Cloud project.
@@ -454,8 +457,13 @@ def get_database_short_name(database: Database) -> Optional[List[str]]:
454457 logging .error (e )
455458 return []
456459
457- def get_table_names (self , connection , schema : str | None = None , ** kw ):
458- client = self ._client
460+ def get_table_names (self , connection : Connection , schema : str | None = None , ** kw ) -> Optional [List [str ]]:
461+ client , _ = create_datastore_client (
462+ credentials_path = self .credentials_path ,
463+ credentials_info = self .credentials_info ,
464+ credentials_base64 = self .credentials_base64 ,
465+ database = schema
466+ )
459467 query = client .query (kind = "__kind__" )
460468 kinds = list (query .fetch ())
461469
@@ -469,7 +477,12 @@ def get_kind_name(kind):
469477
470478 def get_columns (self , connection : Connection , table_name : str , schema : str | None = None , ** kw ):
471479 """Retrieve column information from the database with optimized parallel processing."""
472- client = self ._client
480+ client , _ = create_datastore_client (
481+ credentials_path = self .credentials_path ,
482+ credentials_info = self .credentials_info ,
483+ credentials_base64 = self .credentials_base64 ,
484+ database = schema
485+ )
473486 ancestor_key = client .key ("__kind__" , table_name )
474487 query = client .query (kind = "__property__" , ancestor = ancestor_key )
475488 properties = list (query .fetch ())
0 commit comments