Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/add-database-build-test.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add end-to-end test for calibration database build pipeline.
196 changes: 196 additions & 0 deletions policyengine_us_data/tests/test_database_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""
End-to-end test for the calibration database build pipeline.

Runs every ETL script in the same order as ``make database`` and
validates the resulting SQLite database has the expected structure and
content. This catches API mismatches, missing imports, and data-loading
errors that unit tests on individual tables would miss.
"""

import sqlite3
import subprocess
import sys
from pathlib import Path

import pytest

from policyengine_us_data.storage import STORAGE_FOLDER

# Where the calibration database lives: a "calibration" folder inside the
# package storage folder, holding a single "policy_data.db" file.
DB_DIR = STORAGE_FOLDER.joinpath("calibration")
DB_PATH = DB_DIR.joinpath("policy_data.db")

# HuggingFace location of the stratified CPS dataset.
# The ETL scripts use it only to derive the time period (2024).
HF_DATASET = (
    "hf://policyengine/policyengine-us-data/calibration/"
    "stratified_extended_cps.h5"
)

# (script path relative to PKG_ROOT, CLI args) pairs, listed in the same
# order as `make database` in the Makefile. The flag marks scripts that are
# given ``--dataset``; create_database_tables.py does not use etl_argparser,
# and validate_database.py is likewise run without arguments.
# Each entry gets its own fresh argument list.
PIPELINE_SCRIPTS = [
    (script, ["--dataset", HF_DATASET] if takes_dataset else [])
    for script, takes_dataset in (
        ("db/create_database_tables.py", False),
        ("db/create_initial_strata.py", True),
        ("db/etl_national_targets.py", True),
        ("db/etl_age.py", True),
        ("db/etl_medicaid.py", True),
        ("db/etl_snap.py", True),
        ("db/etl_state_income_tax.py", True),
        ("db/etl_irs_soi.py", True),
        ("db/validate_database.py", False),
    )
]

# Two levels above this file: policyengine_us_data/
PKG_ROOT = Path(__file__).resolve().parents[1]


def _run_script(
    relative_path: str,
    extra_args: list,
    timeout: float = 300,
) -> subprocess.CompletedProcess:
    """Run a package script in a subprocess and return the completed process.

    Args:
        relative_path: Script path relative to ``PKG_ROOT``.
        extra_args: Command-line arguments appended after the script path.
        timeout: Seconds to wait before the run is aborted. Defaults to
            300, the limit that used to be hard-coded.

    Returns:
        The ``subprocess.CompletedProcess`` with stdout and stderr captured
        as text. A non-zero exit status is not raised here; callers inspect
        ``returncode`` themselves.

    Raises:
        AssertionError: If the script file does not exist.
        subprocess.TimeoutExpired: If the script runs longer than
            ``timeout`` seconds.
    """
    script = PKG_ROOT / relative_path
    assert script.exists(), f"Script not found: {script}"
    return subprocess.run(
        # Use the current interpreter so the script runs in the same
        # environment as the test session.
        [sys.executable, str(script), *extra_args],
        capture_output=True,
        text=True,
        timeout=timeout,
        # Failures are reported by the caller, not raised from here.
        check=False,
    )


@pytest.fixture(scope="module")
def built_db():
    """Build the calibration database from scratch once per module.

    Any existing database file is removed first so the tests validate a
    clean build. Every entry of ``PIPELINE_SCRIPTS`` is then run in order;
    failures are collected rather than raised immediately so that a single
    test run reports every broken script.

    Returns:
        ``DB_PATH``, the path of the freshly built SQLite database.

    Fails (via ``pytest.fail``) when any script exits non-zero or times
    out, and asserts that the database file exists afterwards.
    """
    DB_DIR.mkdir(parents=True, exist_ok=True)
    if DB_PATH.exists():
        DB_PATH.unlink()

    errors = []
    for script, args in PIPELINE_SCRIPTS:
        try:
            result = _run_script(script, args)
        except subprocess.TimeoutExpired as exc:
            # A hung script is reported alongside the other failures
            # instead of aborting the fixture with a raw traceback and
            # losing the errors gathered so far.
            errors.append(f"{script} timed out after {exc.timeout}s")
            continue
        if result.returncode != 0:
            errors.append(
                f"{script} failed (rc={result.returncode}):\n"
                f"  stderr (last 500 chars): "
                f"{result.stderr[-500:]}"
            )

    if errors:
        pytest.fail(
            f"{len(errors)} ETL script(s) failed:\n" + "\n\n".join(errors)
        )

    assert DB_PATH.exists(), "policy_data.db was not created"
    return DB_PATH


def test_all_etl_scripts_succeed(built_db):
    """Surface the pipeline build as its own named test.

    The ``built_db`` fixture already fails when a script does not pass;
    this test only confirms that the path it hands back exists.
    """
    db_file_present = built_db.exists()
    assert db_file_present


def test_expected_tables_exist(built_db):
    """Core tables must be present in the built database.

    Reads the table names from ``sqlite_master`` and checks that
    ``strata``, ``stratum_constraints`` and ``targets`` are among them.
    """
    conn = sqlite3.connect(str(built_db))
    try:
        tables = {
            row[0]
            for row in conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table'"
            )
        }
    finally:
        # Close even when the query raises, so a failing test does not
        # leak an open handle on the database file.
        conn.close()

    for expected in ["strata", "stratum_constraints", "targets"]:
        assert expected in tables, f"Missing table: {expected}"


def test_national_targets_loaded(built_db):
    """National targets should include well-known variables.

    Selects the distinct target variables attached to strata that have no
    rows in ``stratum_constraints`` and checks that ``snap``,
    ``social_security`` and ``ssi`` are among them.
    """
    conn = sqlite3.connect(str(built_db))
    try:
        # The national stratum has no constraints in stratum_constraints,
        # so the LEFT JOIN leaves sc.stratum_id NULL for it.
        rows = conn.execute("""
            SELECT DISTINCT t.variable
            FROM targets t
            JOIN strata s ON t.stratum_id = s.stratum_id
            LEFT JOIN stratum_constraints sc
                ON s.stratum_id = sc.stratum_id
            WHERE sc.stratum_id IS NULL
        """).fetchall()
    finally:
        # Close even when the query raises (e.g. a table is missing).
        conn.close()

    variables = {r[0] for r in rows}
    for expected in ["snap", "social_security", "ssi"]:
        assert expected in variables, (
            f"National target '{expected}' missing. "
            f"Found: {sorted(variables)}"
        )


def test_state_income_tax_targets(built_db):
    """State income tax targets should cover all income-tax states.

    Collects the ``state_income_tax`` target value for every stratum that
    has a ``state_fips`` constraint, requires at least 42 distinct states,
    and sanity-checks the California total.
    """
    conn = sqlite3.connect(str(built_db))
    try:
        rows = conn.execute("""
            SELECT sc.value, t.value
            FROM targets t
            JOIN strata s ON t.stratum_id = s.stratum_id
            JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id
            WHERE t.variable = 'state_income_tax'
                AND sc.constraint_variable = 'state_fips'
        """).fetchall()
    finally:
        # Close even when the query raises (e.g. a table is missing).
        conn.close()

    # Constraint value (state FIPS) -> target value.
    state_totals = {r[0]: r[1] for r in rows}

    n = len(state_totals)
    assert n >= 42, f"Expected >= 42 state income tax targets, got {n}"

    # California is expected to exceed $100B. Accept either FIPS spelling,
    # and test for None explicitly so that a stored 0 is reported as "too
    # small" rather than being mistaken for a missing target.
    ca_val = state_totals.get("06")
    if ca_val is None:
        ca_val = state_totals.get("6")
    assert ca_val is not None, "California (FIPS 06) target missing"
    assert ca_val > 100e9, (
        f"California income tax should be > $100B, "
        f"got ${ca_val / 1e9:.1f}B"
    )


def test_congressional_district_strata(built_db):
    """Should have strata for >= 435 congressional districts.

    Counts the distinct ``congressional_district_geoid`` constraint values
    in ``stratum_constraints``.
    """
    conn = sqlite3.connect(str(built_db))
    try:
        n_cds = conn.execute("""
            SELECT COUNT(DISTINCT sc.value)
            FROM stratum_constraints sc
            WHERE sc.constraint_variable = 'congressional_district_geoid'
        """).fetchone()[0]
    finally:
        # Close even when the query raises (e.g. a table is missing).
        conn.close()

    assert n_cds >= 435, f"Expected >= 435 CD strata, got {n_cds}"


def test_all_target_variables_exist_in_policyengine(built_db):
    """Every target variable must be a valid policyengine-us variable.

    Compares the distinct ``variable`` values in ``targets`` against
    ``system.variables`` from policyengine-us.
    """
    # Imported here rather than at module level, as in the original.
    from policyengine_us.system import system

    conn = sqlite3.connect(str(built_db))
    try:
        variables = {
            r[0]
            for r in conn.execute("SELECT DISTINCT variable FROM targets")
        }
    finally:
        # Close even when the query raises (e.g. a table is missing).
        conn.close()

    # Sort so the failure message is stable between runs; key=str keeps
    # the sort from raising if a NULL variable name ever shows up.
    missing = sorted(
        (v for v in variables if v not in system.variables), key=str
    )
    assert not missing, f"Target variables not in policyengine-us: {missing}"


def test_total_target_count(built_db):
    """Sanity check: the database should hold a healthy number of targets."""
    conn = sqlite3.connect(str(built_db))
    try:
        count = conn.execute("SELECT COUNT(*) FROM targets").fetchone()[0]
    finally:
        # Close even when the query raises (e.g. a table is missing).
        conn.close()

    # With national + age + medicaid + SNAP + state income tax + IRS SOI,
    # we expect thousands of targets.
    assert count > 1000, f"Expected > 1000 total targets, got {count}"