diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c78249a9c..da85e02a6f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.14.1 + rev: v0.15.2 hooks: # Run the linter. - id: ruff diff --git a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 8aa025bfaa..a12b4d92fa 100644 --- a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -31,7 +31,6 @@ "metadata": {}, "outputs": [], "source": [ - "import petab.v1 as petab\n", "from amici.importers.petab import *\n", "from petab.v2 import Problem\n", "\n", @@ -403,7 +402,9 @@ "nps = jax_problem._np_numeric[ic, :]\n", "\n", "# Load parameters for the specified condition\n", - "p = jax_problem.load_model_parameters(jax_problem._petab_problem.experiments[0], is_preeq=False)\n", + "p = jax_problem.load_model_parameters(\n", + " jax_problem._petab_problem.experiments[0], is_preeq=False\n", + ")\n", "\n", "\n", "# Define a function to compute the gradient with respect to dynamic timepoints\n", @@ -576,7 +577,6 @@ "outputs": [], "source": [ "from amici.sim.sundials import SensitivityMethod, SensitivityOrder\n", - "from amici.sim.sundials.petab.v1 import simulate_petab\n", "\n", "# Import the PEtab problem as a standard AMICI model\n", "pi = PetabImporter(\n", diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index e9fef47edc..d6e96c708c 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -2648,7 +2648,7 @@ def get_implicit_roots(self) -> list[sp.Expr]: for e in self._events if not e.has_explicit_trigger_times() ] - + def has_algebraic_states(self) -> bool: """ Checks whether the model has algebraic states @@ -2666,7 +2666,7 @@ def has_event_assignments(self) -> bool: boolean indicating if event assignments are present """ return any(event.updates_state for event in self._events) - + def has_priority_events(self) -> bool: """ Checks whether the model has events with priorities defined @@ -2675,7 +2675,7 @@ def has_priority_events(self) -> bool: boolean indicating if priority events are present """ return any(event.get_priority() is not None for event in self._events) - + def has_implicit_event_assignments(self) -> bool: """ Checks whether the model has event assignments with implicit triggers @@ -2686,9 +2686,13 @@ def has_implicit_event_assignments(self) -> bool: """ fixed_symbols = set([k._symbol for k in self._fixed_parameters]) allowed_symbols = fixed_symbols | {amici_time_symbol} - # TODO: update to use has_explicit_trigger_times once + # TODO: update to use has_explicit_trigger_times once # https://github.com/AMICI-dev/AMICI/issues/3126 is resolved - return any(event.updates_state and event._has_implicit_triggers(allowed_symbols) for event in self._events) + return any( + event.updates_state + and event._has_implicit_triggers(allowed_symbols) + for event in self._events + ) def toposort_expressions( self, reorder: bool = True diff --git a/python/sdist/amici/_symbolic/de_model_components.py b/python/sdist/amici/_symbolic/de_model_components.py index 76b6c1f501..ecce215cf2 100644 --- a/python/sdist/amici/_symbolic/de_model_components.py +++ b/python/sdist/amici/_symbolic/de_model_components.py @@ -740,7 +740,7 @@ def __init__( :param priority: The priority of the event assignment. :param is_negative_event: - Whether this event is a "negative" event, i.e., an event that is + Whether this event is a "negative" event, i.e., an event that is added to mirror an existing event with inverted trigger condition to avoid immediate retriggering of the original event (JAX simulations). @@ -863,17 +863,16 @@ def has_explicit_trigger_times( """ if allowed_symbols is None: return len(self._t_root) > 0 - + return len(self._t_root) > 0 and all( t.is_Number or t.free_symbols.issubset(allowed_symbols) for t in self._t_root ) - + def _has_implicit_triggers( self, allowed_symbols: set[sp.Symbol] | None = None ) -> bool: - """Check whether the event has implicit triggers. - """ + """Check whether the event has implicit triggers.""" t = self.get_val() return not t.free_symbols.issubset(allowed_symbols) diff --git a/python/sdist/amici/constants.py b/python/sdist/amici/constants.py index 1ee46111c4..b751b7a84f 100644 --- a/python/sdist/amici/constants.py +++ b/python/sdist/amici/constants.py @@ -8,7 +8,7 @@ import enum -class SymbolId(str, enum.Enum): +class SymbolId(enum.StrEnum): """ Defines the different fields in the symbol dict to which sbml entities get parsed to. diff --git a/python/sdist/amici/importers/petab/_petab_importer.py b/python/sdist/amici/importers/petab/_petab_importer.py index e49d73d668..bcb8633c51 100644 --- a/python/sdist/amici/importers/petab/_petab_importer.py +++ b/python/sdist/amici/importers/petab/_petab_importer.py @@ -23,8 +23,8 @@ from amici import get_model_dir from amici._symbolic import DEModel, Event from amici.importers.utils import MeasurementChannel, amici_time_symbol -from amici.logging import get_logger from amici.jax.petab import JAXProblem +from amici.logging import get_logger from .v1.sbml_import import _add_global_parameter @@ -608,7 +608,7 @@ def create_simulator( model_module = self.import_module(force_import=force_import) model = model_module.Model() return JAXProblem(model, self.petab_problem) - + model = self.import_module(force_import=force_import).get_model() em = ExperimentManager(model=model, petab_problem=self.petab_problem) return PetabSimulator(em=em) diff --git a/python/sdist/amici/importers/pysb/__init__.py b/python/sdist/amici/importers/pysb/__init__.py index 75a2bae0cd..38c5ac2d9b 100644 --- a/python/sdist/amici/importers/pysb/__init__.py +++ b/python/sdist/amici/importers/pysb/__init__.py @@ -10,7 +10,6 @@ import itertools import logging import os -import re import sys from collections.abc import Callable, Iterable from pathlib import Path @@ -31,9 +30,7 @@ FixedParameter, FreeParameter, LogLikelihoodY, - NoiseParameter, Observable, - ObservableParameter, SigmaY, ) from amici.logging import get_logger, log_execution_time, set_log_level diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 93111651b2..5584ba5657 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -1902,7 +1902,7 @@ def _process_events(self) -> None: if self.jax: # Add a negative event for JAX models to handle - # TODO: remove once condition function directions can be + # TODO: remove once condition function directions can be # traced through diffrax solve neg_event_id = event_id + "_negative" neg_event_sym = sp.Symbol(neg_event_id) diff --git a/python/sdist/amici/importers/sbml/splines.py b/python/sdist/amici/importers/sbml/splines.py index 9a6a99157a..c618cc57b7 100644 --- a/python/sdist/amici/importers/sbml/splines.py +++ b/python/sdist/amici/importers/sbml/splines.py @@ -1511,10 +1511,9 @@ def _spline_user_functions( "AmiciSplineSensitivity": [ ( lambda *args: True, - lambda spline_id, - x, - param_id, - *p: f"sspl_{spline_ids.index(spline_id)}_{p_index[param_id]}", + lambda spline_id, x, param_id, *p: ( + f"sspl_{spline_ids.index(spline_id)}_{p_index[param_id]}" + ), ) ], } diff --git a/python/sdist/amici/importers/utils.py b/python/sdist/amici/importers/utils.py index 5b1026b612..8672902048 100644 --- a/python/sdist/amici/importers/utils.py +++ b/python/sdist/amici/importers/utils.py @@ -71,7 +71,7 @@ def __init__(self, data): annotation_namespace = "https://github.com/AMICI-dev/AMICI" -class ObservableTransformation(str, enum.Enum): +class ObservableTransformation(enum.StrEnum): """ Different modes of observable transformation. """ diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index b20748d098..95ef034ace 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -491,6 +491,7 @@ def _handle_event( return y0_next, h_next, stats + def _check_cascading_events( t0_next: float, y0_next: jt.Float[jt.Array, "nxs"], @@ -521,6 +522,7 @@ def _check_cascading_events( return y0_next + def _apply_event_assignments( roots_found, roots_dir, @@ -543,7 +545,7 @@ def _apply_event_assignments( ] ).T - # apply one event at a time + # apply one event at a time if h_next.shape[0] and y0_next.shape[0]: n_pairs = h_next.shape[0] // 2 inds_seq = jnp.arange(n_pairs) diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index d50f44c7a2..063adc5045 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -127,18 +127,18 @@ def _root_cond_fn(self, t, y, args, **_): TPL_EROOT_EQ return jnp.hstack((TPL_IROOT_RET, TPL_EROOT_RET)) - + def _delta_x(self, y, p, tcl): TPL_X_SYMS = y TPL_ALL_P_SYMS = p TPL_TCL_SYMS = tcl # FIXME: workaround until state from event time is properly passed TPL_X_OLD_SYMS = y - + TPL_DELTAX_EQ return TPL_DELTAX_RET - + @property def event_initial_values(self): return TPL_EVENT_INITIAL_VALUES diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index ed072e3d2f..b2274e1fb1 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -3,6 +3,7 @@ # ruff: noqa: F821 F722 import enum +import os from abc import abstractmethod from collections.abc import Callable from dataclasses import field @@ -15,9 +16,7 @@ import jaxtyping as jt from optimistix import AbstractRootFinder -import os - -from ._simulation import eq, solve, _apply_event_assignments +from ._simulation import _apply_event_assignments, eq, solve class ReturnValue(enum.Enum): @@ -363,25 +362,20 @@ def expression_ids(self) -> list[str]: ... def _root_cond_fn_event( - self, - ie: int, - t: float, - y: jt.Float[jt.Array, "nxs"], - args: tuple, - **_ - ): + self, ie: int, t: float, y: jt.Float[jt.Array, "nxs"], args: tuple, **_ + ): """ Root condition function for a specific event index. - :param ie: + :param ie: event index - :param t: + :param t: time point - :param y: + :param y: state vector - :param args: + :param args: tuple of arguments required for _root_cond_fn - :return: + :return: mask of root condition value for the specified event index """ __, __, h = args @@ -390,7 +384,9 @@ def _root_cond_fn_event( masked_rval = jnp.where(h == 0.0, rval, 1.0) return masked_rval.at[ie].get() - def _root_cond_fns(self) -> list[Callable[[float, jt.Float[jt.Array, "nxs"], tuple], jt.Float]]: + def _root_cond_fns( + self, + ) -> list[Callable[[float, jt.Float[jt.Array, "nxs"], tuple], jt.Float]]: """Return condition functions for implicit discontinuities. These functions are passed to :class:`diffrax.Event` and must evaluate @@ -429,15 +425,15 @@ def _initialise_heaviside_variables( """ h0 = self.event_initial_values.astype(float) if os.getenv("JAX_DEBUG") == "1": - jax.debug.print( - "h0: {}", - h0, - ) + jax.debug.print( + "h0: {}", + h0, + ) roots_found = self._root_cond_fn(t0, x_solver, (p, tcl, h0)) return jnp.where( - jnp.logical_and(roots_found >= 0.0, h0 == 1.0), - jnp.ones_like(h0), - jnp.zeros_like(h0) + jnp.logical_and(roots_found >= 0.0, h0 == 1.0), + jnp.ones_like(h0), + jnp.zeros_like(h0), ) def _x_rdatas( @@ -526,9 +522,9 @@ def _ys( observables """ return jax.vmap( - lambda t, x, p, tcl, h, iy, op: self._y(t, x, p, tcl, h, op) - .at[iy] - .get(), + lambda t, x, p, tcl, h, iy, op: ( + self._y(t, x, p, tcl, h, op).at[iy].get() + ), in_axes=(0, 0, None, None, 0, 0, 0), )(ts, xs, p, tcl, hs, iys, ops) @@ -566,11 +562,9 @@ def _sigmays( standard deviations of the observables """ return jax.vmap( - lambda t, x, p, tcl, h, iy, op, np: self._sigmay( - self._y(t, x, p, tcl, h, op), p, np - ) - .at[iy] - .get(), + lambda t, x, p, tcl, h, iy, op, np: ( + self._sigmay(self._y(t, x, p, tcl, h, op), p, np).at[iy].get() + ), in_axes=(0, 0, None, None, 0, 0, 0, 0), )(ts, xs, p, tcl, hs, iys, ops, nps) @@ -638,7 +632,7 @@ def simulate_condition_unjitted( x_solver, _, h, _ = self._handle_t0_event( t0, x_solver, - p, + p, tcl, root_finder, self._root_cond_fn, @@ -740,12 +734,14 @@ def simulate_condition_unjitted( output = tcl elif ret in (ReturnValue.res, ReturnValue.chi2): obs_trafo = jax.vmap( - lambda y, iy_trafo: jnp.array( - # needs to follow order in amici.jax.petab.SCALE_TO_INT - [y, safe_log(y), safe_log(y) / jnp.log(10)] - ) - .at[iy_trafo] - .get(), + lambda y, iy_trafo: ( + jnp.array( + # needs to follow order in amici.jax.petab.SCALE_TO_INT + [y, safe_log(y), safe_log(y) / jnp.log(10)] + ) + .at[iy_trafo] + .get() + ), ) ys_obj = obs_trafo( self._ys(ts, x, p, tcl, hs, iys, ops), iy_trafos @@ -845,7 +841,7 @@ def simulate_condition( mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. :param h_mask: - mask for heaviside variables. If `True`, the corresponding heaviside variable is updated during simulation, otherwise it + mask for heaviside variables. If `True`, the corresponding heaviside variable is updated during simulation, otherwise it it marked as 1.0. :param ret: which output to return. See :class:`ReturnValue` for available options. @@ -905,7 +901,7 @@ def preequilibrate_condition( :param mask_reinit: mask for re-initialization. If `True`, the corresponding state variable is re-initialized. :param h_mask: - mask for heaviside variables. If `True`, the corresponding heaviside variable is updated during simulation, otherwise it + mask for heaviside variables. If `True`, the corresponding heaviside variable is updated during simulation, otherwise it it marked as 1.0. :param solver: ODE solver @@ -979,11 +975,11 @@ def _handle_t0_event( ): y0 = y0_next.copy() rf0 = self.event_initial_values - 0.5 - + if h_preeq.shape[0]: # return immediately because preequilibration is equivalent to handling t0 event? return y0, t0_next, h_preeq, stats - else: + else: h = jnp.where(h_mask, jnp.heaviside(rf0, 0.0), jnp.ones_like(rf0)) args = (p, tcl, h) rfx = root_cond_fn(t0_next, y0_next, args) @@ -1005,11 +1001,11 @@ def _handle_t0_event( ) droot_dt = ( # ∂root_cond_fn/∂t - jax.jacfwd(root_cond_fn, argnums=0)(t0_next, y0_next, args) - + - # ∂root_cond_fn/∂y * ∂y/∂t - jax.jacfwd(root_cond_fn, argnums=1)(t0_next, y0_next, args) - @ self._xdot(t0_next, y0_next, args) + jax.jacfwd(root_cond_fn, argnums=0)(t0_next, y0_next, args) + + + # ∂root_cond_fn/∂y * ∂y/∂t + jax.jacfwd(root_cond_fn, argnums=1)(t0_next, y0_next, args) + @ self._xdot(t0_next, y0_next, args) ) h_next = jnp.where( roots_zero, @@ -1030,6 +1026,7 @@ def _handle_t0_event( return y0_next, t0_next, h_next, stats + def safe_log(x: jnp.float_) -> jnp.float_: """ Safe logarithm that returns `jnp.log(jnp.finfo(jnp.float_).eps)` for x <= 0. diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 3cbe2a34b2..7d69fa3622 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -17,7 +17,6 @@ from pathlib import Path import sympy as sp -import numpy as np from amici import ( amiciModulePath, @@ -62,7 +61,7 @@ def _jax_variable_equations( code_printer._get_sym_lines( (s.name for s in model.sym(eq_name)), # sp.Matrix to support event assignments which are lists - sp.Matrix(model.eq(eq_name)).subs(subs), + sp.Matrix(model.eq(eq_name)).subs(subs), indent, ) )[indent:] # remove indent for first line @@ -155,7 +154,7 @@ def __init__( raise NotImplementedError( "The JAX backend does not support event priorities." ) - + if ode_model.has_implicit_event_assignments(): raise NotImplementedError( "The JAX backend does not support event assignments with implicit triggers." @@ -262,11 +261,19 @@ def _generate_jax_code(self) -> None: # tuple of variable names (ids as they are unique) **_jax_variable_ids(self.model, ("p", "k", "y", "w", "x_rdata")), "P_VALUES": _jnp_array_str(self.model.val("p")), - "ALL_P_VALUES": _jnp_array_str(self.model.val("p") + self.model.val("k")), - "ALL_P_IDS": "".join(f'"{s.name}", ' for s in self._get_all_p_syms()) - if self._get_all_p_syms() else "tuple()", - "ALL_P_SYMS": "".join(f"{s.name}, " for s in self._get_all_p_syms()) - if self._get_all_p_syms() else "_", + "ALL_P_VALUES": _jnp_array_str( + self.model.val("p") + self.model.val("k") + ), + "ALL_P_IDS": "".join( + f'"{s.name}", ' for s in self._get_all_p_syms() + ) + if self._get_all_p_syms() + else "tuple()", + "ALL_P_SYMS": "".join( + f"{s.name}, " for s in self._get_all_p_syms() + ) + if self._get_all_p_syms() + else "_", "ROOTS": _jnp_array_str( { _print_trigger_root(root) @@ -277,9 +284,7 @@ def _generate_jax_code(self) -> None: "N_IEVENTS": str(len(self.model.get_implicit_roots())), "N_EEVENTS": str(len(self.model.get_explicit_roots())), "EVENT_INITIAL_VALUES": _jnp_array_str( - [ - e.get_initial_value() for e in self.model._events - ] + [e.get_initial_value() for e in self.model._events] ), **{ "MODEL_NAME": self.model_name, @@ -360,6 +365,7 @@ def set_name(self, model_name: str) -> None: self.model_name = model_name + def _print_trigger_root(root: sp.Expr) -> str: """Convert a trigger root expression into a string representation. diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 3520701fe6..ae2172c81d 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -1,14 +1,13 @@ """PEtab wrappers for JAX models.""" "" -import copy import logging +import os import re import shutil from collections.abc import Callable, Iterable, Sized from numbers import Number from pathlib import Path -import os import diffrax import equinox as eqx import h5py @@ -25,12 +24,14 @@ from amici import _module_from_path from amici.importers.petab.v1.parameter_mapping import ( ParameterMappingForCondition, - create_parameter_mapping, ) from amici.jax.model import JAXModel, ReturnValue from amici.logging import get_logger from amici.sim.jax import ( - add_default_experiment_names_to_v2_problem, get_simulation_conditions_v2, _build_simulation_df_v2, _try_float + _build_simulation_df_v2, + _try_float, + add_default_experiment_names_to_v2_problem, + get_simulation_conditions_v2, ) DEFAULT_CONTROLLER_SETTINGS = { @@ -87,6 +88,7 @@ def __init__(self, petab_problem: petabv1.Problem): self.__dict__.update(petab_problem.__dict__) self.hybridization_df = _get_hybridization_df(petab_problem) + class HybridV2Problem(petabv2.Problem): hybridization_df: pd.DataFrame extensions_config: dict @@ -101,7 +103,7 @@ def __init__(self, petab_problem: petabv2.Problem): def _get_hybridization_df(petab_problem): if not hasattr(petab_problem, "extensions_config"): return None - + if "sciml" in petab_problem.extensions_config: hybridizations = [ pd.read_csv(hf, sep="\t", index_col=0) @@ -113,7 +115,9 @@ def _get_hybridization_df(petab_problem): return hybridization_df -def _get_hybrid_petab_problem(petab_problem: petabv1.Problem | petabv2.Problem): +def _get_hybrid_petab_problem( + petab_problem: petabv1.Problem | petabv2.Problem, +): if isinstance(petab_problem, petabv2.Problem): return HybridV2Problem(petab_problem) return HybridProblem(petab_problem) @@ -154,7 +158,9 @@ class JAXProblem(eqx.Module): _petab_measurement_indices: np.ndarray _petab_problem: petabv1.Problem | HybridProblem | petabv2.Problem - def __init__(self, model: JAXModel, petab_problem: petabv1.Problem | petabv2.Problem): + def __init__( + self, model: JAXModel, petab_problem: petabv1.Problem | petabv2.Problem + ): """ Initialize a JAXProblem instance with a model and a PEtab problem. @@ -165,9 +171,11 @@ def __init__(self, model: JAXModel, petab_problem: petabv1.Problem | petabv2.Pro """ if isinstance(petab_problem, petabv1.Problem): raise TypeError( - "JAXProblem does not support PEtab v1 problems. Upgrade the problem to PEtab v2." + "JAXProblem does not support PEtab v1 problems. Upgrade the problem to PEtab v2." ) - petab_problem = add_default_experiment_names_to_v2_problem(petab_problem) + petab_problem = add_default_experiment_names_to_v2_problem( + petab_problem + ) scs = get_simulation_conditions_v2(petab_problem) self.simulation_conditions = scs.conditionId.to_list() self._petab_problem = _get_hybrid_petab_problem(petab_problem) @@ -273,7 +281,10 @@ def _get_measurements( petab_indices = dict() n_pars = dict() - for col in [petabv2.C.OBSERVABLE_PARAMETERS, petabv2.C.NOISE_PARAMETERS]: + for col in [ + petabv2.C.OBSERVABLE_PARAMETERS, + petabv2.C.NOISE_PARAMETERS, + ]: n_pars[col] = 0 if col in self._petab_problem.measurement_df: if pd.api.types.is_numeric_dtype( @@ -287,17 +298,20 @@ def _get_measurements( self._petab_problem.measurement_df[col] .str.split(petabv2.C.PARAMETER_SEPARATOR) .apply( - lambda x: len(x) - if isinstance(x, Sized) - else 1 - int(pd.isna(x)) + lambda x: ( + len(x) + if isinstance(x, Sized) + else 1 - int(pd.isna(x)) + ) ) .max() ) for _, simulation_condition in simulation_conditions.iterrows(): - if "preequilibration" in simulation_condition[ - petabv2.C.CONDITION_ID - ]: + if ( + "preequilibration" + in simulation_condition[petabv2.C.CONDITION_ID] + ): continue if isinstance(self._petab_problem, HybridV2Problem): @@ -331,7 +345,10 @@ def _get_measurements( for oid in m[petabv2.C.OBSERVABLE_ID].values ] ) - if petabv2.C.NOISE_DISTRIBUTION in self._petab_problem.observable_df: + if ( + petabv2.C.NOISE_DISTRIBUTION + in self._petab_problem.observable_df + ): iy_trafos = np.array( [ SCALE_TO_INT[petabv2.C.LOG] @@ -359,8 +376,11 @@ def get_parameter_override(x): ] return x - for col in [petabv2.C.OBSERVABLE_PARAMETERS, petabv2.C.NOISE_PARAMETERS]: - if col not in m or m[col].isna().all() or all(m[col] == ''): + for col in [ + petabv2.C.OBSERVABLE_PARAMETERS, + petabv2.C.NOISE_PARAMETERS, + ]: + if col not in m or m[col].isna().all() or all(m[col] == ""): mat_numeric = jnp.ones((len(m), n_pars[col])) par_mask = np.zeros_like(mat_numeric, dtype=bool) par_index = np.zeros_like(mat_numeric, dtype=int) @@ -369,15 +389,17 @@ def get_parameter_override(x): par_mask = np.zeros_like(mat_numeric, dtype=bool) par_index = np.zeros_like(mat_numeric, dtype=int) else: - split_vals = m[col].str.split(petabv2.C.PARAMETER_SEPARATOR) + split_vals = m[col].str.split( + petabv2.C.PARAMETER_SEPARATOR + ) list_vals = split_vals.apply( - lambda x: [get_parameter_override(y) for y in x] - if isinstance(x, list) - else [] - if pd.isna(x) - else [ - x - ] # every string gets transformed to lists, so this is already a float + lambda x: ( + [get_parameter_override(y) for y in x] + if isinstance(x, list) + else [] + if pd.isna(x) + else [x] + ) # every string gets transformed to lists, so this is already a float ) vals = list_vals.apply( lambda x: np.pad( @@ -391,9 +413,11 @@ def get_parameter_override(x): # deconstruct such that we can reconstruct mapped parameter overrides via vectorized operations # mat = np.where(par_mask, map(lambda ip: p.at[ip], par_index), mat_numeric) par_index = np.vectorize( - lambda x: self.parameter_ids.index(x) - if x in self.parameter_ids - else -1 + lambda x: ( + self.parameter_ids.index(x) + if x in self.parameter_ids + else -1 + ) )(mat) # map out numeric values par_mask = par_index != -1 @@ -420,9 +444,13 @@ def get_parameter_override(x): parameter_overrides_par_indices[ petabv2.C.OBSERVABLE_PARAMETERS ], # 7 - parameter_overrides_numeric_vals[petabv2.C.NOISE_PARAMETERS], # 8 + parameter_overrides_numeric_vals[ + petabv2.C.NOISE_PARAMETERS + ], # 8 parameter_overrides_mask[petabv2.C.NOISE_PARAMETERS], # 9 - parameter_overrides_par_indices[petabv2.C.NOISE_PARAMETERS], # 10 + parameter_overrides_par_indices[ + petabv2.C.NOISE_PARAMETERS + ], # 10 ) petab_indices[tuple(simulation_condition)] = tuple(index.tolist()) @@ -524,15 +552,18 @@ def pad_and_stack(output_index: int): def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: if isinstance(self._petab_problem, HybridV2Problem): - simulation_conditions = ( - get_simulation_conditions_v2(self._petab_problem) + simulation_conditions = get_simulation_conditions_v2( + self._petab_problem + ) + return tuple( + tuple([row.conditionId]) + for _, row in simulation_conditions.iterrows() ) - return tuple(tuple([row.conditionId]) for _, row in simulation_conditions.iterrows()) else: - simulation_conditions = ( - self._petab_problem.get_simulation_conditions_from_measurement_df() + simulation_conditions = self._petab_problem.get_simulation_conditions_from_measurement_df() + return tuple( + tuple(row) for _, row in simulation_conditions.iterrows() ) - return tuple(tuple(row) for _, row in simulation_conditions.iterrows()) def _initialize_model_parameters(self, model: JAXModel) -> dict: """ @@ -811,13 +842,11 @@ def _initialize_model_with_nominal_values( # Create scaled parameter array if isinstance(self._petab_problem, HybridV2Problem): param_map = { - p.id: p.nominal_value - for p in self._petab_problem.parameters + p.id: p.nominal_value for p in self._petab_problem.parameters } - parameter_array = jnp.array([ - float(param_map[pval]) - for pval in self.parameter_ids - ]) + parameter_array = jnp.array( + [float(param_map[pval]) for pval in self.parameter_ids] + ) else: parameter_array = self._create_scaled_parameter_array() @@ -843,9 +872,9 @@ def _get_inputs(self) -> dict: .max(axis=0) + 1 ) - inputs[row["netId"]][row[petabv2.C.MODEL_ENTITY_ID]] = data_flat[ - "value" - ].values.reshape(shape) + inputs[row["netId"]][row[petabv2.C.MODEL_ENTITY_ID]] = ( + data_flat["value"].values.reshape(shape) + ) return inputs @property @@ -856,7 +885,9 @@ def parameter_ids(self) -> list[str]: :return: PEtab parameter ids """ - return self._petab_problem.parameter_df[petabv2.C.ESTIMATE].index.tolist() + return self._petab_problem.parameter_df[ + petabv2.C.ESTIMATE + ].index.tolist() @property def nn_output_ids(self) -> list[str]: @@ -868,7 +899,11 @@ def nn_output_ids(self) -> list[str]: """ if self._petab_problem.mapping_df is None: return [] - if self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID].isnull().all(): + if ( + self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID] + .isnull() + .all() + ): return [] return self._petab_problem.mapping_df[ self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID] @@ -917,9 +952,9 @@ def _is_net_input(model_id): model_id_map = ( self._petab_problem.mapping_df[ - self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID].apply( - _is_net_input - ) + self._petab_problem.mapping_df[ + petabv2.C.MODEL_ENTITY_ID + ].apply(_is_net_input) ] .reset_index() .set_index(petabv2.C.MODEL_ENTITY_ID)[petabv2.C.PETAB_ENTITY_ID] @@ -998,7 +1033,8 @@ def _is_net_input(model_id): ] if petab_id in set(self._petab_problem.parameter_df.index) else self._petab_problem.parameter_df.loc[ - hybridization_parameter_map[petab_id], petabv2.C.NOMINAL_VALUE + hybridization_parameter_map[petab_id], + petabv2.C.NOMINAL_VALUE, ] for model_id, petab_id in model_id_map.items() ] @@ -1047,17 +1083,16 @@ def load_model_parameters( for ind, pname in enumerate(self.model.parameter_ids) ] ) - pscale = tuple( - [ - petabv2.C.LIN - for _ in self.model.parameter_ids - ] - ) + pscale = tuple([petabv2.C.LIN for _ in self.model.parameter_ids]) return self._unscale(p, pscale) - + def _map_experiment_model_parameter_value( - self, pname: str, p_index: int, experiment: petabv2.Experiment, is_preeq: bool + self, + pname: str, + p_index: int, + experiment: petabv2.Experiment, + is_preeq: bool, ): """ Get values for the given parameter `pname` from the relevant petab tables. @@ -1084,7 +1119,9 @@ def _map_experiment_model_parameter_value( break init_val = self.model.parameters[p_index] - params_nominals = {p.id: p.nominal_value for p in self._petab_problem.parameters} + params_nominals = { + p.id: p.nominal_value for p in self._petab_problem.parameters + } targets_map = { ch.target_id: ch.target_value for c in self._petab_problem.conditions @@ -1100,13 +1137,20 @@ def _map_experiment_model_parameter_value( ("observable_placeholders", "observable_parameters"), ("noise_placeholders", "noise_parameters"), ): - placeholders = [getattr(o, placeholder_attr) for o in self._petab_problem.observables] + placeholders = [ + getattr(o, placeholder_attr) + for o in self._petab_problem.observables + ] for placeholders in placeholders: - params_list = getattr(self._petab_problem.measurements[0], param_attr) + params_list = getattr( + self._petab_problem.measurements[0], param_attr + ) for i, p in enumerate(placeholders): if str(p) == pname: - val = self._find_val(str(params_list[i]), params_nominals) + val = self._find_val( + str(params_list[i]), params_nominals + ) return val return init_val @@ -1192,7 +1236,9 @@ def _state_reinitialisation_value( ) # only remaining option is nominal value for PEtab parameter # that is not estimated, return nominal value - return self._petab_problem.parameter_df.loc[xval, petabv2.C.NOMINAL_VALUE] + return self._petab_problem.parameter_df.loc[ + xval, petabv2.C.NOMINAL_VALUE + ] def load_reinitialisation( self, @@ -1294,16 +1340,19 @@ def _prepare_experiments( h_mask = jnp.stack( [ - jnp.ones(self.model.n_events) - if (exp_id in exp_ids) + jnp.ones(self.model.n_events) + if (exp_id in exp_ids) else jnp.zeros(self.model.n_events) for exp_id in all_exp_ids ] ) - t_zeros = jnp.stack([ - exp.periods[0].time if exp.periods[0].time >= 0.0 else 0.0 for exp in experiments - ]) + t_zeros = jnp.stack( + [ + exp.periods[0].time if exp.periods[0].time >= 0.0 else 0.0 + for exp in experiments + ] + ) if self.parameters.size: if isinstance(self._petab_problem, HybridV2Problem): @@ -1328,7 +1377,7 @@ def _prepare_experiments( else: unscaled_parameters = jnp.zeros((*self._ts_masks.shape[:2], 0)) - # placeholder values from sundials code may be needed here + # placeholder values from sundials code may be needed here if op_numeric is not None and op_numeric.size: op_array = jnp.where( op_mask, @@ -1363,7 +1412,15 @@ def _prepare_experiments( for sc, p in zip(conditions, p_array) ] ) - return p_array, mask_reinit_array, x_reinit_array, op_array, np_array, h_mask, t_zeros + return ( + p_array, + mask_reinit_array, + x_reinit_array, + op_array, + np_array, + h_mask, + t_zeros, + ) @eqx.filter_vmap( in_axes={ @@ -1516,30 +1573,42 @@ def run_simulations( Output value and condition specific results and statistics. Results and statistics are returned as a dict with arrays with the leading dimension corresponding to the simulation conditions. """ - simulation_conditions = [cid for exp in experiments for p in exp.periods for cid in p.condition_ids] - dynamic_conditions = list(sc for sc in simulation_conditions if "preequilibration" not in sc) + simulation_conditions = [ + cid + for exp in experiments + for p in exp.periods + for cid in p.condition_ids + ] + dynamic_conditions = list( + sc for sc in simulation_conditions if "preequilibration" not in sc + ) dynamic_conditions = list(dict.fromkeys(dynamic_conditions)) - p_array, mask_reinit_array, x_reinit_array, op_array, np_array, h_mask, t_zeros = ( - self._prepare_experiments( - experiments, - dynamic_conditions, - False, - self._op_numeric, - self._op_mask, - self._op_indices, - self._np_numeric, - self._np_mask, - self._np_indices, - ) + ( + p_array, + mask_reinit_array, + x_reinit_array, + op_array, + np_array, + h_mask, + t_zeros, + ) = self._prepare_experiments( + experiments, + dynamic_conditions, + False, + self._op_numeric, + self._op_mask, + self._op_indices, + self._np_numeric, + self._np_mask, + self._np_indices, ) init_override_mask = jnp.stack( [ jnp.array( [ - p - in set(self.model.parameter_ids) + p in set(self.model.parameter_ids) for p in self.model.state_ids ] ) @@ -1550,9 +1619,10 @@ def run_simulations( [ jnp.array( [ - self._eval_nn(p, exp.periods[-1].condition_ids[0]) # TODO: Add mapping of p to eval_nn? - if p - in set(self.model.parameter_ids) + self._eval_nn( + p, exp.periods[-1].condition_ids[0] + ) # TODO: Add mapping of p to eval_nn? + if p in set(self.model.parameter_ids) else 1.0 for p in self.model.state_ids ] @@ -1653,13 +1723,20 @@ def run_preequilibrations( ], max_steps: jnp.int_, ): - simulation_conditions = [cid for exp in experiments for p in exp.periods for cid in p.condition_ids] + simulation_conditions = [ + cid + for exp in experiments + for p in exp.periods + for cid in p.condition_ids + ] preequilibration_conditions = list( {sc for sc in simulation_conditions if "preequilibration" in sc} ) p_array, mask_reinit_array, x_reinit_array, _, _, h_mask, _ = ( - self._prepare_experiments(experiments, preequilibration_conditions, True, None, None) + self._prepare_experiments( + experiments, preequilibration_conditions, True, None, None + ) ) return self.run_preequilibration( p_array, @@ -1696,7 +1773,7 @@ def run_simulations( :param problem: Problem to run simulations for. :param simulation_experiments: - Simulation experiments to run simulations for. This is an iterable of experiment ids. + Simulation experiments to run simulations for. This is an iterable of experiment ids. Default is to run simulations for all experiments. :param solver: ODE solver to use for simulation. @@ -1714,7 +1791,9 @@ def run_simulations( :return: Overall output value and condition specific results and statistics. """ - if isinstance(problem, HybridProblem) or isinstance(problem._petab_problem, petabv1.Problem): + if isinstance(problem, HybridProblem) or isinstance( + problem._petab_problem, petabv1.Problem + ): raise TypeError( "run_simulations does not support PEtab v1 problems. Upgrade the problem to PEtab v2." ) @@ -1725,10 +1804,21 @@ def run_simulations( if simulation_experiments is None: experiments = problem._petab_problem.experiments else: - experiments = [exp for exp in problem._petab_problem.experiments if exp.id in simulation_experiments] + experiments = [ + exp + for exp in problem._petab_problem.experiments + if exp.id in simulation_experiments + ] - simulation_conditions = [cid for exp in experiments for p in exp.periods for cid in p.condition_ids] - dynamic_conditions = list(sc for sc in simulation_conditions if "preequilibration" not in sc) + simulation_conditions = [ + cid + for exp in experiments + for p in exp.periods + for cid in p.condition_ids + ] + dynamic_conditions = list( + sc for sc in simulation_conditions if "preequilibration" not in sc + ) dynamic_conditions = list(dict.fromkeys(dynamic_conditions)) conditions = { "dynamic_conditions": dynamic_conditions, @@ -1750,18 +1840,8 @@ def run_simulations( preresults = { "stats_preeq": None, } - preeqs_array = jnp.stack( - [ - jnp.array([]) - for _ in experiments - ] - ) - h_preeqs = jnp.stack( - [ - jnp.array([]) - for _ in experiments - ] - ) + preeqs_array = jnp.stack([jnp.array([]) for _ in experiments]) + h_preeqs = jnp.stack([jnp.array([]) for _ in experiments]) output, results = problem.run_simulations( experiments, @@ -1855,7 +1935,10 @@ def petab_simulate( f"{petabv2.C.CONDITION_ID} == '{sc}'" )[petabv2.C.OBSERVABLE_PARAMETERS] ) - if petabv2.C.NOISE_PARAMETERS in problem._petab_problem.measurement_df: + if ( + petabv2.C.NOISE_PARAMETERS + in problem._petab_problem.measurement_df + ): df_sc[petabv2.C.NOISE_PARAMETERS] = ( problem._petab_problem.measurement_df.query( f"{petabv2.C.CONDITION_ID} == '{sc}'" @@ -1872,5 +1955,3 @@ def petab_simulate( ) dfs.append(df_sc) return pd.concat(dfs).sort_index() - - diff --git a/python/sdist/amici/sim/jax/__init__.py b/python/sdist/amici/sim/jax/__init__.py index f17cda0e31..29b0399c9f 100644 --- a/python/sdist/amici/sim/jax/__init__.py +++ b/python/sdist/amici/sim/jax/__init__.py @@ -1,9 +1,9 @@ """Functionality for simulating JAX-based AMICI models.""" +import jax.numpy as jnp +import pandas as pd import petab.v2 as petabv2 -import pandas as pd -import jax.numpy as jnp def add_default_experiment_names_to_v2_problem(petab_problem: petabv2.Problem): """Add default experiment names to PEtab v2 problem. @@ -17,18 +17,26 @@ def add_default_experiment_names_to_v2_problem(petab_problem: petabv2.Problem): petab_problem.visualization_df = None if petab_problem.condition_df is None: - default_condition = petabv2.core.Condition(id="__default__", changes=[], conditionId="__default__") + default_condition = petabv2.core.Condition( + id="__default__", changes=[], conditionId="__default__" + ) petab_problem.condition_tables[0].elements = [default_condition] - if petab_problem.experiment_df is None or petab_problem.experiment_df.empty: - condition_ids = petab_problem.condition_df[petabv2.C.CONDITION_ID].values - condition_ids = [c for c in condition_ids if "preequilibration" not in c] + if ( + petab_problem.experiment_df is None + or petab_problem.experiment_df.empty + ): + condition_ids = petab_problem.condition_df[ + petabv2.C.CONDITION_ID + ].values + condition_ids = [ + c for c in condition_ids if "preequilibration" not in c + ] default_experiment = petabv2.core.Experiment( id="__default__", periods=[ petabv2.core.ExperimentPeriod( - time=0.0, - condition_ids=condition_ids + time=0.0, condition_ids=condition_ids ) ], ) @@ -43,6 +51,7 @@ def add_default_experiment_names_to_v2_problem(petab_problem: petabv2.Problem): return petab_problem + def get_simulation_conditions_v2(petab_problem) -> pd.DataFrame: """Get simulation conditions from PEtab v2 measurement DataFrame. @@ -59,6 +68,7 @@ def get_simulation_conditions_v2(petab_problem) -> pd.DataFrame: experiment_df = experiment_df.drop(columns=[petabv2.C.TIME]) return experiment_df + def _build_simulation_df_v2(problem, y, dyn_conditions): """Build petab simulation DataFrame of similation results from a PEtab v2 problem.""" dfs = [] @@ -93,7 +103,7 @@ def _build_simulation_df_v2(problem, y, dyn_conditions): if ( petabv2.C.OBSERVABLE_PARAMETERS in problem._petab_problem.measurement_df - ): + ): df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = ( problem._petab_problem.measurement_df.query( f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" @@ -108,13 +118,16 @@ def _build_simulation_df_v2(problem, y, dyn_conditions): dfs.append(df_sc) return pd.concat(dfs).sort_index() -def _conditions_to_experiment_map(experiment_df: pd.DataFrame) -> dict[str, str]: + +def _conditions_to_experiment_map( + experiment_df: pd.DataFrame, +) -> dict[str, str]: condition_to_experiment = { - row.conditionId: row.experimentId - for row in experiment_df.itertuples() + row.conditionId: row.experimentId for row in experiment_df.itertuples() } return condition_to_experiment + def _try_float(value): try: return float(value) @@ -122,4 +135,4 @@ def _try_float(value): msg = str(e).lower() if isinstance(e, ValueError) and "could not convert" in msg: return value - raise \ No newline at end of file + raise diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index d20c8459f1..d6cdb45c1b 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -355,7 +355,7 @@ def test_time_dependent_discontinuity(tmp_path): sbml = antimony2sbml(ant_model) importer = SbmlImporter(sbml, from_file=False) - try: + try: importer.sbml2jax("time_disc", output_dir=tmp_path) module = amici._module_from_path("time_disc", tmp_path / "__init__.py") @@ -366,7 +366,9 @@ def test_time_dependent_discontinuity(tmp_path): tcl = model._tcl(x0_full, p) x0 = model._x_solver(x0_full) ts = jnp.array([0.0, 1.0, 2.0]) - h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) + h = model._initialise_heaviside_variables( + 0.0, model._x_solver(x0), p, tcl + ) assert len(model._root_cond_fns()) > 0 assert model._known_discs(p).size == 0 @@ -421,14 +423,18 @@ def test_time_dependent_discontinuity_equilibration(tmp_path): try: importer.sbml2jax("time_disc_eq", output_dir=tmp_path) - module = amici._module_from_path("time_disc_eq", tmp_path / "__init__.py") + module = amici._module_from_path( + "time_disc_eq", tmp_path / "__init__.py" + ) model = module.Model() p = jnp.array([1.0]) x0_full = model._x0(0.0, p) tcl = model._tcl(x0_full, p) x0 = model._x_solver(x0_full) - h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) + h = model._initialise_heaviside_variables( + 0.0, model._x_solver(x0), p, tcl + ) assert len(model._root_cond_fns()) > 0 assert model._known_discs(p).size == 0 diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index a8c1b951e8..3ab174c6fb 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -246,10 +246,10 @@ def model_steadystate_module(): observables = amici.assignment_rules_to_observables( sbml_importer.sbml_model, - filter_function=lambda variable: variable.getId().startswith( - "observable_" - ) - and not variable.getId().endswith("_sigma"), + filter_function=lambda variable: ( + variable.getId().startswith("observable_") + and not variable.getId().endswith("_sigma") + ), as_dict=True, ) observables[ @@ -529,7 +529,9 @@ def model_test_likelihoods(tempdir): MC( "o7", formula="x1", - noise_distribution=lambda str_symbol: f"Abs({str_symbol} - m{str_symbol}) / sigma{str_symbol}", + noise_distribution=lambda str_symbol: ( + f"Abs({str_symbol} - m{str_symbol}) / sigma{str_symbol}" + ), ), ] diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 270b994cb3..084ec787bb 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -10,7 +10,7 @@ from amici.importers.petab.v1 import ( import_petab_problem, ) -from amici.jax.petab import run_simulations, DEFAULT_CONTROLLER_SETTINGS +from amici.jax.petab import run_simulations from amici.sim.sundials import SensitivityMethod, SensitivityOrder from amici.sim.sundials.petab.v1 import ( LLH, @@ -111,7 +111,10 @@ def test_jax_llh(benchmark_problem): lambda x: x.parameters, jax_problem, jnp.array( - [problem_parameters[pid] for pid in jax_problem.parameter_ids] + [ + problem_parameters[pid] + for pid in jax_problem.parameter_ids + ] ), ) @@ -153,7 +156,9 @@ def test_jax_llh(benchmark_problem): sllh_amici = r_amici[SLLH] np.testing.assert_allclose( sllh_jax.parameters, - np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]), + np.array( + [sllh_amici[pid] for pid in jax_problem.parameter_ids] + ), rtol=1e-2, atol=1e-2, err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}", diff --git a/tests/petab_test_suite/test_petab_v2_suite.py b/tests/petab_test_suite/test_petab_v2_suite.py index d43cec15c2..7a3b60675e 100755 --- a/tests/petab_test_suite/test_petab_v2_suite.py +++ b/tests/petab_test_suite/test_petab_v2_suite.py @@ -5,6 +5,7 @@ import sys import diffrax +import jax import pandas as pd import petabtests import pytest @@ -21,7 +22,6 @@ ) from amici.sim.sundials.petab import PetabSimulator from petab import v2 -import jax logger = get_logger(__name__, logging.DEBUG) set_log_level(get_logger("amici.petab_import"), logging.DEBUG) @@ -70,7 +70,6 @@ def _test_case(case, model_type, version, jax): f"petab_{model_type}_test_case_{case}_{version.replace('.', '_')}" ) - if jax: from amici.jax import petab_simulate, run_simulations from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS @@ -90,28 +89,25 @@ def _test_case(case, model_type, version, jax): if case.startswith("0016"): controller = diffrax.PIDController( - **DEFAULT_CONTROLLER_SETTINGS, - dtmax=0.5 + **DEFAULT_CONTROLLER_SETTINGS, dtmax=0.5 ) else: - controller = diffrax.PIDController( - **DEFAULT_CONTROLLER_SETTINGS - ) + controller = diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS) llh, _ = run_simulations( - jax_problem, - steady_state_event=steady_state_event, + jax_problem, + steady_state_event=steady_state_event, controller=controller, ) chi2, _ = run_simulations( - jax_problem, - ret="chi2", - steady_state_event=steady_state_event, + jax_problem, + ret="chi2", + steady_state_event=steady_state_event, controller=controller, ) simulation_df = petab_simulate( - jax_problem, - steady_state_event=steady_state_event, + jax_problem, + steady_state_event=steady_state_event, controller=controller, ) else: @@ -137,7 +133,9 @@ def _test_case(case, model_type, version, jax): ) chi2 = sum(rdata.chi2 for rdata in rdatas) llh = res.llh - simulation_df = rdatas_to_simulation_df(rdatas, ps.model, pi.petab_problem) + simulation_df = rdatas_to_simulation_df( + rdatas, ps.model, pi.petab_problem + ) solution = petabtests.load_solution(case, model_type, version=version) gt_chi2 = solution[petabtests.CHI2] @@ -247,12 +245,12 @@ def run(): n_total = 0 version = "v2.0.0" - for jax in (False, True): + for jax_ in (False, True): cases = list(petabtests.get_cases("sbml", version=version)) n_total += len(cases) for case in cases: try: - test_case(case, "sbml", version=version, jax=jax) + test_case(case, "sbml", version=version, jax=jax_) n_success += 1 except Skipped: n_skipped += 1