Patch notes (text before the first "diff --git" header is ignored by
`git apply` and `patch`).

Summary of the change:
  * tests/test_utils.py: the `_make_ridership_df` helper loses its docstring
    and now builds columns named `*_pct_of_comparable_pre_pandemic_day`
    (previously `*_of_comparable_pre_pandemic_day`); the sample values for
    the buses and LIRR columns change; one blank line before the following
    test is removed.
  * utils.py: `plot_ridership_recovery` requires the renamed `*_pct_*`
    columns, multiplies the four ridership columns by 100 on a copy of the
    input frame, and moves the baseline line (y=1.0 -> y=100) and the y-axis
    limit (1.5 -> 150) onto the same scale. The plotting calls are
    reformatted to one argument per line, blank lines are added around the
    definitions, and the new file no longer ends with a newline.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Line breaks were restored so that every hunk matches the old/new line counts
in its "@@" header, but the positions of blank context lines and the exact
indentation widths are inferred. Confirm against the blobs named on the
"index" lines before applying.

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6ffadac..a3b67c9 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -53,16 +53,14 @@ def test_clean_mta_df_already_sorted():
 # ---------- Tests for plot_ridership_recovery ----------
 
 def _make_ridership_df():
-    """Helper: create a small valid ridership DataFrame for testing."""
     return pd.DataFrame({
         "date": pd.to_datetime(["2020-03-01", "2020-03-02", "2020-03-03"]),
-        "subways_of_comparable_pre_pandemic_day": [0.9, 0.5, 0.6],
-        "buses_of_comparable_pre_pandemic_day": [0.95, 0.6, 0.7],
-        "lirr_of_comparable_pre_pandemic_day": [0.85, 0.4, 0.5],
-        "metro_north_of_comparable_pre_pandemic_day": [0.88, 0.45, 0.55],
+        "subways_pct_of_comparable_pre_pandemic_day": [0.9, 0.5, 0.6],
+        "buses_pct_of_comparable_pre_pandemic_day": [1.05, 0.6, 0.7],
+        "lirr_pct_of_comparable_pre_pandemic_day": [0.8, 0.4, 0.5],
+        "metro_north_pct_of_comparable_pre_pandemic_day": [0.88, 0.45, 0.55],
     })
 
-
 def test_plot_ridership_recovery_returns_figure():
     """Test that the function returns a matplotlib Figure without error."""
     df = _make_ridership_df()
diff --git a/utils.py b/utils.py
index b7c698d..75a4eff 100644
--- a/utils.py
+++ b/utils.py
@@ -1,6 +1,7 @@
 import pandas as pd
 import matplotlib.pyplot as plt
 
+
 def clean_mta_df(df: pd.DataFrame) -> pd.DataFrame:
     out = df.copy()
 
@@ -12,42 +13,84 @@ def clean_mta_df(df: pd.DataFrame) -> pd.DataFrame:
 
     return out
 
+
 def plot_ridership_recovery(df: pd.DataFrame) -> plt.Figure:
     """Plot MTA ridership recovery by transit mode as % of pre-pandemic levels."""
     required_cols = [
         "date",
-        "subways_of_comparable_pre_pandemic_day",
-        "buses_of_comparable_pre_pandemic_day",
-        "lirr_of_comparable_pre_pandemic_day",
-        "metro_north_of_comparable_pre_pandemic_day",
+        "subways_pct_of_comparable_pre_pandemic_day",
+        "buses_pct_of_comparable_pre_pandemic_day",
+        "lirr_pct_of_comparable_pre_pandemic_day",
+        "metro_north_pct_of_comparable_pre_pandemic_day",
     ]
+
     missing = [c for c in required_cols if c not in df.columns]
     if missing:
         raise KeyError(f"Missing required columns: {missing}")
 
+    plot_df = df.copy()
+    plot_df["subways_pct_of_comparable_pre_pandemic_day"] = (
+        plot_df["subways_pct_of_comparable_pre_pandemic_day"] * 100
+    )
+    plot_df["buses_pct_of_comparable_pre_pandemic_day"] = (
+        plot_df["buses_pct_of_comparable_pre_pandemic_day"] * 100
+    )
+    plot_df["lirr_pct_of_comparable_pre_pandemic_day"] = (
+        plot_df["lirr_pct_of_comparable_pre_pandemic_day"] * 100
+    )
+    plot_df["metro_north_pct_of_comparable_pre_pandemic_day"] = (
+        plot_df["metro_north_pct_of_comparable_pre_pandemic_day"] * 100
+    )
+
     fig, ax = plt.subplots(figsize=(14, 7))
 
-    ax.plot(df["date"], df["subways_of_comparable_pre_pandemic_day"],
-            label="Subway", alpha=0.8, linewidth=1.2)
-    ax.plot(df["date"], df["buses_of_comparable_pre_pandemic_day"],
-            label="Bus", alpha=0.8, linewidth=1.2)
-    ax.plot(df["date"], df["lirr_of_comparable_pre_pandemic_day"],
-            label="LIRR", alpha=0.8, linewidth=1.2)
-    ax.plot(df["date"], df["metro_north_of_comparable_pre_pandemic_day"],
-            label="Metro-North", alpha=0.8, linewidth=1.2)
+    ax.plot(
+        plot_df["date"],
+        plot_df["subways_pct_of_comparable_pre_pandemic_day"],
+        label="Subway",
+        alpha=0.8,
+        linewidth=1.2,
+    )
+    ax.plot(
+        plot_df["date"],
+        plot_df["buses_pct_of_comparable_pre_pandemic_day"],
+        label="Bus",
+        alpha=0.8,
+        linewidth=1.2,
+    )
+    ax.plot(
+        plot_df["date"],
+        plot_df["lirr_pct_of_comparable_pre_pandemic_day"],
+        label="LIRR",
+        alpha=0.8,
+        linewidth=1.2,
+    )
+    ax.plot(
+        plot_df["date"],
+        plot_df["metro_north_pct_of_comparable_pre_pandemic_day"],
+        label="Metro-North",
+        alpha=0.8,
+        linewidth=1.2,
+    )
 
-    ax.axhline(y=1.0, color="gray", linestyle="--", linewidth=1.5,
-               label="Pre-pandemic baseline (100%)")
+    ax.axhline(
+        y=100,
+        color="gray",
+        linestyle="--",
+        linewidth=1.5,
+        label="Pre-pandemic baseline (100%)",
+    )
 
     ax.set_xlabel("Date", fontsize=12)
     ax.set_ylabel("% of Pre-Pandemic Ridership", fontsize=12)
     ax.set_title(
         "MTA Ridership Recovery: Subway vs Bus vs Commuter Rail (2020-Present)",
-        fontsize=14, fontweight="bold",
+        fontsize=14,
+        fontweight="bold",
     )
     ax.legend(loc="lower right", fontsize=10)
     ax.grid(True, alpha=0.3)
-    ax.set_ylim(0, 1.5)
+    ax.set_ylim(0, 150)
 
     fig.tight_layout()
-    return fig
+    return fig
\ No newline at end of file