Skip to content

Add support in JAX backend for PEtab v2#3115

Open
BSnelling wants to merge 41 commits intomainfrom
bes/petabv2_jax
Open

Add support in JAX backend for PEtab v2#3115
BSnelling wants to merge 41 commits intomainfrom
bes/petabv2_jax

Conversation

@BSnelling
Copy link
Collaborator

No description provided.

@codecov
Copy link

codecov bot commented Jan 9, 2026

Codecov Report

❌ Patch coverage is 83.33333% with 40 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.32%. Comparing base (126e936) to head (a1aefbc).

Files with missing lines Patch % Lines
python/sdist/amici/jax/petab.py 73.57% 37 Missing ⚠️
python/sdist/amici/sim/jax/__init__.py 96.07% 2 Missing ⚠️
python/sdist/amici/jax/model.py 90.90% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #3115      +/-   ##
==========================================
- Coverage   78.24%   77.32%   -0.92%     
==========================================
  Files         315      316       +1     
  Lines       20482    20648     +166     
  Branches     1484     1483       -1     
==========================================
- Hits        16027    15967      -60     
- Misses       4447     4673     +226     
  Partials        8        8              
Flag Coverage Δ
cpp 70.39% <24.05%> (-1.32%) ⬇️
cpp_python 37.85% <7.59%> (-0.01%) ⬇️
petab 47.14% <83.33%> (+0.58%) ⬆️
petab_sciml ?
python 68.85% <24.05%> (-1.52%) ⬇️
sbmlsuite-jax 32.76% <25.00%> (-0.65%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
python/sdist/amici/_symbolic/de_model.py 91.68% <100.00%> (-0.24%) ⬇️
...ython/sdist/amici/_symbolic/de_model_components.py 93.25% <100.00%> (+0.15%) ⬆️
...hon/sdist/amici/importers/petab/_petab_importer.py 86.74% <ø> (-0.04%) ⬇️
...dist/amici/importers/petab/v1/parameter_mapping.py 77.83% <100.00%> (-7.55%) ⬇️
python/sdist/amici/importers/pysb/__init__.py 94.57% <ø> (-0.07%) ⬇️
python/sdist/amici/importers/sbml/__init__.py 94.26% <ø> (-0.10%) ⬇️
python/sdist/amici/jax/_simulation.py 94.44% <100.00%> (-2.62%) ⬇️
python/sdist/amici/jax/ode_export.py 88.05% <100.00%> (-1.18%) ⬇️
python/sdist/amici/jax/model.py 84.54% <90.90%> (-2.96%) ⬇️
python/sdist/amici/sim/jax/__init__.py 96.07% <96.07%> (ø)
... and 1 more

... and 5 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@BSnelling BSnelling marked this pull request as ready for review January 26, 2026 12:31
@BSnelling BSnelling requested a review from a team as a code owner January 26, 2026 12:31
@BSnelling BSnelling requested a review from FFroehlich January 26, 2026 12:31
@sonarqubecloud
Copy link

@dweindl
Copy link
Member

dweindl commented Feb 3, 2026

Sorry for the delay, I'll try to review tomorrow or the day after.

Copy link
Member

@dweindl dweindl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, thanks. Here some initial comments. Didn't manage to go over everything yet.

if (s.startswith("_petab_") and "indicator" in s) or s == "t":
continue
implicit_symbols.append(s)
return len(implicit_symbols) > 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sub-package should stay PEtab-independent. I think a better way to handle that would be passing all indicator variables as fixed_parameters to SbmlImporter._build_ode_model via SbmlImporter.sbml2jax and then use the same logic as for the non-jax event handling.

self._initialize_model_with_nominal_values(model)
)
self._parameter_mappings = self._get_parameter_mappings(scs)
if isinstance(petab_problem, petabv1.Problem):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge with the if isinstance above?

"""
scs = petab_problem.get_simulation_conditions_from_measurement_df()
self.simulation_conditions = tuple(tuple(sc) for sc in scs.values)
if isinstance(petab_problem, petabv2.Problem):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the plan to support both, petab v1 and v2, or is this just until petab-sciml is updated?

Copy link
Collaborator Author

@BSnelling BSnelling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review @dweindl! I have implemented most points, but I have a couple of outstanding questions, if you could advise on those.

")"
],
"id": "3f2ab1acb3ba818f"
"amici_model.simulate(petab_problem.get_x_nominal_dict())"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the notebook runs there is an error from this cell that the FIM was not computed: https://github.com/AMICI-dev/AMICI/actions/runs/21985309084/job/63517900881?pr=3115

@dweindl can you advise on that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My fault. I hope #3125 fixes that.

"""Check whether the event has implicit triggers.
"""
t = self.get_val()
return not t.free_symbols.issubset(allowed_symbols)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the PEtab related references from this function and am passing the indicator variables as fixed_parameters. I wanted to include this logic in the has_explicit_trigger_times function but in all the JAX examples, self._t_root is not set. Should _t_root be set during the import of petab events?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, okay. I think this might be a general issue, not just with JAX. My guess is that sympy isn't able to solve those petab event trigger piecewise expressions for $t$ (here

try:
self._t_root = sp.solve(self.get_val(), amici_time_symbol)
except NotImplementedError:
# the trigger can't be solved for `t`
pass
). That should be feasible, but I'm not sure how easy it is to implement. I won't be able to look into that for at least another week. Okay to keep your previous approach for now then. Please open an issue to address that later.

Copy link
Collaborator Author

@BSnelling BSnelling Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, sympy is failing to solve these expressions. Opened #3126

self.simulation_conditions = tuple(tuple(sc) for sc in scs.values)
if isinstance(petab_problem, petabv1.Problem):
raise TypeError(
"JAXProblem does not support PEtab v1 problems. Upgrade the problem to PEtab v2."
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The plan is just to support v2, so I've added this error if a v1 problem is encountered.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great. Can you please replace all the petab.v1 constants by those in petab.v2 to make it clear that we don't want to deal with v1 here?
It seems there is still quite a bit of v1 code left here. Probably all occurrences of simulationConditionId/SIMULATION_CONDITION_ID. Is that still needed for petab-sciml? If not please remove.

Copy link
Member

@dweindl dweindl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @BSnelling. Here a few more comments. Okay for me to merge once all tests pass and to address the remaining issues together with the petab-sciml update if that's more convenient.

Comment on lines +88 to +104
petabv1.OBSERVABLE_ID: obs,
petabv2.C.EXPERIMENT_ID: [experiment_id] * len(t),
petabv1.TIME: t[problem._ts_masks[ic, :]],
petabv1.SIMULATION: y[ic, problem._ts_masks[ic, :]],
},
index=problem._petab_measurement_indices[ic, :],
)
if (
petabv1.OBSERVABLE_PARAMETERS
in problem._petab_problem.measurement_df
):
df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = (
problem._petab_problem.measurement_df.query(
f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'"
)[petabv2.C.OBSERVABLE_PARAMETERS]
)
if petabv1.NOISE_PARAMETERS in problem._petab_problem.measurement_df:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
petabv1.OBSERVABLE_ID: obs,
petabv2.C.EXPERIMENT_ID: [experiment_id] * len(t),
petabv1.TIME: t[problem._ts_masks[ic, :]],
petabv1.SIMULATION: y[ic, problem._ts_masks[ic, :]],
},
index=problem._petab_measurement_indices[ic, :],
)
if (
petabv1.OBSERVABLE_PARAMETERS
in problem._petab_problem.measurement_df
):
df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = (
problem._petab_problem.measurement_df.query(
f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'"
)[petabv2.C.OBSERVABLE_PARAMETERS]
)
if petabv1.NOISE_PARAMETERS in problem._petab_problem.measurement_df:
petabv2.OBSERVABLE_ID: obs,
petabv2.C.EXPERIMENT_ID: [experiment_id] * len(t),
petabv2.TIME: t[problem._ts_masks[ic, :]],
petabv2.SIMULATION: y[ic, problem._ts_masks[ic, :]],
},
index=problem._petab_measurement_indices[ic, :],
)
if (
petabv2.OBSERVABLE_PARAMETERS
in problem._petab_problem.measurement_df
):
df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = (
problem._petab_problem.measurement_df.query(
f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'"
)[petabv2.C.OBSERVABLE_PARAMETERS]
)
if petabv2.NOISE_PARAMETERS in problem._petab_problem.measurement_df:

No effect, but since the function is specifically for petab v2, let's use the constants from there.

self.simulation_conditions = tuple(tuple(sc) for sc in scs.values)
if isinstance(petab_problem, petabv1.Problem):
raise TypeError(
"JAXProblem does not support PEtab v1 problems. Upgrade the problem to PEtab v2."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great. Can you please replace all the petab.v1 constants by those in petab.v2 to make it clear that we don't want to deal with v1 here?
It seems there is still quite a bit of v1 code left here. Probably all occurrences of simulationConditionId/SIMULATION_CONDITION_ID. Is that still needed for petab-sciml? If not please remove.

Comment on lines 376 to 378
self._petab_problem.observable_df.loc[
oid, petab.OBSERVABLE_TRANSFORMATION
oid, petabv1.OBSERVABLE_TRANSFORMATION
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a general comment: With petab.v2.Problem, all .*_df properties construct a new DataFrame on each call. Try to avoid accessing those properties inside loops.

In many cases it might be more convenient to directly use the underlying objects instead of the DataFrames (e.g., Problem.observables instead of Problem.observable_df).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants