Skip to content

Commit 9b23fef

Browse files
Merge branch 'main' into docs_examples
2 parents: e636656 + cccfad4 · commit 9b23fef

3 files changed

Lines changed: 229 additions & 116 deletions

File tree

diffly/_conditions.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
import datetime as dt
55
from collections.abc import Mapping
6+
from typing import cast
67

78
import polars as pl
89
from polars.datatypes import DataType, DataTypeClass
@@ -206,12 +207,7 @@ def _compare_sequence_columns(
206207
n_elements = dtype_right.shape[0]
207208
has_same_length = col_left.list.len().eq(pl.lit(n_elements))
208209
else: # pl.List vs pl.List
209-
if not isinstance(max_list_length, int):
210-
# Fallback for nested list comparisons where no max_list_length is
211-
# available: perform a direct equality comparison without element-wise
212-
# unrolling.
213-
return _eq_missing(col_left.eq_missing(col_right), col_left, col_right)
214-
n_elements = max_list_length
210+
n_elements = cast(int, max_list_length)
215211
has_same_length = col_left.list.len().eq_missing(col_right.list.len())
216212

217213
if n_elements == 0:
@@ -232,7 +228,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
232228
abs_tol=abs_tol,
233229
rel_tol=rel_tol,
234230
abs_tol_temporal=abs_tol_temporal,
235-
max_list_length=None,
231+
max_list_length=max_list_length,
236232
)
237233
for i in range(n_elements)
238234
]

diffly/comparison.py

Lines changed: 36 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -993,22 +993,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]
993993

994994
@cached_property
995995
def _max_list_lengths_by_column(self) -> dict[str, int]:
996-
list_columns = [
997-
col
998-
for col in self._other_common_columns
999-
if isinstance(self.left_schema[col], pl.List)
1000-
and isinstance(self.right_schema[col], pl.List)
1001-
]
1002-
if not list_columns:
996+
"""Max list length across all nesting levels, for columns where both sides
997+
contain a List anywhere in their type tree."""
998+
left_exprs: list[pl.Expr] = []
999+
right_exprs: list[pl.Expr] = []
1000+
columns: list[str] = []
1001+
1002+
for col in self._other_common_columns:
1003+
col_left = _list_length_exprs(pl.col(col), self.left_schema[col])
1004+
col_right = _list_length_exprs(pl.col(col), self.right_schema[col])
1005+
if not (col_left and col_right):
1006+
continue
1007+
columns.append(col)
1008+
left_exprs.append(pl.max_horizontal(col_left).alias(col))
1009+
right_exprs.append(pl.max_horizontal(col_right).alias(col))
1010+
1011+
if not columns:
10031012
return {}
10041013

1005-
exprs = [pl.col(col).list.len().max().alias(col) for col in list_columns]
10061014
[left_max, right_max] = pl.collect_all(
1007-
[self.left.select(exprs), self.right.select(exprs)]
1015+
[self.left.select(left_exprs), self.right.select(right_exprs)]
10081016
)
10091017
return {
10101018
col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0))
1011-
for col in list_columns
1019+
for col in columns
10121020
}
10131021

10141022
def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
@@ -1189,3 +1197,21 @@ def right_only(self) -> Schema:
11891197
{'score': Int64}
11901198
"""
11911199
return self.right() - self.left()
1200+
1201+
1202+
def _list_length_exprs(
1203+
expr: pl.Expr, dtype: pl.DataType | pl.datatypes.DataTypeClass
1204+
) -> list[pl.Expr]:
1205+
"""Collect max-list-length scalar expressions for every List level in the type
1206+
tree."""
1207+
if isinstance(dtype, pl.List):
1208+
return [expr.list.len().max(), *_list_length_exprs(expr.explode(), dtype.inner)]
1209+
if isinstance(dtype, pl.Array):
1210+
return _list_length_exprs(expr.explode(), dtype.inner)
1211+
if isinstance(dtype, pl.Struct):
1212+
return [
1213+
e
1214+
for field in dtype.fields
1215+
for e in _list_length_exprs(expr.struct[field.name], field.dtype)
1216+
]
1217+
return []

0 commit comments

Comments
 (0)