Started on Dec 19. Probably related to jax 0.8.2.
________________________ test_preequilibration_failure _________________________
lotka_volterra = <petab.v1.problem.Problem object at 0x7f3ba850b820>
def test_preequilibration_failure(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
# oscillating system, preequilibation should fail when interaction is active
with TemporaryDirectoryWinSafe(prefix="normal") as model_dir:
jax_problem = import_petab_problem(
petab_problem, jax=True, output_dir=model_dir
)
> r = run_simulations(jax_problem)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
python/tests/test_jax.py:291:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
python/sdist/amici/jax/petab.py:1599: in run_simulations
preeqs, preresults = problem.run_preequilibrations(
python/sdist/amici/jax/petab.py:1528: in run_preequilibrations
return self.run_preequilibration(
venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py:81: in __call__
return self.func(*self.args, *args, **kwargs, **self.keywords)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/_vmap_pmap.py:169: in __call__
vmapd, nonvmapd_arr, static = jax.vmap(
python/sdist/amici/jax/petab.py:1503: in run_preequilibration
return self.model.preequilibrate_condition(
venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py:81: in __call__
return self.func(*self.args, *args, **kwargs, **self.keywords)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
python/sdist/amici/jax/model.py:916: in preequilibrate_condition
current_x, _, stats_preeq = eq(
python/sdist/amici/jax/_simulation.py:76: in eq
sol, _, stats = _run_segment(
python/sdist/amici/jax/_simulation.py:368: in _run_segment
sol = diffrax.diffeqsolve(
venv/lib/python3.13/site-packages/diffrax/_integrate.py:1297: in diffeqsolve
_, _, dense_info_struct, _, _ = eqx.filter_eval_shape(
venv/lib/python3.13/site-packages/diffrax/_solver/runge_kutta.py:1151: in step
final_val = eqxi.while_loop(
venv/lib/python3.13/site-packages/equinox/internal/_loop/loop.py:107: in while_loop
return checkpointed_while_loop(
venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:247: in checkpointed_while_loop
body_fun_ = filter_closure_convert(body_fun_, init_val_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/internal/_loop/common.py:511: in new_body_fun
buffer_val2 = body_fun(buffer_val)
^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/diffrax/_solver/runge_kutta.py:986: in rk_stage
nonlinear_sol = optx.root_find(
venv/lib/python3.13/site-packages/optimistix/_root_find.py:218: in root_find
return iterative_solve(
venv/lib/python3.13/site-packages/optimistix/_iterate.py:344: in iterative_solve
) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_adjoint.py:133: in apply
return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_ad.py:60: in implicit_jvp
root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_ad.py:67: in _implicit_impl
return jtu.tree_map(jnp.asarray, fn_primal(inputs))
^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_iterate.py:240: in _iterate
final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/internal/_loop/loop.py:103: in while_loop
_, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/internal/_loop/common.py:511: in new_body_fun
buffer_val2 = body_fun(buffer_val)
^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_iterate.py:230: in body_fun
new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/diffrax/_root_finder/_verychord.py:127: in step
sol = lx.linear_solve(
venv/lib/python3.13/site-packages/lineax/_solve.py:820: in linear_solve
solution, result, stats = eqxi.filter_primitive_bind(
venv/lib/python3.13/site-packages/equinox/internal/_primitive.py:271: in filter_primitive_bind
flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/internal/_primitive.py:156: in _wrapper
out = rule(*args)
^^^^^^^^^^^
venv/lib/python3.13/site-packages/lineax/_solve.py:126: in _linear_solve_abstract_eval
out = eqx.filter_eval_shape(
venv/lib/python3.13/site-packages/lineax/_solve.py:87: in _linear_solve_impl
out = solver.compute(state, vector, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/lineax/_solve.py:648: in compute
solution, result, _ = solver.compute(state, vector, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LU()
state = ((JitTracer(float64[2,2]), JitTracer(int32[2])), Static(
_leaves=[
ShapeDtypeStruct(shape=(2,), dtype=float64),
...peStruct(shape=(2,), dtype=float64),
PyTreeDef((*, *))
],
_treedef=PyTreeDef(([*, *], *))
), JitTracer(bool[]))
vector = JitTracer(float64[2])
def compute(
self, state: _LUState, vector: PyTree[Array], options: dict[str, Any]
) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
del options
lu_and_piv, packed_structures, transpose = state
> trans = 1 if transpose else 0
^^^^^^^^^
E jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
E The error occurred while tracing the function _fn at /home/runner/work/AMICI/AMICI/venv/lib/python3.13/site-packages/equinox/_eval_shape.py:31 for jit. This concrete value was not available in Python because it depends on the value of the argument _dynamic[1][1][1][2].
E See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
E --------------------
E For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
venv/lib/python3.13/site-packages/lineax/_solver/lu.py:61: TracerBoolConversionError
Various GHA workflows are failing with
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
Started on Dec 19. Probably related to jax 0.8.2.
Last good:
First bad: