10 changes: 5 additions & 5 deletions .github/workflows/rust-test.yml
@@ -64,7 +64,7 @@ jobs:
uses: dtolnay/rust-toolchain@stable

- name: Install test dependencies
- run: pip install pytest numpy pandas scipy
+ run: pip install pytest pytest-xdist numpy pandas scipy

- name: Build and install with maturin
run: |
@@ -121,14 +121,14 @@ jobs:
- name: Run tests with Rust backend (Unix)
if: runner.os != 'Windows'
working-directory: /tmp
- run: DIFF_DIFF_BACKEND=rust pytest tests/ -x -q
+ run: DIFF_DIFF_BACKEND=rust pytest tests/ -q -n auto --dist worksteal

- name: Run tests with Rust backend (Windows)
if: runner.os == 'Windows'
working-directory: ${{ runner.temp }}
run: |
$env:DIFF_DIFF_BACKEND="rust"
- pytest tests/ -x -q
+ pytest tests/ -q -n auto --dist worksteal
shell: pwsh

# Test pure Python fallback (without Rust extension)
@@ -144,12 +144,12 @@ jobs:
python-version: '3.11'

- name: Install dependencies
- run: pip install numpy pandas scipy pytest
+ run: pip install numpy pandas scipy pytest pytest-xdist

- name: Verify pure Python mode
run: |
# Use PYTHONPATH to import directly (skips maturin build)
PYTHONPATH=. python -c "from diff_diff import HAS_RUST_BACKEND; print(f'HAS_RUST_BACKEND: {HAS_RUST_BACKEND}'); assert not HAS_RUST_BACKEND"

- name: Run tests in pure Python mode
- run: PYTHONPATH=. DIFF_DIFF_BACKEND=python pytest tests/ -x -q --ignore=tests/test_rust_backend.py
+ run: PYTHONPATH=. DIFF_DIFF_BACKEND=python pytest tests/ -q --ignore=tests/test_rust_backend.py -n auto --dist worksteal
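For context: the new -n auto and --dist worksteal flags come from pytest-xdist. -n auto starts one worker per CPU core, and worksteal lets idle workers take queued tests from busy ones, which suits suites with uneven test durations. A minimal sketch of reproducing the CI invocation locally from the repo root (CI invokes pytest on the command line; driving it via pytest.main here is an assumption for illustration):

import os
import pytest

# Pick the backend under test, as the workflow does via env vars;
# use "python" instead to exercise the pure Python fallback job.
os.environ["DIFF_DIFF_BACKEND"] = "rust"

# -n auto / --dist worksteal require pytest-xdist to be installed.
raise SystemExit(pytest.main(["tests/", "-q", "-n", "auto", "--dist", "worksteal"]))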
2 changes: 1 addition & 1 deletion CLAUDE.md
@@ -378,7 +378,7 @@ Tests mirror the source modules:
- `tests/test_pretrends.py` - Tests for pre-trends power analysis
- `tests/test_datasets.py` - Tests for dataset loading functions

- Session-scoped `ci_params` fixture in `conftest.py` scales bootstrap iterations and TROP grid sizes in pure Python mode — use `ci_params.bootstrap(n)` and `ci_params.grid(values)` in new tests with `n_bootstrap >= 20`. For SE convergence tests (analytical vs bootstrap comparison), use `ci_params.bootstrap(n, min_n=199)` to ensure sufficient iterations.
+ Session-scoped `ci_params` fixture in `conftest.py` scales bootstrap iterations and TROP grid sizes in pure Python mode — use `ci_params.bootstrap(n)` and `ci_params.grid(values)` in new tests with `n_bootstrap >= 20`. For SE convergence tests (analytical vs bootstrap comparison), use `ci_params.bootstrap(n, min_n=199)` with a conditional tolerance: `threshold = 0.40 if n_boot < 100 else 0.15`. The `min_n` parameter is capped at 49 in pure Python mode to keep CI fast, so convergence tests use wider tolerances when running with fewer bootstrap iterations.

### Test Writing Guidelines

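To make the updated CLAUDE.md guideline concrete, here is a sketch of a new SE convergence test following the fixture contract above; the estimator fit is elided because it depends on the module under test, and all names are illustrative rather than taken from the codebase:

def test_bootstrap_se_convergence_sketch(ci_params):
    # 499 iterations when the Rust backend is available; the min_n floor
    # is capped at 49 in pure Python mode, so n_boot can come back small.
    n_boot = ci_params.bootstrap(499, min_n=199)
    # Conditional tolerance from the guideline: wider when the cap bites.
    threshold = 0.40 if n_boot < 100 else 0.15
    # ... fit once with analytical SEs and once with n_boot bootstrap
    # draws, compute rel_diff between the two SEs, then:
    # assert rel_diff < threshold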
1 change: 1 addition & 0 deletions pyproject.toml
@@ -39,6 +39,7 @@ dependencies = [
[project.optional-dependencies]
dev = [
"pytest>=7.0",
"pytest-xdist>=3.0",
"pytest-cov>=4.0",
"black>=23.0",
"ruff>=0.1.0",
4 changes: 3 additions & 1 deletion tests/conftest.py
@@ -110,10 +110,12 @@ def bootstrap(n: int, *, min_n: int = 11) -> int:

Use a larger min_n for tests comparing analytical vs bootstrap SEs,
which need more iterations for stable convergence.
+ In pure Python mode, min_n is capped at 49 to keep CI fast.
"""
if not _PURE_PYTHON_MODE or n <= 10:
return n
- return min(n, max(min_n, int(math.sqrt(n) * 1.6)))
+ effective_min = min(min_n, 49)
+ return min(n, max(effective_min, int(math.sqrt(n) * 1.6)))

@staticmethod
def grid(values: list) -> list:
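Written out standalone, the scaling rule above behaves as follows (a sketch mirroring the pure-Python-mode branch of CIParams.bootstrap; `scaled` is a hypothetical name used only for illustration):

import math

def scaled(n, min_n=11):
    # Mirrors CIParams.bootstrap when _PURE_PYTHON_MODE is set.
    if n <= 10:
        return n
    effective_min = min(min_n, 49)
    return min(n, max(effective_min, int(math.sqrt(n) * 1.6)))

assert scaled(499) == 35              # int(sqrt(499) * 1.6) = 35
assert scaled(499, min_n=199) == 49   # requested floor capped at 49
assert scaled(100, min_n=199) == 49   # cap still below the original n
assert scaled(30, min_n=40) == 30     # original n caps the result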
24 changes: 18 additions & 6 deletions tests/test_ci_params.py
@@ -7,20 +7,21 @@


class TestCIParamsBootstrap:
- def test_min_n_in_pure_python_mode(self, monkeypatch):
- """min_n raises the floor in pure Python mode."""
+ def test_min_n_capped_at_49_in_pure_python_mode(self, monkeypatch):
+ """min_n is capped at 49 in pure Python mode."""
monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", True)
- assert CIParams.bootstrap(499, min_n=199) == 199
+ assert CIParams.bootstrap(499, min_n=199) == 49

def test_min_n_passthrough_in_rust_mode(self, monkeypatch):
"""min_n has no effect when Rust backend is available."""
monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", False)
assert CIParams.bootstrap(499, min_n=199) == 499

- def test_min_n_capped_at_original_request(self, monkeypatch):
- """min_n never exceeds the original n."""
+ def test_min_n_cap_then_n_cap(self, monkeypatch):
+ """min_n cap (49) applies, then result is min(n, effective_floor)."""
monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", True)
- assert CIParams.bootstrap(100, min_n=199) == 100
+ # effective_min = min(199, 49) = 49; max(49, 16) = 49; min(100, 49) = 49
+ assert CIParams.bootstrap(100, min_n=199) == 49

def test_n_lte_10_ignores_min_n(self, monkeypatch):
"""n <= 10 always returns n regardless of min_n or mode."""
@@ -31,3 +32,14 @@ def test_default_min_n_preserves_existing_behavior(self, monkeypatch):
"""Default min_n=11 matches pre-change behavior."""
monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", True)
assert CIParams.bootstrap(499) == max(11, int(math.sqrt(499) * 1.6)) # 35

+ def test_min_n_cap_with_high_min_n(self, monkeypatch):
+ """min_n=249 is also capped at 49 in pure Python mode."""
+ monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", True)
+ assert CIParams.bootstrap(499, min_n=249) == 49
+
+ def test_n_still_caps_result(self, monkeypatch):
+ """Original n still caps the result when min_n is below cap."""
+ monkeypatch.setattr(conftest_module, "_PURE_PYTHON_MODE", True)
+ # effective_min = min(40, 49) = 40; max(40, 8) = 40; min(30, 40) = 30
+ assert CIParams.bootstrap(30, min_n=40) == 30
11 changes: 7 additions & 4 deletions tests/test_methodology_callaway.py
@@ -803,11 +803,12 @@ class TestSEFormulas:
@pytest.mark.slow
def test_analytical_se_close_to_bootstrap_se(self, ci_params):
"""
- Analytical and bootstrap SEs should be within 20%.
+ Analytical and bootstrap SEs should be within 25%.

Analytical SEs use influence function aggregation.
Bootstrap SEs use multiplier bootstrap.
- They should converge for large samples.
+ They should converge for large samples. Wider tolerance (40%)
+ when min_n cap reduces bootstrap iterations in pure Python mode.

This test is marked slow because it uses 499 bootstrap iterations
for thorough validation of SE convergence.
@@ -833,10 +834,12 @@ def test_analytical_se_close_to_bootstrap_se(self, ci_params):
time='period', first_treat='first_treat'
)

- # Check overall ATT SE
+ # Check overall ATT SE (wider tolerance when min_n cap reduces
+ # bootstrap iterations in pure Python mode)
if results_boot.overall_se > 0:
rel_diff = abs(results_anal.overall_se - results_boot.overall_se) / results_boot.overall_se
- assert rel_diff < 0.25, \
+ threshold = 0.40 if n_boot < 100 else 0.25
+ assert rel_diff < threshold, \
f"Analytical SE ({results_anal.overall_se}) differs from bootstrap SE " \
f"({results_boot.overall_se}) by {rel_diff*100:.1f}%"

10 changes: 5 additions & 5 deletions tests/test_power.py
@@ -399,7 +399,7 @@ def test_simulation_with_large_effect(self):
n_periods=4,
treatment_effect=10.0, # Very large effect
sigma=1.0, # Low noise
- n_simulations=50,
+ n_simulations=30,
seed=42,
progress=False,
)
@@ -416,7 +416,7 @@ def test_simulation_with_zero_effect(self):
n_periods=4,
treatment_effect=0.0, # No effect
sigma=1.0,
- n_simulations=50,
+ n_simulations=30,
seed=42,
progress=False,
)
@@ -459,7 +459,7 @@ def test_simulation_coverage(self):
n_periods=4,
treatment_effect=5.0,
sigma=2.0,
- n_simulations=100,
+ n_simulations=50,
seed=42,
progress=False,
)
@@ -476,7 +476,7 @@ def test_simulation_bias(self):
n_periods=4,
treatment_effect=5.0,
sigma=1.0,
- n_simulations=100,
+ n_simulations=50,
seed=42,
progress=False,
)
@@ -524,7 +524,7 @@ def test_simulation_confidence_interval(self):
did = DifferenceInDifferences()
results = simulate_power(
estimator=did,
- n_simulations=100,
+ n_simulations=50,
seed=42,
progress=False,
)
14 changes: 8 additions & 6 deletions tests/test_staggered.py
@@ -1579,12 +1579,13 @@ def test_analytical_se_vs_bootstrap_se(self, ci_params):
# Point estimates should match exactly
assert abs(results_analytical.overall_att - results_bootstrap.overall_att) < 1e-10

- # SEs should be similar (within 15%)
- # Note: Some difference expected due to bootstrap variance vs asymptotic variance
+ # SEs should be similar (within 15% with enough bootstrap iterations,
+ # wider tolerance when min_n cap reduces iterations in pure Python mode)
rel_diff = abs(
results_analytical.overall_se - results_bootstrap.overall_se
) / results_bootstrap.overall_se
- assert rel_diff < 0.15, (
+ threshold = 0.40 if n_boot < 100 else 0.15
+ assert rel_diff < threshold, (
f"Analytical SE ({results_analytical.overall_se:.4f}) differs from "
f"bootstrap SE ({results_bootstrap.overall_se:.4f}) by {rel_diff:.1%}"
)
@@ -1726,16 +1727,17 @@ def test_event_study_analytical_se(self, ci_params):
assert results_analytical.event_study_effects is not None
assert results_bootstrap.event_study_effects is not None

- # Check each event time SE is similar
+ # Check each event time SE is similar (wider tolerance when
+ # min_n cap reduces bootstrap iterations in pure Python mode)
+ threshold = 0.40 if n_boot < 100 else 0.20
for e in results_analytical.event_study_effects:
if e in results_bootstrap.event_study_effects:
se_analytical = results_analytical.event_study_effects[e]['se']
se_bootstrap = results_bootstrap.event_study_effects[e]['se']

if se_bootstrap > 0:
rel_diff = abs(se_analytical - se_bootstrap) / se_bootstrap
- # Allow 20% difference for event study (more variance)
- assert rel_diff < 0.20, (
+ assert rel_diff < threshold, (
f"Event study SE at e={e}: analytical={se_analytical:.4f}, "
f"bootstrap={se_bootstrap:.4f}, diff={rel_diff:.1%}"
)