From 8cad22231576e29548156e0fe2b3fc0feb3b120a Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 25 Feb 2026 10:21:24 -0500 Subject: [PATCH] Upgrade SQLAlchemy from v1 to v2 for Python 3.14 compatibility SQLAlchemy v1 does not support Python 3.14, making this upgrade a blocker. Key changes: - Update setup.py pin from sqlalchemy>=1.4,<2 to sqlalchemy>=2,<3 - Replace removed engine.execute() with connection-based execution using conn.exec_driver_sql() inside a connection context manager - Add _ResultProxy wrapper to eagerly fetch results so they survive connection closure (maintains existing fetchone()/fetchall() API) - Replace LegacyRow (removed in v2) with Row in type annotations - Update test mocks that used spec=LegacyRow to use plain MagicMock() since Row in v2 has a different interface - Add 10 dedicated SQLAlchemy v2 compatibility tests Co-Authored-By: Claude Opus 4.6 --- changelog_entry.yaml | 4 + policyengine_api/data/data.py | 48 ++++- .../services/household_service.py | 4 +- policyengine_api/services/policy_service.py | 4 +- .../services/report_output_service.py | 4 +- .../services/simulation_service.py | 4 +- setup.py | 2 +- .../python/test_household_routes.py | 10 +- tests/unit/data/test_sqlalchemy_v2.py | 194 ++++++++++++++++++ tests/unit/services/test_household_service.py | 1 - 10 files changed, 258 insertions(+), 17 deletions(-) create mode 100644 tests/unit/data/test_sqlalchemy_v2.py diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..d52575c02 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + changed: + - Upgraded SQLAlchemy from v1 (>=1.4,<2) to v2 (>=2,<3) for Python 3.14 compatibility. Replaced removed engine.execute() with connection-based execution, updated LegacyRow to Row, and added _ResultProxy wrapper for eager result fetching. diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index c64ffd065..2318fc43a 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -13,6 +13,34 @@ load_dotenv() +class _ResultProxy: + """Lightweight wrapper that eagerly fetches results from a + SQLAlchemy CursorResult so they survive connection closure. + Provides fetchone()/fetchall() with dict-like row access.""" + + def __init__(self, cursor_result): + try: + # Use .mappings() so rows behave like dicts + self._rows = list(cursor_result.mappings()) + except Exception: + # For non-SELECT statements (INSERT/UPDATE/DELETE) + # there are no rows to fetch + self._rows = [] + self._index = 0 + + def fetchone(self): + if self._index < len(self._rows): + row = self._rows[self._index] + self._index += 1 + return row + return None + + def fetchall(self): + remaining = self._rows[self._index :] + self._index = len(self._rows) + return remaining + + class PolicyEngineDatabase: """ A wrapper around the database connection. @@ -70,6 +98,22 @@ def _close_pool(self): except: pass + def _execute_remote(self, query_args): + """Execute a query against the remote database using + SQLAlchemy v2 connection-based execution.""" + main_query = query_args[0] + params = query_args[1] if len(query_args) > 1 else None + with self.pool.connect() as conn: + if params is not None: + result = conn.exec_driver_sql(main_query, params) + else: + result = conn.exec_driver_sql(main_query) + conn.commit() + # Return a lightweight wrapper that holds + # the fetched results so they survive the + # connection context closing + return _ResultProxy(result) + def query(self, *query): if self.local: with sqlite3.connect(self.db_url) as conn: @@ -89,7 +133,7 @@ def dict_factory(cursor, row): main_query = main_query.replace("?", "%s") query[0] = main_query try: - return self.pool.execute(*query) + return self._execute_remote(query) # Except InterfaceError and OperationalError, which are thrown when the connection is lost. except ( sqlalchemy.exc.InterfaceError, @@ -98,7 +142,7 @@ def dict_factory(cursor, row): try: self._close_pool() self._create_pool() - return self.pool.execute(*query) + return self._execute_remote(query) except Exception as e: raise e diff --git a/policyengine_api/services/household_service.py b/policyengine_api/services/household_service.py index 4091f71d9..8b3e658f4 100644 --- a/policyengine_api/services/household_service.py +++ b/policyengine_api/services/household_service.py @@ -1,5 +1,5 @@ import json -from sqlalchemy.engine.row import LegacyRow +from sqlalchemy.engine.row import Row from policyengine_api.data import database from policyengine_api.utils import hash_object @@ -24,7 +24,7 @@ def get_household(self, country_id: str, household_id: int) -> dict | None: f"Invalid household ID: {household_id}. Must be a positive integer." ) - row: LegacyRow | None = database.query( + row: Row | None = database.query( f"SELECT * FROM household WHERE id = ? AND country_id = ?", (household_id, country_id), ).fetchone() diff --git a/policyengine_api/services/policy_service.py b/policyengine_api/services/policy_service.py index e89f9fa87..bc63bc34c 100644 --- a/policyengine_api/services/policy_service.py +++ b/policyengine_api/services/policy_service.py @@ -1,5 +1,5 @@ import json -from sqlalchemy.engine.row import LegacyRow +from sqlalchemy.engine.row import Row from policyengine_api.data import database from policyengine_api.utils import hash_object @@ -37,7 +37,7 @@ def get_policy(self, country_id: str, policy_id: int) -> dict | None: raise ValueError("country_id cannot be empty or None") # If no policy found, this will return None - row: LegacyRow | None = database.query( + row: Row | None = database.query( "SELECT * FROM policy WHERE country_id = ? AND id = ?", (country_id, policy_id), ).fetchone() diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 4793ae018..c1d1765fb 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -1,4 +1,4 @@ -from sqlalchemy.engine.row import LegacyRow +from sqlalchemy.engine.row import Row from policyengine_api.data import database from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS @@ -137,7 +137,7 @@ def get_report_output(self, report_output_id: int) -> dict | None: f"Invalid report output ID: {report_output_id}. Must be a positive integer." ) - row: LegacyRow | None = database.query( + row: Row | None = database.query( "SELECT * FROM report_outputs WHERE id = ?", (report_output_id,), ).fetchone() diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index 88f359ae7..bb5b5d290 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -1,5 +1,5 @@ import json -from sqlalchemy.engine.row import LegacyRow +from sqlalchemy.engine.row import Row from policyengine_api.data import database from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS @@ -119,7 +119,7 @@ def get_simulation( f"Invalid simulation ID: {simulation_id}. Must be a positive integer." ) - row: LegacyRow | None = database.query( + row: Row | None = database.query( "SELECT * FROM simulations WHERE id = ? AND country_id = ?", (simulation_id, country_id), ).fetchone() diff --git a/setup.py b/setup.py index afcb5ee8a..beb53f905 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ "python-dotenv", "redis", "rq", - "sqlalchemy>=1.4,<2", + "sqlalchemy>=2,<3", "streamlit", "werkzeug", "Flask-Caching>=2,<3", diff --git a/tests/to_refactor/python/test_household_routes.py b/tests/to_refactor/python/test_household_routes.py index e4ea05a1c..78d766fb8 100644 --- a/tests/to_refactor/python/test_household_routes.py +++ b/tests/to_refactor/python/test_household_routes.py @@ -1,7 +1,6 @@ import pytest import json from unittest.mock import MagicMock, patch -from sqlalchemy.engine.row import LegacyRow from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS @@ -16,8 +15,9 @@ class TestGetHousehold: def test_get_existing_household(self, rest_client, mock_database): """Test getting an existing household.""" - # Mock database response - mock_row = MagicMock(spec=LegacyRow) + # Mock database response as a dict-like object + # (SQLAlchemy v2 Row objects support dict() via ._mapping) + mock_row = MagicMock() mock_row.__getitem__.side_effect = lambda x: valid_db_row[x] mock_row.keys.return_value = valid_db_row.keys() mock_database.query().fetchone.return_value = mock_row @@ -57,7 +57,7 @@ def test_create_household_success( ): """Test successfully creating a new household.""" # Mock database responses - mock_row = MagicMock(spec=LegacyRow) + mock_row = MagicMock() mock_row.__getitem__.side_effect = lambda x: {"id": 1}[x] mock_database.query().fetchone.return_value = mock_row @@ -111,7 +111,7 @@ def test_update_household_success( ): """Test successfully updating an existing household.""" # Mock getting existing household - mock_row = MagicMock(spec=LegacyRow) + mock_row = MagicMock() mock_row.__getitem__.side_effect = lambda x: valid_db_row[x] mock_row.keys.return_value = valid_db_row.keys() mock_database.query().fetchone.return_value = mock_row diff --git a/tests/unit/data/test_sqlalchemy_v2.py b/tests/unit/data/test_sqlalchemy_v2.py new file mode 100644 index 000000000..fceb898e4 --- /dev/null +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -0,0 +1,194 @@ +"""Tests for SQLAlchemy v2 compatibility. + +These tests verify that the database layer works correctly with +SQLAlchemy v2, specifically: +- The _ResultProxy wrapper provides fetchone()/fetchall() on eagerly + fetched results. +- The remote (non-local) query path uses connection-based execution + instead of the removed engine.execute(). +- Row objects returned from the remote path support dict-like access + (dict(row) and row["key"]). +""" + +import pytest +import sqlalchemy + +from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase + + +class TestSQLAlchemyVersion: + """Verify that SQLAlchemy v2 is installed.""" + + def test_sqlalchemy_version_is_v2(self): + major = int(sqlalchemy.__version__.split(".")[0]) + assert ( + major >= 2 + ), f"Expected SQLAlchemy v2+, got {sqlalchemy.__version__}" + + +class TestResultProxy: + """Test the _ResultProxy wrapper that bridges SQLAlchemy v2 + connection-scoped results with the existing query() API.""" + + def test_fetchone_returns_dict_like_rows(self): + """Rows returned by fetchone() should support dict() and + key-based access.""" + engine = sqlalchemy.create_engine("sqlite://") + with engine.connect() as conn: + conn.exec_driver_sql( + "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)" + ) + conn.exec_driver_sql("INSERT INTO test VALUES (1, 'hello')") + result = conn.exec_driver_sql("SELECT * FROM test") + proxy = _ResultProxy(result) + + row = proxy.fetchone() + assert row is not None + assert dict(row) == {"id": 1, "name": "hello"} + assert row["id"] == 1 + assert row["name"] == "hello" + + def test_fetchone_returns_none_when_exhausted(self): + engine = sqlalchemy.create_engine("sqlite://") + with engine.connect() as conn: + conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)") + result = conn.exec_driver_sql("SELECT * FROM test") + proxy = _ResultProxy(result) + + assert proxy.fetchone() is None + + def test_fetchall_returns_all_rows(self): + engine = sqlalchemy.create_engine("sqlite://") + with engine.connect() as conn: + conn.exec_driver_sql( + "CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT)" + ) + conn.exec_driver_sql("INSERT INTO test VALUES (1, 'a')") + conn.exec_driver_sql("INSERT INTO test VALUES (2, 'b')") + conn.exec_driver_sql("INSERT INTO test VALUES (3, 'c')") + result = conn.exec_driver_sql("SELECT * FROM test") + proxy = _ResultProxy(result) + + rows = proxy.fetchall() + assert len(rows) == 3 + assert dict(rows[0]) == {"id": 1, "val": "a"} + assert dict(rows[2]) == {"id": 3, "val": "c"} + + def test_fetchone_then_fetchall_respects_cursor_position(self): + engine = sqlalchemy.create_engine("sqlite://") + with engine.connect() as conn: + conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)") + conn.exec_driver_sql("INSERT INTO test VALUES (1)") + conn.exec_driver_sql("INSERT INTO test VALUES (2)") + conn.exec_driver_sql("INSERT INTO test VALUES (3)") + result = conn.exec_driver_sql("SELECT * FROM test") + proxy = _ResultProxy(result) + + first = proxy.fetchone() + assert dict(first) == {"id": 1} + remaining = proxy.fetchall() + assert len(remaining) == 2 + assert dict(remaining[0]) == {"id": 2} + + def test_result_proxy_for_insert_statement(self): + """INSERT statements produce no rows; _ResultProxy should + handle this gracefully.""" + engine = sqlalchemy.create_engine("sqlite://") + with engine.connect() as conn: + conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)") + result = conn.exec_driver_sql("INSERT INTO test VALUES (1)") + proxy = _ResultProxy(result) + + assert proxy.fetchone() is None + assert proxy.fetchall() == [] + + +class TestRemoteQueryPath: + """Test the non-local query path that uses SQLAlchemy engine + with connection-based execution (v2 pattern).""" + + def _make_remote_db(self): + """Create a PolicyEngineDatabase-like object that uses + a SQLAlchemy engine (the 'remote' path) but backed by + in-memory SQLite for testing.""" + db = PolicyEngineDatabase.__new__(PolicyEngineDatabase) + db.local = False + db.pool = sqlalchemy.create_engine("sqlite://") + # Initialize schema using the remote path + with db.pool.connect() as conn: + conn.exec_driver_sql( + "CREATE TABLE test_table " + "(id INTEGER PRIMARY KEY, name TEXT, value REAL)" + ) + conn.commit() + return db + + def test_remote_insert_and_select(self): + """Test INSERT then SELECT through the remote query path.""" + db = self._make_remote_db() + + # Note: remote path converts ? to %s for MySQL, but SQLite + # uses ? natively. Since exec_driver_sql passes to the DBAPI + # driver directly and SQLite's driver uses ?, we need to + # test with the actual query() method which does the conversion. + # For SQLite DBAPI, ? is the native marker. + + # Use exec_driver_sql directly to bypass ?->%s conversion + # (which would break SQLite) + db._execute_remote( + [ + "INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)", + (1, "test", 3.14), + ] + ) + + result = db._execute_remote( + ["SELECT * FROM test_table WHERE id = ?", (1,)] + ) + row = result.fetchone() + assert row is not None + assert row["id"] == 1 + assert row["name"] == "test" + assert row["value"] == 3.14 + assert dict(row) == {"id": 1, "name": "test", "value": 3.14} + + def test_remote_select_no_results(self): + db = self._make_remote_db() + result = db._execute_remote( + ["SELECT * FROM test_table WHERE id = ?", (999,)] + ) + assert result.fetchone() is None + + def test_remote_update(self): + db = self._make_remote_db() + db._execute_remote( + [ + "INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)", + (1, "original", 1.0), + ] + ) + db._execute_remote( + [ + "UPDATE test_table SET name = ? WHERE id = ?", + ("updated", 1), + ] + ) + result = db._execute_remote( + ["SELECT * FROM test_table WHERE id = ?", (1,)] + ) + row = result.fetchone() + assert row["name"] == "updated" + + def test_remote_delete(self): + db = self._make_remote_db() + db._execute_remote( + [ + "INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)", + (1, "to_delete", 0.0), + ] + ) + db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)]) + result = db._execute_remote( + ["SELECT * FROM test_table WHERE id = ?", (1,)] + ) + assert result.fetchone() is None diff --git a/tests/unit/services/test_household_service.py b/tests/unit/services/test_household_service.py index 9a3ccad6d..3c0ee9b50 100644 --- a/tests/unit/services/test_household_service.py +++ b/tests/unit/services/test_household_service.py @@ -1,7 +1,6 @@ import pytest import json from unittest.mock import MagicMock -from sqlalchemy.engine.row import LegacyRow import re from policyengine_api.services.household_service import HouseholdService