@@ -993,22 +993,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]
993993
994994 @cached_property
995995 def _max_list_lengths_by_column (self ) -> dict [str , int ]:
996- list_columns = [
997- col
998- for col in self ._other_common_columns
999- if isinstance (self .left_schema [col ], pl .List )
1000- and isinstance (self .right_schema [col ], pl .List )
1001- ]
1002- if not list_columns :
996+ """Max list length across all nesting levels, for columns where both sides
997+ contain a List anywhere in their type tree."""
998+ left_exprs : list [pl .Expr ] = []
999+ right_exprs : list [pl .Expr ] = []
1000+ columns : list [str ] = []
1001+
1002+ for col in self ._other_common_columns :
1003+ col_left = _list_length_exprs (pl .col (col ), self .left_schema [col ])
1004+ col_right = _list_length_exprs (pl .col (col ), self .right_schema [col ])
1005+ if not (col_left and col_right ):
1006+ continue
1007+ columns .append (col )
1008+ left_exprs .append (pl .max_horizontal (col_left ).alias (col ))
1009+ right_exprs .append (pl .max_horizontal (col_right ).alias (col ))
1010+
1011+ if not columns :
10031012 return {}
10041013
1005- exprs = [pl .col (col ).list .len ().max ().alias (col ) for col in list_columns ]
10061014 [left_max , right_max ] = pl .collect_all (
1007- [self .left .select (exprs ), self .right .select (exprs )]
1015+ [self .left .select (left_exprs ), self .right .select (right_exprs )]
10081016 )
10091017 return {
10101018 col : max (int (left_max [col ].item () or 0 ), int (right_max [col ].item () or 0 ))
1011- for col in list_columns
1019+ for col in columns
10121020 }
10131021
10141022 def _condition_equal_rows (self , columns : list [str ]) -> pl .Expr :
@@ -1189,3 +1197,21 @@ def right_only(self) -> Schema:
11891197 {'score': Int64}
11901198 """
11911199 return self .right () - self .left ()
1200+
1201+
1202+ def _list_length_exprs (
1203+ expr : pl .Expr , dtype : pl .DataType | pl .datatypes .DataTypeClass
1204+ ) -> list [pl .Expr ]:
1205+ """Collect max-list-length scalar expressions for every List level in the type
1206+ tree."""
1207+ if isinstance (dtype , pl .List ):
1208+ return [expr .list .len ().max (), * _list_length_exprs (expr .explode (), dtype .inner )]
1209+ if isinstance (dtype , pl .Array ):
1210+ return _list_length_exprs (expr .explode (), dtype .inner )
1211+ if isinstance (dtype , pl .Struct ):
1212+ return [
1213+ e
1214+ for field in dtype .fields
1215+ for e in _list_length_exprs (expr .struct [field .name ], field .dtype )
1216+ ]
1217+ return []
0 commit comments