Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
changed:
- Upgraded SQLAlchemy from v1 (>=1.4,<2) to v2 (>=2,<3) for Python 3.14 compatibility. Replaced removed engine.execute() with connection-based execution, updated LegacyRow to Row, and added _ResultProxy wrapper for eager result fetching.
48 changes: 46 additions & 2 deletions policyengine_api/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,34 @@
load_dotenv()


class _ResultProxy:
"""Lightweight wrapper that eagerly fetches results from a
SQLAlchemy CursorResult so they survive connection closure.
Provides fetchone()/fetchall() with dict-like row access."""

def __init__(self, cursor_result):
try:
# Use .mappings() so rows behave like dicts
self._rows = list(cursor_result.mappings())
except Exception:
# For non-SELECT statements (INSERT/UPDATE/DELETE)
# there are no rows to fetch
self._rows = []
self._index = 0

def fetchone(self):
if self._index < len(self._rows):
row = self._rows[self._index]
self._index += 1
return row
return None

def fetchall(self):
remaining = self._rows[self._index :]
self._index = len(self._rows)
return remaining


class PolicyEngineDatabase:
"""
A wrapper around the database connection.
Expand Down Expand Up @@ -70,6 +98,22 @@ def _close_pool(self):
except:
pass

def _execute_remote(self, query_args):
"""Execute a query against the remote database using
SQLAlchemy v2 connection-based execution."""
main_query = query_args[0]
params = query_args[1] if len(query_args) > 1 else None
with self.pool.connect() as conn:
if params is not None:
result = conn.exec_driver_sql(main_query, params)
else:
result = conn.exec_driver_sql(main_query)
conn.commit()
# Return a lightweight wrapper that holds
# the fetched results so they survive the
# connection context closing
return _ResultProxy(result)

def query(self, *query):
if self.local:
with sqlite3.connect(self.db_url) as conn:
Expand All @@ -89,7 +133,7 @@ def dict_factory(cursor, row):
main_query = main_query.replace("?", "%s")
query[0] = main_query
try:
return self.pool.execute(*query)
return self._execute_remote(query)
# Except InterfaceError and OperationalError, which are thrown when the connection is lost.
except (
sqlalchemy.exc.InterfaceError,
Expand All @@ -98,7 +142,7 @@ def dict_factory(cursor, row):
try:
self._close_pool()
self._create_pool()
return self.pool.execute(*query)
return self._execute_remote(query)
except Exception as e:
raise e

Expand Down
4 changes: 2 additions & 2 deletions policyengine_api/services/household_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from sqlalchemy.engine.row import LegacyRow
from sqlalchemy.engine.row import Row

from policyengine_api.data import database
from policyengine_api.utils import hash_object
Expand All @@ -24,7 +24,7 @@ def get_household(self, country_id: str, household_id: int) -> dict | None:
f"Invalid household ID: {household_id}. Must be a positive integer."
)

row: LegacyRow | None = database.query(
row: Row | None = database.query(
f"SELECT * FROM household WHERE id = ? AND country_id = ?",
(household_id, country_id),
).fetchone()
Expand Down
4 changes: 2 additions & 2 deletions policyengine_api/services/policy_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from sqlalchemy.engine.row import LegacyRow
from sqlalchemy.engine.row import Row

from policyengine_api.data import database
from policyengine_api.utils import hash_object
Expand Down Expand Up @@ -37,7 +37,7 @@ def get_policy(self, country_id: str, policy_id: int) -> dict | None:
raise ValueError("country_id cannot be empty or None")

# If no policy found, this will return None
row: LegacyRow | None = database.query(
row: Row | None = database.query(
"SELECT * FROM policy WHERE country_id = ? AND id = ?",
(country_id, policy_id),
).fetchone()
Expand Down
4 changes: 2 additions & 2 deletions policyengine_api/services/report_output_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy.engine.row import LegacyRow
from sqlalchemy.engine.row import Row

from policyengine_api.data import database
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
Expand Down Expand Up @@ -137,7 +137,7 @@ def get_report_output(self, report_output_id: int) -> dict | None:
f"Invalid report output ID: {report_output_id}. Must be a positive integer."
)

row: LegacyRow | None = database.query(
row: Row | None = database.query(
"SELECT * FROM report_outputs WHERE id = ?",
(report_output_id,),
).fetchone()
Expand Down
4 changes: 2 additions & 2 deletions policyengine_api/services/simulation_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from sqlalchemy.engine.row import LegacyRow
from sqlalchemy.engine.row import Row

from policyengine_api.data import database
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_simulation(
f"Invalid simulation ID: {simulation_id}. Must be a positive integer."
)

row: LegacyRow | None = database.query(
row: Row | None = database.query(
"SELECT * FROM simulations WHERE id = ? AND country_id = ?",
(simulation_id, country_id),
).fetchone()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"python-dotenv",
"redis",
"rq",
"sqlalchemy>=1.4,<2",
"sqlalchemy>=2,<3",
"streamlit",
"werkzeug",
"Flask-Caching>=2,<3",
Expand Down
10 changes: 5 additions & 5 deletions tests/to_refactor/python/test_household_routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import json
from unittest.mock import MagicMock, patch
from sqlalchemy.engine.row import LegacyRow

from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS

Expand All @@ -16,8 +15,9 @@
class TestGetHousehold:
def test_get_existing_household(self, rest_client, mock_database):
"""Test getting an existing household."""
# Mock database response
mock_row = MagicMock(spec=LegacyRow)
# Mock database response as a dict-like object
# (SQLAlchemy v2 Row objects support dict() via ._mapping)
mock_row = MagicMock()
mock_row.__getitem__.side_effect = lambda x: valid_db_row[x]
mock_row.keys.return_value = valid_db_row.keys()
mock_database.query().fetchone.return_value = mock_row
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_create_household_success(
):
"""Test successfully creating a new household."""
# Mock database responses
mock_row = MagicMock(spec=LegacyRow)
mock_row = MagicMock()
mock_row.__getitem__.side_effect = lambda x: {"id": 1}[x]
mock_database.query().fetchone.return_value = mock_row

Expand Down Expand Up @@ -111,7 +111,7 @@ def test_update_household_success(
):
"""Test successfully updating an existing household."""
# Mock getting existing household
mock_row = MagicMock(spec=LegacyRow)
mock_row = MagicMock()
mock_row.__getitem__.side_effect = lambda x: valid_db_row[x]
mock_row.keys.return_value = valid_db_row.keys()
mock_database.query().fetchone.return_value = mock_row
Expand Down
194 changes: 194 additions & 0 deletions tests/unit/data/test_sqlalchemy_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Tests for SQLAlchemy v2 compatibility.

These tests verify that the database layer works correctly with
SQLAlchemy v2, specifically:
- The _ResultProxy wrapper provides fetchone()/fetchall() on eagerly
fetched results.
- The remote (non-local) query path uses connection-based execution
instead of the removed engine.execute().
- Row objects returned from the remote path support dict-like access
(dict(row) and row["key"]).
"""

import pytest
import sqlalchemy

from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase


class TestSQLAlchemyVersion:
"""Verify that SQLAlchemy v2 is installed."""

def test_sqlalchemy_version_is_v2(self):
major = int(sqlalchemy.__version__.split(".")[0])
assert (
major >= 2
), f"Expected SQLAlchemy v2+, got {sqlalchemy.__version__}"


class TestResultProxy:
"""Test the _ResultProxy wrapper that bridges SQLAlchemy v2
connection-scoped results with the existing query() API."""

def test_fetchone_returns_dict_like_rows(self):
"""Rows returned by fetchone() should support dict() and
key-based access."""
engine = sqlalchemy.create_engine("sqlite://")
with engine.connect() as conn:
conn.exec_driver_sql(
"CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
)
conn.exec_driver_sql("INSERT INTO test VALUES (1, 'hello')")
result = conn.exec_driver_sql("SELECT * FROM test")
proxy = _ResultProxy(result)

row = proxy.fetchone()
assert row is not None
assert dict(row) == {"id": 1, "name": "hello"}
assert row["id"] == 1
assert row["name"] == "hello"

def test_fetchone_returns_none_when_exhausted(self):
engine = sqlalchemy.create_engine("sqlite://")
with engine.connect() as conn:
conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)")
result = conn.exec_driver_sql("SELECT * FROM test")
proxy = _ResultProxy(result)

assert proxy.fetchone() is None

def test_fetchall_returns_all_rows(self):
engine = sqlalchemy.create_engine("sqlite://")
with engine.connect() as conn:
conn.exec_driver_sql(
"CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT)"
)
conn.exec_driver_sql("INSERT INTO test VALUES (1, 'a')")
conn.exec_driver_sql("INSERT INTO test VALUES (2, 'b')")
conn.exec_driver_sql("INSERT INTO test VALUES (3, 'c')")
result = conn.exec_driver_sql("SELECT * FROM test")
proxy = _ResultProxy(result)

rows = proxy.fetchall()
assert len(rows) == 3
assert dict(rows[0]) == {"id": 1, "val": "a"}
assert dict(rows[2]) == {"id": 3, "val": "c"}

def test_fetchone_then_fetchall_respects_cursor_position(self):
engine = sqlalchemy.create_engine("sqlite://")
with engine.connect() as conn:
conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)")
conn.exec_driver_sql("INSERT INTO test VALUES (1)")
conn.exec_driver_sql("INSERT INTO test VALUES (2)")
conn.exec_driver_sql("INSERT INTO test VALUES (3)")
result = conn.exec_driver_sql("SELECT * FROM test")
proxy = _ResultProxy(result)

first = proxy.fetchone()
assert dict(first) == {"id": 1}
remaining = proxy.fetchall()
assert len(remaining) == 2
assert dict(remaining[0]) == {"id": 2}

def test_result_proxy_for_insert_statement(self):
"""INSERT statements produce no rows; _ResultProxy should
handle this gracefully."""
engine = sqlalchemy.create_engine("sqlite://")
with engine.connect() as conn:
conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)")
result = conn.exec_driver_sql("INSERT INTO test VALUES (1)")
proxy = _ResultProxy(result)

assert proxy.fetchone() is None
assert proxy.fetchall() == []


class TestRemoteQueryPath:
"""Test the non-local query path that uses SQLAlchemy engine
with connection-based execution (v2 pattern)."""

def _make_remote_db(self):
"""Create a PolicyEngineDatabase-like object that uses
a SQLAlchemy engine (the 'remote' path) but backed by
in-memory SQLite for testing."""
db = PolicyEngineDatabase.__new__(PolicyEngineDatabase)
db.local = False
db.pool = sqlalchemy.create_engine("sqlite://")
# Initialize schema using the remote path
with db.pool.connect() as conn:
conn.exec_driver_sql(
"CREATE TABLE test_table "
"(id INTEGER PRIMARY KEY, name TEXT, value REAL)"
)
conn.commit()
return db

def test_remote_insert_and_select(self):
"""Test INSERT then SELECT through the remote query path."""
db = self._make_remote_db()

# Note: remote path converts ? to %s for MySQL, but SQLite
# uses ? natively. Since exec_driver_sql passes to the DBAPI
# driver directly and SQLite's driver uses ?, we need to
# test with the actual query() method which does the conversion.
# For SQLite DBAPI, ? is the native marker.

# Use exec_driver_sql directly to bypass ?->%s conversion
# (which would break SQLite)
db._execute_remote(
[
"INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)",
(1, "test", 3.14),
]
)

result = db._execute_remote(
["SELECT * FROM test_table WHERE id = ?", (1,)]
)
row = result.fetchone()
assert row is not None
assert row["id"] == 1
assert row["name"] == "test"
assert row["value"] == 3.14
assert dict(row) == {"id": 1, "name": "test", "value": 3.14}

def test_remote_select_no_results(self):
db = self._make_remote_db()
result = db._execute_remote(
["SELECT * FROM test_table WHERE id = ?", (999,)]
)
assert result.fetchone() is None

def test_remote_update(self):
db = self._make_remote_db()
db._execute_remote(
[
"INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)",
(1, "original", 1.0),
]
)
db._execute_remote(
[
"UPDATE test_table SET name = ? WHERE id = ?",
("updated", 1),
]
)
result = db._execute_remote(
["SELECT * FROM test_table WHERE id = ?", (1,)]
)
row = result.fetchone()
assert row["name"] == "updated"

def test_remote_delete(self):
db = self._make_remote_db()
db._execute_remote(
[
"INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)",
(1, "to_delete", 0.0),
]
)
db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)])
result = db._execute_remote(
["SELECT * FROM test_table WHERE id = ?", (1,)]
)
assert result.fetchone() is None
1 change: 0 additions & 1 deletion tests/unit/services/test_household_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import json
from unittest.mock import MagicMock
from sqlalchemy.engine.row import LegacyRow
import re

from policyengine_api.services.household_service import HouseholdService
Expand Down
Loading