diff --git a/.gitignore b/.gitignore index b508df9..cff8e75 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,7 @@ __pycache__/ *.pyc -.pytest_cache/ \ No newline at end of file +.pytest_cache/ + +# Streamlit secrets +.streamlit/secrets.toml \ No newline at end of file diff --git a/load_data_to_bq.py b/load_data_to_bq.py new file mode 100644 index 0000000..d9284fe --- /dev/null +++ b/load_data_to_bq.py @@ -0,0 +1,67 @@ +"""Load MTA ridership data from NYC Open Data API into BigQuery.""" + +import sys + +import pandas as pd +import pydata_google_auth + +import pandas_gbq + +PROJECT_ID = "sipa-adv-c-bouncing-penguin" +DATASET_TABLE = "mta_data.daily_ridership" + +SCOPES = [ + "https://www.googleapis.com/auth/bigquery", +] + + +def get_credentials(): + """Get Google credentials with browser-based auth flow.""" + print("Authenticating with Google... A browser window should open.") + print("If it doesn't, copy the URL shown below and open it manually.") + credentials = pydata_google_auth.get_user_credentials( + SCOPES, + auth_local_webserver=False, + ) + print("Authentication successful!") + return credentials + + +def fetch_mta_data() -> pd.DataFrame: + """Pull MTA ridership data from NYC Open Data API.""" + print("Fetching MTA data from NYC Open Data API...") + sys.stdout.flush() + url = "https://data.ny.gov/resource/vxuj-8kew.csv?$limit=50000" + df = pd.read_csv(url) + df["date"] = pd.to_datetime(df["date"]) + print(f"Fetched {len(df)} rows (from {df['date'].min().date()} to {df['date'].max().date()})") + return df + + +def main(): + # Step 1: Authenticate + credentials = get_credentials() + + # Step 2: Fetch data + df = fetch_mta_data() + + # Step 3: Upload to BigQuery + print(f"Uploading to BigQuery: {PROJECT_ID}.{DATASET_TABLE} ...") + sys.stdout.flush() + pandas_gbq.to_gbq( + df, + destination_table=DATASET_TABLE, + project_id=PROJECT_ID, + if_exists="replace", + credentials=credentials, + ) + print("Done! Data loaded to BigQuery successfully.") + + # Step 4: Verify + query = f"SELECT COUNT(*) as row_count FROM `{PROJECT_ID}.{DATASET_TABLE}`" + result = pandas_gbq.read_gbq(query, project_id=PROJECT_ID, credentials=credentials) + print(f"Verification: {result['row_count'].iloc[0]} rows in BigQuery table.") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index a91dd02..b106761 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,7 @@ requests ruff pytest matplotlib -pandera \ No newline at end of file +pandera +pandas-gbq +google-cloud-bigquery +db-dtypes \ No newline at end of file diff --git a/streamlit_app.py b/streamlit_app.py index bfc5a1c..58e52bd 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -276,14 +276,23 @@ def fetch_data(): dates = sel_holidays[sel_holidays["holiday"] == holiday]["date"] color = colors[i % len(colors)] for j, d in enumerate(dates): + d = pd.Timestamp(d) if filtered["date"].min() <= d <= filtered["date"].max(): - fig_holiday.add_vline( - x=d, - line_dash="dot", - line_color=color, - annotation_text=holiday if j == 0 else None, - annotation_position="top left", + d_str = d.strftime("%Y-%m-%d") + fig_holiday.add_shape( + type="line", + x0=d_str, x1=d_str, + y0=0, y1=1, + yref="paper", + line=dict(dash="dot", color=color), ) + if j == 0: + fig_holiday.add_annotation( + x=d_str, y=1, yref="paper", + text=holiday, + showarrow=False, + xanchor="left", + ) fig_holiday.update_layout( yaxis_title="Subway Recovery (% of Pre-Pandemic)", @@ -298,7 +307,7 @@ def fetch_data(): st.markdown("**Average Subway Ridership Around Holidays**") impact_rows = [] for _, row in sel_holidays.iterrows(): - h_date = row["date"] + h_date = pd.Timestamp(row["date"]) # 3-day window around the holiday window = filtered[ (filtered["date"] >= h_date - pd.Timedelta(days=1)) diff --git a/utils.py b/utils.py index 5e7a5d6..e7e6026 100644 --- a/utils.py +++ b/utils.py @@ -1,11 +1,27 @@ import matplotlib.pyplot as plt 
import pandas as pd +from google.cloud import bigquery +from google.oauth2 import service_account + +PROJECT_ID = "sipa-adv-c-bouncing-penguin" +DATASET_TABLE = "mta_data.daily_ridership" def load_mta_data() -> pd.DataFrame: - """Load MTA ridership data from NYC Open Data API.""" - url = "https://data.ny.gov/resource/vxuj-8kew.csv?$limit=50000" - df = pd.read_csv(url) + """Load MTA ridership data from BigQuery.""" + try: + import streamlit as st + + credentials = service_account.Credentials.from_service_account_info( + st.secrets["gcp_service_account"] + ) + client = bigquery.Client(credentials=credentials, project=PROJECT_ID) + except Exception: + # Fallback: use default credentials (e.g. local gcloud auth) + client = bigquery.Client(project=PROJECT_ID) + + query = f"SELECT * FROM `{PROJECT_ID}.{DATASET_TABLE}`" + df = client.query(query).to_dataframe() df = clean_mta_df(df) return df