Skip to content

Commit 19cbb85

Browse files
committed
Improve error messages and fix restoring optional Typed member
1 parent f338194 commit 19cbb85

7 files changed

Lines changed: 124 additions & 43 deletions

File tree

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
# 0.8.5
2+
3+
- Fix error that would occur when restoring optional Typed member with no value
4+
- Use a new `RestoreError` exception that improves error messages when the error mode is set to `raise`
5+
- Fix an edge case where a cached result is not pulled when tables are joined
6+
7+
# 0.8.4
8+
9+
- Fix values query if it contained a related column
10+
111
# 0.8.3
212

313
- Remove the `__ref__` member and use`id(self)` instead. State still contains the `__ref__` key.

atomdb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = "0.8.4"
1+
version = "0.8.5"

atomdb/base.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,20 @@ def generate_getstate(cls: Type["Model"]) -> GetStateFn:
497497
return generate_function(source, namespace, "__getstate__")
498498

499499

500+
class RestoreError(Exception):
501+
"""An exception raised when an error occurs while restoring a model from state."""
502+
503+
def __init__(self, field: str, cls: Type["Model"], e: Exception):
504+
self.field = field
505+
self.cls = cls
506+
self.exc = e
507+
508+
def __str__(self):
509+
return (
510+
f"Error restoring '{self.field}' on object of type {self.cls}: {self.exc}"
511+
)
512+
513+
500514
def generate_restorestate(cls: Type["Model"]) -> RestoreStateFn:
501515
"""Generate an optimized __restorestate__ function for the given model.
502516
@@ -548,6 +562,7 @@ def generate_restorestate(cls: Type["Model"]) -> RestoreStateFn:
548562

549563
namespace: DictType[str, Any] = {
550564
"default_unflatten": default_unflatten,
565+
"RestoreError": RestoreError,
551566
}
552567
for order, f, m, unflatten in setters:
553568
# Since f is potentially an untrusted input, make sure it is a valid
@@ -569,7 +584,9 @@ def generate_restorestate(cls: Type["Model"]) -> RestoreStateFn:
569584
RelModel = types[0]
570585
if RelModel is not None:
571586
namespace[f"rel_model_{f}"] = RelModel
572-
expr = f"await rel_model_{f}.restore(state['{f}'])"
587+
expr = (
588+
f"await rel_model_{f}.restore(v) if (v := state['{f}']) else None"
589+
)
573590
elif is_primitive_member(m):
574591
# Direct assignment
575592
expr = f"state['{f}']"
@@ -584,21 +601,20 @@ def generate_restorestate(cls: Type["Model"]) -> RestoreStateFn:
584601
expr = f"unflatten_{f}(state['{f}'], scope)"
585602

586603
# Do the assignment
587-
if on_error == "raise":
588-
template.append(f" self.{f} = {expr}")
604+
if on_error == "ignore":
605+
handler = "pass"
606+
elif on_error == "log":
607+
handler = f"self.__log_restore_error__(e, '{f}', state, scope)"
589608
else:
590-
if on_error == "log":
591-
handler = f"self.__log_restore_error__(e, '{f}', state, scope)"
592-
else:
593-
handler = "pass"
594-
template.extend(
595-
[
596-
" try:",
597-
f" self.{f} = {expr}",
598-
" except Exception as e:",
599-
f" {handler}",
600-
]
601-
)
609+
handler = f"raise RestoreError('{f}', self.__class__, e) from e"
610+
template.extend(
611+
[
612+
" try:",
613+
f" self.{f} = {expr}",
614+
" except Exception as e:",
615+
f" {handler}",
616+
]
617+
)
602618

603619
# Update restored state
604620
template.append("self.__restored__ = True")
@@ -778,6 +794,7 @@ def __log_restore_error__(
778794
@classmethod
779795
async def restore(cls: Type[M], state: StateType, **kwargs: Any) -> M:
780796
"""Restore an object from the database state"""
797+
assert state is not None
781798
obj = cls.__new__(cls)
782799
await obj.__restorestate__(state)
783800
return obj
@@ -801,6 +818,8 @@ def flatten(self, v: Any, scope: Optional[ScopeType] = None):
801818
a __py__ field and arguments to reconstruct it. Also see the coercers
802819
803820
"""
821+
if v is None:
822+
return None
804823
if isinstance(v, (date, datetime, time)):
805824
# This is inefficient space wise but still allows queries
806825
s: DictType[str, Any] = {

atomdb/sql.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
ModelManager,
6161
ModelMeta,
6262
ModelSerializer,
63+
RestoreError,
6364
RestoreStateFn,
6465
ScopeType,
6566
StateType,
@@ -1934,11 +1935,14 @@ async def get_cached_model(cls: Type[T], pk: Any, state: StateType) -> Optional[
19341935
If the pk is not None an instance of cls.
19351936
19361937
"""
1937-
if cls.__joined_pk__ in state and state[cls.__joined_pk__]:
1938+
cache = cls.objects.cache
1939+
if cls.__joined_pk__ in state and (joined_pk := state[cls.__joined_pk__]):
1940+
obj = cache.get(joined_pk)
1941+
if obj is not None:
1942+
return obj
19381943
return await cls.restore(state) # Restore from joined row result
19391944
if not pk:
19401945
return None
1941-
cache = cls.objects.cache
19421946
obj = cache.get(pk)
19431947
if obj is not None:
19441948
return obj # item is already in the cache
@@ -1986,6 +1990,7 @@ def generate_sql_restorestate(cls: Type["SQLModel"]) -> RestoreStateFn:
19861990
namespace: DictType[str, Any] = {
19871991
"default_unflatten": default_unflatten,
19881992
"get_cached_model": get_cached_model,
1993+
"RestoreError": RestoreError,
19891994
}
19901995

19911996
# The state dict may have data from multiple tables that have been joined
@@ -2048,7 +2053,7 @@ def generate_sql_restorestate(cls: Type["SQLModel"]) -> RestoreStateFn:
20482053
# Only convert if the object has not already been restored
20492054
expr = "\n ".join(
20502055
[
2051-
f"if isinstance(v, rel_model_{f}):",
2056+
f"if v is None or isinstance(v, rel_model_{f}):",
20522057
f" self.{f} = v",
20532058
"else:",
20542059
f" self.{f} = {obj}",
@@ -2067,21 +2072,20 @@ def generate_sql_restorestate(cls: Type["SQLModel"]) -> RestoreStateFn:
20672072
else:
20682073
expr = f"self.{f} = unflatten_{f}(v, scope)"
20692074

2070-
if on_error == "raise":
2071-
template.append(f" {expr}")
2075+
if on_error == "ignore":
2076+
handler = "pass"
2077+
elif on_error == "log":
2078+
handler = f"self.__log_restore_error__(e, '{f}', state, scope)"
20722079
else:
2073-
if on_error == "log":
2074-
handler = f"self.__log_restore_error__(e, '{f}', state, scope)"
2075-
else:
2076-
handler = "pass"
2077-
template.extend(
2078-
[
2079-
" try:",
2080-
f" {expr}",
2081-
" except Exception as e:",
2082-
f" {handler}",
2083-
]
2084-
)
2080+
handler = f"raise RestoreError('{f}', self.__class__, e) from e"
2081+
template.extend(
2082+
[
2083+
" try:",
2084+
f" {expr}",
2085+
" except Exception as e:",
2086+
f" {handler}",
2087+
]
2088+
)
20852089

20862090
# Update restored state
20872091
template.append("self.__restored__ = True")

tests/test_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Model,
2525
ModelManager,
2626
ModelSerializer,
27+
RestoreError,
2728
generate_function,
2829
is_db_field,
2930
is_primitive_member,
@@ -203,8 +204,10 @@ class A(Model):
203204
__on_error__ = "raise"
204205
value = Int()
205206

206-
with pytest.raises(TypeError):
207+
with pytest.raises(RestoreError) as exc_info:
207208
await A.restore({"value": "str"})
209+
restore_error = exc_info.value
210+
assert isinstance(restore_error.exc, TypeError)
208211

209212

210213
async def test_on_error_ignore():
@@ -226,6 +229,7 @@ async def test_on_error_log(caplog):
226229
"""
227230

228231
class C(Model):
232+
__on_error__ = "log"
229233
old_field = Int()
230234
new_field = Int()
231235

tests/test_json.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,25 @@
33
from datetime import date, datetime, time
44
from decimal import Decimal
55

6+
import pytest
67
from atom.api import (
78
Bool,
89
Bytes,
910
ForwardInstance,
1011
Instance,
12+
Int,
1113
List,
1214
Range,
1315
Set,
1416
Str,
1517
Tuple,
18+
Typed,
1619
)
1720

18-
from atomdb.base import JSONModel
21+
from atomdb.base import JSONModel, RestoreError
22+
23+
# Set default to raise
24+
JSONModel.__on_error__ = "raise"
1925

2026

2127
class Dates(JSONModel):
@@ -71,6 +77,17 @@ class Point(JSONModel):
7177
position = Tuple(float)
7278

7379

80+
class Position(JSONModel):
81+
x = Int()
82+
y = Int()
83+
84+
85+
class Rectangle(JSONModel):
86+
height = Range(low=0, value=1)
87+
width = Range(low=0, value=1)
88+
center = Typed(Position)
89+
90+
7491
async def test_json_dates():
7592
now = datetime.now()
7693
obj = Dates(d=now.date(), t=now.time(), dt=now)
@@ -169,3 +186,31 @@ async def test_json_cyclical():
169186
assert r.name == "a"
170187
assert r.related.name == b.name
171188
assert r.related.related == r
189+
190+
191+
async def test_json_optional_typed_model():
192+
# Test save & restore of optional typed member with no value
193+
rect = Rectangle(width=1, height=2)
194+
state = JSONModel.serializer.flatten(rect)
195+
rect = await JSONModel.serializer.unflatten(state)
196+
assert rect.center is None
197+
# Now set the value and make sure it's restored'
198+
rect.center = Position(x=10, y=20)
199+
state = JSONModel.serializer.flatten(rect)
200+
rect = await JSONModel.serializer.unflatten(state)
201+
assert rect.center.x == 10 and rect.center.y == 20
202+
203+
204+
async def test_json_restore_error():
205+
# Test save & restore of optional typed member with no value
206+
rect = Rectangle(width=1, height=2)
207+
state = JSONModel.serializer.flatten(rect)
208+
state["width"] = "Not an int"
209+
with pytest.raises(RestoreError) as exc_info:
210+
await JSONModel.serializer.unflatten(state)
211+
restore_error = exc_info.value
212+
assert (
213+
restore_error.field == "width"
214+
and restore_error.cls is Rectangle
215+
and isinstance(restore_error.exc, TypeError)
216+
)

tests/test_sql.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@
5757
SQLModelManager,
5858
)
5959

60+
#: Raise errors in tests
61+
SQLModel.__on_error__ = "raise"
62+
6063

6164
class AbstractUser(SQLModel):
6265
email = Str().tag(length=64)
@@ -95,8 +98,7 @@ class JobRole(SQLModel):
9598
skill = Instance(JobSkill)
9699
tasks = Relation(lambda: JobTask)
97100

98-
check_one_default = sa.schema.DDL(
99-
"""
101+
check_one_default = sa.schema.DDL("""
100102
CREATE OR REPLACE FUNCTION check_one_default() RETURNS TRIGGER
101103
LANGUAGE plpgsql
102104
AS $$
@@ -107,15 +109,12 @@ class JobRole(SQLModel):
107109
END IF;
108110
RETURN NEW;
109111
END;
110-
$$;"""
111-
)
112+
$$;""")
112113

113-
trigger = sa.schema.DDL(
114-
"""
114+
trigger = sa.schema.DDL("""
115115
CREATE CONSTRAINT TRIGGER check_default_role AFTER INSERT OR UPDATE
116116
OF "default" ON "test_sql.JobRole"
117-
FOR EACH ROW EXECUTE PROCEDURE check_one_default();"""
118-
)
117+
FOR EACH ROW EXECUTE PROCEDURE check_one_default();""")
119118

120119
class Meta:
121120
triggers = [

0 commit comments

Comments
 (0)