Closed
Description
The ExampleJAXPEtab notebook fails to execute its "Model Training" cell when optax==0.2.7 is installed; see, e.g., https://github.com/AMICI-dev/AMICI/actions/runs/22305063329/job/64522092897
The error is raised when optim.update is called on the model and appears to be related to None values in the model pytree:
ValueError: Expected None, got JitTracer<int64[1,48,1]>.
In previous releases of JAX, flatten-up-to used to consider None to be a tree-prefix of non-None values. To obtain the previous behavior, you can usually write:
jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
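The sketch below reproduces the structural mismatch behind this error with plain JAX dictionaries rather than the AMICI model, and then applies the workaround suggested in the error message. The trees, the field names `w` and `idx`, and the `apply_update` helper are all hypothetical. They only mimic an `updates` tree that holds `None` where `params` still holds an integer array (cf. `int64[1,48,1]` above). Whether this is exactly what happens inside optax 0.2.7 is an assumption.

```python
import jax
import jax.numpy as jnp


def apply_update(u, p):
    # Hypothetical stand-in for an optimizer's per-leaf arithmetic.
    return p + u


# Hypothetical trees: `updates` has None for a non-trainable integer field,
# while `params` still carries the integer array at that position.
updates = {"w": jnp.ones((2,)), "idx": None}
params = {"w": jnp.zeros((2,)), "idx": jnp.zeros((1, 48, 1), dtype=jnp.int32)}


def naive_map(f, a, b):
    # In recent JAX, None in `a` is an empty subtree rather than a prefix of
    # whatever `b` holds there, so this raises ValueError("Expected None, got ...").
    return jax.tree.map(f, a, b)


def none_tolerant_map(f, a, b):
    # Treat None as a leaf of `a`, so `b` is flattened only up to that position,
    # and pass the None through unchanged instead of calling `f`.
    return jax.tree.map(
        lambda x, y: None if x is None else f(x, y),
        a,
        b,
        is_leaf=lambda x: x is None,
    )


try:
    naive_map(apply_update, updates, params)
    print("naive tree.map succeeded (older JAX behavior)")
except ValueError as err:
    print("naive tree.map failed:", str(err).split("\n")[0])

result = none_tolerant_map(apply_update, updates, params)
print(result["w"], result["idx"])
```

Because `is_leaf` turns each `None` in the first tree into a leaf, the output keeps the same structure as `updates`, with `None` preserved wherever the first tree had it.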