ExampleJAXPetab notebook breaks with optax==0.2.7 #3130

@BSnelling

The ExampleJAXPEtab notebook fails to execute its "Model Training" cell when optax==0.2.7 is installed. See, for example, this CI run: https://github.com/AMICI-dev/AMICI/actions/runs/22305063329/job/64522092897

The error is raised when optim.update is called on the model, and it appears to be related to None entries in the pytree:

ValueError: Expected None, got JitTracer<int64[1,48,1]>.

In previous releases of JAX, flatten-up-to used to consider None to be a tree-prefix of non-None values. To obtain the previous behavior, you can usually write:
  jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
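
For reference, here is a minimal sketch of the workaround suggested in the error message, applied to a toy gradient tree. It is not the AMICI or optax code path: the helper `scale_grads`, the dictionaries `params` and `grads`, and the 0.1 factor are all hypothetical and only illustrate how `is_leaf` lets `jax.tree.map` treat None as a leaf in the first tree.

```python
import jax
import jax.numpy as jnp


def scale_grads(grads, params):
    """Scale every gradient leaf by 0.1, passing None entries through.

    Hypothetical helper: it mirrors the pattern from the error message,
    not the actual update rule used by optax or the notebook.
    """
    return jax.tree.map(
        # p is unused here; it stands in for a second tree (e.g. parameters)
        # that holds real arrays where grads holds None.
        lambda g, p: None if g is None else 0.1 * g,
        grads,
        params,
        # Treat None in the first tree as a leaf instead of an empty subtree,
        # so it may line up with a non-None value in the second tree.
        is_leaf=lambda x: x is None,
    )


# Toy trees (assumed for illustration): one entry has no gradient.
params = {"w": jnp.ones(3), "frozen": jnp.arange(2)}
grads = {"w": jnp.full(3, 2.0), "frozen": None}

updates = scale_grads(grads, params)
```

Without the `is_leaf` argument, recent JAX releases reject this pair of trees because None in `grads` no longer counts as a prefix of the array stored under the same key in `params`.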

Labels: JAX (related to the JAX backend), new (newly created)
