Skip to content

Commit 4327ecd

Browse files
committed
Add schema related to get_table_names and get_columns
1 parent 501150d commit 4327ecd

2 files changed

Lines changed: 30 additions & 15 deletions

File tree

sqlalchemy_datastore/_helpers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import functools
2121
import re
2222
import os
23-
from typing import Optional
23+
from typing import Optional, Tuple
2424

2525
from google.api_core import client_info
2626
import google.auth
@@ -58,7 +58,8 @@ def create_datastore_client(
5858
credentials_base64: Optional[str] = None,
5959
project_id: Optional[str] = None,
6060
user_agent: Optional[client_info.ClientInfo] = None,
61-
) -> datastore.Client:
61+
database: Optional[str] = None
62+
) -> Tuple[datastore.Client, service_account.Credentials]:
6263
"""Construct a BigQuery client object.
6364
6465
Args:
@@ -77,7 +78,7 @@ def create_datastore_client(
7778
"""
7879

7980
default_project = None
80-
81+
database = database if database != "(default)" else None
8182
if os.getenv("DATASTORE_EMULATOR_HOST") is not None:
8283
client = datastore.Client(project=project_id)
8384
return client
@@ -109,7 +110,8 @@ def create_datastore_client(
109110
client_info=info,
110111
project=project_id,
111112
credentials=credentials,
112-
)
113+
database=database
114+
), credentials
113115

114116

115117
def substitute_re_method(r, flags=0, repl=None):

sqlalchemy_datastore/base.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919
import os
2020
import logging
21-
from typing import Any, List
21+
from typing import Any, List, Optional, Dict
2222
from concurrent import futures
2323

2424
from . import _types
@@ -32,9 +32,8 @@
3232
from sqlalchemy.sql.expression import TextClause
3333

3434
from google.cloud import firestore_admin_v1
35-
from google.cloud.firestore_admin_v1.types import ListDatabasesResponse
36-
from google.api_core.exceptions import GoogleAPIError
37-
from google.api_core import client_info
35+
from google.cloud.firestore_admin_v1.types import ListDatabasesResponse, Database
36+
from google.oauth2 import service_account
3837

3938
logger = logging.getLogger('sqlalchemy.dialects.CloudDatastore')
4039

@@ -354,6 +353,7 @@ def __init__(
354353
self.dataset_id = None
355354
self.list_tables_page_size = list_tables_page_size
356355
self._client = None
356+
self.credentials = None
357357

358358
@classmethod
359359
def dbapi(cls):
@@ -363,7 +363,8 @@ def dbapi(cls):
363363
def do_ping(self, dbapi_connection):
364364
"""Performs a simple operation to check if the connection is still alive."""
365365
try:
366-
# Basic connectivity check
366+
query = self._client.query(kind="__kind__")
367+
query.fetch(limit=1, timeout=30)
367368
return True
368369
except Exception:
369370
return False
@@ -406,11 +407,12 @@ def create_connect_args(self, url):
406407
if user_supplied_client:
407408
return ([], {})
408409
else:
409-
client = create_datastore_client(
410+
client, credentials = create_datastore_client(
410411
credentials_path=self.credentials_path,
411412
credentials_info=self.credentials_info,
412413
credentials_base64=self.credentials_base64,
413414
project_id=self.billing_project_id,
415+
database=None,
414416
)
415417
self.project_id = self.project_id if self.project_id else client.project
416418
self.billing_project_id = (
@@ -423,13 +425,14 @@ def create_connect_args(self, url):
423425
)
424426

425427
self._client = client
428+
self.credentials = credentials
426429
setattr(self._client, "credentials_path", self.credentials_path)
427430
setattr(self._client, "credentials_info", self.credentials_info)
428431
setattr(self._client, "credentials_base64", self.credentials_base64)
429432
return ([], {"client": client})
430433

431-
def get_schema_names(self, connection, **kw):
432-
return self._list_datastore_databases(self.credentials_info, self.project_id)
434+
def get_schema_names(self, connection: Connection, **kw) -> Optional[List[str]]:
435+
return self._list_datastore_databases(self.credentials, self.project_id)
433436

434437
def _list_datastore_databases(self, cred: service_account.Credentials, project_id: str) -> Optional[List[str]]:
435438
"""Lists all Datastore databases for a given Google Cloud project.
@@ -454,8 +457,13 @@ def get_database_short_name(database: Database) -> Optional[List[str]]:
454457
logging.error(e)
455458
return []
456459

457-
def get_table_names(self, connection, schema: str | None = None, **kw):
458-
client = self._client
460+
def get_table_names(self, connection: Connection, schema: str | None = None, **kw) -> Optional[List[str]]:
461+
client, _ = create_datastore_client(
462+
credentials_path=self.credentials_path,
463+
credentials_info=self.credentials_info,
464+
credentials_base64=self.credentials_base64,
465+
database=schema
466+
)
459467
query = client.query(kind="__kind__")
460468
kinds = list(query.fetch())
461469

@@ -469,7 +477,12 @@ def get_kind_name(kind):
469477

470478
def get_columns(self, connection: Connection, table_name: str, schema: str | None = None, **kw):
471479
"""Retrieve column information from the database with optimized parallel processing."""
472-
client = self._client
480+
client, _ = create_datastore_client(
481+
credentials_path=self.credentials_path,
482+
credentials_info=self.credentials_info,
483+
credentials_base64=self.credentials_base64,
484+
database=schema
485+
)
473486
ancestor_key = client.key("__kind__", table_name)
474487
query = client.query(kind="__property__", ancestor=ancestor_key)
475488
properties = list(query.fetch())

0 commit comments

Comments
 (0)