Skip to content

Commit d0ed75a

Browse files
committed
Fixes auto name bug and adds clearer error for missing custom file or function
1 parent a599d44 commit d0ed75a

File tree

4 files changed

+110
-29
lines changed

4 files changed

+110
-29
lines changed

ratapi/inputs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ def get_handle(self, index: int):
7777
"""
7878
custom_file = self.files[index]
7979
full_path = os.path.join(custom_file["path"], custom_file["filename"])
80+
81+
if not os.path.isfile(full_path):
82+
raise FileNotFoundError(f"The custom file ({custom_file['name']}) does not have a valid path.")
83+
84+
if not custom_file["function_name"] and custom_file["language"] != Languages.Matlab:
85+
raise ValueError(f"The custom file ({custom_file['name']}) does not have a valid function name.")
86+
8087
if custom_file["language"] == Languages.Python:
8188
file_handle = get_python_handle(custom_file["filename"], custom_file["function_name"], custom_file["path"])
8289
elif custom_file["language"] == Languages.Matlab:

ratapi/models.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pathlib
44
import warnings
5-
from itertools import count
5+
from contextlib import suppress
66
from typing import Any
77

88
import numpy as np
@@ -18,14 +18,41 @@
1818

1919

2020
# Create a counter for each model
21-
background_number = count(1)
22-
contrast_number = count(1)
23-
custom_file_number = count(1)
24-
data_number = count(1)
25-
domain_contrast_number = count(1)
26-
layer_number = count(1)
27-
parameter_number = count(1)
28-
resolution_number = count(1)
21+
background_number = ["Background", 0]
22+
contrast_number = ["Contrast", 0]
23+
custom_file_number = ["Custom File", 0]
24+
data_number = ["Data", 0]
25+
domain_contrast_number = ["Domain Contrast", 0]
26+
layer_number = ["Layer", 0]
27+
parameter_number = ["Parameter", 0]
28+
resolution_number = ["Resolution", 0]
29+
30+
_model_counter = {
31+
"Background": background_number,
32+
"Contrast": contrast_number,
33+
"ContrastWithRatio": contrast_number,
34+
"CustomFile": custom_file_number,
35+
"Data": data_number,
36+
"DomainContrast": domain_contrast_number,
37+
"Layer": layer_number,
38+
"AbsorptionLayer": layer_number,
39+
"Parameter": parameter_number,
40+
"ProtectedParameter": parameter_number,
41+
"Resolution": resolution_number,
42+
}
43+
44+
45+
def _model_name_factory(model_name: str) -> str:
46+
"""Generate a unique name for model using a global counter.
47+
48+
Parameters
49+
----------
50+
model_name : str
51+
The name of the model class.
52+
"""
53+
title, number = _model_counter[model_name]
54+
_model_counter[model_name][1] += 1
55+
return f"New {title} {(number + 1)}"
2956

3057

3158
class RATModel(BaseModel, validate_assignment=True, extra="forbid"):
@@ -38,6 +65,25 @@ def __repr__(self):
3865
)
3966
return f"{self.__repr_name__()}({fields_repr})"
4067

68+
@field_validator("name", mode="after", check_fields=False)
69+
@classmethod
70+
def update_counter(cls, name: str) -> str:
71+
"""Update the auto name counter if a similar name is manually given.
72+
73+
Parameters
74+
----------
75+
name : str
76+
The name of the model.
77+
"""
78+
title, number = _model_counter[cls.__name__]
79+
prefix = f"New {title} "
80+
if name.startswith(prefix):
81+
with suppress(ValueError):
82+
new_number = int(name[len(prefix) :])
83+
if new_number > number:
84+
_model_counter[cls.__name__][1] = new_number
85+
return name
86+
4187
def __str__(self):
4288
table = prettytable.PrettyTable()
4389
table.field_names = [key.replace("_", " ") for key in self.display_fields]
@@ -116,7 +162,7 @@ class Background(Signal):
116162
117163
"""
118164

119-
name: str = Field(default_factory=lambda: f"New Background {next(background_number)}", min_length=1)
165+
name: str = Field(default_factory=lambda: _model_name_factory("Background"), min_length=1)
120166

121167
@model_validator(mode="after")
122168
def check_unsupported_parameters(self):
@@ -173,7 +219,7 @@ class Contrast(RATModel):
173219
174220
"""
175221

176-
name: str = Field(default_factory=lambda: f"New Contrast {next(contrast_number)}", min_length=1)
222+
name: str = Field(default_factory=lambda: _model_name_factory("Contrast"), min_length=1)
177223
data: str = ""
178224
background: str = ""
179225
background_action: BackgroundActions = BackgroundActions.Add
@@ -255,7 +301,7 @@ class ContrastWithRatio(RATModel):
255301
256302
"""
257303

258-
name: str = Field(default_factory=lambda: f"New Contrast {next(contrast_number)}", min_length=1)
304+
name: str = Field(default_factory=lambda: _model_name_factory("ContrastWithRatio"), min_length=1)
259305
data: str = ""
260306
background: str = ""
261307
background_action: BackgroundActions = BackgroundActions.Add
@@ -309,7 +355,7 @@ class CustomFile(RATModel):
309355
310356
"""
311357

312-
name: str = Field(default_factory=lambda: f"New Custom File {next(custom_file_number)}", min_length=1)
358+
name: str = Field(default_factory=lambda: _model_name_factory("CustomFile"), min_length=1)
313359
filename: str = ""
314360
function_name: str = ""
315361
language: Languages = Languages.Python
@@ -348,7 +394,7 @@ class Data(RATModel, arbitrary_types_allowed=True):
348394
349395
"""
350396

351-
name: str = Field(default_factory=lambda: f"New Data {next(data_number)}", min_length=1)
397+
name: str = Field(default_factory=lambda: _model_name_factory("Data"), min_length=1)
352398
data: np.ndarray = np.empty([0, 3])
353399
data_range: list[float] = Field(default=[], min_length=2, max_length=2)
354400
simulation_range: list[float] = Field(default=[], min_length=2, max_length=2)
@@ -453,7 +499,7 @@ class DomainContrast(RATModel):
453499
454500
"""
455501

456-
name: str = Field(default_factory=lambda: f"New Domain Contrast {next(domain_contrast_number)}", min_length=1)
502+
name: str = Field(default_factory=lambda: _model_name_factory("DomainContrast"), min_length=1)
457503
model: list[str] = []
458504

459505
def __str__(self):
@@ -483,7 +529,7 @@ class Layer(RATModel, populate_by_name=True):
483529
484530
"""
485531

486-
name: str = Field(default_factory=lambda: f"New Layer {next(layer_number)}", min_length=1)
532+
name: str = Field(default_factory=lambda: _model_name_factory("Layer"), min_length=1)
487533
thickness: str
488534
SLD: str = Field(validation_alias="SLD_real")
489535
roughness: str
@@ -522,7 +568,7 @@ class AbsorptionLayer(RATModel, populate_by_name=True):
522568
523569
"""
524570

525-
name: str = Field(default_factory=lambda: f"New Layer {next(layer_number)}", min_length=1)
571+
name: str = Field(default_factory=lambda: _model_name_factory("AbsorptionLayer"), min_length=1)
526572
thickness: str
527573
SLD_real: str = Field(validation_alias="SLD")
528574
SLD_imaginary: str
@@ -555,7 +601,7 @@ class Parameter(RATModel):
555601
556602
"""
557603

558-
name: str = Field(default_factory=lambda: f"New Parameter {next(parameter_number)}", min_length=1)
604+
name: str = Field(default_factory=lambda: _model_name_factory("Parameter"), min_length=1)
559605
min: float = 0.0
560606
value: float = 0.0
561607
max: float = 0.0
@@ -638,7 +684,7 @@ class Resolution(Signal):
638684
639685
"""
640686

641-
name: str = Field(default_factory=lambda: f"New Resolution {next(resolution_number)}", min_length=1)
687+
name: str = Field(default_factory=lambda: _model_name_factory("Resolution"), min_length=1)
642688

643689
@field_validator("type")
644690
@classmethod

tests/test_inputs.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import pathlib
44
import pickle
5+
import tempfile
6+
from unittest.mock import patch
57

68
import numpy as np
79
import pytest
@@ -675,6 +677,30 @@ def test_make_controls(standard_layers_controls) -> None:
675677
check_controls_equal(controls, standard_layers_controls)
676678

677679

680+
@patch("ratapi.wrappers.MatlabWrapper")
681+
def test_file_handles(wrapper):
682+
handle = FileHandles([ratapi.models.CustomFile(name="Test Custom File", filename="cpp_test.dll", language="cpp")])
683+
684+
with pytest.raises(FileNotFoundError, match="The custom file \\(Test Custom File\\) does not have a valid path."):
685+
handle.get_handle(0)
686+
687+
with tempfile.NamedTemporaryFile("w", suffix=".dll") as f:
688+
tmp_file = pathlib.Path(f.name)
689+
handle.files[0]["path"] = tmp_file.parent
690+
handle.files[0]["filename"] = tmp_file.name
691+
handle.files[0]["function_name"] = ""
692+
# No function name should throw exception
693+
with pytest.raises(
694+
ValueError, match="The custom file \\(Test Custom File\\) does not have a valid function name."
695+
):
696+
handle.get_handle(0)
697+
698+
# Matlab does not need function name
699+
handle.files[0]["language"] = "matlab"
700+
handle.get_handle(0)
701+
wrapper.assert_called()
702+
703+
678704
def check_problem_equal(actual_problem, expected_problem) -> None:
679705
"""Compare two instances of the "problem" object for equality."""
680706
scalar_fields = [

tests/test_models.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,23 @@ def test_default_names(model: Callable, model_name: str, model_params: dict) ->
3232
format: "New <model name> <integer>".
3333
"""
3434
model_1 = model(**model_params)
35+
prefix = f"New {model_name} "
36+
assert model_1.name.startswith(prefix)
37+
index = int(model_1.name[len(prefix) :])
38+
3539
model_2 = model(**model_params)
3640
model_3 = model(name="Given Name", **model_params)
3741
model_4 = model(**model_params)
3842

39-
assert model_1.name == f"New {model_name} 1"
40-
assert model_2.name == f"New {model_name} 2"
43+
assert model_1.name == f"New {model_name} {index}"
44+
assert model_2.name == f"New {model_name} {index + 1}"
4145
assert model_3.name == "Given Name"
42-
assert model_4.name == f"New {model_name} 3"
46+
assert model_4.name == f"New {model_name} {index + 2}"
47+
48+
# If user adds name in similar format. The next auto number will take it into account.
49+
model(name=f"{prefix}{index + 20}", **model_params)
50+
model_5 = model(**model_params)
51+
assert model_5.name == f"New {model_name} {index + 21}"
4352

4453

4554
@pytest.mark.parametrize(
@@ -100,13 +109,6 @@ def test_initialise_with_extra_fields(self, model: Callable, model_params: dict)
100109
model(new_field=1, **model_params)
101110

102111

103-
# def test_custom_file_path_is_absolute() -> None:
104-
# """If we use provide a relative path to the custom file model, it should be converted to an absolute path."""
105-
# relative_path = pathlib.Path("./relative_path")
106-
# custom_file = ratapi.models.CustomFile(path=relative_path)
107-
# assert custom_file.path.is_absolute()
108-
109-
110112
def test_data_eq() -> None:
111113
"""If we use the Data.__eq__ method with an object that is not a pydantic BaseModel, we should return
112114
"NotImplemented".

0 commit comments

Comments
 (0)