Skip to content

Commit fa32964

Browse files
✨ add object field convenience accessors
1 parent d6bfc5e commit fa32964

17 files changed

+167
-48
lines changed

mindee/parsing/v2/field/object_field.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
from typing import TYPE_CHECKING, cast
12
from mindee.parsing.common.string_dict import StringDict
23
from mindee.parsing.v2.field.base_field import BaseField
34
from mindee.parsing.v2.field.dynamic_field import FieldType
45
from mindee.parsing.v2.field.inference_fields import InferenceFields
56

7+
if TYPE_CHECKING:
8+
from mindee.parsing.v2.field.list_field import ListField
9+
from mindee.parsing.v2.field.simple_field import SimpleField
10+
611

712
class ObjectField(BaseField):
813
"""Object field containing multiple fields."""
@@ -37,5 +42,92 @@ def multi_str(self) -> str:
3742
first = False
3843
return out_str
3944

45+
@property
46+
def simple_fields(self) -> dict[str, "SimpleField"]:
    """
    Collect the sub-fields whose ``field_type`` is ``FieldType.SIMPLE``.

    :return: A new dictionary mapping each matching key of ``fields`` to its value.
    :rtype: dict[str, SimpleField]
    """
    # ``cast`` only informs the type checker; selection relies on ``field_type``.
    return {
        name: cast("SimpleField", sub_field)
        for name, sub_field in self.fields.items()
        if sub_field.field_type == FieldType.SIMPLE
    }
58+
59+
@property
60+
def list_fields(self) -> dict[str, "ListField"]:
    """
    Collect the sub-fields whose ``field_type`` is ``FieldType.LIST``.

    :return: A new dictionary mapping each matching key of ``fields`` to its value.
    :rtype: dict[str, ListField]
    """
    # ``cast`` only informs the type checker; selection relies on ``field_type``.
    return {
        name: cast("ListField", sub_field)
        for name, sub_field in self.fields.items()
        if sub_field.field_type == FieldType.LIST
    }
74+
75+
@property
76+
def object_fields(self) -> dict[str, "ObjectField"]:
    """
    Collect the sub-fields whose ``field_type`` is ``FieldType.OBJECT``.

    :return: A new dictionary mapping each matching key of ``fields`` to its value.
    :rtype: dict[str, ObjectField]
    """
    # ``cast`` only informs the type checker; selection relies on ``field_type``.
    return {
        name: cast("ObjectField", sub_field)
        for name, sub_field in self.fields.items()
        if sub_field.field_type == FieldType.OBJECT
    }
89+
90+
def get_simple_field(self, field_name: str) -> "SimpleField":
    """
    Look up ``field_name`` in ``fields`` and return it typed as a SimpleField.

    :param field_name: Key of the sub-field to retrieve.
    :type field_name: str
    :return: The sub-field stored under ``field_name``.
    :rtype: SimpleField
    :raises ValueError: If the sub-field's ``field_type`` is not ``FieldType.SIMPLE``.
    """
    # NOTE(review): a missing key is not handled here; whatever the lookup on
    # ``self.fields`` raises (presumably KeyError) propagates -- confirm.
    candidate = self.fields[field_name]
    if candidate.field_type == FieldType.SIMPLE:
        return cast("SimpleField", candidate)
    raise ValueError(f"Field {field_name} is not a SimpleField.")
103+
104+
def get_list_field(self, field_name: str) -> "ListField":
    """
    Look up ``field_name`` in ``fields`` and return it typed as a ListField.

    :param field_name: Key of the sub-field to retrieve.
    :type field_name: str
    :return: The sub-field stored under ``field_name``.
    :rtype: ListField
    :raises ValueError: If the sub-field's ``field_type`` is not ``FieldType.LIST``.
    """
    # NOTE(review): a missing key is not handled here; whatever the lookup on
    # ``self.fields`` raises (presumably KeyError) propagates -- confirm.
    candidate = self.fields[field_name]
    if candidate.field_type == FieldType.LIST:
        return cast("ListField", candidate)
    raise ValueError(f"Field {field_name} is not a ListField.")
117+
118+
def get_object_field(self, field_name: str) -> "ObjectField":
    """
    Look up ``field_name`` in ``fields`` and return it typed as an ObjectField.

    :param field_name: Key of the sub-field to retrieve.
    :type field_name: str
    :return: The sub-field stored under ``field_name``.
    :rtype: ObjectField
    :raises ValueError: If the sub-field's ``field_type`` is not ``FieldType.OBJECT``.
    """
    # NOTE(review): a missing key is not handled here; whatever the lookup on
    # ``self.fields`` raises (presumably KeyError) propagates -- confirm.
    candidate = self.fields[field_name]
    if candidate.field_type == FieldType.OBJECT:
        return cast("ObjectField", candidate)
    raise ValueError(f"Field {field_name} is not an ObjectField.")
131+
40132
def __str__(self) -> str:
    """Return the string produced by :meth:`single_str`."""
    rendered = self.single_str()
    return rendered

tests/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
V2_DATA_DIR = ROOT_DATA_DIR / "v2"
1919
V2_PRODUCT_DATA_DIR = V2_DATA_DIR / "products"
20-
V2_UTILITIES_DATA_DIR = V2_DATA_DIR / "utilities"
2120

2221

2322
def clear_envvars(monkeypatch) -> None:

tests/v2/input/__init__.py

Whitespace-only changes.

tests/v2/input/test_inference_parameters.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from tests.utils import V2_DATA_DIR
1212

1313
expected_data_schema_dict = json.loads(
14-
(V2_DATA_DIR / "inference" / "data_schema_replace_param.json").read_text()
14+
(
15+
V2_DATA_DIR / "products" / "extraction" / "data_schema_replace_param.json"
16+
).read_text()
1517
)
1618
expected_data_schema_str = json.dumps(
1719
expected_data_schema_dict, indent=None, sort_keys=True

tests/v2/input/test_local_response.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010
@pytest.fixture
1111
def file_path() -> Path:
12-
return V2_DATA_DIR / "inference" / "standard_field_types.json"
12+
return V2_DATA_DIR / "products" / "extraction" / "standard_field_types.json"
1313

1414

1515
def _assert_local_response(local_response):
1616
fake_hmac_signing = "ogNjY44MhvKPGTtVsI8zG82JqWQa68woYQH"
17-
signature = "f390d9f7f57ac04f47b6309d8a40236b0182610804fc20e91b1f6028aaca07a7"
17+
signature = "e51bdf80f1a08ed44ee161100fc30a25cb35b4ede671b0a575dc9064a3f5dbf1"
1818

1919
assert local_response._file is not None
2020
assert not local_response.is_valid_hmac_signature(

tests/v2/parsing/test_inference_response.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66

77
from mindee import InferenceResponse
88
from mindee.parsing.v2 import InferenceActiveOptions
9-
from mindee.parsing.v2.field import FieldConfidence, ListField, ObjectField, SimpleField
9+
from mindee.parsing.v2.field.field_confidence import FieldConfidence
10+
from mindee.parsing.v2.field.list_field import ListField
11+
from mindee.parsing.v2.field.object_field import ObjectField
12+
from mindee.parsing.v2.field.simple_field import SimpleField
1013
from mindee.parsing.v2.field.inference_fields import InferenceFields
1114
from mindee.parsing.v2.inference import Inference
1215
from mindee.parsing.v2.inference_file import InferenceFile
@@ -27,14 +30,14 @@ def _get_samples(json_path: Path, rst_path: Path) -> Tuple[dict, str]:
2730

2831

2932
def _get_inference_samples(name: str) -> Tuple[dict, str]:
30-
json_path = V2_DATA_DIR / "inference" / f"{name}.json"
31-
rst_path = V2_DATA_DIR / "inference" / f"{name}.rst"
33+
json_path = V2_DATA_DIR / "products" / "extraction" / f"{name}.json"
34+
rst_path = V2_DATA_DIR / "products" / "extraction" / f"{name}.rst"
3235
return _get_samples(json_path, rst_path)
3336

3437

3538
def _get_product_samples(product, name: str) -> Tuple[dict, str]:
36-
json_path = V2_DATA_DIR / "products" / product / f"{name}.json"
37-
rst_path = V2_DATA_DIR / "products" / product / f"{name}.rst"
39+
json_path = V2_DATA_DIR / "products" / "extraction" / product / f"{name}.json"
40+
rst_path = V2_DATA_DIR / "products" / "extraction" / product / f"{name}.rst"
3841
return _get_samples(json_path, rst_path)
3942

4043

@@ -53,42 +56,54 @@ def test_deep_nested_fields():
5356
response.inference.result.fields["field_object"].fields["sub_object_object"],
5457
ObjectField,
5558
)
59+
fields = response.inference.result.fields
60+
assert isinstance(fields.get("field_object"), ObjectField)
5661
assert isinstance(
57-
response.inference.result.fields["field_object"]
58-
.fields["sub_object_object"]
59-
.fields,
62+
fields.get("field_object").get_simple_field("sub_object_simple"), SimpleField
63+
)
64+
assert isinstance(
65+
fields.get("field_object").get_list_field("sub_object_list"), ListField
66+
)
67+
assert isinstance(
68+
fields.get("field_object").get_object_field("sub_object_object"), ObjectField
69+
)
70+
assert len(fields.get("field_object").simple_fields) == 1
71+
assert len(fields.get("field_object").list_fields) == 1
72+
assert len(fields.get("field_object").object_fields) == 1
73+
assert isinstance(
74+
fields["field_object"].fields["sub_object_object"].fields,
6075
dict,
6176
)
6277
assert isinstance(
63-
response.inference.result.fields["field_object"]
78+
fields["field_object"]
6479
.fields["sub_object_object"]
6580
.fields["sub_object_object_sub_object_list"],
6681
ListField,
6782
)
6883
assert isinstance(
69-
response.inference.result.fields["field_object"]
84+
fields["field_object"]
7085
.fields["sub_object_object"]
7186
.fields["sub_object_object_sub_object_list"]
7287
.items,
7388
list,
7489
)
7590
assert isinstance(
76-
response.inference.result.fields["field_object"]
91+
fields["field_object"]
7792
.fields["sub_object_object"]
7893
.fields["sub_object_object_sub_object_list"]
7994
.items[0],
8095
ObjectField,
8196
)
8297
assert isinstance(
83-
response.inference.result.fields["field_object"]
98+
fields["field_object"]
8499
.fields["sub_object_object"]
85100
.fields["sub_object_object_sub_object_list"]
86101
.items[0]
87102
.fields["sub_object_object_sub_object_list_simple"],
88103
SimpleField,
89104
)
90105
assert (
91-
response.inference.result.fields["field_object"]
106+
fields["field_object"]
92107
.fields["sub_object_object"]
93108
.fields["sub_object_object_sub_object_list"]
94109
.items[0]
@@ -299,7 +314,9 @@ def test_text_context_field_is_false() -> None:
299314

300315
@pytest.mark.v2
301316
def test_text_context_field_is_true() -> None:
302-
with open(V2_DATA_DIR / "inference" / "text_context_enabled.json", "r") as file:
317+
with open(
318+
V2_DATA_DIR / "products" / "extraction" / "text_context_enabled.json", "r"
319+
) as file:
303320
json_sample = json.load(file)
304321
response = InferenceResponse(json_sample)
305322
assert isinstance(response.inference.active_options, InferenceActiveOptions)

tests/v2/product/classification/test_classification_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mindee import ClientV2, PathInput
66
from mindee.v2 import ClassificationParameters, ClassificationResponse
7-
from tests.utils import V2_UTILITIES_DATA_DIR
7+
from tests.utils import V2_PRODUCT_DATA_DIR
88

99

1010
@pytest.fixture(scope="session")
@@ -24,7 +24,7 @@ def test_classification_default_sample(
2424
v2_client: ClientV2, classification_model_id: str
2525
):
2626
input_source = PathInput(
27-
V2_UTILITIES_DATA_DIR / "classification" / "default_invoice.jpg"
27+
V2_PRODUCT_DATA_DIR / "classification" / "default_invoice.jpg"
2828
)
2929
response = v2_client.enqueue_and_get_result(
3030
ClassificationResponse,

tests/v2/product/classification/test_classification_response.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
ClassificationResponse,
1010
)
1111
from mindee.v2.product.classification.classification_result import ClassificationResult
12-
from tests.utils import V2_UTILITIES_DATA_DIR
12+
from tests.utils import V2_PRODUCT_DATA_DIR
1313

1414

1515
@pytest.mark.v2
1616
def test_classification_single():
1717
input_inference = LocalResponse(
18-
V2_UTILITIES_DATA_DIR / "classification" / "classification_single.json"
18+
V2_PRODUCT_DATA_DIR / "classification" / "classification_single.json"
1919
)
2020
classification_response = input_inference.deserialize_response(
2121
ClassificationResponse

tests/v2/product/crop/test_crop_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mindee import ClientV2, PathInput
66
from mindee.v2 import CropParameters, CropResponse
7-
from tests.utils import V2_UTILITIES_DATA_DIR
7+
from tests.utils import V2_PRODUCT_DATA_DIR
88

99

1010
@pytest.fixture(scope="session")
@@ -21,7 +21,7 @@ def v2_client() -> ClientV2:
2121
@pytest.mark.integration
2222
@pytest.mark.v2
2323
def test_crop_default_sample(v2_client: ClientV2, crop_model_id: str):
24-
input_source = PathInput(V2_UTILITIES_DATA_DIR / "crop" / "default_sample.jpg")
24+
input_source = PathInput(V2_PRODUCT_DATA_DIR / "crop" / "default_sample.jpg")
2525
response = v2_client.enqueue_and_get_result(
2626
CropResponse, input_source, CropParameters(crop_model_id)
2727
)

0 commit comments

Comments
 (0)