Skip to content
This repository was archived by the owner on Mar 13, 2026. It is now read-only.

Commit 830109d

Browse files
committed
Support Named Schemas
1 parent e17c5ef commit 830109d

File tree

2 files changed

+128
-1
lines changed

2 files changed

+128
-1
lines changed

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ class SpannerSQLCompiler(SQLCompiler):
233233

234234
compound_keywords = _compound_keywords
235235

236+
def __init__(self, *args, **kwargs):
237+
self.tablealiases = {}
238+
super().__init__(*args, **kwargs)
239+
236240
def get_from_hint_text(self, _, text):
237241
"""Return a hint text.
238242
@@ -378,8 +382,10 @@ def limit_clause(self, select, **kw):
378382
return text
379383

380384
def returning_clause(self, stmt, returning_cols, **kw):
385+
# Set include_table=False because although table names are allowed in
386+
# RETURNING clauses, schema names are not.
381387
columns = [
382-
self._label_select_column(None, c, True, False, {})
388+
self._label_select_column(None, c, True, False, {}, include_table=False)
383389
for c in expression._select_iterables(returning_cols)
384390
]
385391

@@ -391,6 +397,87 @@ def visit_sequence(self, seq, **kw):
391397
seq
392398
)
393399

400+
def visit_table(self, table, spanner_aliased=False, iscrud=False, **kwargs):
401+
"""Produces the table name.
402+
403+
Schema names are not allowed in Spanner SELECT statements. We
404+
need to avoid generating SQL like
405+
406+
SELECT schema.tbl.id
407+
FROM schema.tbl
408+
409+
To do so, we alias the table in order to produce SQL like:
410+
411+
SELECT tbl_1.id, tbl_1.col
412+
FROM schema.tbl AS tbl_1
413+
414+
And do similar for UPDATE and DELETE statements.
415+
416+
We don't need to correct INSERT statements, which is fortunate
417+
because INSERT statements actually do not currently result in
418+
calls to `visit_table`.
419+
420+
This closely mirrors the mssql dialect which also avoids
421+
schema-qualified columns in SELECTs, although the behaviour is
422+
currently behind a deprecated 'legacy_schema_aliasing' flag.
423+
"""
424+
if spanner_aliased is table or self.isinsert:
425+
return super().visit_table(table, **kwargs)
426+
427+
# alias schema-qualified tables
428+
alias = self._schema_aliased_table(table)
429+
if alias is not None:
430+
return self.process(alias, spanner_aliased=table, **kwargs)
431+
else:
432+
return super().visit_table(table, **kwargs)
433+
434+
def visit_alias(self, alias, **kw):
435+
"""Produces alias statements."""
436+
# translate for schema-qualified table aliases
437+
kw["spanner_aliased"] = alias.element
438+
return super().visit_alias(alias, **kw)
439+
440+
def visit_column(self, column, add_to_result_map=None, **kw):
441+
"""Produces column expressions.
442+
443+
In tandem with visit_table, replaces schema-qualified column
444+
names with column names qualified against an alias.
445+
"""
446+
if (
447+
column.table is not None
448+
and not self.isinsert
449+
or self.is_subquery()
450+
):
451+
# translate for schema-qualified table aliases
452+
t = self._schema_aliased_table(column.table)
453+
if t is not None:
454+
converted = elements._corresponding_column_or_error(t, column)
455+
if add_to_result_map is not None:
456+
add_to_result_map(
457+
column.name,
458+
column.name,
459+
(column, column.name, column.key),
460+
column.type,
461+
)
462+
463+
return super().visit_column(converted, **kw)
464+
465+
return super().visit_column(column, add_to_result_map=add_to_result_map, **kw)
466+
467+
def _schema_aliased_table(self, table):
468+
"""Creates an alias for the table if it is schema-qualified.
469+
470+
If the table is schema-qualified, returns an alias for the
471+
table and caches the alias for future references to the
472+
table. If the table is not schema-qualified, returns None.
473+
"""
474+
if getattr(table, "schema", None) is not None:
475+
if table not in self.tablealiases:
476+
self.tablealiases[table] = table.alias()
477+
return self.tablealiases[table]
478+
else:
479+
return None
480+
394481

395482
class SpannerDDLCompiler(DDLCompiler):
396483
"""Spanner DDL statements compiler."""

test/system/test_basics.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
Boolean,
2626
BIGINT,
2727
select,
28+
update,
29+
delete,
2830
)
2931
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
3032
from sqlalchemy.types import REAL
@@ -58,6 +60,16 @@ def define_tables(cls, metadata):
5860
Column("name", String(20)),
5961
)
6062

63+
with cls.bind.begin() as conn:
64+
conn.execute(text("CREATE SCHEMA IF NOT EXISTS schema"))
65+
Table(
66+
"users",
67+
metadata,
68+
Column("ID", Integer, primary_key=True),
69+
Column("name", String(20)),
70+
schema="schema",
71+
)
72+
6173
def test_hello_world(self, connection):
6274
greeting = connection.execute(text("select 'Hello World'"))
6375
eq_("Hello World", greeting.fetchone()[0])
@@ -139,6 +151,12 @@ class User(Base):
139151
ID: Mapped[int] = mapped_column(primary_key=True)
140152
name: Mapped[str] = mapped_column(String(20))
141153

154+
class SchemaUser(Base):
155+
__tablename__ = "users"
156+
__table_args__ = {"schema": "schema"}
157+
ID: Mapped[int] = mapped_column(primary_key=True)
158+
name: Mapped[str] = mapped_column(String(20))
159+
142160
engine = connection.engine
143161
with Session(engine) as session:
144162
number = Number(
@@ -156,3 +174,25 @@ class User(Base):
156174
users = session.scalars(statement).all()
157175
eq_(1, len(users))
158176
is_true(users[0].ID > 0)
177+
178+
with Session(engine) as session:
179+
user = SchemaUser(name="SchemaTest")
180+
session.add(user)
181+
session.commit()
182+
183+
users = session.scalars(select(SchemaUser).where(SchemaUser.name == "SchemaTest")).all()
184+
eq_(1, len(users))
185+
is_true(users[0].ID > 0)
186+
187+
session.execute(update(SchemaUser).where(SchemaUser.name == "SchemaTest").values(name="NewName"))
188+
session.commit()
189+
190+
users = session.scalars(select(SchemaUser).where(SchemaUser.name == "NewName")).all()
191+
eq_(1, len(users))
192+
is_true(users[0].ID > 0)
193+
194+
session.execute(delete(SchemaUser).where(SchemaUser.name=="NewName"))
195+
session.commit()
196+
197+
users = session.scalars(select(SchemaUser).where(SchemaUser.name=="NewName")).all()
198+
eq_(0, len(users))

0 commit comments

Comments
 (0)