Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,14 @@ def test_clean_mta_df_already_sorted():
# ---------- Tests for plot_ridership_recovery ----------

def _make_ridership_df():
    """Build the three-row ridership DataFrame shared by the plotting tests.

    The frame has a ``date`` column (2020-03-01 through 2020-03-03) and, for
    each of subways, buses, LIRR and Metro-North, both a
    ``*_of_comparable_pre_pandemic_day`` column and a
    ``*_pct_of_comparable_pre_pandemic_day`` column.
    """
    day_stamps = pd.to_datetime(["2020-03-01", "2020-03-02", "2020-03-03"])

    # Insertion order below fixes the column order of the resulting frame.
    columns = {"date": day_stamps}
    columns.update(
        {
            "subways_of_comparable_pre_pandemic_day": [0.9, 0.5, 0.6],
            "buses_of_comparable_pre_pandemic_day": [0.95, 0.6, 0.7],
            "lirr_of_comparable_pre_pandemic_day": [0.85, 0.4, 0.5],
            "metro_north_of_comparable_pre_pandemic_day": [0.88, 0.45, 0.55],
        }
    )
    columns.update(
        {
            "subways_pct_of_comparable_pre_pandemic_day": [0.9, 0.5, 0.6],
            # First bus value is above 1.0, i.e. over the baseline.
            "buses_pct_of_comparable_pre_pandemic_day": [1.05, 0.6, 0.7],
            "lirr_pct_of_comparable_pre_pandemic_day": [0.8, 0.4, 0.5],
            "metro_north_pct_of_comparable_pre_pandemic_day": [0.88, 0.45, 0.55],
        }
    )
    return pd.DataFrame(columns)


def test_plot_ridership_recovery_returns_figure():
"""Test that the function returns a matplotlib Figure without error."""
df = _make_ridership_df()
Expand Down
77 changes: 60 additions & 17 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import matplotlib.pyplot as plt


def clean_mta_df(df: pd.DataFrame) -> pd.DataFrame:
out = df.copy()

Expand All @@ -12,42 +13,84 @@ def clean_mta_df(df: pd.DataFrame) -> pd.DataFrame:

return out


def plot_ridership_recovery(df: pd.DataFrame) -> plt.Figure:
    """Plot MTA ridership recovery by transit mode as % of pre-pandemic levels.

    Draws one line per transit mode (Subway, Bus, LIRR, Metro-North) against
    ``df["date"]``, plus a dashed horizontal baseline at 100%.

    Args:
        df: Frame containing ``date`` and the four
            ``*_pct_of_comparable_pre_pandemic_day`` columns. The values are
            multiplied by 100 before plotting, so they are presumably stored
            as fractions (1.0 == baseline) -- confirm against the data source.
            The input frame is not modified.

    Returns:
        The matplotlib Figure holding the chart.

    Raises:
        KeyError: If any required column is absent from ``df``.
    """
    # Column -> legend label; dict order fixes both plot and legend order.
    series_labels = {
        "subways_pct_of_comparable_pre_pandemic_day": "Subway",
        "buses_pct_of_comparable_pre_pandemic_day": "Bus",
        "lirr_pct_of_comparable_pre_pandemic_day": "LIRR",
        "metro_north_pct_of_comparable_pre_pandemic_day": "Metro-North",
    }
    required_cols = ["date", *series_labels]

    missing = [c for c in required_cols if c not in df.columns]
    if missing:
        raise KeyError(f"Missing required columns: {missing}")

    fig, ax = plt.subplots(figsize=(14, 7))

    for column, label in series_labels.items():
        # Scale on the fly so the caller's frame is left untouched and the
        # series share the percent axis used by the baseline and y-limits.
        ax.plot(
            df["date"],
            df[column] * 100,
            label=label,
            alpha=0.8,
            linewidth=1.2,
        )

    ax.axhline(
        y=100,
        color="gray",
        linestyle="--",
        linewidth=1.5,
        label="Pre-pandemic baseline (100%)",
    )

    ax.set_xlabel("Date", fontsize=12)
    ax.set_ylabel("% of Pre-Pandemic Ridership", fontsize=12)
    ax.set_title(
        "MTA Ridership Recovery: Subway vs Bus vs Commuter Rail (2020-Present)",
        fontsize=14,
        fontweight="bold",
    )
    ax.legend(loc="lower right", fontsize=10)
    ax.grid(True, alpha=0.3)
    # Headroom above 100% for modes that exceed their pre-pandemic level.
    ax.set_ylim(0, 150)
    fig.tight_layout()

    return fig
Loading