diff --git a/requirements.txt b/requirements.txt
index e35eafc..a91dd02 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ plotly
 requests
 ruff
 pytest
-matplotlib
\ No newline at end of file
+matplotlib
+pandera
diff --git a/tests/test_validation.py b/tests/test_validation.py
new file mode 100644
index 0000000..1dba220
--- /dev/null
+++ b/tests/test_validation.py
@@ -0,0 +1,81 @@
+import pandas as pd
+import pandera as pa
+import pytest
+
+from validation import validate_mta_data
+
+
+def _row(**overrides):
+    """Return one valid ridership record with ``overrides`` applied."""
+    row = {
+        "date": "2020-03-01",
+        "subways_total_estimated_ridership": 1000000.0,
+        "subways_of_comparable_pre_pandemic_day": 0.5,
+        "buses_total_estimated_ridership": 500000.0,
+        "buses_of_comparable_pre_pandemic_day": 0.6,
+        "lirr_total_estimated_ridership": 100000.0,
+        "lirr_of_comparable_pre_pandemic_day": 0.4,
+        "metro_north_total_estimated_ridership": 80000.0,
+        "metro_north_of_comparable_pre_pandemic_day": 0.35,
+        "bridges_and_tunnels_total_traffic": 700000.0,
+        "bridges_and_tunnels_of_comparable_pre_pandemic_day": 0.9,
+    }
+    row.update(overrides)
+    return row
+
+
+def _frame(**overrides):
+    """Return a single-row dataframe built from ``_row(**overrides)``."""
+    return pd.DataFrame([_row(**overrides)])
+
+
+def test_valid_data():
+    """Test that valid data passes validation."""
+    df = pd.DataFrame(
+        [
+            _row(),
+            _row(
+                date="2020-03-02",
+                subways_total_estimated_ridership=1100000.0,
+                subways_of_comparable_pre_pandemic_day=0.6,
+                buses_total_estimated_ridership=550000.0,
+                buses_of_comparable_pre_pandemic_day=0.65,
+                lirr_total_estimated_ridership=110000.0,
+                lirr_of_comparable_pre_pandemic_day=0.45,
+                metro_north_total_estimated_ridership=85000.0,
+                metro_north_of_comparable_pre_pandemic_day=0.4,
+                bridges_and_tunnels_total_traffic=720000.0,
+                bridges_and_tunnels_of_comparable_pre_pandemic_day=0.92,
+            ),
+        ]
+    )
+    result = validate_mta_data(df)
+    assert len(result) == 2
+
+
+def test_negative_ridership_fails():
+    """Test that negative ridership values fail validation."""
+    df = _frame(subways_total_estimated_ridership=-100.0)
+    with pytest.raises(pa.errors.SchemaError):
+        validate_mta_data(df)
+
+
+def test_ratio_exceeds_max_fails():
+    """Test that pre-pandemic ratio > 2.0 fails validation."""
+    df = _frame(subways_of_comparable_pre_pandemic_day=3.0)
+    with pytest.raises(pa.errors.SchemaError):
+        validate_mta_data(df)
+
+
+def test_missing_date_fails():
+    """Test that null dates fail validation."""
+    df = _frame(date=None)
+    with pytest.raises(pa.errors.SchemaError):
+        validate_mta_data(df)
+
+
+def test_date_before_minimum_fails():
+    """Test that dates earlier than 2020-03-01 fail validation."""
+    df = _frame(date="2020-02-29")
+    with pytest.raises(pa.errors.SchemaError):
+        validate_mta_data(df)
diff --git a/validation.py b/validation.py
new file mode 100644
index 0000000..90cb567
--- /dev/null
+++ b/validation.py
@@ -0,0 +1,100 @@
+"""Pandera schema and validation helper for MTA daily ridership data."""
+
+import pandas as pd
+import pandera as pa
+
+# Earliest date accepted in the "date" column.
+MIN_DATE = "2020-03-01"
+
+# Largest accepted value for the "*_of_comparable_pre_pandemic_day" ratios.
+MAX_PRE_PANDEMIC_RATIO = 2.0
+
+
+def _total_column(description: str) -> pa.Column:
+    """Build a nullable float column whose values must be >= 0."""
+    return pa.Column(
+        float,
+        nullable=True,
+        checks=pa.Check.greater_than_or_equal_to(0),
+        description=description,
+    )
+
+
+def _ratio_column(description: str) -> pa.Column:
+    """Build a nullable float column bounded to [0, MAX_PRE_PANDEMIC_RATIO]."""
+    return pa.Column(
+        float,
+        nullable=True,
+        checks=[
+            pa.Check.greater_than_or_equal_to(0),
+            pa.Check.less_than_or_equal_to(MAX_PRE_PANDEMIC_RATIO),
+        ],
+        description=description,
+    )
+
+
+# Schema for MTA Daily Ridership Data
+mta_schema = pa.DataFrameSchema(
+    {
+        "date": pa.Column(
+            pa.DateTime,
+            nullable=False,
+            checks=pa.Check.greater_than_or_equal_to(MIN_DATE),
+            description="Date of ridership record, starting from March 2020",
+        ),
+        "subways_total_estimated_ridership": _total_column(
+            "Total estimated subway ridership"
+        ),
+        "subways_of_comparable_pre_pandemic_day": _ratio_column(
+            "Subway ridership as ratio of pre-pandemic levels (0 to 2.0)"
+        ),
+        "buses_total_estimated_ridership": _total_column(
+            "Total estimated bus ridership"
+        ),
+        "buses_of_comparable_pre_pandemic_day": _ratio_column(
+            "Bus ridership as ratio of pre-pandemic levels"
+        ),
+        "lirr_total_estimated_ridership": _total_column(
+            "Total estimated LIRR ridership"
+        ),
+        "lirr_of_comparable_pre_pandemic_day": _ratio_column(
+            "LIRR ridership as ratio of pre-pandemic levels"
+        ),
+        "metro_north_total_estimated_ridership": _total_column(
+            "Total estimated Metro-North ridership"
+        ),
+        "metro_north_of_comparable_pre_pandemic_day": _ratio_column(
+            "Metro-North ridership as ratio of pre-pandemic levels"
+        ),
+        "bridges_and_tunnels_total_traffic": _total_column(
+            "Total bridges and tunnels traffic"
+        ),
+        "bridges_and_tunnels_of_comparable_pre_pandemic_day": _ratio_column(
+            "B&T traffic as ratio of pre-pandemic levels"
+        ),
+    },
+    coerce=True,
+)
+
+
+def validate_mta_data(df: pd.DataFrame, *, lazy: bool = False) -> pd.DataFrame:
+    """Validate an MTA ridership dataframe against ``mta_schema``.
+
+    Args:
+        df: Frame holding the ``date`` column and the ridership/traffic
+            columns declared in ``mta_schema``.
+        lazy: If true, collect every failure before raising
+            ``pandera.errors.SchemaErrors``; by default validation stops
+            at the first failure.
+
+    Returns:
+        The validated dataframe as returned by ``mta_schema.validate``
+        (the schema is declared with ``coerce=True``).
+
+    Raises:
+        pandera.errors.SchemaError: On the first failed check when
+            ``lazy`` is false.
+        pandera.errors.SchemaErrors: With all failed checks when
+            ``lazy`` is true.
+    """
+    return mta_schema.validate(df, lazy=lazy)