Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions buckaroo/polars_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import polars as pl


# Map user-facing join names onto Polars equivalents.  Polars renamed the
# "outer" join to "full"; every other supported join type passes through
# under its own name.
_HOW_MAP = {
    "outer": "full",
    "full": "full",
    "inner": "inner",
    "left": "left",
    "right": "right",
}


def col_join_dfs(df1, df2, join_columns, how):
    """Join two Polars DataFrames and compute column-level diff statistics.

    Parameters
    ----------
    df1, df2 : pl.DataFrame
        The two DataFrames to compare.
    join_columns : str or list[str]
        Column name(s) to join on.
    how : str
        Join type ('inner', 'outer', 'left', 'right').  'outer' is
        accepted as an alias for Polars' 'full'.

    Returns
    -------
    m_df : pl.DataFrame
        Merged DataFrame with membership and equality columns.
    column_config_overrides : dict
        Buckaroo column config for styling.
    eqs : dict
        Per-column diff summary.

    Raises
    ------
    ValueError
        If a column name collides with an internal sentinel name, or if
        the join keys are not unique within either dataframe.
    """
    if isinstance(join_columns, str):
        join_columns = [join_columns]

    df2_suffix = "|df2"
    # Marker columns added before the join so membership detection works
    # even when join keys contain nulls (null keys never match in polars).
    _left_marker = "__bk_left"
    _right_marker = "__bk_right"

    # Reject user columns that collide with any sentinel name.  Without the
    # marker check, a user column named __bk_left/__bk_right would be
    # silently clobbered by the with_columns calls below, corrupting
    # membership detection.
    for col in df1.columns + df2.columns:
        if df2_suffix in col:
            raise ValueError(
                f"|df2 is a sentinel column name used by this tool, "
                f"and can't be used in a dataframe passed in, {col} violates that constraint"
            )
        if col in (_left_marker, _right_marker):
            raise ValueError(
                f"{col} is a sentinel column name used by this tool, "
                "and can't be used in a dataframe passed in"
            )

    df1_name, df2_name = "df_1", "df_2"

    # Validate join keys are unique to prevent cartesian explosion
    if not df1.select(pl.struct(join_columns).is_unique().all()).item():
        raise ValueError(
            f"Duplicate join keys found in df1 on columns {join_columns}. "
            "Join keys must be unique in each dataframe for a valid comparison."
        )
    if not df2.select(pl.struct(join_columns).is_unique().all()).item():
        raise ValueError(
            f"Duplicate join keys found in df2 on columns {join_columns}. "
            "Join keys must be unique in each dataframe for a valid comparison."
        )

    # Unknown how values pass through so polars can raise its own error.
    pl_how = _HOW_MAP.get(how, how)

    df1_marked = df1.with_columns(pl.lit(True).alias(_left_marker))
    df2_marked = df2.with_columns(pl.lit(True).alias(_right_marker))

    # coalesce=False keeps both copies of the join keys; we coalesce them
    # ourselves after membership has been computed.
    m_df = df1_marked.join(
        df2_marked, on=join_columns, how=pl_how, suffix=df2_suffix, coalesce=False,
    )

    # Compute membership from marker columns (immune to null join keys).
    # Encoding: 1 = df1 only, 2 = df2 only, 3 = both.
    m_df = m_df.with_columns(
        pl.when(pl.col(_left_marker).is_not_null() & pl.col(_right_marker).is_not_null())
        .then(3)
        .when(pl.col(_left_marker).is_not_null())
        .then(1)
        .otherwise(2)
        .cast(pl.Int8)
        .alias("membership")
    ).drop(_left_marker, _right_marker)

    # Coalesce join keys and drop suffixed copies
    for jc in join_columns:
        jc_right = f"{jc}{df2_suffix}"
        if jc_right in m_df.columns:
            m_df = m_df.with_columns(pl.coalesce(jc, jc_right).alias(jc)).drop(jc_right)

    # Build unified column order: df1's columns first, then df2-only ones.
    col_order = df1.columns.copy()
    for col in df2.columns:
        if col not in col_order:
            col_order.append(col)

    # Compute diff stats from key-aligned rows only (membership == 3);
    # one-sided rows have nothing to compare against.
    eqs = {}
    both_mask = m_df["membership"] == 3
    m_both = m_df.filter(both_mask)
    for col in col_order:
        if col in join_columns:
            eqs[col] = {"diff_count": "join_key"}
        elif col in df1.columns and col in df2.columns:
            df2_col = f"{col}{df2_suffix}"
            if df2_col in m_df.columns:
                # ne_missing counts null-vs-value as a difference and
                # null-vs-null as equal, unlike plain ne.
                eqs[col] = {
                    "diff_count": int(
                        m_both.select(pl.col(col).ne_missing(pl.col(df2_col)).sum()).item()
                    )
                }
            else:
                eqs[col] = {"diff_count": 0}
        else:
            # Column exists on only one side; report which side owns it.
            eqs[col] = {"diff_count": df1_name if col in df1.columns else df2_name}

    column_config_overrides = {}
    # Palette used for categorical coloring of equality/membership values.
    eq_map = ["pink", "#73ae80", "#90b2b3", "#6c83b5"]

    column_config_overrides["membership"] = {"merge_rule": "hidden"}

    both_columns = [c for c in m_df.columns if c.endswith(df2_suffix)]
    for b_col in both_columns:
        a_col = b_col.removesuffix(df2_suffix)
        eq_col = f"{a_col}|eq"
        # Encode equality and membership into one small int for styling.
        m_df = m_df.with_columns(
            (pl.col(a_col).eq_missing(pl.col(b_col)).cast(pl.Int8) * 4 + pl.col("membership"))
            .alias(eq_col)
        )

        column_config_overrides[b_col] = {"merge_rule": "hidden"}
        column_config_overrides[eq_col] = {"merge_rule": "hidden"}
        column_config_overrides[a_col] = {
            "tooltip_config": {"tooltip_type": "simple", "val_column": b_col},
            "color_map_config": {
                "color_rule": "color_categorical",
                "map_name": eq_map,
                "val_column": eq_col,
            },
        }

    # Join keys are colored by membership so one-sided rows stand out.
    for jc in join_columns:
        column_config_overrides[jc] = {
            "color_map_config": {
                "color_rule": "color_categorical",
                "map_name": eq_map,
                "val_column": "membership",
            }
        }

    return m_df, column_config_overrides, eqs
169 changes: 169 additions & 0 deletions tests/unit/polars_compare_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import polars as pl
import pytest

from buckaroo.polars_compare import col_join_dfs


def test_single_join_key():
    """col_join_dfs works with a single join key."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    right = pl.DataFrame({"id": [1, 2, 3], "val": [10, 25, 30]})

    merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert "membership" in merged.columns
    assert (merged["membership"] == 3).all()
    assert eqs["id"]["diff_count"] == "join_key"
    assert eqs["val"]["diff_count"] == 1


def test_multi_key_join():
    """col_join_dfs works with multiple join columns."""
    keys = {"account_id": [1, 1, 2], "as_of_date": ["2024-01", "2024-02", "2024-01"]}
    left = pl.DataFrame({**keys, "amount": [100, 200, 300]})
    right = pl.DataFrame({**keys, "amount": [100, 250, 300]})

    merged, overrides, eqs = col_join_dfs(
        left, right, join_columns=["account_id", "as_of_date"], how="outer"
    )

    assert merged.height == 3
    assert (merged["membership"] == 3).all()
    for key_col in ("account_id", "as_of_date"):
        assert eqs[key_col]["diff_count"] == "join_key"
        assert key_col in overrides
    assert eqs["amount"]["diff_count"] == 1


def test_outer_join_membership():
    """Rows only in one side get correct membership values."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    right = pl.DataFrame({"id": [2, 3, 4], "val": [20, 30, 40]})

    merged, _overrides, _eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert merged.height == 4
    memberships = {}
    for record in merged.iter_rows(named=True):
        memberships[record["id"]] = record["membership"]
    # 1 = df1 only, 2 = df2 only, 3 = present in both
    assert memberships == {1: 1, 2: 3, 3: 3, 4: 2}


def test_reordered_rows():
    """Diff stats are correct even when row order differs."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    right = pl.DataFrame({"id": [3, 1, 2], "val": [30, 10, 20]})

    merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert (merged["membership"] == 3).all()
    assert eqs["val"]["diff_count"] == 0


def test_one_sided_extra_columns():
    """Columns only in one df are reported correctly."""
    left = pl.DataFrame({"id": [1, 2], "x": [10, 20]})
    right = pl.DataFrame({"id": [1, 2], "y": [30, 40]})

    _merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert eqs["x"]["diff_count"] == "df_1"
    assert eqs["y"]["diff_count"] == "df_2"


def test_string_join_columns_normalized():
    """A single string join_columns is accepted."""
    left = pl.DataFrame({"key": [1, 2], "val": [10, 20]})
    right = pl.DataFrame({"key": [1, 2], "val": [10, 25]})

    # Pass a bare string instead of a list; it should be normalized.
    _merged, _overrides, eqs = col_join_dfs(left, right, join_columns="key", how="inner")

    assert eqs["val"]["diff_count"] == 1


def test_sentinel_column_rejected():
    """DataFrames containing '|df2' in column names are rejected."""
    offending = pl.DataFrame({"id": [1], "bad|df2": [10]})
    clean = pl.DataFrame({"id": [1], "val": [20]})

    with pytest.raises(ValueError, match=r"\|df2"):
        col_join_dfs(offending, clean, join_columns=["id"], how="outer")


def test_inner_join():
    """Inner join only keeps matched rows."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    right = pl.DataFrame({"id": [2, 3, 4], "val": [20, 35, 40]})

    merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="inner")

    # Only ids 2 and 3 match; one of them (id=3) differs on val.
    assert merged.height == 2
    assert (merged["membership"] == 3).all()
    assert eqs["val"]["diff_count"] == 1


def test_null_values_in_data():
    """Null-heavy comparisons don't crash and report diffs."""
    left = pl.DataFrame({"id": [1, 2, 3], "val": [None, 20, None]})
    right = pl.DataFrame({"id": [1, 2, 3], "val": [None, None, 30]})

    merged, _overrides, eqs = col_join_dfs(left, right, join_columns=["id"], how="outer")

    assert (merged["membership"] == 3).all()
    # Rows 2 and 3 mix null with a value, so at least two diffs expected.
    assert eqs["val"]["diff_count"] >= 2


def test_duplicate_join_keys_rejected():
    """Duplicate join keys raise ValueError."""
    dup = pl.DataFrame({"id": [1, 1, 2], "val": [10, 20, 30]})
    unique = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})

    # Duplicates on the left side
    with pytest.raises(ValueError, match="Duplicate join keys"):
        col_join_dfs(dup, unique, join_columns=["id"], how="outer")

    # Duplicates on the right side
    with pytest.raises(ValueError, match="Duplicate join keys"):
        col_join_dfs(unique, dup, join_columns=["id"], how="outer")


def test_how_outer_alias():
    """Both 'outer' and 'full' are accepted as how values."""
    left = pl.DataFrame({"id": [1, 2], "val": [10, 20]})
    right = pl.DataFrame({"id": [2, 3], "val": [20, 30]})

    heights = []
    for how in ("outer", "full"):
        merged, _, _ = col_join_dfs(left, right, join_columns=["id"], how=how)
        heights.append(merged.height)

    assert heights == [3, 3]


def test_nullable_join_key_membership():
    """Membership is correct when the join key itself contains nulls.

    Polars does not match null keys (null != null in join semantics),
    so null-keyed rows appear as one-sided. The marker-based membership
    detection must still classify them correctly as df1-only / df2-only.
    """
    left = pl.DataFrame({"id": [None, 2, 3], "val": [10, 20, 30]})
    right = pl.DataFrame({"id": [None, 3, 4], "val": [10, 30, 40]})

    merged, _, _ = col_join_dfs(left, right, join_columns=["id"], how="outer")

    observed = {
        (record["id"], record["membership"])
        for record in merged.iter_rows(named=True)
    }

    expected = {
        (3, 3),      # matched on both sides
        (2, 1),      # df1 only
        (4, 2),      # df2 only
        # Null keys don't match in polars joins — each null-keyed row is one-sided
        (None, 1),   # df1's null key → df1 only
        (None, 2),   # df2's null key → df2 only
    }
    assert expected <= observed
Loading