Skip to content

Commit ee344d9

Browse files
feat: discover and load downstream schemas for cascade and drop
Previously, cascade delete and drop only traversed tables in explicitly activated schemas. If a dependent table lived in an unactivated schema (common in multi-schema pipelines), it was invisible to the dependency graph, causing FK errors at delete time. New Dependencies.load_all_downstream() method iteratively discovers schemas that reference the loaded schemas via FK relationships, expanding the dependency graph until all downstream schemas are included. Uses information_schema (MySQL) and pg_constraint (PostgreSQL) to find cross-schema FK references. Diagram.cascade() and Table.drop() now call load_all_downstream() before building the dependency graph. Includes integration test: two schemas where the downstream schema has an FK to the upstream schema, verifying that cascade delete discovers and deletes from both. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 00bdc82 commit ee344d9

File tree

7 files changed

+147
-4
lines changed

7 files changed

+147
-4
lines changed

src/datajoint/adapters/base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,26 @@ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
830830
"""
831831
...
832832

833+
def find_downstream_schemas_sql(self, schemas_list: str) -> str:
    """
    Generate query to find schemas with FK references to the given schemas.

    Used to discover unloaded schemas that depend on loaded ones.

    This is an optional override: it deliberately raises rather than using
    ``@abstractmethod`` so existing adapter subclasses remain importable.

    Parameters
    ----------
    schemas_list : str
        Comma-separated, quoted schema names for an IN clause.

    Returns
    -------
    str
        SQL query returning rows with a single column ``schema_name``
        containing distinct schema names that reference the given schemas.

    Raises
    ------
    NotImplementedError
        Always, in this base class; concrete adapters must override.
    """
    # NOTE: the original had an unreachable ``...`` after the raise; removed.
    raise NotImplementedError
852+
833853
@abstractmethod
834854
def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str:
835855
"""

src/datajoint/adapters/mysql.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,15 @@ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
687687
f"OR referenced_table_schema is not NULL AND table_schema in ({schemas_list}))"
688688
)
689689

690+
def find_downstream_schemas_sql(self, schemas_list: str) -> str:
    """
    Find schemas with FK references to the given schemas (MySQL).

    Parameters
    ----------
    schemas_list : str
        Comma-separated, quoted schema names for an IN clause.

    Returns
    -------
    str
        Query yielding one ``schema_name`` column: distinct schemas that
        hold FK references into ``schemas_list`` but are not themselves
        in ``schemas_list``.
    """
    # key_column_usage has one row per FK column; DISTINCT collapses
    # multi-column keys to a single schema name. Only the IN clauses
    # interpolate, so only those segments are f-strings (ruff F541).
    return (
        "SELECT DISTINCT table_schema as schema_name "
        "FROM information_schema.key_column_usage "
        f"WHERE referenced_table_schema IN ({schemas_list}) "
        f"AND table_schema NOT IN ({schemas_list})"
    )
698+
690699
def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str:
691700
"""Query to get FK constraint details from information_schema."""
692701
return (

src/datajoint/adapters/postgres.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,20 @@ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
847847
f"ORDER BY c.conname, cols.ord"
848848
)
849849

850+
def find_downstream_schemas_sql(self, schemas_list: str) -> str:
    """
    Find schemas with FK references to the given schemas (PostgreSQL).

    Parameters
    ----------
    schemas_list : str
        Comma-separated, quoted schema names for an IN clause.

    Returns
    -------
    str
        Query yielding one ``schema_name`` column: distinct schemas that
        hold FK constraints (``contype = 'f'``) pointing into
        ``schemas_list`` but are not themselves in ``schemas_list``.
    """
    # ns1/cl1 = referencing (downstream) side, ns2/cl2 = referenced
    # (upstream) side. Only the IN clauses interpolate, so only those
    # segments are f-strings (ruff F541).
    return (
        "SELECT DISTINCT ns1.nspname as schema_name "
        "FROM pg_constraint c "
        "JOIN pg_class cl1 ON c.conrelid = cl1.oid "
        "JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid "
        "JOIN pg_class cl2 ON c.confrelid = cl2.oid "
        "JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid "
        "WHERE c.contype = 'f' "
        f"AND ns2.nspname IN ({schemas_list}) "
        f"AND ns1.nspname NOT IN ({schemas_list})"
    )
863+
850864
def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str:
851865
"""
852866
Query to get FK constraint details from information_schema.

src/datajoint/dependencies.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ def clear(self) -> None:
140140
self._node_alias_count = itertools.count() # reset alias IDs for consistency
141141
super().clear()
142142

143-
def load(self, force: bool = True) -> None:
143+
def load(self, force: bool = True, schema_names: set[str] | None = None) -> None:
144144
"""
145-
Load dependencies for all loaded schemas.
145+
Load dependencies for the given schemas.
146146
147147
Called before operations requiring dependencies: delete, drop,
148148
populate, progress.
@@ -151,6 +151,8 @@ def load(self, force: bool = True) -> None:
151151
----------
152152
force : bool, optional
153153
If True (default), reload even if already loaded.
154+
schema_names : set[str], optional
155+
Schema names to load. If None, uses all activated schemas.
154156
"""
155157
# reload from scratch to prevent duplication of renamed edges
156158
if self._loaded and not force:
@@ -162,7 +164,11 @@ def load(self, force: bool = True) -> None:
162164
adapter = self._conn.adapter
163165

164166
# Build schema list for IN clause
165-
schemas_list = ", ".join(adapter.quote_string(s) for s in self._conn.schemas)
167+
names = schema_names if schema_names is not None else set(self._conn.schemas)
168+
if not names:
169+
self._loaded = True
170+
return
171+
schemas_list = ", ".join(adapter.quote_string(s) for s in names)
166172

167173
# Load primary keys and foreign keys via adapter methods
168174
# Note: Both PyMySQL and psycopg use %s placeholders, so escape % as %%
@@ -220,6 +226,33 @@ def load(self, force: bool = True) -> None:
220226
raise DataJointError("DataJoint can only work with acyclic dependencies")
221227
self._loaded = True
222228

229+
def load_all_downstream(self) -> None:
    """
    Load dependencies including all downstream schemas reachable via FK chains.

    Repeatedly asks the backend which schemas hold FK references into the
    set discovered so far, growing that set until it stops changing. This
    lets cascade delete and drop see dependent tables that live in schemas
    the user never explicitly activated.
    """
    adapter = self._conn.adapter
    discovered = set(self._conn.schemas)
    if not discovered:
        # No activated schemas: fall back to a plain load, which
        # handles the empty case itself.
        self.load()
        return

    # Bounded fixpoint iteration; the cap guards against a pathological
    # discovery loop (each pass must add at least one schema to continue).
    for _ in range(50):
        quoted = ", ".join(adapter.quote_string(s) for s in discovered)
        rows = self._conn.query(adapter.find_downstream_schemas_sql(quoted))
        found = {row[0] for row in rows}
        if found <= discovered:
            break  # fixpoint reached: nothing new referenced us
        discovered |= found

    self.load(force=True, schema_names=discovered)
255+
223256
def topo_sort(self) -> list[str]:
224257
"""
225258
Return table names in topological order.

src/datajoint/diagram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def cascade(cls, table_expr, part_integrity="enforce"):
347347
>>> dj.Diagram.cascade(Session & 'subject_id=1')
348348
"""
349349
conn = table_expr.connection
350-
conn.dependencies.load()
350+
conn.dependencies.load_all_downstream()
351351
node = table_expr.full_table_name
352352

353353
result = cls.__new__(cls)

src/datajoint/table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,7 @@ def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"):
11701170
import networkx as nx
11711171
from .diagram import Diagram
11721172

1173+
self.connection.dependencies.load_all_downstream()
11731174
diagram = Diagram(self)
11741175
# Expand to include all descendants (cross-schema)
11751176
descendants = set(nx.descendants(diagram, self.full_table_name)) | {self.full_table_name}

tests/integration/test_cascade_delete.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,69 @@ class Child(dj.Manual):
226226
# Data must still be intact
227227
assert len(Parent()) == 2
228228
assert len(Child()) == 3
229+
230+
231+
def test_cascade_discovers_downstream_schema(connection_by_backend, db_creds_by_backend):
    """Cascade delete discovers and includes tables in unloaded downstream schemas."""
    import time

    backend = db_creds_by_backend["backend"]
    # Millisecond timestamp suffix keeps schema names unique across runs.
    test_id = str(int(time.time() * 1000))[-8:]

    # [:64] keeps names within the common identifier length limit.
    upstream_name = f"djtest_upstream_{backend}_{test_id}"[:64]
    downstream_name = f"djtest_downstream_{backend}_{test_id}"[:64]

    qi = connection_by_backend.adapter.quote_identifier

    # Clean up any previous runs (best-effort: schemas may not exist;
    # downstream first so its FK does not block dropping upstream)
    for name in (downstream_name, upstream_name):
        try:
            connection_by_backend.query(f"DROP DATABASE IF EXISTS {qi(name)}")
        except Exception:
            pass

    # Create upstream schema and table
    upstream = dj.Schema(upstream_name, connection=connection_by_backend)

    @upstream
    class Parent(dj.Manual):
        definition = """
        parent_id : int
        ---
        name : varchar(100)
        """

    # Create downstream schema with FK to upstream — separate schema object
    downstream = dj.Schema(downstream_name, connection=connection_by_backend)

    @downstream
    class Child(dj.Manual):
        definition = """
        -> Parent
        child_id : int
        ---
        data : varchar(100)
        """

    # Insert data: one parent row with two dependent child rows
    Parent.insert1(dict(parent_id=1, name="Alice"))
    Child.insert1(dict(parent_id=1, child_id=1, data="row1"))
    Child.insert1(dict(parent_id=1, child_id=2, data="row2"))

    # Verify cascade preview discovers the downstream schema
    # (Child lives in a schema that was never re-activated on this graph)
    counts = dj.Diagram.cascade(Parent & "parent_id=1").counts()
    assert Parent.full_table_name in counts
    assert Child.full_table_name in counts
    assert counts[Child.full_table_name] == 2

    # Verify actual delete cascades across schemas
    (Parent & "parent_id=1").delete()
    assert len(Parent()) == 0
    assert len(Child()) == 0

    # Clean up (best-effort, same ordering as above)
    for name in (downstream_name, upstream_name):
        try:
            connection_by_backend.query(f"DROP DATABASE IF EXISTS {qi(name)}")
        except Exception:
            pass

0 commit comments

Comments (0)