From 11a8623e36e6c69f405efd5c42e48215fcd1b373 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Fri, 27 Feb 2026 10:00:11 +0100 Subject: [PATCH 1/2] Fix importing PetabImporter without jax extras Closes #3140. --- CHANGELOG.md | 12 ++++++++++-- python/sdist/amici/importers/petab/__init__.py | 8 +------- .../sdist/amici/importers/petab/_petab_importer.py | 6 ++++-- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c986ad837..cae5ff7e4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,15 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni ## v1.X Series -### v1.0.0 (unreleased) +### v1.0.1 + +**Fixes** + +* Fixed an issue that resulted in failure to import the `PetabImporter` if + the jax-dependencies weren't installed. + + +### v1.0.0 **BREAKING CHANGES** @@ -39,7 +47,7 @@ The following functionality has been removed without replacement: fixed parameters as "fixed parameters", "constant parameters", or "constants". This has now been harmonized to "free" and "fixed" across the API. E.g., `Model.setParameters()` is now `Model.set_free_parameters()`. -* `ReturnDataView.posteq_numsteps` and `ReturnDataView.posteq_numsteps` now +* `ReturnDataView.posteq_numsteps` and `ReturnDataView.preeq_numsteps` now return a one-dimensional array of shape `(num_timepoints,)` instead of a two-dimensional array of shape `(1, num_timepoints)`. * `ReturnDataView.posteq_status` and `ReturnDataView.preeq_status` now diff --git a/python/sdist/amici/importers/petab/__init__.py b/python/sdist/amici/importers/petab/__init__.py index afc36b961f..65157a4ab3 100644 --- a/python/sdist/amici/importers/petab/__init__.py +++ b/python/sdist/amici/importers/petab/__init__.py @@ -21,10 +21,4 @@ passing a :class:`petab.v1.Problem` instance to the PEtab v2 import functions. """ -# FIXME: for some tests (petab-sciml, maybe petab-v1-pysb) we still rely on an -# old PEtab version on which the petab v2 import does not work. -# Once those tests are updated, we can remove this try-except block. -try: - from ._petab_importer import * # noqa: F403, F401 -except ImportError: - pass +from ._petab_importer import * # noqa: F403, F401 diff --git a/python/sdist/amici/importers/petab/_petab_importer.py b/python/sdist/amici/importers/petab/_petab_importer.py index c5c2853bbd..fcd9d55b3e 100644 --- a/python/sdist/amici/importers/petab/_petab_importer.py +++ b/python/sdist/amici/importers/petab/_petab_importer.py @@ -24,7 +24,6 @@ from amici._symbolic import DEModel, Event from amici.importers.utils import MeasurementChannel, amici_time_symbol from amici.logging import get_logger -from amici.sim.jax.petab import JAXProblem from .v1._sbml_import import _add_global_parameter @@ -594,7 +593,7 @@ def create_model(self) -> amici.sim.sundials.Model: def create_simulator( self, force_import: bool = False - ) -> amici.sim.sundials.petab.PetabSimulator: + ) -> amici.sim.sundials.petab.PetabSimulator | amici.sim.jax.JAXProblem: """ Create a PEtab simulator for the imported model. @@ -607,6 +606,9 @@ def create_simulator( if self._jax: model_module = self.import_module(force_import=force_import) model = model_module.Model() + + from amici.sim.jax.petab import JAXProblem + return JAXProblem(model, self.petab_problem) model = self.import_module(force_import=force_import).get_model() From cd6c8b35f8419bd17b3811ab8d43181ee27e0210 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Fri, 27 Feb 2026 10:08:48 +0100 Subject: [PATCH 2/2] local h5py import --- python/sdist/amici/sim/jax/petab.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/sim/jax/petab.py b/python/sdist/amici/sim/jax/petab.py index 19ee7e681c..a7e97407f8 100644 --- a/python/sdist/amici/sim/jax/petab.py +++ b/python/sdist/amici/sim/jax/petab.py @@ -10,7 +10,6 @@ import diffrax import equinox as eqx -import h5py import jax.lax import jax.numpy as jnp import jaxtyping as jt @@ -593,6 +592,9 @@ def _load_parameter_arrays_from_files(self) -> dict: "array_files", [] ) + import h5py + + # TODO(performance): Avoid opening each file multiple times return { file_spec.split("_")[0]: h5py.File(file_spec, "r")["parameters"][ file_spec.split("_")[0] @@ -615,6 +617,9 @@ def _load_input_arrays_from_files(self) -> dict: "array_files", [] ) + import h5py + + # TODO(performance): Avoid opening each file multiple times return { file_spec.split("_")[0]: h5py.File(file_spec, "r")["inputs"] for file_spec in array_files