Skip to content

Commit 609daeb

Browse files
fix: address review — quoting, error message, and initialization issues
1. add_parts(): strip both backticks and double quotes from identifiers so part-table detection works on PostgreSQL. 2. Extract _split_full_name() helper replacing 8 instances of the fragile full_name.replace('"', '`').split('`') pattern in visualization/collapse methods. Works with both quoting styles. 3. Fix error message in __init__: repr(source) not repr(source[0]) — source is a schema/module, not a sequence. 4. Initialize _part_integrity="enforce" in __init__ and _from_table instead of relying on getattr fallback in the copy constructor. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 685a641 commit 609daeb

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

src/datajoint/diagram.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@
4747
logger = logging.getLogger(__name__.split(".")[0])
4848

4949

50+
def _split_full_name(full_name: str) -> tuple[str, str]:
51+
"""Split a quoted full table name into (schema, table) regardless of quote style."""
52+
parts = full_name.strip('`"').split("`.`") if "`" in full_name else full_name.strip('"').split('"."')
53+
if len(parts) == 2:
54+
return parts[0], parts[1]
55+
# Fallback: strip all quotes and split on dot
56+
stripped = full_name.replace("`", "").replace('"', "")
57+
schema, _, table = stripped.partition(".")
58+
return schema, table
59+
60+
5061
class Diagram(nx.DiGraph): # noqa: C901
5162
"""
5263
Schema diagram as a directed acyclic graph (DAG).
@@ -99,7 +110,7 @@ def __init__(self, source, context=None) -> None:
99110
self._cascade_restrictions = copy_module.deepcopy(source._cascade_restrictions)
100111
self._restrict_conditions = copy_module.deepcopy(source._restrict_conditions)
101112
self._restriction_attrs = copy_module.deepcopy(source._restriction_attrs)
102-
self._part_integrity = getattr(source, "_part_integrity", "enforce")
113+
self._part_integrity = source._part_integrity
103114
super().__init__(source)
104115
return
105116

@@ -118,7 +129,7 @@ def __init__(self, source, context=None) -> None:
118129
try:
119130
connection = source.schema.connection
120131
except AttributeError:
121-
raise DataJointError("Could not find database connection in %s" % repr(source[0]))
132+
raise DataJointError("Could not find database connection in %s" % repr(source))
122133

123134
# initialize graph from dependencies
124135
connection.dependencies.load()
@@ -127,6 +138,7 @@ def __init__(self, source, context=None) -> None:
127138
self._cascade_restrictions = {}
128139
self._restrict_conditions = {}
129140
self._restriction_attrs = {}
141+
self._part_integrity = "enforce"
130142

131143
# Enumerate nodes from all the items in the list
132144
self.nodes_to_show = set()
@@ -194,6 +206,7 @@ def _from_table(cls, table_expr) -> "Diagram":
194206
result._cascade_restrictions = {}
195207
result._restrict_conditions = {}
196208
result._restriction_attrs = {}
209+
result._part_integrity = "enforce"
197210
return result
198211

199212
def add_parts(self) -> "Diagram":
@@ -207,8 +220,8 @@ def add_parts(self) -> "Diagram":
207220
"""
208221

209222
def is_part(part, master):
210-
part = [s.strip("`") for s in part.split(".")]
211-
master = [s.strip("`") for s in master.split(".")]
223+
part = [s.strip('`"') for s in part.split(".")]
224+
master = [s.strip('`"') for s in master.split(".")]
212225
return master[0] == part[0] and master[1] + "__" == part[1][: len(master[1]) + 2]
213226

214227
self = Diagram(self) # copy
@@ -769,9 +782,8 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]
769782
for class_name in nodes_to_collapse:
770783
full_name = class_to_full.get(class_name)
771784
if full_name:
772-
parts = full_name.replace('"', "`").split("`")
773-
if len(parts) >= 2:
774-
schema_name = parts[1]
785+
schema_name, _ = _split_full_name(full_name)
786+
if schema_name:
775787
if schema_name not in collapsed_by_schema:
776788
collapsed_by_schema[schema_name] = []
777789
collapsed_by_schema[schema_name].append(class_name)
@@ -794,9 +806,8 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]
794806
for node in graph.nodes():
795807
full_name = class_to_full.get(node)
796808
if full_name:
797-
parts = full_name.replace('"', "`").split("`")
798-
if len(parts) >= 2:
799-
db_schema = parts[1]
809+
db_schema, _ = _split_full_name(full_name)
810+
if db_schema:
800811
cls = self._resolve_class(node)
801812
if cls is not None and hasattr(cls, "__module__"):
802813
module_name = cls.__module__.split(".")[-1]
@@ -839,9 +850,8 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]
839850
for node in graph.nodes():
840851
full_name = class_to_full.get(node)
841852
if full_name:
842-
parts = full_name.replace('"', "`").split("`")
843-
if len(parts) >= 2 and node in nodes_to_collapse:
844-
schema_name = parts[1]
853+
schema_name, _ = _split_full_name(full_name)
854+
if schema_name and node in nodes_to_collapse:
845855
node_mapping[node] = collapsed_labels[schema_name]
846856
else:
847857
node_mapping[node] = node
@@ -854,9 +864,8 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]
854864
neighbor = next(iter(neighbors))
855865
full_name = class_to_full.get(neighbor)
856866
if full_name:
857-
parts = full_name.replace('"', "`").split("`")
858-
if len(parts) >= 2:
859-
schema_name = parts[1]
867+
schema_name, _ = _split_full_name(full_name)
868+
if schema_name:
860869
node_mapping[node] = collapsed_labels[schema_name]
861870
continue
862871
node_mapping[node] = node
@@ -981,10 +990,8 @@ def make_dot(self):
981990
schema_modules = {} # schema_name -> set of module names
982991

983992
for full_name in self.nodes_to_show:
984-
# Extract schema from full table name like `schema`.`table` or "schema"."table"
985-
parts = full_name.replace('"', "`").split("`")
986-
if len(parts) >= 2:
987-
schema_name = parts[1] # schema is between first pair of backticks
993+
schema_name, _ = _split_full_name(full_name)
994+
if schema_name:
988995
class_name = lookup_class_name(full_name, self.context) or full_name
989996
schema_map[class_name] = schema_name
990997

@@ -1248,9 +1255,8 @@ def make_mermaid(self) -> str:
12481255
schema_modules = {} # schema_name -> set of module names
12491256

12501257
for full_name in self.nodes_to_show:
1251-
parts = full_name.replace('"', "`").split("`")
1252-
if len(parts) >= 2:
1253-
schema_name = parts[1]
1258+
schema_name, _ = _split_full_name(full_name)
1259+
if schema_name:
12541260
class_name = lookup_class_name(full_name, self.context) or full_name
12551261
schema_map[class_name] = schema_name
12561262

0 commit comments

Comments
 (0)