
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[] #3113

@dweindl

Various GHA workflows are failing with jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].

Started on Dec 19, probably related to jax 0.8.2: between the last good and the first bad run, only jax and jaxlib changed (0.8.1 → 0.8.2).

Last good:

Successfully installed diffrax-0.7.0 equinox-0.13.2 jax-0.8.1 jaxlib-0.8.1 jaxtyping-0.3.4 lineax-0.0.8 ml_dtypes-0.5.4 opt_einsum-3.4.0 optimistix-0.0.11 wadler-lindig-0.1.7

First bad:

Successfully installed diffrax-0.7.0 equinox-0.13.2 jax-0.8.2 jaxlib-0.8.2 jaxtyping-0.3.4 lineax-0.0.8 ml_dtypes-0.5.4 opt_einsum-3.4.0 optimistix-0.0.11 wadler-lindig-0.1.7
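
A small sketch for confirming which packages differ between the two install lines above. The `parse_installed` helper is hypothetical, written only for this comparison; it splits each `name-version` token on its last hyphen.

```python
GOOD = ("Successfully installed diffrax-0.7.0 equinox-0.13.2 jax-0.8.1 jaxlib-0.8.1 "
        "jaxtyping-0.3.4 lineax-0.0.8 ml_dtypes-0.5.4 opt_einsum-3.4.0 "
        "optimistix-0.0.11 wadler-lindig-0.1.7")
BAD = ("Successfully installed diffrax-0.7.0 equinox-0.13.2 jax-0.8.2 jaxlib-0.8.2 "
       "jaxtyping-0.3.4 lineax-0.0.8 ml_dtypes-0.5.4 opt_einsum-3.4.0 "
       "optimistix-0.0.11 wadler-lindig-0.1.7")


def parse_installed(line: str) -> dict[str, str]:
    """Map package name -> version from a pip 'Successfully installed' line."""
    prefix = "Successfully installed "
    if not line.startswith(prefix):
        raise ValueError("not a pip 'Successfully installed' line")
    packages = {}
    for item in line[len(prefix):].split():
        # Split on the last hyphen so names like "wadler-lindig" stay intact.
        name, _, version = item.rpartition("-")
        packages[name] = version
    return packages


good = parse_installed(GOOD)
bad = parse_installed(BAD)
# Packages whose version differs (or that exist on only one side).
changed = {
    name: (good.get(name), bad.get(name))
    for name in sorted(good.keys() | bad.keys())
    if good.get(name) != bad.get(name)
}
print(changed)
```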

________________________ test_preequilibration_failure _________________________

lotka_volterra = <petab.v1.problem.Problem object at 0x7f3ba850b820>

    def test_preequilibration_failure(lotka_volterra):  # noqa: F811
        petab_problem = lotka_volterra
        # oscillating system, preequilibation should fail when interaction is active
        with TemporaryDirectoryWinSafe(prefix="normal") as model_dir:
            jax_problem = import_petab_problem(
                petab_problem, jax=True, output_dir=model_dir
            )
>           r = run_simulations(jax_problem)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

python/tests/test_jax.py:291: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
python/sdist/amici/jax/petab.py:1599: in run_simulations
    preeqs, preresults = problem.run_preequilibrations(
python/sdist/amici/jax/petab.py:1528: in run_preequilibrations
    return self.run_preequilibration(
venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py:81: in __call__
    return self.func(*self.args, *args, **kwargs, **self.keywords)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/_vmap_pmap.py:169: in __call__
    vmapd, nonvmapd_arr, static = jax.vmap(
python/sdist/amici/jax/petab.py:1503: in run_preequilibration
    return self.model.preequilibrate_condition(
venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py:81: in __call__
    return self.func(*self.args, *args, **kwargs, **self.keywords)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
python/sdist/amici/jax/model.py:916: in preequilibrate_condition
    current_x, _, stats_preeq = eq(
python/sdist/amici/jax/_simulation.py:76: in eq
    sol, _, stats = _run_segment(
python/sdist/amici/jax/_simulation.py:368: in _run_segment
    sol = diffrax.diffeqsolve(
venv/lib/python3.13/site-packages/diffrax/_integrate.py:1297: in diffeqsolve
    _, _, dense_info_struct, _, _ = eqx.filter_eval_shape(
venv/lib/python3.13/site-packages/diffrax/_solver/runge_kutta.py:1151: in step
    final_val = eqxi.while_loop(
venv/lib/python3.13/site-packages/equinox/internal/_loop/loop.py:107: in while_loop
    return checkpointed_while_loop(
venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:247: in checkpointed_while_loop
    body_fun_ = filter_closure_convert(body_fun_, init_val_)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/internal/_loop/common.py:511: in new_body_fun
    buffer_val2 = body_fun(buffer_val)
                  ^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/diffrax/_solver/runge_kutta.py:986: in rk_stage
    nonlinear_sol = optx.root_find(
venv/lib/python3.13/site-packages/optimistix/_root_find.py:218: in root_find
    return iterative_solve(
venv/lib/python3.13/site-packages/optimistix/_iterate.py:344: in iterative_solve
    ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_adjoint.py:133: in apply
    return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_ad.py:60: in implicit_jvp
    root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_ad.py:67: in _implicit_impl
    return jtu.tree_map(jnp.asarray, fn_primal(inputs))
                                     ^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_iterate.py:240: in _iterate
    final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/internal/_loop/loop.py:103: in while_loop
    _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/internal/_loop/common.py:511: in new_body_fun
    buffer_val2 = body_fun(buffer_val)
                  ^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/optimistix/_iterate.py:230: in body_fun
    new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/diffrax/_root_finder/_verychord.py:127: in step
    sol = lx.linear_solve(
venv/lib/python3.13/site-packages/lineax/_solve.py:820: in linear_solve
    solution, result, stats = eqxi.filter_primitive_bind(
venv/lib/python3.13/site-packages/equinox/internal/_primitive.py:271: in filter_primitive_bind
    flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/equinox/internal/_primitive.py:156: in _wrapper
    out = rule(*args)
          ^^^^^^^^^^^
venv/lib/python3.13/site-packages/lineax/_solve.py:126: in _linear_solve_abstract_eval
    out = eqx.filter_eval_shape(
venv/lib/python3.13/site-packages/lineax/_solve.py:87: in _linear_solve_impl
    out = solver.compute(state, vector, options)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
venv/lib/python3.13/site-packages/lineax/_solve.py:648: in compute
    solution, result, _ = solver.compute(state, vector, options)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = LU()
state = ((JitTracer(float64[2,2]), JitTracer(int32[2])), Static(
  _leaves=[
    ShapeDtypeStruct(shape=(2,), dtype=float64),
...peStruct(shape=(2,), dtype=float64),
    PyTreeDef((*, *))
  ],
  _treedef=PyTreeDef(([*, *], *))
), JitTracer(bool[]))
vector = JitTracer(float64[2])

    def compute(
        self, state: _LUState, vector: PyTree[Array], options: dict[str, Any]
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        del options
        lu_and_piv, packed_structures, transpose = state
>       trans = 1 if transpose else 0
                     ^^^^^^^^^
E       jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
E       The error occurred while tracing the function _fn at /home/runner/work/AMICI/AMICI/venv/lib/python3.13/site-packages/equinox/_eval_shape.py:31 for jit. This concrete value was not available in Python because it depends on the value of the argument _dynamic[1][1][1][2].
E       See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
E       --------------------
E       For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

venv/lib/python3.13/site-packages/lineax/_solver/lu.py:61: TracerBoolConversionError
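
A minimal sketch of what the last frame shows, not AMICI or lineax code (`pick_trans` and `solve_like` are hypothetical names). The `transpose` entry of the LU state arrives in `LU.compute` as `JitTracer(bool[])` rather than a Python bool, and a Python conditional on a tracer raises this error under `jax.jit`. Keeping the flag a static Python value avoids it. This only reproduces the error class; it does not identify which change in jax 0.8.2 turned the flag into a tracer.

```python
import jax
import jax.numpy as jnp


def pick_trans(transpose):
    # Same pattern as the failing line in lineax's LU.compute: a Python
    # conditional on `transpose`. Fine for a Python bool, fails for a tracer.
    return 1 if transpose else 0


def solve_like(x, transpose):
    return x * pick_trans(transpose)


# Works: `transpose` is a Python bool marked static, so jit never traces it.
ok = jax.jit(solve_like, static_argnums=1)(jnp.float32(2.0), True)

# Fails: `transpose` is a JAX array, so under jit it becomes a tracer with
# shape bool[] and Python cannot turn it into True/False.
try:
    jax.jit(solve_like)(jnp.float32(2.0), jnp.asarray(True))
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

print(ok, raised)
```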

Labels: JAX (Related to the JAX backend.)
