|
6 | 6 | import polars as pl |
7 | 7 | import pytest |
8 | 8 |
|
9 | | -from diffly._conditions import _can_compare_dtypes, condition_equal_columns |
| 9 | +from diffly._conditions import ( |
| 10 | + _can_compare_dtypes, |
| 11 | + _needs_element_wise_comparison, |
| 12 | + condition_equal_columns, |
| 13 | +) |
10 | 14 | from diffly.comparison import compare_frames |
11 | 15 |
|
12 | 16 |
|
@@ -512,6 +516,45 @@ def test_condition_equal_columns_lists_only_inner() -> None: |
512 | 516 | assert actual.to_list() == [True, False] |
513 | 517 |
|
514 | 518 |
|
| 519 | +def test_condition_equal_columns_list_of_different_enums() -> None: |
| 520 | + # Arrange |
| 521 | + first_enum = pl.Enum(["one", "two"]) |
| 522 | + second_enum = pl.Enum(["one", "two", "three"]) |
| 523 | + |
| 524 | + lhs = pl.DataFrame( |
| 525 | + {"pk": [1, 2], "a": [["one", "two"], ["one", "one"]]}, |
| 526 | + schema_overrides={"a": pl.List(first_enum)}, |
| 527 | + ) |
| 528 | + rhs = pl.DataFrame( |
| 529 | + {"pk": [1, 2], "a": [["one", "two"], ["one", "three"]]}, |
| 530 | + schema_overrides={"a": pl.List(second_enum)}, |
| 531 | + ) |
| 532 | + c = compare_frames(lhs, rhs, primary_key="pk") |
| 533 | + |
| 534 | + # Act |
| 535 | + lhs = lhs.rename({"a": "a_left"}) |
| 536 | + rhs = rhs.rename({"a": "a_right"}) |
| 537 | + actual = ( |
| 538 | + lhs.join(rhs, on="pk", maintain_order="left") |
| 539 | + .select( |
| 540 | + condition_equal_columns( |
| 541 | + "a", |
| 542 | + dtype_left=lhs.schema["a_left"], |
| 543 | + dtype_right=rhs.schema["a_right"], |
| 544 | + max_list_length=c._max_list_lengths_by_column.get("a"), |
| 545 | + abs_tol=c.abs_tol_by_column["a"], |
| 546 | + rel_tol=c.rel_tol_by_column["a"], |
| 547 | + ) |
| 548 | + ) |
| 549 | + .to_series() |
| 550 | + ) |
| 551 | + |
| 552 | + # Assert |
| 553 | + assert c._max_list_lengths_by_column == {"a": 2} |
| 554 | + assert _needs_element_wise_comparison(first_enum, second_enum) |
| 555 | + assert actual.to_list() == [True, False] |
| 556 | + |
| 557 | + |
515 | 558 | @pytest.mark.parametrize( |
516 | 559 | ("dtype_left", "dtype_right", "can_compare_dtypes"), |
517 | 560 | [ |
|
0 commit comments