From 94cb8284d55b52b6b9031fd18eecd3af38d2283f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philipp=20Gro=C3=9Fer?=
Date: Thu, 26 Mar 2026 10:53:46 +0100
Subject: [PATCH] functionality update

---
 README.md                          |   2 +-
 docs/api/backends.md               |   2 +-
 docs/api/index.md                  |   2 +
 docs/api/ops/math.md               |  24 ++++
 docs/getting-started/quickstart.md |   2 +-
 docs/index.md                      |   2 +-
 tests/test_duckdb.py               | 125 +++++++++++++++++++++
 tests/test_math.py                 | 173 +++++++++++++++++++++++++++++
 transformplan/backends/base.py     |  11 ++
 transformplan/backends/duckdb.py   |  29 +++++
 transformplan/backends/polars.py   |  26 +++++
 transformplan/chunking.py          |   1 +
 transformplan/ops/math.py          |  42 +++++++
 transformplan/validation.py        |  63 +++++++++++
 14 files changed, 500 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 476021a..c925559 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@
 ```python
 from transformplan import TransformPlan, Col
 
-# Build readable pipelines with 88 chainable operations
+# Build readable pipelines with 89 chainable operations
 plan = (
     TransformPlan()
     # Standardize column names
diff --git a/docs/api/backends.md b/docs/api/backends.md
index 8537f18..1860eb2 100644
--- a/docs/api/backends.md
+++ b/docs/api/backends.md
@@ -1,6 +1,6 @@
 # Backends
 
-TransformPlan uses a pluggable backend system. Each backend implements the `Backend` ABC, providing all 88 operations plus meta methods for hashing, schema inspection, and type classification.
+TransformPlan uses a pluggable backend system. Each backend implements the `Backend` ABC, providing all 89 operations plus meta methods for hashing, schema inspection, and type classification.
 
 ## Overview
 
diff --git a/docs/api/index.md b/docs/api/index.md
index 2806279..90fd566 100644
--- a/docs/api/index.md
+++ b/docs/api/index.md
@@ -96,6 +96,8 @@ All TransformPlan operations at a glance. Click method names for detailed docume
 | [`math_log`](ops/math.md) | Logarithmic transform |
 | [`math_sqrt`](ops/math.md) | Square root transform |
 | [`math_power`](ops/math.md) | Power transform |
+| [`math_diff_from_agg`](ops/math.md) | Difference from a group aggregate (min, mean, etc.) |
+| [`math_diff_lag`](ops/math.md) | Row-to-row difference using lag (numeric or datetime) |
 | [`math_winsorize`](ops/math.md) | Clip values to percentiles or bounds |
 
 ### Row Operations
diff --git a/docs/api/ops/math.md b/docs/api/ops/math.md
index 70da76d..a3acb58 100644
--- a/docs/api/ops/math.md
+++ b/docs/api/ops/math.md
@@ -40,6 +40,7 @@ plan = (
         - math_cumsum
         - math_rank
         - math_diff_from_agg
+        - math_diff_lag
         - math_standardize
         - math_minmax
         - math_robust_scale
@@ -162,6 +163,29 @@ plan = TransformPlan().math_diff_from_agg(
     agg="max",
     new_column="diff_from_max",
 )
+
+# Row-to-row difference (lag)
+plan = TransformPlan().math_diff_lag(
+    column="timestamp",
+    order_by="timestamp",
+    new_column="time_between",
+    group_by="patient_id",
+)
+
+# Numeric change ordered by date
+plan = TransformPlan().math_diff_lag(
+    column="price",
+    order_by="date",
+    new_column="daily_change",
+)
+
+# Lag of 2 rows
+plan = TransformPlan().math_diff_lag(
+    column="value",
+    order_by="seq",
+    new_column="diff_2",
+    lag=2,
+)
 ```
 
 ### Scaling Operations
diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md
index c1e37f4..9ba88bf 100644
--- a/docs/getting-started/quickstart.md
+++ b/docs/getting-started/quickstart.md
@@ -74,7 +74,7 @@ print(df_result)
 
 ## Using the DuckDB Backend
 
-TransformPlan supports DuckDB as an alternative backend. All 88 operations, validation, and dry-run work identically — the same plan works with both Polars DataFrames and DuckDB relations. Simply pass the backend at execution time:
+TransformPlan supports DuckDB as an alternative backend. All 89 operations, validation, and dry-run work identically — the same plan works with both Polars DataFrames and DuckDB relations. Simply pass the backend at execution time:
 
 ```python
 import duckdb
diff --git a/docs/index.md b/docs/index.md
index 1190eec..9ce7bf4 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -21,7 +21,7 @@ TransformPlan tracks transformation history, validates operations against DataFr
 ```python
 from transformplan import TransformPlan, Col
 
-# Build readable pipelines with 88 chainable operations
+# Build readable pipelines with 89 chainable operations
 plan = (
     TransformPlan()
     # Standardize column names
diff --git a/tests/test_duckdb.py b/tests/test_duckdb.py
index ce69b16..90b5b1c 100644
--- a/tests/test_duckdb.py
+++ b/tests/test_duckdb.py
@@ -1617,6 +1617,131 @@ def test_invalid_agg_raises(
             )
 
 
+class TestMathDiffLag:
+    """Tests for math_diff_lag with DuckDB backend."""
+
+    def test_numeric_basic(
+        self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
+    ) -> None:
+        rel = con.sql(
+            "SELECT * FROM (VALUES (1, 10), (2, 30), (3, 35), (4, 50)) AS t(id, val)"
+        )
+        result, _ = (
+            TransformPlan()
+            .math_diff_lag("val", order_by="id", new_column="diff")
+            .process(rel, backend=backend)
+        )
+        vals = _col_values(result, "diff")
+        assert vals[0] is None
+        assert vals[1:] == [20, 5, 15]
+
+    def test_numeric_lag2(
+        self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
+    ) -> None:
+        rel = con.sql(
+            "SELECT * FROM (VALUES (1, 10), (2, 30), (3, 35), (4, 50)) AS t(id, val)"
+        )
+        result, _ = (
+            TransformPlan()
+            .math_diff_lag("val", order_by="id", new_column="diff", lag=2)
+            .process(rel, backend=backend)
+        )
+        vals = _col_values(result, "diff")
+        assert vals[0] is None
+        assert vals[1] is None
+        assert vals[2:] == [25, 20]
+
+    def test_grouped_numeric(
+        self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
+    ) -> None:
+        rel = con.sql(
+            "SELECT * FROM (VALUES "
+            "('A', 1, 10), ('A', 2, 30), ('A', 3, 35), "
+            "('B', 1, 100), ('B', 2, 150), ('B', 3, 160)"
+            ") AS t(grp, seq, val)"
+        )
+        result, _ = (
+            TransformPlan()
+            .math_diff_lag("val", order_by="seq", new_column="diff", group_by="grp")
+            .rows_sort(["grp", "seq"])
+            .process(rel, backend=backend)
+        )
+        vals = _col_values(result, "diff")
+        assert vals == [None, 20, 5, None, 50, 10]
+
+    def test_datetime_column(
+        self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
+    ) -> None:
+        rel = con.sql(
+            "SELECT * FROM (VALUES "
+            "(1, TIMESTAMP '2024-01-01 00:00:00'), "
+            "(2, TIMESTAMP '2024-01-01 01:00:00'), "
+            "(3, TIMESTAMP '2024-01-01 03:00:00')"
+            ") AS t(id, ts)"
+        )
+        result, _ = (
+            TransformPlan()
+            .math_diff_lag("ts", order_by="id", new_column="gap")
+            .process(rel, backend=backend)
+        )
+        vals = _col_values(result, "gap")
+        assert vals[0] is None
+        assert vals[1].total_seconds() == 3600
+        assert vals[2].total_seconds() == 7200
+
+    def test_datetime_grouped(
+        self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
+    ) -> None:
+        rel = con.sql(
+            "SELECT * FROM (VALUES "
+            "('A', TIMESTAMP '2024-01-01 00:00:00'), "
+            "('A', TIMESTAMP '2024-01-01 02:00:00'), "
+            "('B', TIMESTAMP '2024-01-01 10:00:00'), "
+            "('B', TIMESTAMP '2024-01-01 13:00:00')"
+            ") AS t(patient, ts)"
+        )
+        result, _ = (
+            TransformPlan()
+            .math_diff_lag("ts", order_by="ts", new_column="gap", group_by="patient")
+            .rows_sort(["patient", "ts"])
+            .process(rel, backend=backend)
+        )
+        vals = _col_values(result, "gap")
+        assert vals[0] is None
+        assert vals[1].total_seconds() / 3600 == 2.0
+        assert vals[2] is None
+        assert vals[3].total_seconds() / 3600 == 3.0
+
+    def test_order_by_list(
+        self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
+    ) -> None:
+        rel = con.sql(
+            "SELECT * FROM (VALUES "
+            "(1, 1, 10), (1, 2, 20), (2, 1, 30), (2, 2, 40)"
+            ") AS t(a, b, val)"
+        )
+        result, _ = (
+            TransformPlan()
+            .math_diff_lag("val", order_by=["a", "b"], new_column="diff")
+            .process(rel, backend=backend)
+        )
+        vals = _col_values(result, "diff")
+        assert vals == [None, 10, 10, 10]
+
+    def test_global_no_group(
+        self, backend: DuckDBBackend, con: duckdb.DuckDBPyConnection
+    ) -> None:
+        rel = con.sql("SELECT * FROM (VALUES (3, 30), (1, 10), (2, 20)) AS t(seq, val)")
+        result, _ = (
+            TransformPlan()
+            .math_diff_lag("val", order_by="seq", new_column="diff")
+            .process(rel, backend=backend)
+        )
+        vals = _col_values(result, "diff")
+        assert vals[0] is None
+        assert vals[1:] == [10, 10]
+
+
 class TestColExpr:
     """Tests for col_expr on DuckDB backend."""
 
diff --git a/tests/test_math.py b/tests/test_math.py
index 6798154..ab94263 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -465,3 +465,176 @@ def test_serialization_roundtrip(self, numeric_df: pl.DataFrame) -> None:
         result1, _ = plan.process(numeric_df)
         result2, _ = restored.process(numeric_df)
         assert result1["diff"].to_list() == result2["diff"].to_list()
+
+
+class TestMathDiffLag:
+    """Tests for math_diff_lag operation."""
+
+    def test_numeric_basic(self) -> None:
+        """Test lag=1 on integers ordered by id; first row null."""
+        df = pl.DataFrame({"id": [1, 2, 3, 4], "val": [10, 30, 35, 50]})
+        plan = TransformPlan().math_diff_lag("val", order_by="id", new_column="diff")
+        result, _ = plan.process(df)
+        assert result["diff"].to_list() == [None, 20.0, 5.0, 15.0]
+
+    def test_numeric_lag2(self) -> None:
+        """Test lag=2; first two rows null."""
+        df = pl.DataFrame({"id": [1, 2, 3, 4], "val": [10, 30, 35, 50]})
+        plan = TransformPlan().math_diff_lag(
+            "val", order_by="id", new_column="diff", lag=2
+        )
+        result, _ = plan.process(df)
+        assert result["diff"].to_list() == [None, None, 25.0, 20.0]
+
+    def test_grouped_numeric(self) -> None:
+        """Test partition by group; nulls restart per group."""
+        df = pl.DataFrame(
+            {
+                "grp": ["A", "A", "A", "B", "B", "B"],
+                "seq": [1, 2, 3, 1, 2, 3],
+                "val": [10, 30, 35, 100, 150, 160],
+            }
+        )
+        plan = TransformPlan().math_diff_lag(
+            "val", order_by="seq", new_column="diff", group_by="grp"
+        )
+        result, _ = plan.process(df)
+        expected = [None, 20.0, 5.0, None, 50.0, 10.0]
+        assert result["diff"].to_list() == expected
+
+    def test_datetime_column(self) -> None:
+        """Test datetime input produces duration output."""
+        df = pl.DataFrame(
+            {
+                "id": [1, 2, 3],
+                "ts": [
+                    datetime(2024, 1, 1, 0, 0),
+                    datetime(2024, 1, 1, 1, 0),
+                    datetime(2024, 1, 1, 3, 0),
+                ],
+            }
+        )
+        plan = TransformPlan().math_diff_lag("ts", order_by="id", new_column="gap")
+        result, _ = plan.process(df)
+        assert result["gap"].dtype == pl.Duration
+        vals = result["gap"].to_list()
+        assert vals[0] is None
+        assert vals[1].total_seconds() == 3600
+        assert vals[2].total_seconds() == 7200
+
+    def test_datetime_grouped(self) -> None:
+        """Test primary use case: time between events per patient."""
+        df = pl.DataFrame(
+            {
+                "patient": ["A", "A", "B", "B"],
+                "ts": [
+                    datetime(2024, 1, 1, 0, 0),
+                    datetime(2024, 1, 1, 2, 0),
+                    datetime(2024, 1, 1, 10, 0),
+                    datetime(2024, 1, 1, 13, 0),
+                ],
+            }
+        )
+        plan = TransformPlan().math_diff_lag(
+            "ts", order_by="ts", new_column="gap", group_by="patient"
+        )
+        result, _ = plan.process(df)
+        assert result["gap"].dtype == pl.Duration
+        vals = result["gap"].to_list()
+        assert vals[0] is None
+        assert vals[1].total_seconds() / 3600 == 2.0
+        assert vals[2] is None
+        assert vals[3].total_seconds() / 3600 == 3.0
+
+    def test_order_by_different_column(self) -> None:
+        """Test diffing 'value' ordered by 'timestamp'."""
+        df = pl.DataFrame(
+            {
+                "ts": [
+                    datetime(2024, 1, 1),
+                    datetime(2024, 1, 2),
+                    datetime(2024, 1, 3),
+                ],
+                "val": [100, 130, 125],
+            }
+        )
+        plan = TransformPlan().math_diff_lag("val", order_by="ts", new_column="change")
+        result, _ = plan.process(df)
+        assert result["change"].to_list() == [None, 30.0, -5.0]
+
+    def test_order_by_list(self) -> None:
+        """Test multi-column order_by."""
+        df = pl.DataFrame(
+            {
+                "a": [1, 1, 2, 2],
+                "b": [1, 2, 1, 2],
+                "val": [10, 20, 30, 40],
+            }
+        )
+        plan = TransformPlan().math_diff_lag(
+            "val", order_by=["a", "b"], new_column="diff"
+        )
+        result, _ = plan.process(df)
+        assert result["diff"].to_list() == [None, 10.0, 10.0, 10.0]
+
+    def test_global_no_group(self) -> None:
+        """Test no group_by, global ordering."""
+        df = pl.DataFrame({"seq": [3, 1, 2], "val": [30, 10, 20]})
+        plan = TransformPlan().math_diff_lag("val", order_by="seq", new_column="diff")
+        result, _ = plan.process(df)
+        # After sorting by seq: [10, 20, 30], diffs: [None, 10, 10]
+        assert result["diff"].to_list() == [None, 10.0, 10.0]
+
+    def test_validation_nonexistent_column(self, numeric_df: pl.DataFrame) -> None:
+        """Test validation catches non-existent column."""
+        plan = TransformPlan().math_diff_lag(
+            "nonexistent", order_by="a", new_column="diff"
+        )
+        result = plan.validate(numeric_df)
+        assert not result.is_valid
+        assert "does not exist" in str(result.errors[0])
+
+    def test_validation_wrong_type(self, basic_df: pl.DataFrame) -> None:
+        """Test validation catches string column."""
+        plan = TransformPlan().math_diff_lag("name", order_by="id", new_column="diff")
+        result = plan.validate(basic_df)
+        assert not result.is_valid
+        assert "numeric or datetime" in str(result.errors[0])
+
+    def test_validation_missing_order_by(self, numeric_df: pl.DataFrame) -> None:
+        """Test validation catches missing order_by column."""
+        plan = TransformPlan().math_diff_lag(
+            "a", order_by="nonexistent", new_column="diff"
+        )
+        result = plan.validate(numeric_df)
+        assert not result.is_valid
+        assert "Order-by" in str(result.errors[0])
+
+    def test_validation_missing_group_by(self, numeric_df: pl.DataFrame) -> None:
+        """Test validation catches missing group_by column."""
+        plan = TransformPlan().math_diff_lag(
+            "a", order_by="a", new_column="diff", group_by="nonexistent"
+        )
+        result = plan.validate(numeric_df)
+        assert not result.is_valid
+        assert "Group-by" in str(result.errors[0])
+
+    def test_validation_invalid_lag(self, numeric_df: pl.DataFrame) -> None:
+        """Test validation rejects a lag smaller than 1."""
+        plan = TransformPlan().math_diff_lag(
+            "a", order_by="a", new_column="diff", lag=0
+        )
+        result = plan.validate(numeric_df)
+        assert not result.is_valid
+        assert "lag" in str(result.errors[0])
+
+    def test_serialization_roundtrip(self, numeric_df: pl.DataFrame) -> None:
+        """Test JSON serialization round-trip."""
+        plan = TransformPlan().math_diff_lag(
+            "a", order_by="a", new_column="diff", group_by="b", lag=2
+        )
+        json_str = plan.to_json()
+        restored = TransformPlan.from_json(json_str)
+        result1, _ = plan.process(numeric_df)
+        result2, _ = restored.process(numeric_df)
+        assert result1["diff"].to_list() == result2["diff"].to_list()
diff --git a/transformplan/backends/base.py b/transformplan/backends/base.py
index c0ac565..0a6eff0 100644
--- a/transformplan/backends/base.py
+++ b/transformplan/backends/base.py
@@ -309,6 +309,17 @@ def math_diff_from_agg(
         group_by: list[str] | None,
     ) -> Any: ...
 
+    @abstractmethod
+    def math_diff_lag(
+        self,
+        data: Any,
+        column: str,
+        order_by: list[str],
+        new_column: str,
+        group_by: list[str] | None,
+        lag: int,
+    ) -> Any: ...
+
     @abstractmethod
     def math_standardize(
         self,
diff --git a/transformplan/backends/duckdb.py b/transformplan/backends/duckdb.py
index 7c0deda..69a5a90 100644
--- a/transformplan/backends/duckdb.py
+++ b/transformplan/backends/duckdb.py
@@ -582,6 +582,35 @@ def math_diff_from_agg(
         )
         return self._con.sql(f"SELECT *, {expr} FROM {_sub(data)}")
 
+    def math_diff_lag(
+        self,
+        data: duckdb.DuckDBPyRelation,
+        column: str,
+        order_by: list[str],
+        new_column: str,
+        group_by: list[str] | None,
+        lag: int,
+    ) -> duckdb.DuckDBPyRelation:
+        """Append ``column - LAG(column, lag)`` as ``new_column``.
+
+        The window is ordered by ``order_by`` and, when ``group_by`` is
+        given, partitioned by it. Rows with no row ``lag`` positions back
+        in their window get NULL.
+        """
+        partition = ""
+        if group_by:
+            partition = "PARTITION BY " + ", ".join(_q(g) for g in group_by)
+        order = "ORDER BY " + ", ".join(_q(o) for o in order_by)
+        window = f"{partition} {order}".strip()
+        # int() keeps the offset a plain integer literal: plans can be
+        # restored from JSON, so ``lag`` must never reach the SQL text verbatim.
+        offset = int(lag)
+        expr = (
+            f"({_q(column)} - LAG({_q(column)}, {offset}) OVER ({window})) "
+            f"AS {_q(new_column)}"
+        )
+        return self._con.sql(f"SELECT *, {expr} FROM {_sub(data)}")
+
     def math_standardize(
         self,
         data: duckdb.DuckDBPyRelation,
diff --git a/transformplan/backends/polars.py b/transformplan/backends/polars.py
index 1d6768a..96fe379 100644
--- a/transformplan/backends/polars.py
+++ b/transformplan/backends/polars.py
@@ -366,6 +366,32 @@ def math_diff_from_agg(
             agg_expr = agg_expr.over(group_by)
         return data.with_columns((pl.col(column) - agg_expr).alias(new_column))
 
+    def math_diff_lag(
+        self,
+        data: pl.DataFrame,
+        column: str,
+        order_by: list[str],
+        new_column: str,
+        group_by: list[str] | None,
+        lag: int,
+    ) -> pl.DataFrame:
+        """Append ``column - column.shift(lag)`` as ``new_column``.
+
+        With ``group_by`` the shift runs as a window expression (partitioned
+        by ``group_by``, ordered by ``order_by``) and the frame itself is not
+        re-sorted. Without ``group_by`` the frame is sorted by ``order_by``
+        first, so the returned rows come back in that order.
+        """
+        if group_by:
+            expr = pl.col(column) - pl.col(column).shift(lag).over(
+                partition_by=group_by, order_by=order_by
+            )
+            return data.with_columns(expr.alias(new_column))
+        data = data.sort(order_by)
+        return data.with_columns(
+            (pl.col(column) - pl.col(column).shift(lag)).alias(new_column)
+        )
+
     def math_standardize(
         self,
         data: pl.DataFrame,
diff --git a/transformplan/chunking.py b/transformplan/chunking.py
index 453366b..e90b7ad 100644
--- a/transformplan/chunking.py
+++ b/transformplan/chunking.py
@@ -97,6 +97,7 @@ class OperationMeta:
     "math_diff_from_agg": OperationMeta(
         ChunkMode.GROUP_DEPENDENT, group_param="group_by"
     ),
+    "math_diff_lag": OperationMeta(ChunkMode.GROUP_DEPENDENT, group_param="group_by"),
     # String operations - all chunkable
     "str_replace": OperationMeta(ChunkMode.CHUNKABLE),
     "str_slice": OperationMeta(ChunkMode.CHUNKABLE),
diff --git a/transformplan/ops/math.py b/transformplan/ops/math.py
index 5729699..a6fb226 100644
--- a/transformplan/ops/math.py
+++ b/transformplan/ops/math.py
@@ -327,6 +327,48 @@ def math_diff_from_agg(
             },
         )
 
+    def math_diff_lag(
+        self,
+        column: str,
+        *,
+        order_by: str | list[str],
+        new_column: str,
+        group_by: str | list[str] | None = None,
+        lag: int = 1,
+    ) -> Self:
+        """Compute row-to-row difference using lag.
+
+        Calculates column - LAG(column, lag) ordered by order_by and optionally
+        partitioned by group_by. Works on numeric columns (the result keeps the
+        column's numeric type, e.g. integers stay integers) and datetime columns
+        (result is duration).
+
+        Args:
+            column: Source column (numeric or datetime).
+            order_by: Column(s) defining row order.
+            new_column: Name for result column.
+            group_by: Column(s) to partition by. None for global ordering.
+            lag: Number of rows to look back (must be >= 1; plan validation
+                reports smaller values as errors).
+
+        Returns:
+            Self for method chaining.
+        """
+        if isinstance(order_by, str):
+            order_by = [order_by]
+        if isinstance(group_by, str):
+            group_by = [group_by]
+        return self._register(
+            "math_diff_lag",
+            {
+                "column": column,
+                "order_by": order_by,
+                "new_column": new_column,
+                "group_by": group_by,
+                "lag": lag,
+            },
+        )
+
     # =========================================================================
     # Scaling Operations
     # =========================================================================
diff --git a/transformplan/validation.py b/transformplan/validation.py
index b2f3a10..f78a421 100644
--- a/transformplan/validation.py
+++ b/transformplan/validation.py
@@ -961,6 +961,68 @@ def _validate_math_diff_from_agg(
     tracker.add_column(new_column, tracker.float_type)
 
 
+def _validate_math_diff_lag(
+    tracker: SchemaTracker, params: dict[str, Any], result: ValidationResult, step: int
+) -> None:
+    """Validate a ``math_diff_lag`` step and register its output column.
+
+    Checks that the source column exists and is numeric or datetime, that
+    all order-by / group-by columns exist and that ``lag`` is an integer
+    >= 1. The output column is always registered: duration for datetime
+    input, the input type for numeric input, float otherwise.
+    """
+    column = params["column"]
+    order_by = params["order_by"]
+    new_column = params["new_column"]
+    group_by = params.get("group_by")
+    lag = params.get("lag", 1)
+
+    # Subtraction keeps the operand type (integers stay integers); only
+    # datetimes change type, to a duration. Float is the fallback when the
+    # source column is missing or has an unsupported type.
+    out_type = tracker.float_type
+    if _check_column_exists(tracker, column, result, step, "math_diff_lag"):
+        dtype = tracker.get_dtype(column)
+        if tracker.is_datetime(dtype):
+            out_type = tracker.duration_type
+        elif tracker.is_numeric(dtype):
+            out_type = dtype
+        else:
+            result.add_error(
+                step,
+                "math_diff_lag",
+                f"Column '{column}' must be numeric or datetime, "
+                f"got {tracker.type_name(dtype)}",
+            )
+
+    missing_order = [c for c in order_by if not tracker.has_column(c)]
+    if missing_order:
+        result.add_error(
+            step,
+            "math_diff_lag",
+            f"Order-by columns do not exist: {missing_order}",
+        )
+
+    if group_by:
+        missing_group = [c for c in group_by if not tracker.has_column(c)]
+        if missing_group:
+            result.add_error(
+                step,
+                "math_diff_lag",
+                f"Group-by columns do not exist: {missing_group}",
+            )
+
+    # bool is an int subclass; reject it explicitly.
+    if isinstance(lag, bool) or not isinstance(lag, int) or lag < 1:
+        result.add_error(
+            step,
+            "math_diff_lag",
+            f"lag must be an integer >= 1, got {lag!r}",
+        )
+
+    tracker.add_column(new_column, out_type)
+
+
 def _validate_math_percent_of(
     tracker: SchemaTracker, params: dict[str, Any], result: ValidationResult, step: int
 ) -> None:
@@ -1590,6 +1652,7 @@ def _validate_map_label(
     "math_cumsum": _validate_math_cumsum,
     "math_rank": _validate_math_rank,
     "math_diff_from_agg": _validate_math_diff_from_agg,
+    "math_diff_lag": _validate_math_diff_lag,
     "math_percent_of": _validate_math_percent_of,
     # Scaling ops
     "math_standardize": partial(_validate_math_scaling, op_name="math_standardize"),