diff --git a/CLAUDE.md b/CLAUDE.md index c841b08..783bcf4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -97,6 +97,16 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. - Alternative to Callaway-Sant'Anna with different weighting scheme - Useful robustness check when both estimators agree +- **`diff_diff/imputation.py`** - Borusyak-Jaravel-Spiess imputation DiD estimator: + - `ImputationDiD` - Borusyak et al. (2024) efficient imputation estimator for staggered DiD + - `ImputationDiDResults` - Results with overall ATT, event study, group effects, pre-trend test + - `ImputationBootstrapResults` - Multiplier bootstrap inference results + - `imputation_did()` - Convenience function + - Steps: (1) OLS on untreated obs for unit+time FE, (2) impute counterfactual Y(0), (3) aggregate + - Conservative variance (Theorem 3) with `aux_partition` parameter for SE tightness + - Pre-trend test (Equation 9) via `results.pretrend_test()` + - Proposition 5: NaN for unidentified long-run horizons without never-treated units + - **`diff_diff/triple_diff.py`** - Triple Difference (DDD) estimator: - `TripleDifference` - Ortiz-Villavicencio & Sant'Anna (2025) estimator for DDD designs - `TripleDifferenceResults` - Results with ATT, SEs, cell means, diagnostics @@ -255,6 +265,7 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. Standalone estimators (each has own get_params/set_params): ├── CallawaySantAnna ├── SunAbraham + ├── ImputationDiD ├── TripleDifference ├── TROP ├── SyntheticDiD @@ -364,6 +375,7 @@ Tests mirror the source modules: - `tests/test_estimators.py` - Tests for DifferenceInDifferences, TWFE, MultiPeriodDiD, SyntheticDiD - `tests/test_staggered.py` - Tests for CallawaySantAnna - `tests/test_sun_abraham.py` - Tests for SunAbraham interaction-weighted estimator +- `tests/test_imputation.py` - Tests for ImputationDiD (Borusyak et al. 2024) estimator - `tests/test_triple_diff.py` - Tests for Triple Difference (DDD) estimator - `tests/test_trop.py` - Tests for Triply Robust Panel (TROP) estimator - `tests/test_bacon.py` - Tests for Goodman-Bacon decomposition diff --git a/README.md b/README.md index a902ea6..55bd9ae 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1 - **Wild cluster bootstrap**: Valid inference with few clusters (<50) using Rademacher, Webb, or Mammen weights - **Panel data support**: Two-way fixed effects estimator for panel designs - **Multi-period analysis**: Event-study style DiD with period-specific treatment effects -- **Staggered adoption**: Callaway-Sant'Anna (2021) and Sun-Abraham (2021) estimators for heterogeneous treatment timing +- **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), and Borusyak-Jaravel-Spiess (2024) imputation estimators for heterogeneous treatment timing - **Triple Difference (DDD)**: Ortiz-Villavicencio & Sant'Anna (2025) estimators with proper covariate handling - **Synthetic DiD**: Combined DiD with synthetic control for improved robustness - **Triply Robust Panel (TROP)**: Factor-adjusted DiD with synthetic weights (Athey et al. 2025) @@ -879,6 +879,54 @@ print(f"Sun-Abraham ATT: {sa_results.overall_att:.3f}") # If results differ substantially, investigate heterogeneity ``` +### Borusyak-Jaravel-Spiess Imputation Estimator + +The Borusyak et al. 
(2024) imputation estimator is the **efficient** estimator for staggered DiD under parallel trends, producing ~50% shorter confidence intervals than Callaway-Sant'Anna and 2-3.5x shorter than Sun-Abraham under homogeneous treatment effects. + +```python +from diff_diff import ImputationDiD, imputation_did + +# Basic usage +est = ImputationDiD() +results = est.fit(data, outcome='outcome', unit='unit', + time='period', first_treat='first_treat') +results.print_summary() + +# Event study +results = est.fit(data, outcome='outcome', unit='unit', + time='period', first_treat='first_treat', + aggregate='event_study') + +# Pre-trend test (Equation 9) +pt = results.pretrend_test(n_leads=3) +print(f"F-stat: {pt['f_stat']:.3f}, p-value: {pt['p_value']:.4f}") + +# Convenience function +results = imputation_did(data, 'outcome', 'unit', 'period', 'first_treat', + aggregate='all') +``` + +```python +ImputationDiD( + anticipation=0, # Number of anticipation periods + alpha=0.05, # Significance level + cluster=None, # Cluster variable (defaults to unit) + n_bootstrap=0, # Bootstrap iterations (0=analytical inference) + seed=None, # Random seed + horizon_max=None, # Max event-study horizon + aux_partition="cohort_horizon", # Variance partition: "cohort_horizon", "cohort", "horizon" +) +``` + +**When to use Imputation DiD vs Callaway-Sant'Anna:** + +| Aspect | Imputation DiD | Callaway-Sant'Anna | +|--------|---------------|-------------------| +| Efficiency | Most efficient under homogeneous effects | Less efficient but more robust to heterogeneity | +| Control group | Always uses all untreated obs | Choice of never-treated or not-yet-treated | +| Inference | Conservative variance (Theorem 3) | Multiplier bootstrap | +| Pre-trends | Built-in F-test (Equation 9) | Separate testing | + ### Triple Difference (DDD) Triple Difference (DDD) is used when treatment requires satisfying two criteria: belonging to a treated **group** AND being in an eligible **partition**. The `TripleDifference` class implements the methodology from Ortiz-Villavicencio & Sant'Anna (2025), which correctly handles covariate adjustment (unlike naive implementations). 
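Below is a minimal sketch of the DDD workflow (the column names and the exact `fit()` signature are illustrative assumptions, not the documented API; see the `TripleDifference` reference section later in this document for the authoritative interface):

```python
from diff_diff import TripleDifference

# Hypothetical column names: 'group' flags units in the treated group,
# 'eligible' flags membership in the eligible partition, 'period' is time.
ddd = TripleDifference()
results = ddd.fit(data, outcome='outcome', group='group',
                  partition='eligible', time='period')
results.print_summary()
```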
@@ -2000,6 +2048,60 @@ SunAbraham(
 | `print_summary(alpha)` | Print summary to stdout |
 | `to_dataframe(level)` | Convert to DataFrame ('event_study' or 'cohort') |
 
+### ImputationDiD
+
+```python
+ImputationDiD(
+    anticipation=0,                 # Periods of anticipation effects
+    alpha=0.05,                     # Significance level for CIs
+    cluster=None,                   # Column for cluster-robust SEs
+    n_bootstrap=0,                  # Bootstrap iterations (0 = analytical)
+    seed=None,                      # Random seed
+    rank_deficient_action='warn',   # 'warn', 'error', or 'silent'
+    horizon_max=None,               # Max event-study horizon
+    aux_partition='cohort_horizon', # Variance partition
+)
+```
+
+**fit() Parameters:**
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `data` | DataFrame | Panel data |
+| `outcome` | str | Outcome variable column name |
+| `unit` | str | Unit identifier column |
+| `time` | str | Time period column |
+| `first_treat` | str | First treatment period column (0 for never-treated) |
+| `covariates` | list | Covariate column names |
+| `aggregate` | str | Aggregation: None, "event_study", "group", "all" |
+| `balance_e` | int | Restrict the event study to cohorts observed at all relative times in `[-balance_e, max_h]` |
+
+### ImputationDiDResults
+
+**Attributes:**
+
+| Attribute | Description |
+|-----------|-------------|
+| `overall_att` | Overall average treatment effect on the treated |
+| `overall_se` | Standard error (conservative, Theorem 3) |
+| `overall_t_stat` | T-statistic |
+| `overall_p_value` | P-value for H0: ATT = 0 |
+| `overall_conf_int` | Confidence interval |
+| `event_study_effects` | Dict of relative time -> effect dict (if `aggregate='event_study'` or `'all'`) |
+| `group_effects` | Dict of cohort -> effect dict (if `aggregate='group'` or `'all'`) |
+| `treatment_effects` | DataFrame of unit-level imputed treatment effects |
+| `n_treated_obs` | Number of treated observations |
+| `n_untreated_obs` | Number of untreated observations |
+
+**Methods:**
+
+| Method | Description |
+|--------|-------------|
+| `summary(alpha)` | Get formatted summary string |
+| `print_summary(alpha)` | Print summary to stdout |
+| `to_dataframe(level)` | Convert to DataFrame ('observation', 'event_study', 'group') |
+| `pretrend_test(n_leads)` | Run pre-trend F-test (Equation 9) |
+
 ### TripleDifference
 
 ```python
@@ -2464,6 +2566,14 @@ The `HonestDiD` module implements sensitivity analysis methods for relaxing the
 ### Multi-Period and Staggered Adoption
 
+- **Borusyak, K., Jaravel, X., & Spiess, J. (2024).** "Revisiting Event-Study Designs: Robust and Efficient Estimation." *Review of Economic Studies*, 91(6), 3253-3285. [https://doi.org/10.1093/restud/rdae007](https://doi.org/10.1093/restud/rdae007)
+
+  This paper introduces the imputation estimator implemented in our `ImputationDiD` class:
+  - **Efficient imputation**: OLS on untreated observations → impute counterfactuals → aggregate
+  - **Conservative variance**: Theorem 3 clustered variance estimator with auxiliary model
+  - **Pre-trend test**: Independent of treatment effect estimation (Equation 9)
+  - **Efficiency gains**: ~50% shorter CIs than Callaway-Sant'Anna under homogeneous effects
+
 - **Callaway, B., & Sant'Anna, P. H. C. (2021).** "Difference-in-Differences with Multiple Time Periods." *Journal of Econometrics*, 225(2), 200-230. [https://doi.org/10.1016/j.jeconom.2020.12.001](https://doi.org/10.1016/j.jeconom.2020.12.001)
 
 - **Sant'Anna, P. H. C., & Zhao, J. (2020).** "Doubly Robust Difference-in-Differences Estimators." *Journal of Econometrics*, 219(1), 101-122. 
[https://doi.org/10.1016/j.jeconom.2020.06.003](https://doi.org/10.1016/j.jeconom.2020.06.003) diff --git a/ROADMAP.md b/ROADMAP.md index b76ee5a..9020601 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -10,7 +10,7 @@ For past changes and release history, see [CHANGELOG.md](CHANGELOG.md). diff-diff v2.1.1 is a **production-ready** DiD library with feature parity with R's `did` + `HonestDiD` + `synthdid` ecosystem for core DiD analysis: -- **Core estimators**: Basic DiD, TWFE, MultiPeriod, Callaway-Sant'Anna, Sun-Abraham, Synthetic DiD, Triple Difference (DDD), TROP +- **Core estimators**: Basic DiD, TWFE, MultiPeriod, Callaway-Sant'Anna, Sun-Abraham, Borusyak-Jaravel-Spiess Imputation, Synthetic DiD, Triple Difference (DDD), TROP - **Valid inference**: Robust SEs, cluster SEs, wild bootstrap, multiplier bootstrap, placebo-based variance - **Assumption diagnostics**: Parallel trends tests, placebo tests, Goodman-Bacon decomposition - **Sensitivity analysis**: Honest DiD (Rambachan-Roth), Pre-trends power analysis (Roth 2022) @@ -24,15 +24,9 @@ diff-diff v2.1.1 is a **production-ready** DiD library with feature parity with High-value additions building on our existing foundation. -### Borusyak-Jaravel-Spiess Imputation Estimator +### ~~Borusyak-Jaravel-Spiess Imputation Estimator~~ ✅ Implemented (v2.2) -More efficient than Callaway-Sant'Anna when treatment effects are homogeneous across groups/time. Uses imputation rather than aggregation. - -- Imputes untreated potential outcomes using pre-treatment data -- More efficient under homogeneous effects assumption -- Can handle unbalanced panels more naturally - -**Reference**: Borusyak, Jaravel, and Spiess (2024). *Review of Economic Studies*. +Implemented as `ImputationDiD` — see `diff_diff/imputation.py`. Includes conservative variance (Theorem 3), event study and group aggregation, pre-trend test (Equation 9), multiplier bootstrap, and Proposition 5 handling for no never-treated units. ### Gardner's Two-Stage DiD (did2s) diff --git a/benchmarks/R/benchmark_didimputation.R b/benchmarks/R/benchmark_didimputation.R new file mode 100644 index 0000000..3c4e054 --- /dev/null +++ b/benchmarks/R/benchmark_didimputation.R @@ -0,0 +1,160 @@ +#!/usr/bin/env Rscript +# Benchmark: Imputation DiD Estimator (R `didimputation` package) +# +# Compares against diff_diff.ImputationDiD (Borusyak, Jaravel & Spiess 2024). 
+#
+# Usage:
+#   Rscript benchmark_didimputation.R --data path/to/data.csv --output path/to/results.json
+
+library(didimputation)
+library(fixest)
+library(jsonlite)
+library(data.table)
+
+# Parse command line arguments
+args <- commandArgs(trailingOnly = TRUE)
+
+parse_args <- function(args) {
+  result <- list(
+    data = NULL,
+    output = NULL
+  )
+
+  i <- 1
+  while (i <= length(args)) {
+    if (args[i] == "--data") {
+      result$data <- args[i + 1]
+      i <- i + 2
+    } else if (args[i] == "--output") {
+      result$output <- args[i + 1]
+      i <- i + 2
+    } else {
+      i <- i + 1
+    }
+  }
+
+  if (is.null(result$data) || is.null(result$output)) {
+    stop("Usage: Rscript benchmark_didimputation.R --data <data.csv> --output <results.json>")
+  }
+
+  return(result)
+}
+
+config <- parse_args(args)
+
+# Load data
+message(sprintf("Loading data from: %s", config$data))
+data <- fread(config$data)
+
+# Ensure proper column types
+data[, unit := as.integer(unit)]
+data[, time := as.integer(time)]
+
+# R's didimputation package expects first_treat=0 or NA for never-treated units
+# Our Python implementation uses first_treat=0 for never-treated, which matches
+data[, first_treat := as.integer(first_treat)]
+message(sprintf("Never-treated units (first_treat=0): %d", sum(data$first_treat == 0)))
+
+# Determine event study horizons from the data
+# Compute relative time for treated units
+treated_data <- data[first_treat > 0]
+if (nrow(treated_data) > 0) {
+  treated_data[, rel_time := time - first_treat]
+  min_horizon <- min(treated_data$rel_time)
+  max_horizon <- max(treated_data$rel_time)
+  # Post-treatment horizons only (for event study)
+  post_horizons <- sort(unique(treated_data$rel_time[treated_data$rel_time >= 0]))
+  all_horizons <- sort(unique(treated_data$rel_time))
+  message(sprintf("Horizon range: [%d, %d]", min_horizon, max_horizon))
+  message(sprintf("Post-treatment horizons: %s", paste(post_horizons, collapse = ", ")))
+}
+
+# Run benchmark - Overall ATT (static)
+message("Running did_imputation (static)...")
+start_time <- Sys.time()
+
+static_result <- did_imputation(
+  data = data,
+  yname = "outcome",
+  gname = "first_treat",
+  tname = "time",
+  idname = "unit",
+  cluster_var = "unit"
+)
+
+static_time <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
+message(sprintf("Static estimation completed in %.3f seconds", static_time))
+
+# Extract overall ATT
+overall_att <- static_result$estimate[1]
+overall_se <- static_result$std.error[1]
+message(sprintf("Overall ATT: %.6f (SE: %.6f)", overall_att, overall_se))
+
+# Run benchmark - Event study
+message("Running did_imputation (event study)...")
+es_start_time <- Sys.time()
+
+es_result <- did_imputation(
+  data = data,
+  yname = "outcome",
+  gname = "first_treat",
+  tname = "time",
+  idname = "unit",
+  horizon = TRUE,
+  cluster_var = "unit"
+)
+
+es_time <- as.numeric(difftime(Sys.time(), es_start_time, units = "secs"))
+message(sprintf("Event study estimation completed in %.3f seconds", es_time))
+
+total_time <- static_time + es_time
+
+# Format event study results
+event_study <- data.frame(
+  event_time = as.integer(gsub("tau", "", es_result$term)),
+  att = es_result$estimate,
+  se = es_result$std.error
+)
+
+message("Event study effects:")
+for (i in seq_len(nrow(event_study))) {
+  message(sprintf("  h=%d: ATT=%.4f (SE=%.4f)",
+                  event_study$event_time[i],
+                  event_study$att[i],
+                  event_study$se[i]))
+}
+
+# Format output
+results <- list(
+  estimator = "didimputation::did_imputation",
+
+  # Overall ATT
+  overall_att = overall_att,
+  overall_se = overall_se,
+
+  # 
Event study + event_study = event_study, + + # Timing + timing = list( + static_seconds = static_time, + event_study_seconds = es_time, + total_seconds = total_time + ), + + # Metadata + metadata = list( + r_version = R.version.string, + didimputation_version = as.character(packageVersion("didimputation")), + n_units = length(unique(data$unit)), + n_periods = length(unique(data$time)), + n_obs = nrow(data) + ) +) + +# Write output +message(sprintf("Writing results to: %s", config$output)) +dir.create(dirname(config$output), recursive = TRUE, showWarnings = FALSE) +write_json(results, config$output, auto_unbox = TRUE, pretty = TRUE, digits = 10) + +message(sprintf("Completed in %.3f seconds", total_time)) diff --git a/benchmarks/R/requirements.R b/benchmarks/R/requirements.R index 2393f4b..42e4750 100644 --- a/benchmarks/R/requirements.R +++ b/benchmarks/R/requirements.R @@ -7,6 +7,7 @@ required_packages <- c( # Core DiD packages "did", # Callaway-Sant'Anna (2021) staggered DiD + "didimputation", # Borusyak, Jaravel & Spiess (2024) imputation DiD "HonestDiD", # Rambachan & Roth (2023) sensitivity analysis "fixest", # Fast TWFE and basic DiD diff --git a/benchmarks/python/benchmark_imputation.py b/benchmarks/python/benchmark_imputation.py new file mode 100644 index 0000000..27f556f --- /dev/null +++ b/benchmarks/python/benchmark_imputation.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +""" +Benchmark: Imputation DiD Estimator (diff-diff ImputationDiD). + +Compares against R's didimputation::did_imputation (Borusyak et al. 2024). + +Usage: + python benchmark_imputation.py --data path/to/data.csv --output path/to/results.json +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +# IMPORTANT: Parse --backend and set environment variable BEFORE importing diff_diff +# This ensures the backend configuration is respected by all modules +def _get_backend_from_args(): + """Parse --backend argument without importing diff_diff.""" + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--backend", default="auto", choices=["auto", "python", "rust"]) + args, _ = parser.parse_known_args() + return args.backend + +_requested_backend = _get_backend_from_args() +if _requested_backend in ("python", "rust"): + os.environ["DIFF_DIFF_BACKEND"] = _requested_backend + +# NOW import diff_diff and other dependencies (will see the env var) +import pandas as pd + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from diff_diff import ImputationDiD, HAS_RUST_BACKEND +from benchmarks.python.utils import Timer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark Imputation DiD estimator" + ) + parser.add_argument("--data", required=True, help="Path to input CSV data") + parser.add_argument("--output", required=True, help="Path to output JSON results") + parser.add_argument( + "--backend", default="auto", choices=["auto", "python", "rust"], + help="Backend to use: auto (default), python (pure Python), rust (Rust backend)" + ) + return parser.parse_args() + + +def get_actual_backend() -> str: + """Return the actual backend being used based on HAS_RUST_BACKEND.""" + return "rust" if HAS_RUST_BACKEND else "python" + + +def main(): + args = parse_args() + + # Get actual backend (already configured via env var before imports) + actual_backend = get_actual_backend() + print(f"Using backend: {actual_backend}") + + # Load data + print(f"Loading data from: {args.data}") + df = pd.read_csv(args.data) + 
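+    # Sanity check (illustrative addition, not from the original script):
+    # per the notes in benchmark_didimputation.R, both implementations code
+    # never-treated units as first_treat == 0, so report that count here.
+    n_never = int((df["first_treat"] == 0).sum())
+    print(f"Never-treated observations (first_treat=0): {n_never}")
+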
+ # Run benchmark + print("Running ImputationDiD estimation...") + est = ImputationDiD( + n_bootstrap=0, # Analytical SE (conservative variance, Theorem 3) + ) + + with Timer() as estimation_timer: + results = est.fit( + df, + outcome="outcome", + time="time", + unit="unit", + first_treat="first_treat", + aggregate="event_study", + ) + + estimation_time = estimation_timer.elapsed + total_time = estimation_time + + # Store data info before looping (avoid shadowing) + n_units = len(df["unit"].unique()) + n_periods = len(df["time"].unique()) + n_obs = len(df) + + # Format event study effects (if available) + es_effects = [] + if results.event_study_effects: + for rel_t, effect_data in sorted(results.event_study_effects.items()): + # Skip reference period marker (n_obs == 0) + if effect_data.get("n_obs", 1) == 0: + continue + es_effects.append({ + "event_time": int(rel_t), + "att": float(effect_data["effect"]), + "se": float(effect_data["se"]), + }) + + # Build output + output = { + "estimator": "diff_diff.ImputationDiD", + "backend": actual_backend, + # Overall ATT + "overall_att": float(results.overall_att), + "overall_se": float(results.overall_se), + # Event study + "event_study": es_effects, + # Timing + "timing": { + "estimation_seconds": estimation_time, + "total_seconds": total_time, + }, + # Metadata + "metadata": { + "n_units": n_units, + "n_periods": n_periods, + "n_obs": n_obs, + "n_treated_obs": results.n_treated_obs, + "n_untreated_obs": results.n_untreated_obs, + }, + } + + # Write output + print(f"Writing results to: {args.output}") + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2) + + print(f"Overall ATT: {results.overall_att:.6f} (SE: {results.overall_se:.6f})") + print(f"Completed in {total_time:.3f} seconds") + return output + + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_benchmarks.py b/benchmarks/run_benchmarks.py index b4d3d41..a7cd9b1 100644 --- a/benchmarks/run_benchmarks.py +++ b/benchmarks/run_benchmarks.py @@ -957,6 +957,139 @@ def run_multiperiod_benchmark( return results +def run_imputation_benchmark( + data_path: Path, + name: str = "imputation", + scale: str = "small", + n_replications: int = 1, + backends: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Run Imputation DiD benchmarks (Python and R) with replications.""" + print(f"\n{'='*60}") + print(f"IMPUTATION DID BENCHMARK ({scale})") + print(f"{'='*60}") + + if backends is None: + backends = ["python", "rust"] + + timeouts = TIMEOUT_CONFIGS.get(scale, TIMEOUT_CONFIGS["small"]) + results = { + "name": name, + "scale": scale, + "n_replications": n_replications, + "python_pure": None, + "python_rust": None, + "r": None, + "comparison": None, + } + + # Run Python benchmark for each backend + for backend in backends: + # Map backend name to label (python -> pure, rust -> rust) + backend_label = f"python_{'pure' if backend == 'python' else backend}" + print(f"\nRunning Python (diff_diff.ImputationDiD, backend={backend}) - {n_replications} replications...") + py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json" + py_output.parent.mkdir(parents=True, exist_ok=True) + + py_timings = [] + py_result = None + for rep in range(n_replications): + try: + py_result = run_python_benchmark( + "benchmark_imputation.py", data_path, py_output, + timeout=timeouts["python"], + backend=backend, + ) + py_timings.append(py_result["timing"]["total_seconds"]) + if rep == 0: + 
print(f" ATT: {py_result['overall_att']:.4f}") + print(f" SE: {py_result['overall_se']:.4f}") + print(f" Rep {rep+1}/{n_replications}: {py_timings[-1]:.3f}s") + except Exception as e: + print(f" Rep {rep+1} failed: {e}") + + if py_result and py_timings: + timing_stats = compute_timing_stats(py_timings) + py_result["timing"] = timing_stats + results[backend_label] = py_result + print(f" Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s") + + # For backward compatibility, also store as "python" (use rust if available) + if results.get("python_rust"): + results["python"] = results["python_rust"] + elif results.get("python_pure"): + results["python"] = results["python_pure"] + + # R benchmark with replications + print(f"\nRunning R (didimputation::did_imputation) - {n_replications} replications...") + r_output = RESULTS_DIR / "accuracy" / f"r_{name}_{scale}.json" + + r_timings = [] + r_result = None + for rep in range(n_replications): + try: + r_result = run_r_benchmark( + "benchmark_didimputation.R", data_path, r_output, + timeout=timeouts["r"] + ) + r_timings.append(r_result["timing"]["total_seconds"]) + if rep == 0: + print(f" ATT: {r_result['overall_att']:.4f}") + print(f" SE: {r_result['overall_se']:.4f}") + print(f" Rep {rep+1}/{n_replications}: {r_timings[-1]:.3f}s") + except Exception as e: + print(f" Rep {rep+1} failed: {e}") + + if r_result and r_timings: + timing_stats = compute_timing_stats(r_timings) + r_result["timing"] = timing_stats + results["r"] = r_result + print(f" Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s") + + # Compare results + if results.get("python") and results.get("r"): + print("\nComparison (Python vs R):") + comparison = compare_estimates( + results["python"], results["r"], "ImputationDiD", scale=scale, + python_pure_results=results.get("python_pure"), + python_rust_results=results.get("python_rust"), + ) + results["comparison"] = comparison + print(f" ATT diff: {comparison.att_diff:.2e}") + print(f" SE rel diff: {comparison.se_rel_diff:.1%}") + print(f" Status: {'PASS' if comparison.passed else 'FAIL'}") + + # Event study comparison + py_effects = results["python"].get("event_study", []) + r_effects = results["r"].get("event_study", []) + if py_effects and r_effects: + corr, max_diff, all_close = compare_event_study(py_effects, r_effects) + print(f" Event study correlation: {corr:.6f}") + print(f" Event study max diff: {max_diff:.2e}") + print(f" Event study all close: {all_close}") + + # Print timing comparison table + print("\nTiming Comparison:") + print(f" {'Backend':<15} {'Time (s)':<12} {'vs R':<12} {'vs Pure Python':<15}") + print(f" {'-'*54}") + + r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None + pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None + rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None + + if r_mean: + print(f" {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}") + if pure_mean: + r_speedup = f"{r_mean/pure_mean:.2f}x" if r_mean else "-" + print(f" {'Python (pure)':<15} {pure_mean:<12.3f} {r_speedup:<12} {'1.00x':<15}") + if rust_mean: + r_speedup = f"{r_mean/rust_mean:.2f}x" if r_mean else "-" + pure_speedup = f"{pure_mean/rust_mean:.2f}x" if pure_mean else "-" + print(f" {'Python (rust)':<15} {rust_mean:<12.3f} {r_speedup:<12} {pure_speedup:<15}") + + return results + + def main(): parser = argparse.ArgumentParser( description="Run 
diff-diff benchmarks against R packages" @@ -968,7 +1101,7 @@ def main(): ) parser.add_argument( "--estimator", - choices=["callaway", "synthdid", "basic", "twfe", "multiperiod"], + choices=["callaway", "synthdid", "basic", "twfe", "multiperiod", "imputation"], help="Run specific estimator benchmark", ) parser.add_argument( @@ -1079,6 +1212,17 @@ def main(): ) all_results.append(results) + if args.all or args.estimator == "imputation": + # Imputation DiD uses the same staggered data as Callaway-Sant'Anna + stag_key = f"staggered_{scale}" + if stag_key in datasets: + results = run_imputation_benchmark( + datasets[stag_key], + scale=scale, + n_replications=args.replications, + ) + all_results.append(results) + # Generate summary report if all_results: print(f"\n{'='*60}") diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py index b2f55b5..cfc9ace 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -95,6 +95,12 @@ CSBootstrapResults, GroupTimeEffect, ) +from diff_diff.imputation import ( + ImputationBootstrapResults, + ImputationDiD, + ImputationDiDResults, + imputation_did, +) from diff_diff.sun_abraham import ( SABootstrapResults, SunAbraham, @@ -145,6 +151,7 @@ "SyntheticDiD", "CallawaySantAnna", "SunAbraham", + "ImputationDiD", "TripleDifference", "TROP", # Bacon Decomposition @@ -163,6 +170,9 @@ "GroupTimeEffect", "SunAbrahamResults", "SABootstrapResults", + "ImputationDiDResults", + "ImputationBootstrapResults", + "imputation_did", "TripleDifferenceResults", "triple_difference", "TROPResults", diff --git a/diff_diff/imputation.py b/diff_diff/imputation.py new file mode 100644 index 0000000..28455ab --- /dev/null +++ b/diff_diff/imputation.py @@ -0,0 +1,2480 @@ +""" +Borusyak-Jaravel-Spiess (2024) Imputation DiD Estimator. + +Implements the efficient imputation estimator for staggered +Difference-in-Differences from Borusyak, Jaravel & Spiess (2024), +"Revisiting Event-Study Designs: Robust and Efficient Estimation", +Review of Economic Studies. + +The estimator: +1. Runs OLS on untreated observations to estimate unit + time fixed effects +2. Imputes counterfactual Y(0) for treated observations +3. Aggregates imputed treatment effects with researcher-chosen weights + +Inference uses the conservative clustered variance estimator (Theorem 3). +""" + +import warnings +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np +import pandas as pd +from scipy import sparse, stats +from scipy.sparse.linalg import spsolve + +from diff_diff.linalg import solve_ols +from diff_diff.results import _get_significance_stars +from diff_diff.utils import compute_confidence_interval, compute_p_value + +# ============================================================================= +# Results Dataclasses +# ============================================================================= + + +@dataclass +class ImputationBootstrapResults: + """ + Results from ImputationDiD bootstrap inference. + + Bootstrap is a library extension beyond Borusyak et al. (2024), which + proposes only analytical inference via the conservative variance estimator. + Provided for consistency with CallawaySantAnna and SunAbraham. + + Attributes + ---------- + n_bootstrap : int + Number of bootstrap iterations. + weight_type : str + Type of bootstrap weights (currently "rademacher" only). + alpha : float + Significance level used for confidence intervals. + overall_att_se : float + Bootstrap standard error for overall ATT. 
+ overall_att_ci : tuple + Bootstrap confidence interval for overall ATT. + overall_att_p_value : float + Bootstrap p-value for overall ATT. + event_study_ses : dict, optional + Bootstrap SEs for event study effects. + event_study_cis : dict, optional + Bootstrap CIs for event study effects. + event_study_p_values : dict, optional + Bootstrap p-values for event study effects. + group_ses : dict, optional + Bootstrap SEs for group effects. + group_cis : dict, optional + Bootstrap CIs for group effects. + group_p_values : dict, optional + Bootstrap p-values for group effects. + bootstrap_distribution : np.ndarray, optional + Full bootstrap distribution of overall ATT. + """ + + n_bootstrap: int + weight_type: str + alpha: float + overall_att_se: float + overall_att_ci: Tuple[float, float] + overall_att_p_value: float + event_study_ses: Optional[Dict[int, float]] = None + event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None + event_study_p_values: Optional[Dict[int, float]] = None + group_ses: Optional[Dict[Any, float]] = None + group_cis: Optional[Dict[Any, Tuple[float, float]]] = None + group_p_values: Optional[Dict[Any, float]] = None + bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) + + +@dataclass +class ImputationDiDResults: + """ + Results from Borusyak-Jaravel-Spiess (2024) imputation DiD estimation. + + Attributes + ---------- + treatment_effects : pd.DataFrame + Unit-level treatment effects with columns: unit, time, tau_hat, weight. + overall_att : float + Overall average treatment effect on the treated. + overall_se : float + Standard error of overall ATT. + overall_t_stat : float + T-statistic for overall ATT. + overall_p_value : float + P-value for overall ATT. + overall_conf_int : tuple + Confidence interval for overall ATT. + event_study_effects : dict, optional + Dictionary mapping relative time h to effect dict with keys: + 'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'. + group_effects : dict, optional + Dictionary mapping cohort g to effect dict. + groups : list + List of treatment cohorts. + time_periods : list + List of all time periods. + n_obs : int + Total number of observations. + n_treated_obs : int + Number of treated observations (|Omega_1|). + n_untreated_obs : int + Number of untreated observations (|Omega_0|). + n_treated_units : int + Number of ever-treated units. + n_control_units : int + Number of units contributing to Omega_0. + alpha : float + Significance level used. + pretrend_results : dict, optional + Populated by pretrend_test(). + bootstrap_results : ImputationBootstrapResults, optional + Bootstrap inference results. 
+ """ + + treatment_effects: pd.DataFrame + overall_att: float + overall_se: float + overall_t_stat: float + overall_p_value: float + overall_conf_int: Tuple[float, float] + event_study_effects: Optional[Dict[int, Dict[str, Any]]] + group_effects: Optional[Dict[Any, Dict[str, Any]]] + groups: List[Any] + time_periods: List[Any] + n_obs: int + n_treated_obs: int + n_untreated_obs: int + n_treated_units: int + n_control_units: int + alpha: float = 0.05 + pretrend_results: Optional[Dict[str, Any]] = field(default=None, repr=False) + bootstrap_results: Optional[ImputationBootstrapResults] = field(default=None, repr=False) + # Internal: stores data needed for pretrend_test() + _estimator_ref: Optional[Any] = field(default=None, repr=False) + + def __repr__(self) -> str: + """Concise string representation.""" + sig = _get_significance_stars(self.overall_p_value) + return ( + f"ImputationDiDResults(ATT={self.overall_att:.4f}{sig}, " + f"SE={self.overall_se:.4f}, " + f"n_groups={len(self.groups)}, " + f"n_treated_obs={self.n_treated_obs})" + ) + + def summary(self, alpha: Optional[float] = None) -> str: + """ + Generate formatted summary of estimation results. + + Parameters + ---------- + alpha : float, optional + Significance level. Defaults to alpha used in estimation. + + Returns + ------- + str + Formatted summary. + """ + alpha = alpha or self.alpha + conf_level = int((1 - alpha) * 100) + + lines = [ + "=" * 85, + "Imputation DiD Estimator Results (Borusyak et al. 2024)".center(85), + "=" * 85, + "", + f"{'Total observations:':<30} {self.n_obs:>10}", + f"{'Treated observations:':<30} {self.n_treated_obs:>10}", + f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}", + f"{'Treated units:':<30} {self.n_treated_units:>10}", + f"{'Control units:':<30} {self.n_control_units:>10}", + f"{'Treatment cohorts:':<30} {len(self.groups):>10}", + f"{'Time periods:':<30} {len(self.time_periods):>10}", + "", + ] + + # Overall ATT + lines.extend( + [ + "-" * 85, + "Overall Average Treatment Effect on the Treated".center(85), + "-" * 85, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + t_str = ( + f"{self.overall_t_stat:>10.3f}" if np.isfinite(self.overall_t_stat) else f"{'NaN':>10}" + ) + p_str = ( + f"{self.overall_p_value:>10.4f}" + if np.isfinite(self.overall_p_value) + else f"{'NaN':>10}" + ) + sig = _get_significance_stars(self.overall_p_value) + + lines.extend( + [ + f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " + f"{t_str} {p_str} {sig:>6}", + "-" * 85, + "", + f"{conf_level}% Confidence Interval: " + f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", + "", + ] + ) + + # Event study effects + if self.event_study_effects: + lines.extend( + [ + "-" * 85, + "Event Study (Dynamic) Effects".center(85), + "-" * 85, + f"{'Rel. Period':<15} {'Estimate':>12} {'Std. 
Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for h in sorted(self.event_study_effects.keys()): + eff = self.event_study_effects[h] + if eff.get("n_obs", 1) == 0: + # Reference period marker + lines.append( + f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}" + ) + elif np.isnan(eff["effect"]): + lines.append(f"{h:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + e_sig = _get_significance_stars(eff["p_value"]) + e_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + e_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{h:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{e_t} {e_p} {e_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + # Group effects + if self.group_effects: + lines.extend( + [ + "-" * 85, + "Group (Cohort) Effects".center(85), + "-" * 85, + f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for g in sorted(self.group_effects.keys()): + eff = self.group_effects[g] + if np.isnan(eff["effect"]): + lines.append(f"{g:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + g_sig = _get_significance_stars(eff["p_value"]) + g_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + g_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{g:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{g_t} {g_p} {g_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + # Pre-trend test + if self.pretrend_results is not None: + pt = self.pretrend_results + lines.extend( + [ + "-" * 85, + "Pre-Trend Test (Equation 9)".center(85), + "-" * 85, + f"{'F-statistic:':<30} {pt['f_stat']:>10.3f}", + f"{'P-value:':<30} {pt['p_value']:>10.4f}", + f"{'Degrees of freedom:':<30} {pt['df']:>10}", + f"{'Number of leads:':<30} {pt['n_leads']:>10}", + "-" * 85, + "", + ] + ) + + lines.extend( + [ + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 85, + ] + ) + + return "\n".join(lines) + + def print_summary(self, alpha: Optional[float] = None) -> None: + """Print summary to stdout.""" + print(self.summary(alpha)) + + def to_dataframe(self, level: str = "observation") -> pd.DataFrame: + """ + Convert results to DataFrame. + + Parameters + ---------- + level : str, default="observation" + Level of aggregation: + - "observation": Unit-level treatment effects + - "event_study": Event study effects by relative time + - "group": Group (cohort) effects + + Returns + ------- + pd.DataFrame + Results as DataFrame. + """ + if level == "observation": + return self.treatment_effects.copy() + + elif level == "event_study": + if self.event_study_effects is None: + raise ValueError( + "Event study effects not computed. " + "Use aggregate='event_study' or aggregate='all'." + ) + rows = [] + for h, data in sorted(self.event_study_effects.items()): + rows.append( + { + "relative_period": h, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "n_obs": data.get("n_obs", np.nan), + } + ) + return pd.DataFrame(rows) + + elif level == "group": + if self.group_effects is None: + raise ValueError( + "Group effects not computed. " "Use aggregate='group' or aggregate='all'." 
+ ) + rows = [] + for g, data in sorted(self.group_effects.items()): + rows.append( + { + "group": g, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "n_obs": data.get("n_obs", np.nan), + } + ) + return pd.DataFrame(rows) + + else: + raise ValueError( + f"Unknown level: {level}. Use 'observation', 'event_study', or 'group'." + ) + + def pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]: + """ + Run a pre-trend test (Equation 9 of Borusyak et al. 2024). + + Adds pre-treatment lead indicators to the Step 1 OLS and tests + their joint significance via a cluster-robust Wald F-test. + + Parameters + ---------- + n_leads : int, optional + Number of pre-treatment leads to include. If None, uses all + available pre-treatment periods minus one (for the reference period). + + Returns + ------- + dict + Dictionary with keys: 'f_stat', 'p_value', 'df', 'n_leads', + 'lead_coefficients'. + """ + if self._estimator_ref is None: + raise RuntimeError( + "Pre-trend test requires internal estimator reference. " + "Re-fit the model to use this method." + ) + result = self._estimator_ref._pretrend_test(n_leads=n_leads) + self.pretrend_results = result + return result + + @property + def is_significant(self) -> bool: + """Check if overall ATT is significant.""" + return bool(self.overall_p_value < self.alpha) + + @property + def significance_stars(self) -> str: + """Significance stars for overall ATT.""" + return _get_significance_stars(self.overall_p_value) + + +# ============================================================================= +# Main Estimator +# ============================================================================= + + +class ImputationDiD: + """ + Borusyak-Jaravel-Spiess (2024) imputation DiD estimator. + + This is the efficient estimator for staggered Difference-in-Differences + under parallel trends. It produces shorter confidence intervals than + Callaway-Sant'Anna (~50% shorter) and Sun-Abraham (2-3.5x shorter) + under homogeneous treatment effects. + + The estimation procedure: + 1. Run OLS on untreated observations to estimate unit + time fixed effects + 2. Impute counterfactual Y(0) for treated observations + 3. Aggregate imputed treatment effects with researcher-chosen weights + + Inference uses the conservative clustered variance estimator from Theorem 3 + of the paper. + + Parameters + ---------- + anticipation : int, default=0 + Number of periods before treatment where effects may occur. + alpha : float, default=0.05 + Significance level for confidence intervals. + cluster : str, optional + Column name for cluster-robust standard errors. + If None, clusters at the unit level by default. + n_bootstrap : int, default=0 + Number of bootstrap iterations. If 0, uses analytical inference + (conservative variance from Theorem 3). + seed : int, optional + Random seed for reproducibility. + rank_deficient_action : str, default="warn" + Action when design matrix is rank-deficient: + - "warn": Issue warning and drop linearly dependent columns + - "error": Raise ValueError + - "silent": Drop columns silently + horizon_max : int, optional + Maximum event-study horizon. If set, event study effects are only + computed for |h| <= horizon_max. 
+ aux_partition : str, default="cohort_horizon" + Controls the auxiliary model partition for Theorem 3 variance: + - "cohort_horizon": Groups by cohort x relative time (tightest SEs) + - "cohort": Groups by cohort only (more conservative) + - "horizon": Groups by relative time only (more conservative) + + Attributes + ---------- + results_ : ImputationDiDResults + Estimation results after calling fit(). + is_fitted_ : bool + Whether the model has been fitted. + + Examples + -------- + Basic usage: + + >>> from diff_diff import ImputationDiD, generate_staggered_data + >>> data = generate_staggered_data(n_units=200, seed=42) + >>> est = ImputationDiD() + >>> results = est.fit(data, outcome='outcome', unit='unit', + ... time='time', first_treat='first_treat') + >>> results.print_summary() + + With event study: + + >>> est = ImputationDiD() + >>> results = est.fit(data, outcome='outcome', unit='unit', + ... time='time', first_treat='first_treat', + ... aggregate='event_study') + >>> from diff_diff import plot_event_study + >>> plot_event_study(results) + + Notes + ----- + The imputation estimator uses ALL untreated observations (never-treated + + not-yet-treated periods of eventually-treated units) to estimate the + counterfactual model. There is no ``control_group`` parameter because this + is fundamental to the method's efficiency. + + References + ---------- + Borusyak, K., Jaravel, X., & Spiess, J. (2024). Revisiting Event-Study + Designs: Robust and Efficient Estimation. Review of Economic Studies, + 91(6), 3253-3285. + """ + + def __init__( + self, + anticipation: int = 0, + alpha: float = 0.05, + cluster: Optional[str] = None, + n_bootstrap: int = 0, + seed: Optional[int] = None, + rank_deficient_action: str = "warn", + horizon_max: Optional[int] = None, + aux_partition: str = "cohort_horizon", + ): + if rank_deficient_action not in ("warn", "error", "silent"): + raise ValueError( + f"rank_deficient_action must be 'warn', 'error', or 'silent', " + f"got '{rank_deficient_action}'" + ) + if aux_partition not in ("cohort_horizon", "cohort", "horizon"): + raise ValueError( + f"aux_partition must be 'cohort_horizon', 'cohort', or 'horizon', " + f"got '{aux_partition}'" + ) + + self.anticipation = anticipation + self.alpha = alpha + self.cluster = cluster + self.n_bootstrap = n_bootstrap + self.seed = seed + self.rank_deficient_action = rank_deficient_action + self.horizon_max = horizon_max + self.aux_partition = aux_partition + + self.is_fitted_ = False + self.results_: Optional[ImputationDiDResults] = None + + # Internal state preserved for pretrend_test() + self._fit_data: Optional[Dict[str, Any]] = None + + def fit( + self, + data: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]] = None, + aggregate: Optional[str] = None, + balance_e: Optional[int] = None, + ) -> ImputationDiDResults: + """ + Fit the imputation DiD estimator. + + Parameters + ---------- + data : pd.DataFrame + Panel data with unit and time identifiers. + outcome : str + Name of outcome variable column. + unit : str + Name of unit identifier column. + time : str + Name of time period column. + first_treat : str + Name of column indicating when unit was first treated. + Use 0 (or np.inf) for never-treated units. + covariates : list of str, optional + List of covariate column names. + aggregate : str, optional + Aggregation mode: None/"simple" (overall ATT only), + "event_study", "group", or "all". 
+ balance_e : int, optional + When computing event study, restrict to cohorts observed at all + relative times in [-balance_e, max_h]. + + Returns + ------- + ImputationDiDResults + Object containing all estimation results. + + Raises + ------ + ValueError + If required columns are missing or data validation fails. + """ + # Validate inputs + required_cols = [outcome, unit, time, first_treat] + if covariates: + required_cols.extend(covariates) + + missing = [c for c in required_cols if c not in data.columns] + if missing: + raise ValueError(f"Missing columns: {missing}") + + # Create working copy + df = data.copy() + + # Ensure numeric types + df[time] = pd.to_numeric(df[time]) + df[first_treat] = pd.to_numeric(df[first_treat]) + + # Validate absorbing treatment: first_treat must be constant within each unit + ft_nunique = df.groupby(unit)[first_treat].nunique() + non_constant = ft_nunique[ft_nunique > 1] + if len(non_constant) > 0: + example_unit = non_constant.index[0] + example_vals = sorted(df.loc[df[unit] == example_unit, first_treat].unique()) + warnings.warn( + f"{len(non_constant)} unit(s) have non-constant '{first_treat}' " + f"values (e.g., unit '{example_unit}' has values {example_vals}). " + f"ImputationDiD assumes treatment is an absorbing state " + f"(once treated, always treated) with a single treatment onset " + f"time per unit. Non-constant first_treat violates this assumption " + f"and may produce unreliable estimates.", + UserWarning, + stacklevel=2, + ) + + # Coerce to per-unit value so downstream code + # (_never_treated, _treated, _rel_time) uses a single + # consistent first_treat per unit. + df[first_treat] = df.groupby(unit)[first_treat].transform("first") + + # Identify treatment status + df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf) + + # Check for always-treated units (treated in all observed periods) + min_time = df[time].min() + always_treated_mask = (~df["_never_treated"]) & (df[first_treat] <= min_time) + n_always_treated = df.loc[always_treated_mask, unit].nunique() + if n_always_treated > 0: + warnings.warn( + f"{n_always_treated} unit(s) are treated in all observed periods " + f"(first_treat <= {min_time}). These units have no untreated " + "observations and cannot contribute to the counterfactual model. " + "Their treatment effects will be imputed but may be unreliable.", + UserWarning, + stacklevel=2, + ) + + # Create treatment indicator D_it + # D_it = 1 if t >= first_treat and first_treat > 0 + # With anticipation: D_it = 1 if t >= first_treat - anticipation + effective_treat = df[first_treat] - self.anticipation + df["_treated"] = (~df["_never_treated"]) & (df[time] >= effective_treat) + + # Identify Omega_0 (untreated) and Omega_1 (treated) + omega_0_mask = ~df["_treated"] + omega_1_mask = df["_treated"] + + n_omega_0 = int(omega_0_mask.sum()) + n_omega_1 = int(omega_1_mask.sum()) + + if n_omega_0 == 0: + raise ValueError( + "No untreated observations found. Cannot estimate counterfactual model." + ) + if n_omega_1 == 0: + raise ValueError("No treated observations found. Nothing to estimate.") + + # Identify groups and time periods + time_periods = sorted(df[time].unique()) + treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0 and g != np.inf]) + + if len(treatment_groups) == 0: + raise ValueError("No treated units found. 
Check 'first_treat' column.") + + # Unit info + unit_info = ( + df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index() + ) + n_treated_units = int((~unit_info["_never_treated"]).sum()) + # Control units = units with at least one untreated observation + units_in_omega_0 = df.loc[omega_0_mask, unit].unique() + n_control_units = len(units_in_omega_0) + + # Cluster variable + cluster_var = self.cluster if self.cluster is not None else unit + if self.cluster is not None and self.cluster not in df.columns: + raise ValueError( + f"Cluster column '{self.cluster}' not found in data. " + f"Available columns: {list(df.columns)}" + ) + + # Compute relative time + df["_rel_time"] = np.where( + ~df["_never_treated"], + df[time] - df[first_treat], + np.nan, + ) + + # ---- Step 1: OLS on untreated observations ---- + unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model( + df, outcome, unit, time, covariates, omega_0_mask + ) + + # ---- Rank condition checks ---- + # Check: every treated unit should have >= 1 untreated period (for unit FE) + treated_unit_ids = df.loc[omega_1_mask, unit].unique() + units_with_fe = set(unit_fe.keys()) + units_missing_fe = set(treated_unit_ids) - units_with_fe + + # Check: every post-treatment period should have >= 1 untreated unit (for time FE) + post_period_ids = df.loc[omega_1_mask, time].unique() + periods_with_fe = set(time_fe.keys()) + periods_missing_fe = set(post_period_ids) - periods_with_fe + + if units_missing_fe or periods_missing_fe: + parts = [] + if units_missing_fe: + sorted_missing = sorted(units_missing_fe) + parts.append( + f"{len(units_missing_fe)} treated unit(s) have no untreated " + f"periods (units: {sorted_missing[:5]}" + f"{'...' if len(units_missing_fe) > 5 else ''})" + ) + if periods_missing_fe: + sorted_missing = sorted(periods_missing_fe) + parts.append( + f"{len(periods_missing_fe)} post-treatment period(s) have no " + f"untreated units (periods: {sorted_missing[:5]}" + f"{'...' if len(periods_missing_fe) > 5 else ''})" + ) + msg = ( + "Rank condition violated: " + + "; ".join(parts) + + ". Affected treatment effects will be NaN." 
+ ) + if self.rank_deficient_action == "error": + raise ValueError(msg) + elif self.rank_deficient_action == "warn": + warnings.warn(msg, UserWarning, stacklevel=2) + # "silent": continue without warning + + # ---- Step 2: Impute treatment effects ---- + tau_hat, y_hat_0 = self._impute_treatment_effects( + df, + outcome, + unit, + time, + covariates, + omega_1_mask, + unit_fe, + time_fe, + grand_mean, + delta_hat, + ) + + # Store tau_hat in dataframe + df["_tau_hat"] = np.nan + df.loc[omega_1_mask, "_tau_hat"] = tau_hat + + # ---- Step 3: Aggregate ---- + # Always compute overall ATT (simple aggregation) + valid_tau = tau_hat[np.isfinite(tau_hat)] + + if len(valid_tau) == 0: + overall_att = np.nan + else: + overall_att = float(np.mean(valid_tau)) + + # ---- Conservative variance (Theorem 3) ---- + # Build weights matching the ATT: uniform over finite tau_hat, zero for NaN + overall_weights = np.zeros(n_omega_1) + finite_mask = np.isfinite(tau_hat) + n_valid = int(finite_mask.sum()) + if n_valid > 0: + overall_weights[finite_mask] = 1.0 / n_valid + + if n_valid == 0: + overall_se = np.nan + else: + overall_se = self._compute_conservative_variance( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + weights=overall_weights, + cluster_var=cluster_var, + kept_cov_mask=kept_cov_mask, + ) + + overall_t = ( + overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan + ) + overall_p = compute_p_value(overall_t) + overall_ci = ( + compute_confidence_interval(overall_att, overall_se, self.alpha) + if np.isfinite(overall_se) and overall_se > 0 + else (np.nan, np.nan) + ) + + # Event study and group aggregation + event_study_effects = None + group_effects = None + + if aggregate in ("event_study", "all"): + event_study_effects = self._aggregate_event_study( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + cluster_var=cluster_var, + treatment_groups=treatment_groups, + balance_e=balance_e, + kept_cov_mask=kept_cov_mask, + ) + + if aggregate in ("group", "all"): + group_effects = self._aggregate_group( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + cluster_var=cluster_var, + treatment_groups=treatment_groups, + kept_cov_mask=kept_cov_mask, + ) + + # Build treatment effects dataframe + treated_df = df.loc[omega_1_mask, [unit, time, "_tau_hat", "_rel_time"]].copy() + treated_df = treated_df.rename(columns={"_tau_hat": "tau_hat", "_rel_time": "rel_time"}) + # Weights consistent with actual ATT: zero for NaN tau_hat, 1/n_valid for finite + tau_finite = treated_df["tau_hat"].notna() + n_valid_te = int(tau_finite.sum()) + if n_valid_te > 0: + treated_df["weight"] = np.where(tau_finite, 1.0 / n_valid_te, 0.0) + else: + treated_df["weight"] = 0.0 + + # Store fit data for pretrend_test + self._fit_data = { + "df": df, + "outcome": outcome, + "unit": unit, + "time": time, + "first_treat": first_treat, + "covariates": covariates, + "omega_0_mask": omega_0_mask, + "omega_1_mask": omega_1_mask, 
+ "cluster_var": cluster_var, + "unit_fe": unit_fe, + "time_fe": time_fe, + "grand_mean": grand_mean, + "delta_hat": delta_hat, + "kept_cov_mask": kept_cov_mask, + } + + # Pre-compute cluster psi sums for bootstrap + psi_data = None + if self.n_bootstrap > 0 and n_valid > 0: + try: + psi_data = self._precompute_bootstrap_psi( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + cluster_var=cluster_var, + kept_cov_mask=kept_cov_mask, + overall_weights=overall_weights, + event_study_effects=event_study_effects, + group_effects=group_effects, + treatment_groups=treatment_groups, + tau_hat=tau_hat, + balance_e=balance_e, + ) + except Exception as e: + warnings.warn( + f"Bootstrap pre-computation failed: {e}. " "Skipping bootstrap inference.", + UserWarning, + stacklevel=2, + ) + psi_data = None + + # Bootstrap + bootstrap_results = None + if self.n_bootstrap > 0 and psi_data is not None: + bootstrap_results = self._run_bootstrap( + original_att=overall_att, + original_event_study=event_study_effects, + original_group=group_effects, + psi_data=psi_data, + ) + + # Update inference with bootstrap results + overall_se = bootstrap_results.overall_att_se + overall_t = ( + overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan + ) + overall_p = bootstrap_results.overall_att_p_value + overall_ci = bootstrap_results.overall_att_ci + + # Update event study + if event_study_effects and bootstrap_results.event_study_ses: + for h in event_study_effects: + if ( + h in bootstrap_results.event_study_ses + and event_study_effects[h].get("n_obs", 1) > 0 + ): + event_study_effects[h]["se"] = bootstrap_results.event_study_ses[h] + event_study_effects[h]["conf_int"] = bootstrap_results.event_study_cis[h] + event_study_effects[h]["p_value"] = bootstrap_results.event_study_p_values[ + h + ] + eff_val = event_study_effects[h]["effect"] + se_val = event_study_effects[h]["se"] + event_study_effects[h]["t_stat"] = ( + eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan + ) + + # Update group effects + if group_effects and bootstrap_results.group_ses: + for g in group_effects: + if g in bootstrap_results.group_ses: + group_effects[g]["se"] = bootstrap_results.group_ses[g] + group_effects[g]["conf_int"] = bootstrap_results.group_cis[g] + group_effects[g]["p_value"] = bootstrap_results.group_p_values[g] + eff_val = group_effects[g]["effect"] + se_val = group_effects[g]["se"] + group_effects[g]["t_stat"] = ( + eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan + ) + + # Construct results + self.results_ = ImputationDiDResults( + treatment_effects=treated_df, + overall_att=overall_att, + overall_se=overall_se, + overall_t_stat=overall_t, + overall_p_value=overall_p, + overall_conf_int=overall_ci, + event_study_effects=event_study_effects, + group_effects=group_effects, + groups=treatment_groups, + time_periods=time_periods, + n_obs=len(df), + n_treated_obs=n_omega_1, + n_untreated_obs=n_omega_0, + n_treated_units=n_treated_units, + n_control_units=n_control_units, + alpha=self.alpha, + bootstrap_results=bootstrap_results, + _estimator_ref=self, + ) + + self.is_fitted_ = True + return self.results_ + + # ========================================================================= + # Step 1: OLS on untreated observations + # 
========================================================================= + + def _iterative_fe( + self, + y: np.ndarray, + unit_vals: np.ndarray, + time_vals: np.ndarray, + idx: pd.Index, + max_iter: int = 100, + tol: float = 1e-10, + ) -> Tuple[Dict[Any, float], Dict[Any, float]]: + """ + Estimate unit and time FE via iterative alternating projection (Gauss-Seidel). + + Converges to the exact OLS solution for both balanced and unbalanced panels. + For balanced panels, converges in 1-2 iterations (identical to one-pass). + For unbalanced panels, typically 5-20 iterations. + + Returns + ------- + unit_fe : dict + Mapping from unit -> unit fixed effect. + time_fe : dict + Mapping from time -> time fixed effect. + """ + n = len(y) + alpha = np.zeros(n) # unit FE broadcast to obs level + beta = np.zeros(n) # time FE broadcast to obs level + + with np.errstate(invalid="ignore", divide="ignore"): + for iteration in range(max_iter): + # Update time FE: beta_t = mean_i(y_it - alpha_i) + resid_after_alpha = y - alpha + beta_new = ( + pd.Series(resid_after_alpha, index=idx) + .groupby(time_vals) + .transform("mean") + .values + ) + + # Update unit FE: alpha_i = mean_t(y_it - beta_t) + resid_after_beta = y - beta_new + alpha_new = ( + pd.Series(resid_after_beta, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) + + # Check convergence on FE changes + max_change = max( + np.max(np.abs(alpha_new - alpha)), + np.max(np.abs(beta_new - beta)), + ) + alpha = alpha_new + beta = beta_new + if max_change < tol: + break + + unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict() + time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict() + return unit_fe, time_fe + + @staticmethod + def _iterative_demean( + vals: np.ndarray, + unit_vals: np.ndarray, + time_vals: np.ndarray, + idx: pd.Index, + max_iter: int = 100, + tol: float = 1e-10, + ) -> np.ndarray: + """Demean a vector by iterative alternating projection (unit + time FE removal). + + Converges to the exact within-transformation for both balanced and + unbalanced panels. For balanced panels, converges in 1-2 iterations. + """ + result = vals.copy() + with np.errstate(invalid="ignore", divide="ignore"): + for _ in range(max_iter): + time_means = ( + pd.Series(result, index=idx).groupby(time_vals).transform("mean").values + ) + result_after_time = result - time_means + unit_means = ( + pd.Series(result_after_time, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) + result_new = result_after_time - unit_means + if np.max(np.abs(result_new - result)) < tol: + result = result_new + break + result = result_new + return result + + @staticmethod + def _compute_balanced_cohort_mask( + df_treated: pd.DataFrame, + first_treat: str, + all_horizons: List[int], + balance_e: int, + cohort_rel_times: Dict[Any, Set[int]], + ) -> np.ndarray: + """Compute boolean mask selecting treated obs from balanced cohorts. + + A cohort is 'balanced' if it has observations at every relative time + in [-balance_e, max(all_horizons)]. + + Parameters + ---------- + df_treated : pd.DataFrame + Post-treatment observations (Omega_1). + first_treat : str + Column name for cohort identifier. + all_horizons : list of int + Post-treatment horizons in the event study. + balance_e : int + Number of pre-treatment periods to require. + cohort_rel_times : dict + Maps each cohort value to the set of all observed relative times + (including pre-treatment) from the full panel. Built by + _build_cohort_rel_times(). 
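+
+ Returns
+ -------
+ np.ndarray
+ Boolean mask over the rows of df_treated; True where the row's
+ cohort has observations at every relative time in
+ [-balance_e, max(all_horizons)].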
+ """ + if not all_horizons: + return np.ones(len(df_treated), dtype=bool) + + max_h = max(all_horizons) + required_range = set(range(-balance_e, max_h + 1)) + + balanced_cohorts = set() + for g, horizons in cohort_rel_times.items(): + if required_range.issubset(horizons): + balanced_cohorts.add(g) + + return df_treated[first_treat].isin(balanced_cohorts).values + + @staticmethod + def _build_cohort_rel_times( + df: pd.DataFrame, + first_treat: str, + ) -> Dict[Any, Set[int]]: + """Build mapping of cohort -> set of observed relative times from full panel. + + Precondition: df must have '_never_treated' and '_rel_time' columns + (set by fit() before any aggregation calls). + """ + treated_mask = ~df["_never_treated"] + treated_df = df.loc[treated_mask] + result: Dict[Any, Set[int]] = {} + ft_vals = treated_df[first_treat].values + rt_vals = treated_df["_rel_time"].values + for i in range(len(treated_df)): + h = rt_vals[i] + if np.isfinite(h): + result.setdefault(ft_vals[i], set()).add(int(h)) + return result + + def _fit_untreated_model( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + ) -> Tuple[ + Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray] + ]: + """ + Step 1: Estimate unit + time FE on untreated observations. + + Uses iterative alternating projection (Gauss-Seidel) to compute exact + OLS fixed effects for both balanced and unbalanced panels. For balanced + panels, converges in 1-2 iterations (identical to one-pass demeaning). + + Returns + ------- + unit_fe : dict + Unit fixed effects {unit_id: alpha_i}. + time_fe : dict + Time fixed effects {time_period: beta_t}. + grand_mean : float + Grand mean (0.0 — absorbed into iterative FE). + delta_hat : np.ndarray or None + Covariate coefficients (if covariates provided). + kept_cov_mask : np.ndarray or None + Boolean mask of shape (n_covariates,) indicating which covariates + have finite coefficients. None if no covariates. 
+ """ + df_0 = df.loc[omega_0_mask] + + if covariates is None or len(covariates) == 0: + # No covariates: estimate FE via iterative alternating projection + # (exact OLS for both balanced and unbalanced panels) + y = df_0[outcome].values.copy() + unit_fe, time_fe = self._iterative_fe( + y, df_0[unit].values, df_0[time].values, df_0.index + ) + # grand_mean = 0: iterative FE absorb the intercept + return unit_fe, time_fe, 0.0, None, None + + else: + # With covariates: iteratively demean Y and X, OLS for delta, + # then recover FE from covariate-adjusted outcome + y = df_0[outcome].values.copy() + X_raw = df_0[covariates].values.copy() + units = df_0[unit].values + times = df_0[time].values + n_cov = len(covariates) + + # Step A: Iteratively demean Y and all X columns to remove unit+time FE + y_dm = self._iterative_demean(y, units, times, df_0.index) + X_dm = np.column_stack( + [ + self._iterative_demean(X_raw[:, j], units, times, df_0.index) + for j in range(n_cov) + ] + ) + + # Step B: OLS for covariate coefficients on demeaned data + result = solve_ols( + X_dm, + y_dm, + return_vcov=False, + rank_deficient_action=self.rank_deficient_action, + column_names=covariates, + ) + delta_hat = result[0] + + # Mask of covariates with finite coefficients (before cleaning) + # Used to exclude rank-deficient covariates from variance design matrices + kept_cov_mask = np.isfinite(delta_hat) + + # Replace NaN coefficients with 0 for adjustment + # (rank-deficient covariates are dropped) + delta_hat_clean = np.where(np.isfinite(delta_hat), delta_hat, 0.0) + + # Step C: Recover FE from covariate-adjusted outcome using iterative FE + y_adj = y - X_raw @ delta_hat_clean + unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index) + + # grand_mean = 0: iterative FE absorb the intercept + return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask + + # ========================================================================= + # Step 2: Impute counterfactuals + # ========================================================================= + + def _impute_treatment_effects( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + covariates: Optional[List[str]], + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Step 2: Impute Y(0) for treated observations and compute tau_hat. + + Returns + ------- + tau_hat : np.ndarray + Imputed treatment effects for each treated observation. + y_hat_0 : np.ndarray + Imputed counterfactual Y(0). 
+ """ + df_1 = df.loc[omega_1_mask] + n_1 = len(df_1) + + # Look up unit and time FE + alpha_i = df_1[unit].map(unit_fe).values + beta_t = df_1[time].map(time_fe).values + + # Handle missing FE (set to NaN) + alpha_i = np.where(pd.isna(alpha_i), np.nan, alpha_i).astype(float) + beta_t = np.where(pd.isna(beta_t), np.nan, beta_t).astype(float) + + y_hat_0 = grand_mean + alpha_i + beta_t + + if delta_hat is not None and covariates: + X_1 = df_1[covariates].values + y_hat_0 = y_hat_0 + X_1 @ delta_hat + + tau_hat = df_1[outcome].values - y_hat_0 + + return tau_hat, y_hat_0 + + # ========================================================================= + # Conservative Variance (Theorem 3) + # ========================================================================= + + def _compute_cluster_psi_sums( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + weights: np.ndarray, + cluster_var: str, + kept_cov_mask: Optional[np.ndarray] = None, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute cluster-level influence function sums (Theorem 3). + + psi_i = sum_t v_it * epsilon_tilde_it, summed within each cluster. + + Returns + ------- + cluster_psi_sums : np.ndarray + Array of cluster-level psi sums. + cluster_ids_unique : np.ndarray + Unique cluster identifiers (matching order of psi sums). + """ + df_0 = df.loc[omega_0_mask] + df_1 = df.loc[omega_1_mask] + n_0 = len(df_0) + n_1 = len(df_1) + + # ---- Compute v_it for treated observations ---- + v_treated = weights.copy() + + # ---- Compute v_it for untreated observations ---- + if covariates is None or len(covariates) == 0: + # FE-only case: closed-form + treated_units = df_1[unit].values + treated_times = df_1[time].values + + w_by_unit: Dict[Any, float] = {} + for i_idx in range(n_1): + u = treated_units[i_idx] + w_by_unit[u] = w_by_unit.get(u, 0.0) + weights[i_idx] + + w_by_time: Dict[Any, float] = {} + for i_idx in range(n_1): + t = treated_times[i_idx] + w_by_time[t] = w_by_time.get(t, 0.0) + weights[i_idx] + + w_total = float(np.sum(weights)) + + n0_by_unit = df_0.groupby(unit).size().to_dict() + n0_by_time = df_0.groupby(time).size().to_dict() + + untreated_units = df_0[unit].values + untreated_times = df_0[time].values + v_untreated = np.zeros(n_0) + + for j in range(n_0): + u = untreated_units[j] + t = untreated_times[j] + w_i = w_by_unit.get(u, 0.0) + w_t = w_by_time.get(t, 0.0) + n0_i = n0_by_unit.get(u, 1) + n0_t = n0_by_time.get(t, 1) + v_untreated[j] = -(w_i / n0_i + w_t / n0_t - w_total / n_0) + else: + v_untreated = self._compute_v_untreated_with_covariates( + df_0, + df_1, + unit, + time, + covariates, + weights, + delta_hat, + kept_cov_mask=kept_cov_mask, + ) + + # ---- Compute auxiliary model residuals (Equation 8) ---- + epsilon_treated = self._compute_auxiliary_residuals_treated( + df_1, + outcome, + unit, + time, + first_treat, + covariates, + unit_fe, + time_fe, + grand_mean, + delta_hat, + v_treated, + ) + epsilon_untreated = self._compute_residuals_untreated( + df_0, outcome, unit, time, covariates, unit_fe, time_fe, grand_mean, delta_hat + ) + + # ---- psi_it = v_it * epsilon_tilde_it ---- + v_all = np.empty(len(df)) + v_all[omega_1_mask.values] = v_treated + v_all[omega_0_mask.values] = v_untreated + + eps_all = np.empty(len(df)) + eps_all[omega_1_mask.values] = epsilon_treated + 
eps_all[omega_0_mask.values] = epsilon_untreated + + ve_product = v_all * eps_all + # NaN eps from missing FE (rank condition violation). Zero their variance + # contribution — matches R's did_imputation which drops unimputable obs. + np.nan_to_num(ve_product, copy=False, nan=0.0) + + # Sum within clusters + cluster_ids = df[cluster_var].values + ve_series = pd.Series(ve_product, index=df.index) + cluster_sums = ve_series.groupby(cluster_ids).sum() + + return cluster_sums.values, cluster_sums.index.values + + def _compute_conservative_variance( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + weights: np.ndarray, + cluster_var: str, + kept_cov_mask: Optional[np.ndarray] = None, + ) -> float: + """ + Compute conservative clustered variance (Theorem 3, Equation 7). + + Parameters + ---------- + weights : np.ndarray + Aggregation weights w_it for treated observations. + Shape: (n_treated,), must sum to 1. + + Returns + ------- + float + Standard error. + """ + cluster_psi_sums, _ = self._compute_cluster_psi_sums( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + weights=weights, + cluster_var=cluster_var, + kept_cov_mask=kept_cov_mask, + ) + sigma_sq = float((cluster_psi_sums**2).sum()) + return np.sqrt(max(sigma_sq, 0.0)) + + def _compute_v_untreated_with_covariates( + self, + df_0: pd.DataFrame, + df_1: pd.DataFrame, + unit: str, + time: str, + covariates: List[str], + weights: np.ndarray, + delta_hat: Optional[np.ndarray], + kept_cov_mask: Optional[np.ndarray] = None, + ) -> np.ndarray: + """ + Compute v_it for untreated observations with covariates. + + Uses the projection: v_untreated = -A_0 (A_0'A_0)^{-1} A_1' w_treated + + Uses scipy.sparse for FE dummy columns to reduce memory from O(N*(U+T)) + to O(N) for the FE portion. 
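+
+ Returns
+ -------
+ np.ndarray
+ v_it values for the untreated observations, aligned with the
+ rows of df_0.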
+ """ + # Exclude rank-deficient covariates from design matrices + if kept_cov_mask is not None and not np.all(kept_cov_mask): + covariates = [c for c, k in zip(covariates, kept_cov_mask) if k] + + units_0 = df_0[unit].values + times_0 = df_0[time].values + units_1 = df_1[unit].values + times_1 = df_1[time].values + + all_units = np.unique(np.concatenate([units_0, units_1])) + all_times = np.unique(np.concatenate([times_0, times_1])) + unit_to_idx = {u: i for i, u in enumerate(all_units)} + time_to_idx = {t: i for i, t in enumerate(all_times)} + n_units = len(all_units) + n_times = len(all_times) + n_cov = len(covariates) + n_fe_cols = (n_units - 1) + (n_times - 1) + + def _build_A_sparse(df_sub, unit_vals, time_vals): + n = len(df_sub) + + # Unit dummies (drop first) — vectorized + u_indices = np.array([unit_to_idx[u] for u in unit_vals]) + u_mask = u_indices > 0 # skip first unit (dropped) + u_rows = np.arange(n)[u_mask] + u_cols = u_indices[u_mask] - 1 + + # Time dummies (drop first) — vectorized + t_indices = np.array([time_to_idx[t] for t in time_vals]) + t_mask = t_indices > 0 + t_rows = np.arange(n)[t_mask] + t_cols = (n_units - 1) + t_indices[t_mask] - 1 + + rows = np.concatenate([u_rows, t_rows]) + cols = np.concatenate([u_cols, t_cols]) + data = np.ones(len(rows)) + + A_fe = sparse.csr_matrix((data, (rows, cols)), shape=(n, n_fe_cols)) + + # Covariates (dense, typically few columns) + if n_cov > 0: + A_cov = sparse.csr_matrix(df_sub[covariates].values) + A = sparse.hstack([A_fe, A_cov], format="csr") + else: + A = A_fe + + return A + + A_0 = _build_A_sparse(df_0, units_0, times_0) + A_1 = _build_A_sparse(df_1, units_1, times_1) + + # Compute A_1' w (sparse.T @ dense -> dense) + A1_w = A_1.T @ weights # shape (p,) + + # Solve (A_0'A_0) z = A_1' w using sparse direct solver + A0tA0_sparse = A_0.T @ A_0 # stays sparse + try: + z = spsolve(A0tA0_sparse.tocsc(), A1_w) + except Exception: + # Fallback to dense lstsq if sparse solver fails (e.g., singular matrix) + A0tA0_dense = A0tA0_sparse.toarray() + z, _, _, _ = np.linalg.lstsq(A0tA0_dense, A1_w, rcond=None) + + # v_untreated = -A_0 z (sparse @ dense -> dense) + v_untreated = -(A_0 @ z) + return v_untreated + + def _compute_auxiliary_residuals_treated( + self, + df_1: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + v_treated: np.ndarray, + ) -> np.ndarray: + """ + Compute v_it-weighted auxiliary residuals for treated obs (Equation 8). + + Computes v_it-weighted tau_tilde_g per Equation 8 of Borusyak et al. (2024): + tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it) within group g. 
+ + epsilon_tilde_it = Y_it - alpha_i - beta_t [- X'delta] - tau_tilde_g + """ + n_1 = len(df_1) + + # Compute base residuals (Y - Y_hat(0) = tau_hat) + # NaN for missing FE (consistent with _impute_treatment_effects) + alpha_i = df_1[unit].map(unit_fe).values.astype(float) # NaN for missing + beta_t = df_1[time].map(time_fe).values.astype(float) # NaN for missing + y_hat_0 = grand_mean + alpha_i + beta_t + + if delta_hat is not None and covariates: + y_hat_0 = y_hat_0 + df_1[covariates].values @ delta_hat + + tau_hat = df_1[outcome].values - y_hat_0 + + # Partition Omega_1 and compute tau_tilde for each group + if self.aux_partition == "cohort_horizon": + group_keys = list(zip(df_1[first_treat].values, df_1["_rel_time"].values)) + elif self.aux_partition == "cohort": + group_keys = list(df_1[first_treat].values) + elif self.aux_partition == "horizon": + group_keys = list(df_1["_rel_time"].values) + else: + group_keys = list(range(n_1)) # each obs is its own group + + # Compute v_it-weighted average tau within each partition group (Equation 8) + # tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it) within group g + group_series = pd.Series(group_keys, index=df_1.index) + tau_series = pd.Series(tau_hat, index=df_1.index) + v_series = pd.Series(v_treated, index=df_1.index) + + weighted_tau_sum = (v_series * tau_series).groupby(group_series).sum() + weight_sum = v_series.groupby(group_series).sum() + + # Guard: zero-weight groups -> their tau_tilde doesn't affect variance + # (v_it ~ 0 means these obs contribute nothing to the estimand) + # Use simple mean as fallback. This is common for event-study SE computation + # where weights target a specific horizon, making other partition groups zero. + zero_weight_groups = weight_sum.abs() < 1e-15 + if zero_weight_groups.any(): + simple_means = tau_series.groupby(group_series).mean() + tau_tilde_map = weighted_tau_sum / weight_sum + tau_tilde_map = tau_tilde_map.where(~zero_weight_groups, simple_means) + else: + tau_tilde_map = weighted_tau_sum / weight_sum + + tau_tilde = group_series.map(tau_tilde_map).values + + # Auxiliary residuals + epsilon_treated = tau_hat - tau_tilde + + return epsilon_treated + + def _compute_residuals_untreated( + self, + df_0: pd.DataFrame, + outcome: str, + unit: str, + time: str, + covariates: Optional[List[str]], + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + ) -> np.ndarray: + """Compute Step 1 residuals for untreated observations.""" + alpha_i = df_0[unit].map(unit_fe).fillna(0.0).values + beta_t = df_0[time].map(time_fe).fillna(0.0).values + y_hat = grand_mean + alpha_i + beta_t + + if delta_hat is not None and covariates: + y_hat = y_hat + df_0[covariates].values @ delta_hat + + return df_0[outcome].values - y_hat + + # ========================================================================= + # Aggregation + # ========================================================================= + + def _aggregate_event_study( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + cluster_var: str, + treatment_groups: List[Any], + balance_e: Optional[int] = None, + kept_cov_mask: Optional[np.ndarray] = None, + ) -> Dict[int, Dict[str, Any]]: + """Aggregate treatment effects by event-study horizon.""" + df_1 = 
df.loc[omega_1_mask] + tau_hat = df["_tau_hat"].loc[omega_1_mask].values + rel_times = df_1["_rel_time"].values + + # Get all horizons + all_horizons = sorted(set(int(h) for h in rel_times if np.isfinite(h))) + + # Apply horizon_max filter + if self.horizon_max is not None: + all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max] + + # Apply balance_e filter + if balance_e is not None: + cohort_rel_times = self._build_cohort_rel_times(df, first_treat) + balanced_mask = pd.Series( + self._compute_balanced_cohort_mask( + df_1, first_treat, all_horizons, balance_e, cohort_rel_times + ), + index=df_1.index, + ) + else: + balanced_mask = pd.Series(True, index=df_1.index) + + # Check Proposition 5: no never-treated units + has_never_treated = df["_never_treated"].any() + h_bar = np.inf + if not has_never_treated and len(treatment_groups) > 1: + h_bar = max(treatment_groups) - min(treatment_groups) + + # Reference period + ref_period = -1 - self.anticipation + + event_study_effects: Dict[int, Dict[str, Any]] = {} + + # Add reference period marker + event_study_effects[ref_period] = { + "effect": 0.0, + "se": 0.0, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (0.0, 0.0), + "n_obs": 0, + } + + # Collect horizons with Proposition 5 violations + prop5_horizons = [] + + for h in all_horizons: + if h == ref_period: + continue + + # Select treated obs at this horizon from balanced cohorts + h_mask = (rel_times == h) & balanced_mask.values + n_h = int(h_mask.sum()) + + if n_h == 0: + continue + + # Proposition 5 check + if not has_never_treated and h >= h_bar: + prop5_horizons.append(h) + event_study_effects[h] = { + "effect": np.nan, + "se": np.nan, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_obs": n_h, + } + continue + + tau_h = tau_hat[h_mask] + valid_tau = tau_h[np.isfinite(tau_h)] + + if len(valid_tau) == 0: + event_study_effects[h] = { + "effect": np.nan, + "se": np.nan, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_obs": n_h, + } + continue + + effect = float(np.mean(valid_tau)) + + # Compute SE via conservative variance with horizon-specific weights + weights_h = np.zeros(int(omega_1_mask.sum())) + # Map h_mask (relative to df_1) to weights array + h_indices_in_omega1 = np.where(h_mask)[0] + n_valid = len(valid_tau) + # Only weight valid (finite) observations + finite_mask = np.isfinite(tau_hat[h_mask]) + valid_h_indices = h_indices_in_omega1[finite_mask] + for idx in valid_h_indices: + weights_h[idx] = 1.0 / n_valid + + se = self._compute_conservative_variance( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + weights=weights_h, + cluster_var=cluster_var, + kept_cov_mask=kept_cov_mask, + ) + + t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan + p_value = compute_p_value(t_stat) + conf_int = ( + compute_confidence_interval(effect, se, self.alpha) + if np.isfinite(se) and se > 0 + else (np.nan, np.nan) + ) + + event_study_effects[h] = { + "effect": effect, + "se": se, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + "n_obs": n_h, + } + + # Proposition 5 warning + if prop5_horizons: + warnings.warn( + f"Horizons {prop5_horizons} are not identified without " + f"never-treated units (Proposition 5). 
Set to NaN.", + UserWarning, + stacklevel=3, + ) + + # Check for empty result set after filtering + real_effects = [ + h for h, v in event_study_effects.items() if h != ref_period and v.get("n_obs", 0) > 0 + ] + if len(real_effects) == 0: + filter_info = [] + if balance_e is not None: + filter_info.append(f"balance_e={balance_e}") + if self.horizon_max is not None: + filter_info.append(f"horizon_max={self.horizon_max}") + filter_str = " and ".join(filter_info) if filter_info else "filters" + warnings.warn( + f"Event study aggregation produced no horizons with observations " + f"after applying {filter_str}. The result contains only the " + f"reference period marker. Consider relaxing filter parameters.", + UserWarning, + stacklevel=3, + ) + + return event_study_effects + + def _aggregate_group( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + cluster_var: str, + treatment_groups: List[Any], + kept_cov_mask: Optional[np.ndarray] = None, + ) -> Dict[Any, Dict[str, Any]]: + """Aggregate treatment effects by cohort.""" + df_1 = df.loc[omega_1_mask] + tau_hat = df["_tau_hat"].loc[omega_1_mask].values + cohorts = df_1[first_treat].values + + group_effects: Dict[Any, Dict[str, Any]] = {} + + for g in treatment_groups: + g_mask = cohorts == g + n_g = int(g_mask.sum()) + + if n_g == 0: + continue + + tau_g = tau_hat[g_mask] + valid_tau = tau_g[np.isfinite(tau_g)] + + if len(valid_tau) == 0: + group_effects[g] = { + "effect": np.nan, + "se": np.nan, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_obs": n_g, + } + continue + + effect = float(np.mean(valid_tau)) + + # Compute SE with group-specific weights + weights_g = np.zeros(int(omega_1_mask.sum())) + finite_mask = np.isfinite(tau_hat) & g_mask + g_indices = np.where(finite_mask)[0] + n_valid = len(valid_tau) + for idx in g_indices: + weights_g[idx] = 1.0 / n_valid + + se = self._compute_conservative_variance( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + weights=weights_g, + cluster_var=cluster_var, + kept_cov_mask=kept_cov_mask, + ) + + t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan + p_value = compute_p_value(t_stat) + conf_int = ( + compute_confidence_interval(effect, se, self.alpha) + if np.isfinite(se) and se > 0 + else (np.nan, np.nan) + ) + + group_effects[g] = { + "effect": effect, + "se": se, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + "n_obs": n_g, + } + + return group_effects + + # ========================================================================= + # Pre-trend test (Equation 9) + # ========================================================================= + + def _pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]: + """ + Run pre-trend test (Equation 9). + + Adds pre-treatment lead indicators to the Step 1 OLS on Omega_0 + and tests their joint significance via cluster-robust Wald F-test. 
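+
+ Parameters
+ ----------
+ n_leads : int, optional
+ Number of pre-treatment leads to test, counted from the period
+ closest to treatment. If None, all available leads are used.
+
+ Returns
+ -------
+ dict
+ Keys: 'f_stat', 'p_value', 'df', 'n_leads', and
+ 'lead_coefficients' (mapping relative time -> coefficient).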
+ """ + if self._fit_data is None: + raise RuntimeError("Must call fit() before pretrend_test().") + + fd = self._fit_data + df = fd["df"] + outcome = fd["outcome"] + unit = fd["unit"] + time = fd["time"] + first_treat = fd["first_treat"] + covariates = fd["covariates"] + omega_0_mask = fd["omega_0_mask"] + cluster_var = fd["cluster_var"] + + df_0 = df.loc[omega_0_mask].copy() + + # Compute relative time for untreated obs + # For not-yet-treated units in their pre-treatment periods + rel_time_0 = np.where( + ~df_0["_never_treated"], + df_0[time] - df_0[first_treat], + np.nan, + ) + + # Get available pre-treatment relative times (negative values) + pre_rel_times = sorted( + set(int(h) for h in rel_time_0 if np.isfinite(h) and h < -self.anticipation) + ) + + if len(pre_rel_times) == 0: + return { + "f_stat": np.nan, + "p_value": np.nan, + "df": 0, + "n_leads": 0, + "lead_coefficients": {}, + } + + # Exclude the reference period (last pre-treatment period) + ref = -1 - self.anticipation + pre_rel_times = [h for h in pre_rel_times if h != ref] + + if n_leads is not None: + # Take the n_leads periods closest to treatment + pre_rel_times = sorted(pre_rel_times, reverse=True)[:n_leads] + pre_rel_times = sorted(pre_rel_times) + + if len(pre_rel_times) == 0: + return { + "f_stat": np.nan, + "p_value": np.nan, + "df": 0, + "n_leads": 0, + "lead_coefficients": {}, + } + + # Build lead indicators + lead_cols = [] + for h in pre_rel_times: + col_name = f"_lead_{h}" + df_0[col_name] = ((rel_time_0 == h)).astype(float) + lead_cols.append(col_name) + + # Within-transform via iterative demeaning (exact for unbalanced panels) + y_dm = self._iterative_demean( + df_0[outcome].values, df_0[unit].values, df_0[time].values, df_0.index + ) + + all_x_cols = lead_cols[:] + if covariates: + all_x_cols.extend(covariates) + + X_dm = np.column_stack( + [ + self._iterative_demean( + df_0[col].values, df_0[unit].values, df_0[time].values, df_0.index + ) + for col in all_x_cols + ] + ) + + # OLS with cluster-robust SEs + cluster_ids = df_0[cluster_var].values + result = solve_ols( + X_dm, + y_dm, + cluster_ids=cluster_ids, + return_vcov=True, + rank_deficient_action=self.rank_deficient_action, + column_names=all_x_cols, + ) + coefficients = result[0] + vcov = result[2] + + # Extract lead coefficients and their sub-VCV + n_leads_actual = len(lead_cols) + gamma = coefficients[:n_leads_actual] + V_gamma = vcov[:n_leads_actual, :n_leads_actual] + + # Wald F-test: F = (gamma' V^{-1} gamma) / n_leads + try: + V_inv_gamma = np.linalg.solve(V_gamma, gamma) + wald_stat = float(gamma @ V_inv_gamma) + f_stat = wald_stat / n_leads_actual + except np.linalg.LinAlgError: + f_stat = np.nan + + # P-value from F distribution + if np.isfinite(f_stat) and f_stat >= 0: + n_clusters = len(np.unique(cluster_ids)) + df_denom = max(n_clusters - 1, 1) + p_value = float(stats.f.sf(f_stat, n_leads_actual, df_denom)) + else: + p_value = np.nan + + # Store lead coefficients + lead_coefficients = {} + for j, h in enumerate(pre_rel_times): + lead_coefficients[h] = float(gamma[j]) + + return { + "f_stat": f_stat, + "p_value": p_value, + "df": n_leads_actual, + "n_leads": n_leads_actual, + "lead_coefficients": lead_coefficients, + } + + # ========================================================================= + # Bootstrap + # ========================================================================= + + def _compute_percentile_ci( + self, + boot_dist: np.ndarray, + alpha: float, + ) -> Tuple[float, float]: + """Compute percentile confidence interval 
from bootstrap distribution.""" + lower = float(np.percentile(boot_dist, alpha / 2 * 100)) + upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100)) + return (lower, upper) + + def _compute_bootstrap_pvalue( + self, + original_effect: float, + boot_dist: np.ndarray, + n_valid: Optional[int] = None, + ) -> float: + """ + Compute two-sided bootstrap p-value. + + Uses the percentile method: p-value is the proportion of bootstrap + estimates on the opposite side of zero from the original estimate, + doubled for two-sided test. + + Parameters + ---------- + original_effect : float + Original point estimate. + boot_dist : np.ndarray + Bootstrap distribution of the effect. + n_valid : int, optional + Number of valid bootstrap samples. If None, uses self.n_bootstrap. + """ + if original_effect >= 0: + p_one_sided = float(np.mean(boot_dist <= 0)) + else: + p_one_sided = float(np.mean(boot_dist >= 0)) + p_value = min(2 * p_one_sided, 1.0) + n_for_floor = n_valid if n_valid is not None else self.n_bootstrap + p_value = max(p_value, 1 / (n_for_floor + 1)) + return p_value + + def _precompute_bootstrap_psi( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + cluster_var: str, + kept_cov_mask: Optional[np.ndarray], + overall_weights: np.ndarray, + event_study_effects: Optional[Dict[int, Dict[str, Any]]], + group_effects: Optional[Dict[Any, Dict[str, Any]]], + treatment_groups: List[Any], + tau_hat: np.ndarray, + balance_e: Optional[int], + ) -> Dict[str, Any]: + """ + Pre-compute cluster-level influence function sums for each bootstrap target. + + For each aggregation target (overall, per-horizon, per-group), computes + psi_i = sum_t v_it * epsilon_tilde_it for each cluster. The multiplier + bootstrap then perturbs these psi sums with Rademacher weights. + + Computational cost scales with the number of aggregation targets, since + each target requires its own v_untreated computation (weight-dependent). + """ + result: Dict[str, Any] = {} + + common = dict( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + cluster_var=cluster_var, + kept_cov_mask=kept_cov_mask, + ) + + # Overall ATT + overall_psi, cluster_ids = self._compute_cluster_psi_sums(**common, weights=overall_weights) + result["overall"] = (overall_psi, cluster_ids) + + # Event study: per-horizon weights + # NOTE: weight logic duplicated from _aggregate_event_study. + # If weight scheme changes there, update here too. 
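+ # A Rademacher draw perturbs these cluster sums as T_b = sum_i w_b_i * psi_i,
+ # so Var(T_b | data) = sum_i psi_i^2, i.e. the Theorem 3 variance for the
+ # corresponding aggregation target.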
+ if event_study_effects: + result["event_study"] = {} + df_1 = df.loc[omega_1_mask] + rel_times = df_1["_rel_time"].values + n_omega_1 = int(omega_1_mask.sum()) + + # Balanced cohort mask (same logic as _aggregate_event_study) + balanced_mask = None + if balance_e is not None: + all_horizons = sorted(set(int(h) for h in rel_times if np.isfinite(h))) + if self.horizon_max is not None: + all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max] + cohort_rel_times = self._build_cohort_rel_times(df, first_treat) + balanced_mask = self._compute_balanced_cohort_mask( + df_1, first_treat, all_horizons, balance_e, cohort_rel_times + ) + + ref_period = -1 - self.anticipation + for h in event_study_effects: + if event_study_effects[h].get("n_obs", 0) == 0: + continue + if h == ref_period: + continue + if not np.isfinite(event_study_effects[h].get("effect", np.nan)): + continue + h_mask = rel_times == h + if balanced_mask is not None: + h_mask = h_mask & balanced_mask + weights_h = np.zeros(n_omega_1) + finite_h = np.isfinite(tau_hat) & h_mask + n_valid_h = int(finite_h.sum()) + if n_valid_h == 0: + continue + weights_h[np.where(finite_h)[0]] = 1.0 / n_valid_h + + psi_h, _ = self._compute_cluster_psi_sums(**common, weights=weights_h) + result["event_study"][h] = psi_h + + # Group effects: per-group weights + # NOTE: weight logic duplicated from _aggregate_group. + # If weight scheme changes there, update here too. + if group_effects: + result["group"] = {} + df_1 = df.loc[omega_1_mask] + cohorts = df_1[first_treat].values + n_omega_1 = int(omega_1_mask.sum()) + + for g in group_effects: + if group_effects[g].get("n_obs", 0) == 0: + continue + if not np.isfinite(group_effects[g].get("effect", np.nan)): + continue + g_mask = cohorts == g + weights_g = np.zeros(n_omega_1) + finite_g = np.isfinite(tau_hat) & g_mask + n_valid_g = int(finite_g.sum()) + if n_valid_g == 0: + continue + weights_g[np.where(finite_g)[0]] = 1.0 / n_valid_g + + psi_g, _ = self._compute_cluster_psi_sums(**common, weights=weights_g) + result["group"][g] = psi_g + + return result + + def _run_bootstrap( + self, + original_att: float, + original_event_study: Optional[Dict[int, Dict[str, Any]]], + original_group: Optional[Dict[Any, Dict[str, Any]]], + psi_data: Dict[str, Any], + ) -> ImputationBootstrapResults: + """ + Run multiplier bootstrap on pre-computed influence function sums. + + Uses T_b = sum_i w_b_i * psi_i where w_b_i are Rademacher weights + and psi_i are cluster-level influence function sums from Theorem 3. + SE = std(T_b, ddof=1). + """ + if self.n_bootstrap < 50: + warnings.warn( + f"n_bootstrap={self.n_bootstrap} is low. 
Consider n_bootstrap >= 199 " + "for reliable inference.", + UserWarning, + stacklevel=3, + ) + + rng = np.random.default_rng(self.seed) + + from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch + + overall_psi, cluster_ids = psi_data["overall"] + n_clusters = len(cluster_ids) + + # Generate ALL weights upfront: shape (n_bootstrap, n_clusters) + all_weights = _generate_bootstrap_weights_batch( + self.n_bootstrap, n_clusters, "rademacher", rng + ) + + # Overall ATT bootstrap draws + boot_overall = all_weights @ overall_psi # (n_bootstrap,) + + # Event study: loop over horizons + boot_event_study: Optional[Dict[int, np.ndarray]] = None + if original_event_study and "event_study" in psi_data: + boot_event_study = {} + for h, psi_h in psi_data["event_study"].items(): + boot_event_study[h] = all_weights @ psi_h + + # Group effects: loop over groups + boot_group: Optional[Dict[Any, np.ndarray]] = None + if original_group and "group" in psi_data: + boot_group = {} + for g, psi_g in psi_data["group"].items(): + boot_group[g] = all_weights @ psi_g + + # --- Inference (percentile bootstrap, matching CS/SA convention) --- + # Shift perturbation-centered draws to effect-centered draws. + # The multiplier bootstrap produces T_b = sum w_b_i * psi_i centered at 0. + # CS adds the original effect back (L411 of staggered_bootstrap.py). + # We do the same here so percentile CIs and empirical p-values work correctly. + boot_overall_shifted = boot_overall + original_att + + overall_se = float(np.std(boot_overall, ddof=1)) + overall_ci = ( + self._compute_percentile_ci(boot_overall_shifted, self.alpha) + if overall_se > 0 + else (np.nan, np.nan) + ) + overall_p = ( + self._compute_bootstrap_pvalue(original_att, boot_overall_shifted) + if overall_se > 0 + else np.nan + ) + + event_study_ses = None + event_study_cis = None + event_study_p_values = None + if boot_event_study and original_event_study: + event_study_ses = {} + event_study_cis = {} + event_study_p_values = {} + for h in boot_event_study: + se_h = float(np.std(boot_event_study[h], ddof=1)) + event_study_ses[h] = se_h + orig_eff = original_event_study[h]["effect"] + if se_h > 0 and np.isfinite(orig_eff): + shifted_h = boot_event_study[h] + orig_eff + event_study_p_values[h] = self._compute_bootstrap_pvalue(orig_eff, shifted_h) + event_study_cis[h] = self._compute_percentile_ci(shifted_h, self.alpha) + else: + event_study_p_values[h] = np.nan + event_study_cis[h] = (np.nan, np.nan) + + group_ses = None + group_cis = None + group_p_values = None + if boot_group and original_group: + group_ses = {} + group_cis = {} + group_p_values = {} + for g in boot_group: + se_g = float(np.std(boot_group[g], ddof=1)) + group_ses[g] = se_g + orig_eff = original_group[g]["effect"] + if se_g > 0 and np.isfinite(orig_eff): + shifted_g = boot_group[g] + orig_eff + group_p_values[g] = self._compute_bootstrap_pvalue(orig_eff, shifted_g) + group_cis[g] = self._compute_percentile_ci(shifted_g, self.alpha) + else: + group_p_values[g] = np.nan + group_cis[g] = (np.nan, np.nan) + + return ImputationBootstrapResults( + n_bootstrap=self.n_bootstrap, + weight_type="rademacher", + alpha=self.alpha, + overall_att_se=overall_se, + overall_att_ci=overall_ci, + overall_att_p_value=overall_p, + event_study_ses=event_study_ses, + event_study_cis=event_study_cis, + event_study_p_values=event_study_p_values, + group_ses=group_ses, + group_cis=group_cis, + group_p_values=group_p_values, + bootstrap_distribution=boot_overall_shifted, + ) + + # 
========================================================================= + # sklearn-compatible interface + # ========================================================================= + + def get_params(self) -> Dict[str, Any]: + """Get estimator parameters (sklearn-compatible).""" + return { + "anticipation": self.anticipation, + "alpha": self.alpha, + "cluster": self.cluster, + "n_bootstrap": self.n_bootstrap, + "seed": self.seed, + "rank_deficient_action": self.rank_deficient_action, + "horizon_max": self.horizon_max, + "aux_partition": self.aux_partition, + } + + def set_params(self, **params) -> "ImputationDiD": + """Set estimator parameters (sklearn-compatible).""" + for key, value in params.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + raise ValueError(f"Unknown parameter: {key}") + return self + + def summary(self) -> str: + """Get summary of estimation results.""" + if not self.is_fitted_: + raise RuntimeError("Model must be fitted before calling summary()") + assert self.results_ is not None + return self.results_.summary() + + def print_summary(self) -> None: + """Print summary to stdout.""" + print(self.summary()) + + +# ============================================================================= +# Convenience function +# ============================================================================= + + +def imputation_did( + data: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]] = None, + aggregate: Optional[str] = None, + balance_e: Optional[int] = None, + **kwargs, +) -> ImputationDiDResults: + """ + Convenience function for imputation DiD estimation. + + This is a shortcut for creating an ImputationDiD estimator and calling fit(). + + Parameters + ---------- + data : pd.DataFrame + Panel data. + outcome : str + Outcome variable column name. + unit : str + Unit identifier column name. + time : str + Time period column name. + first_treat : str + Column indicating first treatment period (0 for never-treated). + covariates : list of str, optional + Covariate column names. + aggregate : str, optional + Aggregation mode: None, "simple", "event_study", "group", "all". + balance_e : int, optional + Balance event study to cohorts observed at all relative times. + **kwargs + Additional keyword arguments passed to ImputationDiD constructor. + + Returns + ------- + ImputationDiDResults + Estimation results. + + Examples + -------- + >>> from diff_diff import imputation_did, generate_staggered_data + >>> data = generate_staggered_data(seed=42) + >>> results = imputation_did(data, 'outcome', 'unit', 'time', 'first_treat', + ... 
aggregate='event_study') + >>> results.print_summary() + """ + est = ImputationDiD(**kwargs) + return est.fit( + data, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + aggregate=aggregate, + balance_e=balance_e, + ) diff --git a/diff_diff/visualization.py b/diff_diff/visualization.py index df9c301..6903d41 100644 --- a/diff_diff/visualization.py +++ b/diff_diff/visualization.py @@ -16,6 +16,7 @@ from diff_diff.pretrends import PreTrendsPowerCurve, PreTrendsPowerResults from diff_diff.results import MultiPeriodDiDResults from diff_diff.staggered import CallawaySantAnnaResults + from diff_diff.imputation import ImputationDiDResults from diff_diff.sun_abraham import SunAbrahamResults # Type alias for results that can be plotted @@ -23,6 +24,7 @@ "MultiPeriodDiDResults", "CallawaySantAnnaResults", "SunAbrahamResults", + "ImputationDiDResults", pd.DataFrame, ] @@ -420,7 +422,7 @@ def _extract_plot_data( # Detect reference period from n_groups=0 marker (normalization constraint) # This handles anticipation > 0 where reference is at e = -1 - anticipation for period, effect_data in results.event_study_effects.items(): - if effect_data.get("n_groups", 1) == 0: + if effect_data.get("n_groups", 1) == 0 or effect_data.get("n_obs", 1) == 0: reference_period = period break # Fallback to -1 if no marker found (backward compatibility) @@ -438,7 +440,7 @@ def _extract_plot_data( raise TypeError( f"Cannot extract plot data from {type(results).__name__}. " "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, " - "SunAbrahamResults, or DataFrame." + "SunAbrahamResults, ImputationDiDResults, or DataFrame." ) diff --git a/docs/api/imputation.rst b/docs/api/imputation.rst new file mode 100644 index 0000000..fc91ca3 --- /dev/null +++ b/docs/api/imputation.rst @@ -0,0 +1,127 @@ +Imputation DiD (Borusyak et al. 2024) +======================================= + +Efficient imputation estimator for staggered Difference-in-Differences. + +This module implements the methodology from Borusyak, Jaravel & Spiess (2024), +"Revisiting Event-Study Designs: Robust and Efficient Estimation", +*Review of Economic Studies*. + +The estimator: + +1. Runs OLS on untreated observations to estimate unit + time fixed effects +2. Imputes counterfactual Y(0) for treated observations +3. Aggregates imputed treatment effects with researcher-chosen weights + +Inference uses the conservative clustered variance estimator from Theorem 3. + +**When to use ImputationDiD:** + +- Staggered adoption settings where treatment effects may be **homogeneous** + across cohorts and time — produces ~50% shorter CIs than Callaway-Sant'Anna +- When you want to use **all untreated observations** (never-treated + + not-yet-treated) for maximum efficiency +- As a complement to Callaway-Sant'Anna or Sun-Abraham: if all three agree, + results are robust; if they disagree, investigate heterogeneity + +**Reference:** Borusyak, K., Jaravel, X., & Spiess, J. (2024). Revisiting +Event-Study Designs: Robust and Efficient Estimation. *Review of Economic +Studies*, 91(6), 3253-3285. + +.. module:: diff_diff.imputation + +ImputationDiD +------------- + +Main estimator class for imputation DiD estimation. + +.. autoclass:: diff_diff.ImputationDiD + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + + .. rubric:: Methods + + .. 
autosummary:: + + ~ImputationDiD.fit + ~ImputationDiD.get_params + ~ImputationDiD.set_params + +ImputationDiDResults +-------------------- + +Results container for imputation DiD estimation. + +.. autoclass:: diff_diff.ImputationDiDResults + :members: + :undoc-members: + :show-inheritance: + + .. rubric:: Methods + + .. autosummary:: + + ~ImputationDiDResults.summary + ~ImputationDiDResults.print_summary + ~ImputationDiDResults.to_dataframe + ~ImputationDiDResults.pretrend_test + +ImputationBootstrapResults +-------------------------- + +Bootstrap inference results. + +.. autoclass:: diff_diff.ImputationBootstrapResults + :members: + :undoc-members: + :show-inheritance: + +Convenience Function +-------------------- + +.. autofunction:: diff_diff.imputation_did + +Example Usage +------------- + +Basic usage:: + + from diff_diff import ImputationDiD, generate_staggered_data + + data = generate_staggered_data(n_units=200, seed=42) + est = ImputationDiD() + results = est.fit(data, outcome='outcome', unit='unit', + time='period', first_treat='first_treat') + results.print_summary() + +Event study with visualization:: + + from diff_diff import ImputationDiD, plot_event_study + + est = ImputationDiD() + results = est.fit(data, outcome='outcome', unit='unit', + time='period', first_treat='first_treat', + aggregate='event_study') + plot_event_study(results) + +Pre-trend test:: + + results = est.fit(data, outcome='outcome', unit='unit', + time='period', first_treat='first_treat') + pt = results.pretrend_test(n_leads=3) + print(f"F-stat: {pt['f_stat']:.3f}, p-value: {pt['p_value']:.4f}") + +Comparison with other estimators:: + + from diff_diff import ImputationDiD, CallawaySantAnna, SunAbraham + + # All three should agree under homogeneous effects + imp = ImputationDiD().fit(data, ...) + cs = CallawaySantAnna().fit(data, ...) + sa = SunAbraham().fit(data, ...) + + print(f"Imputation ATT: {imp.overall_att:.3f} (SE: {imp.overall_se:.3f})") + print(f"CS ATT: {cs.overall_att:.3f} (SE: {cs.overall_se:.3f})") + print(f"SA ATT: {sa.overall_att:.3f} (SE: {sa.overall_se:.3f})") diff --git a/docs/api/index.rst b/docs/api/index.rst index 9f1c3a3..c87a464 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -18,6 +18,7 @@ Core estimator classes for DiD analysis: diff_diff.SyntheticDiD diff_diff.CallawaySantAnna diff_diff.SunAbraham + diff_diff.ImputationDiD diff_diff.TripleDifference diff_diff.TROP @@ -39,6 +40,8 @@ Result containers returned by estimators: diff_diff.GroupTimeEffect diff_diff.SunAbrahamResults diff_diff.SABootstrapResults + diff_diff.ImputationDiDResults + diff_diff.ImputationBootstrapResults diff_diff.TripleDifferenceResults diff_diff.trop.TROPResults @@ -181,6 +184,7 @@ Detailed documentation by module: estimators staggered + imputation triple_diff trop results diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 97897ea..52f5e6a 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -11,6 +11,7 @@ This document provides the academic foundations and key implementation requireme 2. [Modern Staggered Estimators](#modern-staggered-estimators) - [CallawaySantAnna](#callawaysantanna) - [SunAbraham](#sunabraham) + - [ImputationDiD](#imputationdid) 3. 
[Advanced Estimators](#advanced-estimators) - [SyntheticDiD](#syntheticdid) - [TripleDifference](#tripledifference) @@ -447,6 +448,111 @@ where weights ŵ_{g,e} = n_{g,e} / Σ_g n_{g,e} (sample share of cohort g at eve --- +## ImputationDiD + +**Primary source:** [Borusyak, K., Jaravel, X., & Spiess, J. (2024). Revisiting Event-Study Designs: Robust and Efficient Estimation. *Review of Economic Studies*, 91(6), 3253-3285.](https://doi.org/10.1093/restud/rdae007) + +**Key implementation requirements:** + +*Assumption checks / warnings:* +- **Parallel trends (Assumption 1):** `E[Y_it(0)] = alpha_i + beta_t` for all observations. General form allows `E[Y_it(0)] = alpha_i + beta_t + X'_it * delta` with time-varying covariates. +- **No-anticipation effects (Assumption 2):** `Y_it = Y_it(0)` for all untreated observations. Adjustable via `anticipation` parameter. +- Treatment must be absorbing: `D_it` switches from 0 to 1 and stays at 1. +- Covariate space of treated observations must be spanned by untreated observations (rank condition). For unit/period FE case: every treated unit must have ≥1 untreated period; every post-treatment period must have ≥1 untreated unit. +- Without never-treated units, long-run effects at horizon `K_it >= H_bar` (where `H_bar = max(first_treat) - min(first_treat)`) are not identified (Proposition 5). Set to NaN with warning. + +*Estimator equation (Theorem 2, as implemented):* + +``` +Step 1. Estimate counterfactual model on untreated observations only (it in Omega_0): + Y_it = alpha_i + beta_t [+ X'_it * delta] + epsilon_it + +Step 2. For each treated observation (it in Omega_1), impute: + Y_hat_it(0) = alpha_hat_i + beta_hat_t [+ X'_it * delta_hat] + tau_hat_it = Y_it - Y_hat_it(0) + +Step 3. Aggregate: + tau_hat_w = sum_{it in Omega_1} w_it * tau_hat_it +``` + +where: +- `Omega_0 = {it : D_it = 0}` — all untreated observations (never-treated + not-yet-treated) +- `Omega_1 = {it : D_it = 1}` — all treated observations +- `w_it` = pre-specified weights (overall ATT: `w_it = 1/N_1`) + +*Common estimation targets (weighting schemes):* +- Overall ATT: `w_it = 1/N_1` for all `it in Omega_1` +- Horizon-specific: `w_it = 1[K_it = h] / |Omega_{1,h}|` for `K_it = t - E_i` +- Group-specific: `w_it = 1[G_i = g] / |Omega_{1,g}|` + +*Standard errors (Theorem 3, Equation 7):* + +Conservative clustered variance estimator: +``` +sigma_hat^2_w = sum_i ( sum_{t: it in Omega} v_it * epsilon_tilde_it )^2 +``` + +Observation weights `v_it`: +- For treated `(i,t) in Omega_1`: `v_it = w_it` (the aggregation weight) +- For untreated `(i,t) in Omega_0` (FE-only case): `v_it = -(w_i./n_{0,i} + w_.t/n_{0,t} - w../N_0)` + where `w_i. = sum of w over treated obs of unit i`, `n_{0,i} = untreated periods for unit i`, etc. +- For untreated with covariates: `v_untreated = -A_0 (A_0' A_0)^{-1} A_1' w_treated` + where `A_0`, `A_1` are design matrices for untreated/treated observations. + +**Note on v_it derivation:** The paper's Supplementary Proposition A3 provides the explicit formula for `v_it^*`, but was not in the extraction range for the paper review. The FE-only closed form above is reconstructed from Theorem 3's general form — it follows from the chain rule of the imputation estimator's dependence on the Step 1 OLS estimates. The covariate case uses the OLS projection matrix directly. 
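+
+For concreteness, a minimal NumPy sketch of the FE-only closed form above (an illustrative reconstruction under the same notation, not the library's internal code; the function name and argument layout are hypothetical):
+
+```
+import numpy as np
+from collections import Counter
+
+def v_untreated_fe_only(w, treated_units, treated_times,
+                        untreated_units, untreated_times):
+    """Illustrative v_it for untreated observations, FE-only case."""
+    n0 = len(untreated_units)
+    w_unit, w_time = {}, {}
+    for u, t, wk in zip(treated_units, treated_times, w):
+        w_unit[u] = w_unit.get(u, 0.0) + wk   # w_i. : weight on unit i's treated obs
+        w_time[t] = w_time.get(t, 0.0) + wk   # w_.t : weight on period t's treated obs
+    w_total = float(np.sum(w))                # w..
+    n0_unit = Counter(untreated_units)        # n_{0,i}
+    n0_time = Counter(untreated_times)        # n_{0,t}
+    # v_it = -(w_i./n_{0,i} + w_.t/n_{0,t} - w../N_0)
+    return np.array([
+        -(w_unit.get(u, 0.0) / n0_unit[u]
+          + w_time.get(t, 0.0) / n0_time[t]
+          - w_total / n0)
+        for u, t in zip(untreated_units, untreated_times)
+    ])
+```
+
+Together with `v_it = w_it` on the treated cells and the Equation 8 auxiliary residuals, these weights enter `sigma_hat^2_w` above.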
+ +Auxiliary model residuals (Equation 8): +- Partition `Omega_1` into groups `G_g` (default: cohort × horizon) +- Compute `tau_tilde_g` for each group (weighted average within group) +- `epsilon_tilde_it = Y_it - alpha_hat_i - beta_hat_t [- X'delta_hat] - tau_tilde_g` (treated) +- `epsilon_tilde_it = Y_it - alpha_hat_i - beta_hat_t [- X'delta_hat]` (untreated, i.e., Step 1 residuals) + +The `aux_partition` parameter controls the partition: `"cohort_horizon"` (default, tightest SEs), `"cohort"` (coarser, more conservative), `"horizon"` (groups by relative time only). + +*Pre-trend test (Test 1, Equation 9):* +``` +Y_it = alpha_i + beta_t [+ X'_it * delta] + W'_it * gamma + epsilon_it +``` +- Estimate on untreated observations only +- Test `gamma = 0` via cluster-robust Wald F-test +- Independent of treatment effect estimation (Proposition 9) + +*Edge cases:* +- **Unbalanced panels:** FE estimated via iterative alternating projection (Gauss-Seidel), equivalent to OLS with unit+time dummies. Converges in O(max_iter) passes; typically 5-20 iterations for unbalanced panels, 1-2 for balanced. One-pass demeaning is only exact for balanced panels. +- **No never-treated units (Proposition 5):** Long-run effects at horizons `h >= H_bar` are not identified. Set to NaN with warning listing affected horizons. +- **Rank condition failure:** Every treated unit must have ≥1 untreated period; every post-treatment period must have ≥1 untreated unit. Behavior controlled by `rank_deficient_action`: "warn" (default), "error", or "silent". Missing FE produce NaN treatment effects for affected observations. +- **Always-treated units:** Units with `first_treat` at or before the earliest time period have no untreated observations. Warning emitted; these units are excluded from Step 1 OLS but their treated observations contribute to aggregation if imputation is possible. +- **NaN propagation:** If all `tau_hat` values for a given horizon or group are NaN, the aggregated effect and all inference fields (SE, t-stat, p-value, CI) are set to NaN. NaN in v*eps product (from missing FE) is zeroed for variance computation (matching R's did_imputation which drops unimputable obs). +- **NaN inference for undefined statistics:** t_stat uses NaN when SE is non-finite or zero; p_value and CI also NaN. Matches CallawaySantAnna NaN convention. +- **Pre-trend test:** Uses iterative demeaning (same as Step 1 FE) for exact within-transformation on unbalanced panels. One-pass demeaning is only exact for balanced panels. +- **Overall ATT variance:** Weights zero out non-finite tau_hat and renormalize, matching the ATT estimand (which averages only finite tau_hat). `_compute_conservative_variance` returns 0.0 for all-zeros weights, so the n_valid==0 guard is necessary to return NaN SE. +- **`balance_e` cohort filtering:** When `balance_e` is set, cohort balance is checked against the *full panel* (pre + post treatment) via `_build_cohort_rel_times()`, requiring observations at every relative time in `[-balance_e, max_h]`. Both analytical aggregation and bootstrap inference use the same `_compute_balanced_cohort_mask` with pre-computed cohort horizons. +- **Bootstrap clustering:** Multiplier bootstrap generates weights at `cluster_var` granularity (defaults to `unit` if `cluster` not specified). Invalid cluster column raises ValueError. +- **Non-constant `first_treat` within a unit:** Emits `UserWarning` identifying the count and example unit. 
The estimator proceeds using the first observed value per unit (via `.first()` aggregation), but results may be unreliable. +- **treatment_effects DataFrame weights:** `weight` column uses `1/n_valid` for finite tau_hat and 0 for NaN tau_hat, consistent with the ATT estimand. +- **Rank-deficient covariates in variance:** Covariates with NaN coefficients (dropped for rank deficiency in Step 1) are excluded from the variance design matrices `A_0`/`A_1`. Only covariates with finite coefficients participate in the `v_it` projection. +- **Sparse variance solver:** `_compute_v_untreated_with_covariates` uses `scipy.sparse.linalg.spsolve` to solve `(A_0'A_0) z = A_1'w` without densifying the normal equations matrix. Falls back to dense `lstsq` if the sparse solver fails. +- **Bootstrap inference:** Uses multiplier bootstrap on the Theorem 3 influence function: `psi_i = sum_t v_it * epsilon_tilde_it`. Cluster-level psi sums are pre-computed for each aggregation target (overall, per-horizon, per-group), then perturbed with Rademacher weights. This is a library extension (not in the paper) consistent with CallawaySantAnna/SunAbraham bootstrap patterns. +- **Auxiliary residuals (Equation 8):** Uses v_it-weighted tau_tilde_g formula: `tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it)` within each partition group. Zero-weight groups (common in event-study SE computation) fall back to unweighted mean. + +**Reference implementation(s):** +- Stata: `did_imputation` (Borusyak, Jaravel, Spiess; available from SSC) +- R: `didimputation` package (Kyle Butts) + +**Requirements checklist:** +- [x] Step 1: OLS on untreated observations only (never-treated + not-yet-treated) +- [x] Step 2: Impute counterfactual `Y_hat_it(0)` for treated observations +- [x] Step 3: Aggregate with researcher-chosen weights `w_it` +- [x] Conservative clustered variance estimator (Theorem 3, Equation 7) +- [x] Auxiliary model for treated residuals (Equation 8) with configurable partition (`aux_partition`) +- [x] Supports unit FE, period FE, and time-varying covariates +- [x] Refuses to estimate unidentified estimands (Proposition 5) — sets NaN with warning +- [x] Pre-trend test uses only untreated observations (Test 1, Equation 9) +- [x] Supports balanced and unbalanced panels (iterative Gauss-Seidel demeaning for exact FE) +- [x] Event study and group aggregation + +--- + # Advanced Estimators ## SyntheticDiD @@ -1075,6 +1181,7 @@ should be a deliberate user choice. | TwoWayFixedEffects | Cluster at unit | Wild bootstrap | | CallawaySantAnna | Analytical (influence fn) | Multiplier bootstrap | | SunAbraham | Cluster-robust + delta method | Pairs bootstrap | +| ImputationDiD | Conservative clustered (Thm 3) | Multiplier bootstrap (library extension; percentile CIs and empirical p-values, consistent with CS/SA) | | SyntheticDiD | Placebo variance (Alg 4) | Block bootstrap | | TripleDifference | HC1 / cluster-robust | Influence function for IPW/DR | | TROP | Block bootstrap | — | @@ -1094,6 +1201,7 @@ should be a deliberate user choice. 
| TwoWayFixedEffects | fixest | `feols(y ~ treat \| unit + time, ...)` | | CallawaySantAnna | did | `att_gt()` | | SunAbraham | fixest | `sunab()` | +| ImputationDiD | didimputation | `did_imputation()` | | SyntheticDiD | synthdid | `synthdid_estimate()` | | TripleDifference | - | (forthcoming) | | TROP | - | (forthcoming) | diff --git a/docs/tutorials/11_imputation_did.ipynb b/docs/tutorials/11_imputation_did.ipynb new file mode 100644 index 0000000..c953d46 --- /dev/null +++ b/docs/tutorials/11_imputation_did.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Imputation DiD (Borusyak, Jaravel & Spiess 2024)\n", + "\n", + "This tutorial demonstrates the `ImputationDiD` estimator, which implements the efficient imputation approach from Borusyak, Jaravel & Spiess (2024), \"Revisiting Event-Study Designs: Robust and Efficient Estimation\", *Review of Economic Studies*.\n", + "\n", + "**When to use ImputationDiD:**\n", + "- Staggered adoption settings where treatment effects may be **homogeneous** across cohorts and time \u2014 produces ~50% shorter CIs than Callaway-Sant'Anna\n", + "- When you want to use **all untreated observations** (never-treated + not-yet-treated) for maximum efficiency\n", + "- As a complement to Callaway-Sant'Anna or Sun-Abraham: if all three agree, results are robust; if they disagree, investigate heterogeneity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "from diff_diff import (\n", + " ImputationDiD, CallawaySantAnna, SunAbraham,\n", + " generate_staggered_data, plot_event_study\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage\n", + "\n", + "The imputation estimator follows a simple three-step process:\n", + "1. Estimate unit and time fixed effects using only untreated observations\n", + "2. Impute counterfactual Y(0) for treated observations\n", + "3. Aggregate imputed treatment effects with researcher-chosen weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate staggered adoption data with known treatment effect\n", + "data = generate_staggered_data(n_units=300, n_periods=10, treatment_effect=2.0, seed=42)\n", + "\n", + "# Fit the imputation estimator\n", + "est = ImputationDiD()\n", + "results = est.fit(data, outcome='outcome', unit='unit', time='period', first_treat='first_treat')\n", + "results.print_summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Event Study\n", + "\n", + "Event study aggregation estimates treatment effects at each relative time horizon, enabling visualization of dynamic treatment effects and pre-trend assessment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fit with event study aggregation\n", + "est = ImputationDiD()\n", + "results_es = est.fit(data, outcome='outcome', unit='unit', time='period',\n", + " first_treat='first_treat', aggregate='event_study')\n", + "\n", + "# Plot event study\n", + "plot_event_study(results_es, title='Imputation DiD Event Study')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# View event study effects as a table\n", + "results_es.to_dataframe(level='event_study')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-Trend Test\n", + "\n", + "The imputation estimator includes a built-in pre-trend test (Equation 9 in the paper). It tests whether pre-treatment leads are jointly zero using a Wald F-test on untreated observations only.\n", + "\n", + "A key advantage: the pre-trend test is **independent** of the treatment effect estimator (Proposition 9), avoiding the pre-testing problem identified by Roth (2022)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run pre-trend test\n", + "pt = results.pretrend_test(n_leads=3)\n", + "print(f\"F-statistic: {pt['f_stat']:.3f}\")\n", + "print(f\"P-value: {pt['p_value']:.4f}\")\n", + "print(f\"Leads tested: {pt['n_leads']}\")\n", + "print(f\"\\nConclusion: {'Fail to reject' if pt['p_value'] > 0.05 else 'Reject'} parallel trends at 5% level\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparison with Other Estimators\n", + "\n", + "Under homogeneous treatment effects, ImputationDiD, Callaway-Sant'Anna, and Sun-Abraham should produce similar point estimates. The key difference is efficiency \u2014 ImputationDiD produces shorter confidence intervals." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fit all three estimators on the same data\n", + "imp = ImputationDiD().fit(data, outcome='outcome', unit='unit',\n", + " time='period', first_treat='first_treat')\n", + "cs = CallawaySantAnna().fit(data, outcome='outcome', unit='unit',\n", + " time='period', first_treat='first_treat')\n", + "sa = SunAbraham().fit(data, outcome='outcome', unit='unit',\n", + " time='period', first_treat='first_treat')\n", + "\n", + "print(\"Estimator Comparison (True effect = 2.0)\")\n", + "print(\"=\" * 55)\n", + "print(f\"{'Estimator':<25} {'ATT':>8} {'SE':>8} {'CI Width':>10}\")\n", + "print(\"-\" * 55)\n", + "\n", + "for name, r in [(\"ImputationDiD\", imp), (\"CallawaySantAnna\", cs), (\"SunAbraham\", sa)]:\n", + " ci_width = r.overall_conf_int[1] - r.overall_conf_int[0]\n", + " print(f\"{name:<25} {r.overall_att:>8.3f} {r.overall_se:>8.3f} {ci_width:>10.3f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Group Aggregation\n", + "\n", + "Group aggregation estimates average treatment effects by treatment cohort (groups defined by first treatment period)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fit with group aggregation\n", + "results_grp = ImputationDiD().fit(data, outcome='outcome', unit='unit',\n", + " time='period', first_treat='first_treat',\n", + " aggregate='group')\n", + "results_grp.to_dataframe(level='group')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced Features\n", + "\n", + "### Anticipation\n", + "\n", + "If treatment effects begin before the official treatment date, use the `anticipation` parameter to account for this." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Account for 1 period of anticipation\n", + "est_antic = ImputationDiD(anticipation=1)\n", + "results_antic = est_antic.fit(data, outcome='outcome', unit='unit',\n", + " time='period', first_treat='first_treat')\n", + "print(f\"ATT (no anticipation): {results.overall_att:.3f}\")\n", + "print(f\"ATT (1-period anticipation): {results_antic.overall_att:.3f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Auxiliary Model Partition\n", + "\n", + "The `aux_partition` parameter controls the auxiliary model partition for the conservative variance estimator (Theorem 3). Finer partitions give tighter SEs but may overfit with few observations per group." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compare different partition choices\n", + "for partition in ['cohort_horizon', 'cohort', 'horizon']:\n", + " r = ImputationDiD(aux_partition=partition).fit(\n", + " data, outcome='outcome', unit='unit',\n", + " time='period', first_treat='first_treat')\n", + " print(f\"aux_partition='{partition}': ATT={r.overall_att:.3f}, SE={r.overall_se:.3f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "| Feature | ImputationDiD | CallawaySantAnna | SunAbraham |\n", + "|---------|--------------|------------------|------------|\n", + "| **Approach** | Impute Y(0) via FE model | Group-time ATT(g,t) | Saturated regression |\n", + "| **Efficiency** | Most efficient under homogeneity | Less efficient | Least efficient |\n", + "| **Robustness** | Requires homogeneity for efficiency | Fully robust to heterogeneity | Robust to heterogeneity |\n", + "| **Control group** | All untreated (always) | Never-treated or not-yet-treated | Never-treated |\n", + "| **Best for** | Homogeneous effects, maximum power | Heterogeneous effects, flexible | Robustness check |\n", + "\n", + "**Reference:** Borusyak, K., Jaravel, X., & Spiess, J. (2024). Revisiting Event-Study Designs: Robust and Efficient Estimation. *Review of Economic Studies*, 91(6), 3253-3285." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/test_imputation.py b/tests/test_imputation.py new file mode 100644 index 0000000..029af74 --- /dev/null +++ b/tests/test_imputation.py @@ -0,0 +1,1954 @@ +""" +Tests for Borusyak-Jaravel-Spiess (2024) imputation DiD estimator. 
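+
+Covers fitting and aggregation (overall ATT, event study, group), the
+pre-trend test, the conservative variance estimator (Theorem 3),
+multiplier-bootstrap inference, cross-checks against CallawaySantAnna and
+SunAbraham, and edge cases.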
+""" + +import warnings + +import numpy as np +import pandas as pd +import pytest + +from diff_diff.imputation import ( + ImputationBootstrapResults, + ImputationDiD, + ImputationDiDResults, + imputation_did, +) + +# ============================================================================= +# Shared test data generation +# ============================================================================= + + +def generate_test_data( + n_units: int = 100, + n_periods: int = 10, + treatment_effect: float = 2.0, + never_treated_frac: float = 0.3, + dynamic_effects: bool = True, + seed: int = 42, +) -> pd.DataFrame: + """Generate synthetic staggered adoption data for testing.""" + rng = np.random.default_rng(seed) + + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + + n_never = int(n_units * never_treated_frac) + n_treated = n_units - n_never + + cohort_periods = np.array([3, 5, 7]) + first_treat = np.zeros(n_units, dtype=int) + if n_treated > 0: + cohort_assignments = rng.choice(len(cohort_periods), size=n_treated) + first_treat[n_never:] = cohort_periods[cohort_assignments] + + first_treat_expanded = np.repeat(first_treat, n_periods) + + unit_fe = rng.standard_normal(n_units) * 2.0 + time_fe = np.linspace(0, 1, n_periods) + + unit_fe_expanded = np.repeat(unit_fe, n_periods) + time_fe_expanded = np.tile(time_fe, n_units) + + post = (times >= first_treat_expanded) & (first_treat_expanded > 0) + relative_time = times - first_treat_expanded + + if dynamic_effects: + dynamic_mult = 1 + 0.1 * np.maximum(relative_time, 0) + else: + dynamic_mult = np.ones_like(relative_time, dtype=float) + + effect = treatment_effect * dynamic_mult + + outcomes = ( + unit_fe_expanded + time_fe_expanded + effect * post + rng.standard_normal(len(units)) * 0.5 + ) + + return pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_expanded, + } + ) + + +# ============================================================================= +# TestImputationDiD +# ============================================================================= + + +class TestImputationDiD: + """Tests for ImputationDiD estimator.""" + + def test_basic_fit(self): + """Test basic model fitting.""" + data = generate_test_data() + + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert est.is_fitted_ + assert isinstance(results, ImputationDiDResults) + assert results.overall_att is not None + assert results.overall_se > 0 + assert results.n_treated_obs > 0 + assert results.n_untreated_obs > 0 + assert results.n_treated_units > 0 + assert results.n_control_units > 0 + assert len(results.groups) == 3 + assert len(results.time_periods) == 10 + + def test_positive_treatment_effect(self): + """Test recovery of positive treatment effect.""" + data = generate_test_data(treatment_effect=3.0, seed=123) + + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert results.overall_att > 0 + # Effect should be close to 3.0 (dynamic effects add some) + assert abs(results.overall_att - 3.0) < 2 * results.overall_se + 1.5 + + def test_zero_treatment_effect(self): + """Test with no treatment effect.""" + data = generate_test_data(treatment_effect=0.0, seed=456) + + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + 
) + + assert abs(results.overall_att) < 3 * results.overall_se + 0.5 + + def test_aggregate_simple(self): + """Test that default aggregate computes overall ATT.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert results.overall_att is not None + assert results.overall_se > 0 + assert results.event_study_effects is None + assert results.group_effects is None + + def test_aggregate_event_study(self): + """Test event study aggregation.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + assert results.event_study_effects is not None + assert len(results.event_study_effects) > 0 + assert results.group_effects is None + + for h, eff in results.event_study_effects.items(): + assert "effect" in eff + assert "se" in eff + assert "t_stat" in eff + assert "p_value" in eff + assert "conf_int" in eff + assert "n_obs" in eff + + def test_aggregate_group(self): + """Test group aggregation.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", + ) + + assert results.group_effects is not None + assert len(results.group_effects) == 3 # 3 cohorts + assert results.event_study_effects is None + + for g, eff in results.group_effects.items(): + assert "effect" in eff + assert "se" in eff + assert eff["se"] > 0 + + def test_aggregate_all(self): + """Test 'all' aggregation computes both event study and group.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="all", + ) + + assert results.event_study_effects is not None + assert results.group_effects is not None + + def test_covariates(self): + """Test estimation with covariates.""" + data = generate_test_data() + rng = np.random.default_rng(99) + data["x1"] = rng.standard_normal(len(data)) + data["x2"] = rng.standard_normal(len(data)) + + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], + ) + + assert results.overall_att is not None + assert results.overall_se > 0 + + def test_anticipation(self): + """Test anticipation parameter.""" + data = generate_test_data() + + est0 = ImputationDiD(anticipation=0) + results0 = est0.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + est1 = ImputationDiD(anticipation=1) + results1 = est1.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + # With anticipation=1, more obs are treated, fewer untreated + assert results1.n_treated_obs > results0.n_treated_obs + + # Reference period changes + ref0 = [h for h, e in results0.event_study_effects.items() if e.get("n_obs", 1) == 0] + ref1 = [h for h, e in results1.event_study_effects.items() if e.get("n_obs", 1) == 0] + assert -1 in ref0 + assert -2 in ref1 + + def test_balance_e(self): + """Test balance_e restricts event study to balanced cohorts.""" + data = generate_test_data() + + est = ImputationDiD() + results_unbal = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + 
first_treat="first_treat", + aggregate="event_study", + ) + results_bal = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + balance_e=2, + ) + + # Balanced should have same or fewer horizons + assert len(results_bal.event_study_effects) <= len(results_unbal.event_study_effects) + 5 + + def test_horizon_max(self): + """Test horizon_max caps event study horizons.""" + data = generate_test_data() + + est = ImputationDiD(horizon_max=3) + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + for h in results.event_study_effects: + if results.event_study_effects[h].get("n_obs", 0) > 0: + assert abs(h) <= 3 + + def test_summary(self): + """Test summary output.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="all", + ) + + summary = results.summary() + assert "Imputation DiD" in summary + assert "ATT" in summary + assert "Event Study" in summary + assert "Group" in summary + + def test_to_dataframe_observation(self): + """Test to_dataframe at observation level.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + df = results.to_dataframe("observation") + assert "tau_hat" in df.columns + assert "weight" in df.columns + assert len(df) == results.n_treated_obs + + def test_to_dataframe_event_study(self): + """Test to_dataframe at event study level.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + df = results.to_dataframe("event_study") + assert "relative_period" in df.columns + assert "effect" in df.columns + assert "se" in df.columns + + def test_to_dataframe_group(self): + """Test to_dataframe at group level.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", + ) + + df = results.to_dataframe("group") + assert "group" in df.columns + assert len(df) == 3 + + def test_to_dataframe_errors(self): + """Test to_dataframe raises on invalid level.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + with pytest.raises(ValueError, match="Unknown level"): + results.to_dataframe("invalid") + + with pytest.raises(ValueError, match="Event study effects not computed"): + results.to_dataframe("event_study") + + def test_get_params(self): + """Test get_params returns all constructor parameters.""" + est = ImputationDiD( + anticipation=1, + alpha=0.10, + n_bootstrap=100, + seed=42, + horizon_max=5, + aux_partition="cohort", + ) + params = est.get_params() + + assert params["anticipation"] == 1 + assert params["alpha"] == 0.10 + assert params["n_bootstrap"] == 100 + assert params["seed"] == 42 + assert params["horizon_max"] == 5 + assert params["aux_partition"] == "cohort" + assert params["cluster"] is None + assert params["rank_deficient_action"] == "warn" + + def test_set_params(self): + """Test set_params modifies attributes.""" + est = ImputationDiD() + est.set_params(alpha=0.10, anticipation=2) + + 
assert est.alpha == 0.10 + assert est.anticipation == 2 + + def test_set_params_unknown(self): + """Test set_params raises on unknown parameter.""" + est = ImputationDiD() + with pytest.raises(ValueError, match="Unknown parameter"): + est.set_params(nonexistent=True) + + def test_missing_columns(self): + """Test error on missing columns.""" + data = generate_test_data() + + est = ImputationDiD() + with pytest.raises(ValueError, match="Missing columns"): + est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="nonexistent", + ) + + def test_significance_properties(self): + """Test is_significant and significance_stars properties.""" + data = generate_test_data(treatment_effect=5.0) + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert results.is_significant + assert results.significance_stars in ("***", "**", "*", ".") + + def test_repr(self): + """Test string representation.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + r = repr(results) + assert "ImputationDiDResults" in r + assert "ATT=" in r + + def test_convenience_function(self): + """Test imputation_did convenience function.""" + data = generate_test_data() + results = imputation_did( + data, + "outcome", + "unit", + "time", + "first_treat", + aggregate="event_study", + ) + + assert isinstance(results, ImputationDiDResults) + assert results.event_study_effects is not None + + def test_convenience_function_kwargs(self): + """Test imputation_did passes kwargs to constructor.""" + data = generate_test_data() + results = imputation_did( + data, + "outcome", + "unit", + "time", + "first_treat", + alpha=0.10, + ) + + assert results.alpha == 0.10 + + def test_unbalanced_panel(self): + """Test with unbalanced panel (some units missing periods).""" + data = generate_test_data(seed=99) + rng = np.random.default_rng(99) + + # Drop some observations randomly + keep = rng.random(len(data)) > 0.1 + data_unbal = data[keep].reset_index(drop=True) + + est = ImputationDiD() + results = est.fit( + data_unbal, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert results.overall_att is not None + assert results.overall_se > 0 + + def test_balance_e_checks_pre_treatment_periods(self): + """balance_e should drop cohorts missing pre-treatment observations.""" + # Cohort A (first_treat=4): units 0-4, all periods 0-7 + # rel_times: -4, -3, -2, -1, 0, 1, 2, 3 + # Cohort B (first_treat=6): units 5-9, all periods 0-7 EXCEPT time=4 + # rel_times: -6, -5, -4, -3, -1, 0, 1 (missing -2) + # Never-treated: units 10-14, all periods + # + # horizon_max=1 caps post-treatment to {0,1} so both cohorts can + # cover the required post-treatment range. Without it, the union of + # all_horizons includes h=2,3 which cohort B can't reach (max h=1). 
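+        #
+        # This exercises the documented `balance_e` contract: cohort balance
+        # is checked against the full panel via pre-computed cohort rel_times,
+        # so a missing pre-treatment period (rel_time=-2 for cohort B) drops
+        # the whole cohort from the balanced event study.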
+ rows = [] + rng = np.random.default_rng(123) + + # Cohort A: complete panel + for u in range(5): + for t in range(8): + y = u * 0.5 + t * 0.1 + (3.0 if t >= 4 else 0.0) + rows.append( + { + "unit": u, + "time": t, + "first_treat": 4, + "outcome": y + rng.normal(0, 0.01), + } + ) + + # Cohort B: missing time=4 (which is rel_time = 4 - 6 = -2) + for u in range(5, 10): + for t in range(8): + if t == 4: + continue # drop => missing rel_time=-2 + y = u * 0.5 + t * 0.1 + (3.0 if t >= 6 else 0.0) + rows.append( + { + "unit": u, + "time": t, + "first_treat": 6, + "outcome": y + rng.normal(0, 0.01), + } + ) + + # Never-treated + for u in range(10, 15): + for t in range(8): + y = u * 0.5 + t * 0.1 + rows.append( + { + "unit": u, + "time": t, + "first_treat": 0, + "outcome": y + rng.normal(0, 0.01), + } + ) + + data = pd.DataFrame(rows) + est = ImputationDiD(horizon_max=1) + + # balance_e=2, horizon_max=1: required = {-2,-1,0,1} + # Cohort B missing -2 => should be dropped + results_bal2 = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + balance_e=2, + ) + + # balance_e=1, horizon_max=1: required = {-1,0,1} + # Both cohorts have -1 => both kept + results_bal1 = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + balance_e=1, + ) + + # Cohort B dropped at balance_e=2 => fewer obs at horizon 0 + n_obs_bal2_h0 = results_bal2.event_study_effects[0]["n_obs"] + n_obs_bal1_h0 = results_bal1.event_study_effects[0]["n_obs"] + assert n_obs_bal2_h0 < n_obs_bal1_h0, ( + f"balance_e=2 should drop cohort B (missing rel_time=-2), " + f"got n_obs={n_obs_bal2_h0} vs {n_obs_bal1_h0}" + ) + + +# ============================================================================= +# TestImputationDiDResults +# ============================================================================= + + +class TestImputationDiDResults: + """Tests for ImputationDiDResults.""" + + def test_pretrend_test(self): + """Test pre-trend test on data with parallel trends.""" + data = generate_test_data(dynamic_effects=False, seed=77, n_units=200) + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + pt = results.pretrend_test() + assert "f_stat" in pt + assert "p_value" in pt + assert "n_leads" in pt + assert pt["n_leads"] > 0 + + # Under parallel trends, should not reject + assert pt["p_value"] > 0.01 + + def test_pretrend_with_violation(self): + """Test pre-trend test detects trend violation.""" + data = generate_test_data(seed=88, n_units=200) + + # Add a pre-treatment trend for treated units + rng = np.random.default_rng(88) + for idx in data.index: + if data.loc[idx, "first_treat"] > 0: + t = data.loc[idx, "time"] + ft = data.loc[idx, "first_treat"] + if t < ft: + data.loc[idx, "outcome"] += 0.5 * (t - ft) + + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + pt = results.pretrend_test() + # With pre-trend violation, should reject (low p-value) + assert pt["p_value"] < 0.10 + + def test_pretrend_unbalanced_panel(self): + """Test pretrend_test uses iterative demeaning for unbalanced panels.""" + data = generate_test_data(dynamic_effects=False, seed=77, n_units=200) + # Make unbalanced by dropping ~15% of observations + rng = np.random.default_rng(77) + keep = rng.random(len(data)) > 0.15 + data_unbal = 
data[keep].reset_index(drop=True) + + est = ImputationDiD() + results = est.fit( + data_unbal, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + pt = results.pretrend_test() + assert pt["n_leads"] > 0 + # Under parallel trends, should not reject + assert pt["p_value"] > 0.01 + + def test_pretrend_n_leads(self): + """Test pre-trend test with specified number of leads.""" + data = generate_test_data(n_units=200, seed=55) + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + pt = results.pretrend_test(n_leads=2) + assert pt["n_leads"] == 2 + + +# ============================================================================= +# TestImputationVariance +# ============================================================================= + + +class TestImputationVariance: + """Tests for conservative variance estimation (Theorem 3).""" + + def test_se_positive(self): + """Test that SE is positive.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert results.overall_se > 0 + + def test_se_positive_event_study(self): + """Test that event study SEs are positive.""" + data = generate_test_data() + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + for h, eff in results.event_study_effects.items(): + if eff.get("n_obs", 0) > 0 and np.isfinite(eff["se"]): + assert eff["se"] > 0 + + def test_aux_partition_cohort_horizon(self): + """Test cohort_horizon partition produces valid SEs.""" + data = generate_test_data() + est = ImputationDiD(aux_partition="cohort_horizon") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + assert results.overall_se > 0 + + def test_aux_partition_cohort(self): + """Test cohort partition produces valid SEs.""" + data = generate_test_data() + est = ImputationDiD(aux_partition="cohort") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + assert results.overall_se > 0 + + def test_aux_partition_horizon(self): + """Test horizon partition produces valid SEs.""" + data = generate_test_data() + est = ImputationDiD(aux_partition="horizon") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + assert results.overall_se > 0 + + def test_coarser_partition_more_conservative(self): + """Test that coarser partition gives more conservative (larger) SEs.""" + data = generate_test_data(n_units=200, seed=42) + + est_fine = ImputationDiD(aux_partition="cohort_horizon") + results_fine = est_fine.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + est_coarse = ImputationDiD(aux_partition="cohort") + results_coarse = est_coarse.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + # Coarser partition should give >= SE (approximately) + # Allow small tolerance for numerical issues + assert results_coarse.overall_se >= results_fine.overall_se * 0.95 + + def test_invalid_aux_partition(self): + """Test that invalid aux_partition raises ValueError.""" + with pytest.raises(ValueError, match="aux_partition"): + ImputationDiD(aux_partition="invalid") + + def 
test_sparse_solver_matches_dense(self): + """Test that sparse solver produces finite SEs with covariates.""" + data = generate_test_data(n_units=100, n_periods=10, seed=42) + rng = np.random.default_rng(42) + data["x1"] = rng.standard_normal(len(data)) + + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1"], + ) + + assert np.isfinite(results.overall_se) + assert results.overall_se > 0 + + def test_sparse_solver_dense_fallback(self): + """Test that dense fallback produces finite SE when spsolve fails.""" + import unittest.mock + + data = generate_test_data(n_units=80, n_periods=8, seed=42) + rng = np.random.default_rng(42) + data["x1"] = rng.standard_normal(len(data)) + + est = ImputationDiD() + + # Monkey-patch spsolve to force fallback to dense lstsq + with unittest.mock.patch( + "diff_diff.imputation.spsolve", side_effect=RuntimeError("test failure") + ): + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1"], + ) + + assert np.isfinite(results.overall_se) + assert results.overall_se > 0 + + +# ============================================================================= +# TestImputationBootstrap +# ============================================================================= + + +class TestImputationBootstrap: + """Tests for bootstrap inference.""" + + def test_basic_bootstrap(self, ci_params): + """Test basic bootstrap inference.""" + data = generate_test_data() + n_boot = ci_params.bootstrap(50) + est = ImputationDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert results.bootstrap_results is not None + assert isinstance(results.bootstrap_results, ImputationBootstrapResults) + assert results.bootstrap_results.n_bootstrap == n_boot + assert results.bootstrap_results.overall_att_se > 0 + + def test_bootstrap_reproducibility(self, ci_params): + """Test that same seed gives same results.""" + data = generate_test_data() + n_boot = ci_params.bootstrap(50) + + est1 = ImputationDiD(n_bootstrap=n_boot, seed=42) + r1 = est1.fit(data, outcome="outcome", unit="unit", time="time", first_treat="first_treat") + + est2 = ImputationDiD(n_bootstrap=n_boot, seed=42) + r2 = est2.fit(data, outcome="outcome", unit="unit", time="time", first_treat="first_treat") + + assert r1.overall_se == r2.overall_se + + def test_bootstrap_different_seeds(self, ci_params): + """Test that different seeds give different results.""" + data = generate_test_data() + n_boot = ci_params.bootstrap(50) + + est1 = ImputationDiD(n_bootstrap=n_boot, seed=42) + r1 = est1.fit(data, outcome="outcome", unit="unit", time="time", first_treat="first_treat") + + est2 = ImputationDiD(n_bootstrap=n_boot, seed=99) + r2 = est2.fit(data, outcome="outcome", unit="unit", time="time", first_treat="first_treat") + + # Results should differ (at least slightly) + assert r1.overall_se != r2.overall_se + + def test_bootstrap_event_study(self, ci_params): + """Test bootstrap with event study aggregation.""" + data = generate_test_data() + n_boot = ci_params.bootstrap(50) + est = ImputationDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + br = results.bootstrap_results + assert br.event_study_ses is not None + assert len(br.event_study_ses) > 0 + + def 
test_bootstrap_group(self, ci_params): + """Test bootstrap with group aggregation.""" + data = generate_test_data() + n_boot = ci_params.bootstrap(50) + est = ImputationDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", + ) + + br = results.bootstrap_results + assert br.group_ses is not None + assert len(br.group_ses) == 3 + + def test_bootstrap_balance_e_consistency(self, ci_params): + """Test bootstrap event study respects balance_e filtering.""" + data = generate_test_data(n_units=150, seed=42) + n_boot = ci_params.bootstrap(50) + + # Run WITH balance_e + est_bal = ImputationDiD(n_bootstrap=n_boot, seed=42) + results_bal = est_bal.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + balance_e=2, + ) + + # Run WITHOUT balance_e + est_nobal = ImputationDiD(n_bootstrap=n_boot, seed=42) + results_nobal = est_nobal.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + assert results_bal.bootstrap_results is not None + assert results_bal.bootstrap_results.event_study_ses is not None + + # Verify SEs are finite + for h in results_bal.event_study_effects: + eff = results_bal.event_study_effects[h] + if eff.get("n_obs", 0) > 0 and np.isfinite(eff["effect"]): + if h in results_bal.bootstrap_results.event_study_ses: + assert np.isfinite(results_bal.bootstrap_results.event_study_ses[h]) + + # Verify balance_e changed bootstrap SEs at some horizon + if results_nobal.bootstrap_results is not None: + bal_ses = results_bal.bootstrap_results.event_study_ses + nobal_ses = results_nobal.bootstrap_results.event_study_ses + shared_h = set(bal_ses.keys()) & set(nobal_ses.keys()) + any_different = any( + not np.isclose(bal_ses[h], nobal_ses[h], rtol=0.05) + for h in shared_h + if np.isfinite(bal_ses[h]) and np.isfinite(nobal_ses[h]) + ) + assert any_different, "balance_e should change bootstrap SEs for at least one horizon" + + def test_bootstrap_p_value_significance(self, ci_params): + """Test bootstrap p-value for significant effect.""" + data = generate_test_data(treatment_effect=5.0, n_units=200) + n_boot = ci_params.bootstrap(199, min_n=99) + est = ImputationDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + # Strong effect should be significant + assert results.overall_p_value < 0.05 + + def test_bootstrap_zero_noise_near_zero_se(self, ci_params): + """Bootstrap SE ~ 0 when influence function is zero (constant effect, no noise).""" + n_units, n_periods = 40, 8 + true_effect = 3.0 + rows = [] + for i in range(n_units): + ft = 4 if i < 20 else 0 + unit_fe = i * 0.5 + for t in range(n_periods): + y = unit_fe + t * 0.1 # exact FE, no noise + if ft > 0 and t >= ft: + y += true_effect + rows.append({"unit": i, "time": t, "outcome": y, "first_treat": ft}) + data = pd.DataFrame(rows) + + n_boot = ci_params.bootstrap(99) + est = ImputationDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert abs(results.overall_att - true_effect) < 1e-8 + assert results.bootstrap_results is not None + # With zero noise, influence function sums are ~0, so SE should be ~0 + assert results.bootstrap_results.overall_att_se < 0.01 + + def test_bootstrap_percentile_ci(self, ci_params): + """Test 
that bootstrap CIs use percentile method, not normal approx.""" + data = generate_test_data(dynamic_effects=False, seed=42) + n_boot = ci_params.bootstrap(50) + est = ImputationDiD(n_bootstrap=n_boot, seed=42) + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + br = results.bootstrap_results + assert br is not None + + # Verify CIs match percentile of bootstrap distribution + dist = br.bootstrap_distribution + expected_lower = float(np.percentile(dist, 2.5)) + expected_upper = float(np.percentile(dist, 97.5)) + np.testing.assert_allclose(br.overall_att_ci[0], expected_lower, rtol=1e-10) + np.testing.assert_allclose(br.overall_att_ci[1], expected_upper, rtol=1e-10) + + +# ============================================================================= +# TestImputationVsOtherEstimators +# ============================================================================= + + +class TestImputationVsOtherEstimators: + """Cross-validation with CallawaySantAnna and SunAbraham.""" + + def test_similar_point_estimates_vs_cs(self): + """Test that point estimates are similar to CallawaySantAnna.""" + from diff_diff import CallawaySantAnna + + data = generate_test_data(n_units=200, treatment_effect=2.0, seed=42, dynamic_effects=False) + + imp_est = ImputationDiD() + imp_results = imp_est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + cs = CallawaySantAnna() + cs_results = cs.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + # Point estimates should be reasonably close + cs_att = cs_results.overall_att + imp_att = imp_results.overall_att + assert abs(imp_att - cs_att) < 1.0 + + def test_similar_point_estimates_vs_sa(self): + """Test that point estimates are similar to SunAbraham.""" + from diff_diff import SunAbraham + + data = generate_test_data(n_units=200, treatment_effect=2.0, seed=42, dynamic_effects=False) + + imp_est = ImputationDiD() + imp_results = imp_est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + sa = SunAbraham() + sa_results = sa.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + # Point estimates should be reasonably close + assert abs(imp_results.overall_att - sa_results.overall_att) < 1.0 + + def test_shorter_cis_under_homogeneous_effects(self): + """Under homogeneous effects, imputation CIs should be shorter.""" + data = generate_test_data( + n_units=300, + treatment_effect=2.0, + seed=42, + dynamic_effects=False, + ) + + imp_est = ImputationDiD() + imp_results = imp_est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + from diff_diff import CallawaySantAnna + + cs = CallawaySantAnna() + cs_results = cs.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + imp_ci_width = imp_results.overall_conf_int[1] - imp_results.overall_conf_int[0] + cs_ci_width = cs_results.overall_conf_int[1] - cs_results.overall_conf_int[0] + + # Imputation CIs should be shorter (or at least not much longer) + assert imp_ci_width < cs_ci_width * 1.5 + + +# ============================================================================= +# TestImputationEdgeCases +# ============================================================================= + + +class TestImputationEdgeCases: + """Tests for edge cases.""" + + def test_single_cohort(self): + """Test with 
a single treatment cohort.""" + rng = np.random.default_rng(42) + n_units = 50 + n_periods = 8 + + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + + first_treat = np.zeros(n_units, dtype=int) + first_treat[25:] = 4 # Single cohort at period 4 + + first_treat_exp = np.repeat(first_treat, n_periods) + post = (times >= first_treat_exp) & (first_treat_exp > 0) + + outcomes = ( + np.repeat(rng.standard_normal(n_units) * 2, n_periods) + + np.tile(np.linspace(0, 1, n_periods), n_units) + + 2.0 * post + + rng.standard_normal(len(units)) * 0.5 + ) + + data = pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_exp, + } + ) + + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert len(results.groups) == 1 + assert results.overall_se > 0 + assert abs(results.overall_att - 2.0) < 1.0 + + def test_no_never_treated(self): + """Test with no never-treated units (Proposition 5).""" + data = generate_test_data(never_treated_frac=0.0, seed=42) + + est = ImputationDiD() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + # Should still estimate + assert results.overall_att is not None + assert results.overall_se > 0 + + # Proposition 5: long-run horizons should be NaN + prop5_nans = [ + h + for h, eff in results.event_study_effects.items() + if np.isnan(eff["effect"]) and eff.get("n_obs", 0) > 0 + ] + assert len(prop5_nans) > 0, "Should have Prop 5 NaN horizons" + + # Check all inference fields are NaN for Prop 5 horizons + for h in prop5_nans: + eff = results.event_study_effects[h] + assert np.isnan(eff["se"]) + assert np.isnan(eff["t_stat"]) + assert np.isnan(eff["p_value"]) + assert np.isnan(eff["conf_int"][0]) + assert np.isnan(eff["conf_int"][1]) + + def test_two_periods(self): + """Test with just two periods (basic 2x2 DiD).""" + rng = np.random.default_rng(42) + n_units = 60 + n_periods = 2 + + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + + first_treat = np.zeros(n_units, dtype=int) + first_treat[30:] = 1 # Treated in period 1 + + first_treat_exp = np.repeat(first_treat, n_periods) + post = (times >= first_treat_exp) & (first_treat_exp > 0) + + outcomes = ( + np.repeat(rng.standard_normal(n_units) * 2, n_periods) + + 3.0 * post + + rng.standard_normal(len(units)) * 0.5 + ) + + data = pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_exp, + } + ) + + est = ImputationDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert abs(results.overall_att - 3.0) < 1.0 + + def test_rank_deficiency_warn(self): + """Test rank_deficient_action='warn' doesn't error.""" + data = generate_test_data() + est = ImputationDiD(rank_deficient_action="warn") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + assert results.overall_se > 0 + + def test_rank_deficiency_error(self): + """Test rank_deficient_action='error' works.""" + est = ImputationDiD(rank_deficient_action="error") + # Should work fine on good data + data = generate_test_data() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", 
+
+            first_treat="first_treat",
+        )
+        assert results.overall_se > 0
+
+    def test_invalid_rank_deficient_action(self):
+        """Test invalid rank_deficient_action raises ValueError."""
+        with pytest.raises(ValueError, match="rank_deficient_action"):
+            ImputationDiD(rank_deficient_action="ignore")
+
+    def test_always_treated_warning(self):
+        """Test warning for units treated in all periods."""
+        rng = np.random.default_rng(42)
+        n_units = 40
+        n_periods = 6
+
+        units = np.repeat(np.arange(n_units), n_periods)
+        # first_treat == 0 means never-treated, so an always-treated unit
+        # needs first_treat > 0 but <= min(time). Start times at 1 so that
+        # first_treat = 1 marks units treated in every observed period.
+        times = np.tile(np.arange(1, n_periods + 1), n_units)
+
+        first_treat = np.zeros(n_units, dtype=int)
+        first_treat[:10] = 0  # never treated
+        first_treat[10:15] = 1  # treated from the very beginning (always treated)
+        first_treat[15:] = 4
+
+        first_treat_exp = np.repeat(first_treat, n_periods)
+        post = (times >= first_treat_exp) & (first_treat_exp > 0)
+
+        outcomes = (
+            np.repeat(rng.standard_normal(n_units) * 2, n_periods)
+            + 2.0 * post
+            + rng.standard_normal(len(units)) * 0.5
+        )
+
+        data = pd.DataFrame(
+            {
+                "unit": units,
+                "time": times,
+                "outcome": outcomes,
+                "first_treat": first_treat_exp,
+            }
+        )
+
+        est = ImputationDiD()
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            est.fit(
+                data,
+                outcome="outcome",
+                unit="unit",
+                time="time",
+                first_treat="first_treat",
+            )
+
+        # Should have issued a warning about always-treated
+        always_treated_warnings = [
+            x for x in w if "treated in all observed periods" in str(x.message)
+        ]
+        assert len(always_treated_warnings) > 0
+
+    def test_no_treated_units(self):
+        """Test error when no treated units."""
+        data = generate_test_data()
+        data["first_treat"] = 0  # All never-treated
+
+        est = ImputationDiD()
+        with pytest.raises(ValueError, match="No treated"):
+            est.fit(
+                data,
+                outcome="outcome",
+                unit="unit",
+                time="time",
+                first_treat="first_treat",
+            )
+
+    def test_nan_propagation_all_nan_horizon(self):
+        """Test NaN propagation when all tau_hat at a horizon are NaN."""
+        data = generate_test_data(never_treated_frac=0.0, seed=42)
+
+        est = ImputationDiD()
+        with warnings.catch_warnings(record=True):
+            
warnings.simplefilter("always") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + # Check that NaN horizons have all-NaN inference + for h, eff in results.event_study_effects.items(): + if eff.get("n_obs", 0) > 0 and np.isnan(eff["effect"]): + assert np.isnan(eff["se"]) + assert np.isnan(eff["t_stat"]) + assert np.isnan(eff["p_value"]) + assert np.isnan(eff["conf_int"][0]) + assert np.isnan(eff["conf_int"][1]) + + def test_summary_not_fitted(self): + """Test error when calling summary before fit.""" + est = ImputationDiD() + with pytest.raises(RuntimeError, match="must be fitted"): + est.summary() + + def test_rank_condition_missing_untreated_period(self): + """Test warning when a post-treatment period has no untreated units.""" + # Construct data where ALL units are treated from period 2 onward, + # so periods 2+ have no untreated observations + rng = np.random.default_rng(42) + n_units, n_periods = 20, 5 + rows = [] + for i in range(n_units): + ft = 2 # all units treated at period 2 + for t in range(n_periods): + y = rng.standard_normal() + i * 0.1 + t * 0.05 + if t >= ft: + y += 1.0 # treatment effect + rows.append( + { + "unit": i, + "time": t, + "outcome": y, + "first_treat": ft, + } + ) + data = pd.DataFrame(rows) + + est = ImputationDiD(rank_deficient_action="warn") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + rank_warnings = [x for x in w if "Rank condition" in str(x.message)] + assert len(rank_warnings) > 0, "Should warn about rank condition violation" + + # Affected horizons should have NaN effects (periods with no untreated units) + if results.event_study_effects: + nan_effects = [ + h + for h, d in results.event_study_effects.items() + if np.isnan(d["effect"]) and d.get("n_obs", 1) > 0 + ] + assert len(nan_effects) > 0, "Some horizons should have NaN effects" + + def test_rank_condition_error_mode(self): + """Test error raised when rank condition fails with action='error'.""" + # Same setup as test_rank_condition_missing_untreated_period + rng = np.random.default_rng(42) + n_units, n_periods = 20, 5 + rows = [] + for i in range(n_units): + ft = 2 + for t in range(n_periods): + y = rng.standard_normal() + i * 0.1 + t * 0.05 + if t >= ft: + y += 1.0 + rows.append( + { + "unit": i, + "time": t, + "outcome": y, + "first_treat": ft, + } + ) + data = pd.DataFrame(rows) + + est = ImputationDiD(rank_deficient_action="error") + with pytest.raises(ValueError, match="Rank condition"): + est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + def test_bootstrap_cluster_not_unit(self, ci_params): + """Test bootstrap uses cluster column when cluster != unit.""" + data = generate_test_data(n_units=100, n_periods=8, seed=42) + # Create cluster column grouping every 5 units + unit_to_cluster = {u: u // 5 for u in data["unit"].unique()} + data["cluster_id"] = data["unit"].map(unit_to_cluster) + + n_boot = ci_params.bootstrap(99, min_n=49) + est = ImputationDiD(cluster="cluster_id", n_bootstrap=n_boot, seed=42) + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + assert results.bootstrap_results is not None + assert results.bootstrap_results.overall_att_se > 0 + + # Bootstrap SE with cluster should differ from 
unit-level bootstrap + est_unit = ImputationDiD(n_bootstrap=n_boot, seed=42) + results_unit = est_unit.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + assert ( + results.bootstrap_results.overall_att_se + != results_unit.bootstrap_results.overall_att_se + ) + + def test_bootstrap_invalid_cluster_column(self): + """Test error when cluster column doesn't exist.""" + data = generate_test_data(n_units=50, seed=42) + est = ImputationDiD(cluster="nonexistent_col") + with pytest.raises(ValueError, match="not found"): + est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + def test_plot_reference_with_anticipation(self): + """Test event study plot detects reference period with anticipation.""" + data = generate_test_data(n_units=100, n_periods=10, seed=42) + est = ImputationDiD(anticipation=1) + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + # Reference should be at -2 (= -1 - anticipation) + assert -2 in results.event_study_effects + assert results.event_study_effects[-2]["n_obs"] == 0 # reference marker + + # Test that plot_event_study auto-detects it + pytest.importorskip("matplotlib") + from diff_diff import plot_event_study + + fig = plot_event_study(results) + assert fig is not None + + def test_overall_se_with_partial_nan_tau_hat(self): + """Test overall SE uses finite-only weights when some tau_hat are NaN.""" + # Create staggered data: cohort A treated at t=2, cohort B never-treated + # but drop all never-treated obs at t=5, so t=5 time FE is unidentified + # -> tau_hat for (cohort A, t=5) will be NaN + rng = np.random.default_rng(42) + n_units, n_periods = 40, 6 + rows = [] + for i in range(n_units): + if i < 20: + ft = 2 # early-treated + else: + ft = 99 # never-treated + for t in range(n_periods): + # Drop never-treated at t=5 to create unidentified time FE + if ft == 99 and t == 5: + continue + y = rng.standard_normal() + i * 0.1 + t * 0.05 + if t >= ft: + y += 1.0 + rows.append( + { + "unit": i, + "time": t, + "outcome": y, + "first_treat": ft, + } + ) + data = pd.DataFrame(rows) + + est = ImputationDiD(rank_deficient_action="silent") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + tau_hat = results.treatment_effects["tau_hat"] + n_nan = tau_hat.isna().sum() + n_finite = tau_hat.notna().sum() + + # Verify the scenario actually produces partial NaN + assert n_nan > 0, "Expected some NaN tau_hat (missing time FE at t=5)" + assert n_finite > 0, "Expected some finite tau_hat" + + # Partial NaN case: SE should be finite (computed from finite-only weights) + assert np.isfinite( + results.overall_se + ), f"overall_se should be finite with {n_finite} finite and {n_nan} NaN tau_hat" + assert np.isfinite(results.overall_att) + + def test_iterative_demean_balanced_matches_one_pass(self): + """Test _iterative_demean matches one-pass for balanced panels.""" + rng = np.random.default_rng(42) + n_units, n_periods = 20, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + vals = rng.standard_normal(n_units * n_periods) + idx = pd.RangeIndex(len(vals)) + + result_iter = ImputationDiD._iterative_demean(vals, units, times, idx) + + # One-pass for balanced panel + s = pd.DataFrame({"val": vals, "unit": units, "time": times}) + gm = s["val"].mean() + um = 
s.groupby("unit")["val"].transform("mean").values + tm = s.groupby("time")["val"].transform("mean").values + result_onepass = vals - um - tm + gm + + np.testing.assert_allclose(result_iter, result_onepass, atol=1e-8) + + def test_unbalanced_panel_fe_correctness(self): + """Test FE estimates match OLS for unbalanced panel.""" + # Create small unbalanced panel with known FE structure + rng = np.random.default_rng(42) + n_units, n_periods = 8, 5 + unit_fe_true = rng.standard_normal(n_units) * 2.0 + time_fe_true = np.linspace(0, 1, n_periods) + + rows = [] + for i in range(n_units): + for t in range(n_periods): + # Drop ~20% of obs to make unbalanced + if rng.random() < 0.2: + continue + y = unit_fe_true[i] + time_fe_true[t] + rng.standard_normal() * 0.01 + rows.append( + { + "unit": i, + "time": t, + "outcome": y, + "first_treat": n_periods, # all never-treated -> Omega_0 + } + ) + + df_0 = pd.DataFrame(rows) + + # Compute FE via iterative method (what we're testing) + est = ImputationDiD() + unit_fe_iter, time_fe_iter = est._iterative_fe( + df_0["outcome"].values, + df_0["unit"].values, + df_0["time"].values, + df_0.index, + ) + + # Compute exact OLS FE via lstsq with dummy variables + unique_units = sorted(df_0["unit"].unique()) + unique_times = sorted(df_0["time"].unique()) + n = len(df_0) + n_u = len(unique_units) + n_t = len(unique_times) + u_map = {u: i for i, u in enumerate(unique_units)} + t_map = {t: i for i, t in enumerate(unique_times)} + + X = np.zeros((n, 1 + (n_u - 1) + (n_t - 1))) + X[:, 0] = 1.0 # intercept + for j in range(n): + uid = u_map[df_0["unit"].iloc[j]] + tid = t_map[df_0["time"].iloc[j]] + if uid > 0: + X[j, uid] = 1.0 + if tid > 0: + X[j, n_u + tid - 1] = 1.0 + + beta_ols = np.linalg.lstsq(X, df_0["outcome"].values, rcond=None)[0] + + # Reconstruct OLS fitted values + intercept = beta_ols[0] + unit_fe_ols = {unique_units[0]: intercept} + for i in range(1, n_u): + unit_fe_ols[unique_units[i]] = intercept + beta_ols[i] + time_fe_ols = {unique_times[0]: 0.0} + for i in range(1, n_t): + time_fe_ols[unique_times[i]] = beta_ols[n_u + i - 1] + + # Compare fitted values (parameterization-invariant check) + for j in range(n): + u = df_0["unit"].iloc[j] + t = df_0["time"].iloc[j] + y_hat_iter = unit_fe_iter[u] + time_fe_iter[t] + y_hat_ols = unit_fe_ols[u] + time_fe_ols[t] + assert abs(y_hat_iter - y_hat_ols) < 1e-6, ( + f"Fitted values differ at unit={u}, time={t}: " + f"iterative={y_hat_iter:.8f} vs OLS={y_hat_ols:.8f}" + ) + + def test_non_constant_first_treat_warning(self): + """Warn when first_treat varies within a unit (violates absorbing treatment).""" + data = generate_test_data(dynamic_effects=False, seed=42) + # Corrupt first_treat for unit 0: make it vary across rows + bad_unit = data["unit"].unique()[0] + mask = data["unit"] == bad_unit + data.loc[mask & (data["time"] >= 5), "first_treat"] = 99 + + est = ImputationDiD() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + absorbing_warnings = [x for x in w if "non-constant" in str(x.message)] + assert len(absorbing_warnings) >= 1, "Expected warning about non-constant first_treat" + # Verify warning mentions the unit count and example + msg = str(absorbing_warnings[0].message) + assert "1 unit(s)" in msg + assert str(bad_unit) in msg + + # Behavioral assertion: estimator still produces results (warns, doesn't crash) + assert isinstance(results, ImputationDiDResults) + 
+
+    def test_non_constant_first_treat_warning(self):
+        """Warn when first_treat varies within a unit (violates absorbing treatment)."""
+        data = generate_test_data(dynamic_effects=False, seed=42)
+        # Corrupt first_treat for unit 0: make it vary across rows
+        bad_unit = data["unit"].unique()[0]
+        mask = data["unit"] == bad_unit
+        data.loc[mask & (data["time"] >= 5), "first_treat"] = 99
+
+        est = ImputationDiD()
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            results = est.fit(
+                data,
+                outcome="outcome",
+                unit="unit",
+                time="time",
+                first_treat="first_treat",
+            )
+        absorbing_warnings = [x for x in w if "non-constant" in str(x.message)]
+        assert len(absorbing_warnings) >= 1, "Expected warning about non-constant first_treat"
+        # Verify warning mentions the unit count and example
+        msg = str(absorbing_warnings[0].message)
+        assert "1 unit(s)" in msg
+        assert str(bad_unit) in msg
+
+        # Behavioral assertion: estimator still produces results (warns, doesn't crash)
+        assert isinstance(results, ImputationDiDResults)
+        assert np.isfinite(results.overall_att)
+
+        # Behavioral assertion: coercion applied; first_treat is now constant per unit
+        fit_df = est._fit_data["df"]
+        bad_rows = fit_df[fit_df["unit"] == bad_unit]
+        ft_vals = bad_rows["first_treat"].unique()
+        assert (
+            len(ft_vals) == 1
+        ), f"first_treat should be coerced to a single value per unit, got {ft_vals}"
+
+    def test_treatment_effects_weight_nan_consistency(self):
+        """Test that treatment_effects weights are 0 for NaN tau_hat and 1/n_valid for finite."""
+        # Reuse the partial-NaN scenario from test_overall_se_with_partial_nan_tau_hat
+        rng = np.random.default_rng(42)
+        n_units, n_periods = 40, 6
+        rows = []
+        for i in range(n_units):
+            if i < 20:
+                ft = 2  # early-treated
+            else:
+                ft = 99  # never-treated
+            for t in range(n_periods):
+                # Drop never-treated at t=5 to create unidentified time FE
+                if ft == 99 and t == 5:
+                    continue
+                y = rng.standard_normal() + i * 0.1 + t * 0.05
+                if t >= ft:
+                    y += 1.0
+                rows.append({"unit": i, "time": t, "outcome": y, "first_treat": ft})
+        data = pd.DataFrame(rows)
+
+        est = ImputationDiD(rank_deficient_action="silent")
+        results = est.fit(
+            data, outcome="outcome", unit="unit", time="time", first_treat="first_treat"
+        )
+
+        te = results.treatment_effects
+        nan_rows = te[te["tau_hat"].isna()]
+        finite_rows = te[te["tau_hat"].notna()]
+
+        # Verify scenario produces partial NaN
+        assert len(nan_rows) > 0
+        assert len(finite_rows) > 0
+
+        # NaN tau_hat rows have weight 0
+        assert (nan_rows["weight"] == 0.0).all(), "NaN tau_hat rows should have weight 0"
+
+        # Finite weights sum to ~1.0
+        assert abs(finite_rows["weight"].sum() - 1.0) < 1e-10, "Finite weights should sum to 1"
+
+        # Each finite weight equals 1/n_finite
+        n_finite = len(finite_rows)
+        expected_weight = 1.0 / n_finite
+        np.testing.assert_allclose(finite_rows["weight"].values, expected_weight, rtol=1e-10)
+
+    def test_rank_deficient_covariates_excluded_from_variance(self):
+        """Rank-deficient covariates are excluded from variance design matrices."""
+        data = generate_test_data(n_units=80, n_periods=8, seed=42)
+        rng = np.random.default_rng(42)
+        data["x1"] = rng.standard_normal(len(data))
+        data["x2"] = 2.0 * data["x1"]  # perfectly collinear
+
+        est = ImputationDiD(rank_deficient_action="silent")
+        results = est.fit(
+            data,
+            outcome="outcome",
+            unit="unit",
+            time="time",
+            first_treat="first_treat",
+            covariates=["x1", "x2"],
+        )
+
+        # SE should be finite (not blown up by singular design matrix)
+        assert np.isfinite(results.overall_se), "SE should be finite with rank-deficient covariates"
+        assert results.overall_se > 0
+
+        # Verify kept_cov_mask is stored and has one True + one False
+        mask = est._fit_data["kept_cov_mask"]
+        assert mask is not None
+        assert mask.sum() == 1, f"Expected 1 kept covariate, got {mask.sum()}"
+        assert len(mask) == 2
+        assert (~mask).sum() == 1, "Expected 1 dropped covariate"
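+
+    # Rationale for the test above: if the collinear column were kept, the
+    # variance design matrices would be singular and the resulting SEs could
+    # blow up or turn NaN. Dropping it once and recording kept_cov_mask keeps
+    # the point-estimation and variance designs consistent with each other.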
first_treat="first_treat", + ) + psi_warnings = [x for x in w if "Bootstrap pre-computation failed" in str(x.message)] + assert len(psi_warnings) >= 1 + + # Behavioral assertion: bootstrap_results is None + assert results.bootstrap_results is None + # Analytical SE still present + assert results.overall_se > 0 + + def test_event_study_empty_after_filtering(self): + """Warn when balance_e/horizon_max filter out all treated horizons.""" + data = generate_test_data(dynamic_effects=False, seed=42) + # balance_e=100 requires cohorts to span [-100, max_h+1], which none do. + # All cohorts fail the balanced check, so all horizons have n_h=0. + est = ImputationDiD() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + balance_e=100, + ) + empty_warnings = [x for x in w if "no horizons with observations" in str(x.message)] + assert len(empty_warnings) >= 1, "Expected warning about empty event study" + + # Only reference period should remain + ref_period = -1 + assert ref_period in results.event_study_effects + real_effects = { + h: v + for h, v in results.event_study_effects.items() + if h != ref_period and v.get("n_obs", 0) > 0 + } + assert len(real_effects) == 0 + + def test_balanced_cohort_mask_requires_negative_horizons(self): + """_compute_balanced_cohort_mask must check negative relative times.""" + cohort_rel_times = { + 5: {-2, -1, 0, 1, 2}, + 7: {-1, 0, 1, 2}, # missing -2 + } + df_treated = pd.DataFrame({"first_treat": [5, 5, 5, 7, 7, 7]}) + all_horizons = [0, 1, 2] + + # balance_e=2 requires {-2,-1,0,1,2}: only cohort 5 passes + mask2 = ImputationDiD._compute_balanced_cohort_mask( + df_treated, "first_treat", all_horizons, 2, cohort_rel_times + ) + assert mask2.tolist() == [True, True, True, False, False, False] + + # balance_e=1 requires {-1,0,1,2}: both pass + mask1 = ImputationDiD._compute_balanced_cohort_mask( + df_treated, "first_treat", all_horizons, 1, cohort_rel_times + ) + assert all(mask1)