Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions ormar/databases/query_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,19 @@ async def execute(self, query: Executable) -> Any:
Execute a query (INSERT, UPDATE, DELETE).

:param query: SQLAlchemy query expression
:return: For INSERT, returns last row id; for UPDATE/DELETE, returns row count
:return: For INSERT, the inserted primary key or ``None`` if the backend
cannot return one (e.g. Oracle MySQL inserting into a
non-AUTO_INCREMENT pk with a server default — no RETURNING support).
For UPDATE/DELETE, the row count.
"""
result: CursorResult[Any] = await self._connection.execute(query)

# For INSERT queries, try to get the inserted primary key
# PostgreSQL/MySQL use inserted_primary_key, SQLite uses lastrowid
if result.context and result.context.isinsert: # pragma: no cover
# For INSERT queries, try to get the inserted primary key via the
# dialect's best-available mechanism (RETURNING on PostgreSQL / SQLite
# 3.35+ / MariaDB 10.5+, LAST_INSERT_ID() on MySQL AUTO_INCREMENT).
# Do NOT fall back to rowcount here: rowcount is not a pk, and
# returning it would silently corrupt `Model.pk` in `save()`.
if result.context and result.context.isinsert:
if result.inserted_primary_key:
pk_value = result.inserted_primary_key[0]
if pk_value is not None:
Expand All @@ -76,6 +82,8 @@ async def execute(self, query: Executable) -> Any:
if hasattr(result, "lastrowid") and result.lastrowid: # pragma: no cover
return result.lastrowid

return None # pragma: no cover

return result.rowcount if result.rowcount is not None else 0

async def execute_many(
Expand Down
22 changes: 20 additions & 2 deletions ormar/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,34 @@ async def save(self: T) -> T:
expr = self.ormar_config.table.insert()
expr = expr.values(**self_fields)

pkname = self.ormar_config.pkname
pk_returned_from_insert = False
pk = await self._execute_query(expr)
if pk and isinstance(pk, self.pk_type()):
setattr(self, self.ormar_config.pkname, pk)
setattr(self, pkname, pk)
pk_returned_from_insert = True

if self.pk is None:
raise ModelPersistenceError( # pragma: no cover
f"Could not recover the generated primary key for "
f"{self.__class__.__name__} after INSERT. This happens on "
"backends that lack RETURNING support for server-side "
"defaults on non-AUTO_INCREMENT primary keys — most notably "
"Oracle MySQL, which does not implement RETURNING in any "
"version. Use autoincrement=True, provide the primary key "
"client-side, or switch to a RETURNING-capable backend "
"(PostgreSQL, SQLite 3.35+, MariaDB 10.5+)."
)

self.set_save_status(True)
# refresh server side defaults
# refresh server-side defaults — but skip the pk if the insert already
# returned it, so save() stays a single round-trip when the pk is the
# only server_default field.
if any(
field.server_default is not None
for name, field in self.ormar_config.model_fields.items()
if name not in self_fields
and not (pk_returned_from_insert and name == pkname)
):
await self.load()

Expand Down
179 changes: 179 additions & 0 deletions tests/test_model_methods/test_save_server_default_pk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Tests for ``Model.save()`` behavior around ``server_default`` on the pk.

Covers two closely related concerns:

1. N+1 fix (PR #919) — when the pk is the only ``server_default`` field, the
INSERT's RETURNING clause already provides the pk, so ``save()`` must not
issue a second SELECT to reload the model.

2. Pk-recovery loud-fail — on backends that cannot return a server-generated
pk (Oracle MySQL has no RETURNING), ``save()`` must raise
``ModelPersistenceError`` rather than silently storing a bogus pk (the old
behavior was to return ``rowcount`` from the executor, which ``save()``
then mistook for the pk).
"""

from typing import Any, List

import pytest
from sqlalchemy import event, text

import ormar
from ormar.exceptions import ModelPersistenceError
from tests.lifespan import init_tests
from tests.settings import create_config

base_ormar_config = create_config()

_IS_MYSQL = "mysql" in base_ormar_config.database.url


class ServerDefaultPk(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="server_default_pk")

id: int = ormar.Integer(
primary_key=True, autoincrement=False, server_default=text("100")
)
name: str = ormar.String(max_length=100)


class ServerDefaultNonPk(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="server_default_nonpk")

id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
company: str = ormar.String(max_length=100, server_default="Acme")


class ServerDefaultPkAndNonPk(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="server_default_pk_and_nonpk")

id: int = ormar.Integer(
primary_key=True, autoincrement=False, server_default=text("200")
)
name: str = ormar.String(max_length=100)
company: str = ormar.String(max_length=100, server_default="Acme")


create_test_database = init_tests(base_ormar_config)


class _StatementCounter:
"""Records every statement executed on the sqlalchemy engine."""

def __init__(self) -> None:
self.statements: List[str] = []

def __enter__(self) -> "_StatementCounter":
sync_engine = base_ormar_config.database.engine.sync_engine

def before_cursor_execute(
conn: Any,
cursor: Any,
statement: str,
parameters: Any,
context: Any,
executemany: bool,
) -> None:
self.statements.append(statement)

self._listener = before_cursor_execute
self._sync_engine = sync_engine
event.listen(sync_engine, "before_cursor_execute", self._listener)
return self

def __exit__(self, *exc: Any) -> None:
event.remove(self._sync_engine, "before_cursor_execute", self._listener)


def _table_selects(statements: List[str], tablename: str) -> List[str]:
return [
s
for s in statements
if s.lstrip().upper().startswith("SELECT") and tablename in s
]


@pytest.mark.asyncio
@pytest.mark.skipif(
_IS_MYSQL,
reason=(
"Oracle MySQL has no RETURNING clause, so a server_default on a "
"non-AUTO_INCREMENT pk cannot be recovered — covered by "
"test_save_raises_when_server_default_pk_cannot_be_recovered instead."
),
)
async def test_save_does_not_reload_when_only_pk_has_server_default(): # noqa: E501 # pragma: no cover
"""INSERT returns the server-generated pk, so no SELECT should follow."""
async with base_ormar_config.database:
async with base_ormar_config.database.transaction(force_rollback=True):
with _StatementCounter() as counter:
instance = ServerDefaultPk(name="first")
await instance.save()

selects = _table_selects(
counter.statements, ServerDefaultPk.ormar_config.tablename
)
assert instance.pk is not None
assert selects == [], counter.statements


@pytest.mark.asyncio
async def test_save_still_reloads_when_non_pk_has_server_default():
"""Regression guard: non-pk server defaults must still trigger a reload."""
async with base_ormar_config.database:
async with base_ormar_config.database.transaction(force_rollback=True):
with _StatementCounter() as counter:
instance = ServerDefaultNonPk(id=1, name="first")
await instance.save()

selects = _table_selects(
counter.statements, ServerDefaultNonPk.ormar_config.tablename
)
assert instance.company == "Acme"
assert len(selects) == 1, counter.statements


@pytest.mark.asyncio
@pytest.mark.skipif(
_IS_MYSQL,
reason=(
"Oracle MySQL has no RETURNING clause, so a server_default on a "
"non-AUTO_INCREMENT pk cannot be recovered — covered by "
"test_save_raises_when_server_default_pk_cannot_be_recovered instead."
),
)
async def test_save_reloads_once_when_both_pk_and_non_pk_have_server_default(): # noqa: E501 # pragma: no cover
"""Mixed case: still need exactly one reload for the non-pk column."""
async with base_ormar_config.database:
async with base_ormar_config.database.transaction(force_rollback=True):
with _StatementCounter() as counter:
instance = ServerDefaultPkAndNonPk(name="first")
await instance.save()

selects = _table_selects(
counter.statements, ServerDefaultPkAndNonPk.ormar_config.tablename
)
assert instance.pk is not None
assert instance.company == "Acme"
assert len(selects) == 1, counter.statements


@pytest.mark.asyncio
@pytest.mark.skipif(
not _IS_MYSQL,
reason=(
"This loud-fail path only fires on backends that cannot return a "
"server-generated pk. RETURNING-capable backends (PostgreSQL, "
"SQLite 3.35+, MariaDB 10.5+) succeed here."
),
)
async def test_save_raises_when_server_default_pk_cannot_be_recovered(): # noqa: E501 # pragma: no cover
"""Oracle MySQL: no RETURNING → save() must raise, not silently store a
bogus pk (the old bug was to coerce rowcount into ``Model.pk``)."""
async with base_ormar_config.database:
async with base_ormar_config.database.transaction(force_rollback=True):
instance = ServerDefaultPk(name="first")
with pytest.raises(ModelPersistenceError, match="primary key"):
await instance.save()
assert instance.pk is None