diff --git a/python/sdist/amici/exporters/jax/jax.template.py b/python/sdist/amici/exporters/jax/jax.template.py index 37b6339d5a..4b74ae2702 100644 --- a/python/sdist/amici/exporters/jax/jax.template.py +++ b/python/sdist/amici/exporters/jax/jax.template.py @@ -5,7 +5,6 @@ import jax.numpy as jnp import jax.random as jr # noqa: F401 import jaxtyping as jt # noqa: F401 -from interpax import interp1d # noqa: F401 from jax.numpy import inf as oo # noqa: F401 from jax.numpy import nan as nan # noqa: F401 diff --git a/python/sdist/pyproject.toml b/python/sdist/pyproject.toml index 425ce7d533..334d2f51ab 100644 --- a/python/sdist/pyproject.toml +++ b/python/sdist/pyproject.toml @@ -87,12 +87,11 @@ examples = [ "scipy", ] jax = [ - "jax>=0.7.2,<0.8.2", + "jax>=0.7.2", "diffrax>=0.7.0", "jaxtyping>=0.2.34", "equinox>=0.13.2", "optimistix>=0.0.9", - "interpax>=0.3.9", ] sciml = [ "h5py"