From 6cea00b76922ad1673f52beeee73189bed1a5ed2 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 1 Dec 2025 11:40:41 +0000 Subject: [PATCH 01/28] event assignments jax - sbml cases 348 - 404 --- python/sdist/amici/_symbolic/de_model.py | 7 +- .../amici/importers/petab/_petab_importer.py | 2 +- .../amici/importers/petab/v1/sbml_import.py | 1 + python/sdist/amici/importers/sbml/__init__.py | 15 +++ python/sdist/amici/jax/_simulation.py | 66 +++++++++---- python/sdist/amici/jax/jax.template.py | 25 ++--- python/sdist/amici/jax/model.py | 99 +++++++++++++++---- python/sdist/amici/jax/ode_export.py | 8 +- tests/sbml/testSBMLSuiteJax.py | 4 +- 9 files changed, 165 insertions(+), 62 deletions(-) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 6ac2650628..6f7a0852b4 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -1592,7 +1592,12 @@ def _compute_equation(self, name: str) -> None: else: event_eqs.append(state_update) - self._eqs[name] = event_eqs + self._eqs[name] = sp.Matrix(event_eqs) + + elif name == "x_old": + self._eqs[name] = sp.Matrix( + [state.get_x_rdata() for state in self.states()] + ) elif name == "z": event_observables = [ diff --git a/python/sdist/amici/importers/petab/_petab_importer.py b/python/sdist/amici/importers/petab/_petab_importer.py index 17c022766f..a992f2e338 100644 --- a/python/sdist/amici/importers/petab/_petab_importer.py +++ b/python/sdist/amici/importers/petab/_petab_importer.py @@ -336,7 +336,7 @@ def _do_import_sbml(self): show_model_info(self.petab_problem.model.sbml_model) sbml_importer = amici.SbmlImporter( - self.petab_problem.model.sbml_model, + self.petab_problem.model.sbml_model, jax=self._jax ) self._check_placeholders() diff --git a/python/sdist/amici/importers/petab/v1/sbml_import.py b/python/sdist/amici/importers/petab/v1/sbml_import.py index 2c7c308024..5de63dd9b3 100644 --- a/python/sdist/amici/importers/petab/v1/sbml_import.py +++ b/python/sdist/amici/importers/petab/v1/sbml_import.py @@ -331,6 +331,7 @@ def import_model_sbml( sbml_importer = amici.SbmlImporter( sbml_model, discard_annotations=discard_sbml_annotations, + jax=jax, ) sbml_model = sbml_importer.sbml_model diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 55cb4b3079..9e46f63363 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -100,6 +100,7 @@ def __init__( show_sbml_warnings: bool = False, from_file: bool = True, discard_annotations: bool = False, + jax: bool = False, ) -> None: """ Initialize. @@ -187,6 +188,7 @@ def __init__( ignore_units=True, evaluate=True, ) + self.jax = jax @log_execution_time("loading SBML", logger) def _process_document(self) -> None: @@ -1880,6 +1882,19 @@ def _process_events(self) -> None: "priority": self._sympify(event.getPriority()), } + if self.jax: + # Add a negative event for JAX models to handle + neg_event_id = event_id + "_negative" + neg_event_sym = sp.Symbol(neg_event_id) + self.symbols[SymbolId.EVENT][neg_event_sym] = { + "name": neg_event_id, + "value": -trigger, + "assignments": None, + "initial_value": not initial_value, + "use_values_from_trigger_time": use_trig_val, + "priority": self._sympify(event.getPriority()), + } + @log_execution_time("processing observation model", logger) def _process_observation_model( self, diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index 605791d33e..62928b2bf4 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -33,6 +33,7 @@ def eq( term: diffrax.ODETerm, root_cond_fns: list[Callable], root_cond_fn: Callable, + delta_x: Callable, known_discs: jt.Float[jt.Array, "*nediscs"], max_steps: jnp.int_, ) -> tuple[jt.Float[jt.Array, "nxs"], jt.Float[jt.Array, "ne"], dict]: @@ -61,6 +62,8 @@ def eq( list of individual root condition functions for discontinuities :param root_cond_fn: root condition function for all discontinuities + :param delta_x: + function to compute state changes at events :param known_discs: known discontinuities, used to clip the step size controller :param max_steps: @@ -138,17 +141,14 @@ def body_fn(carry): y0_next, t0_next, h_next, stats = _handle_event( t0_next, - jnp.inf, y0_next, p, tcl, h, - solver, - controller, root_finder, - diffrax.DirectAdjoint(), term, root_cond_fn, + delta_x, stats, ) @@ -186,7 +186,9 @@ def solve( term: diffrax.ODETerm, root_cond_fns: list[Callable], root_cond_fn: Callable, + delta_x: Callable, known_discs: jt.Float[jt.Array, "*nediscs"], + t_eps: jt.Float = 1e-6, ) -> tuple[jt.Float[jt.Array, "nt nxs"], jt.Float[jt.Array, "nt ne"], dict]: """ Simulate the ODE system for the specified timepoints. @@ -213,6 +215,8 @@ def solve( list of individual root condition functions for discontinuities :param root_cond_fn: root condition function for all discontinuities + :param delta_x: + function to compute state changes at events :param known_discs: known discontinuities, used to clip the step size controller :return: @@ -254,7 +258,7 @@ def cond_fn(carry): def body_fn(carry): ys, t_start, y0, hs, h, stats = carry - sol, idx, stats = _run_segment( + sol, _, stats = _run_segment( t_start, ts[-1], y0, @@ -267,7 +271,7 @@ def body_fn(carry): max_steps, # TODO: figure out how to pass `max_steps - stats['num_steps']` here adjoint, root_cond_fns, - [True] * len(root_cond_fns), + [None] * len(root_cond_fns), diffrax.SaveAt( subs=[ diffrax.SubSaveAt( @@ -297,24 +301,39 @@ def body_fn(carry): y0_next, t0_next, h_next, stats = _handle_event( t0_next, - ts_next, y0_next, p, tcl, h, - solver, - controller, root_finder, - adjoint, term, root_cond_fn, + delta_x, stats, ) - was_event = jnp.isin(ts, sol.ts[1]) - hs = jnp.where(was_event[:, None], h_next[None, :], hs) + after_event = sol.ts[1] < ts + hs = jnp.where(after_event[:, None], h_next[None, :], hs) - return ys, t0_next, y0_next, hs, h_next, stats + # Advance state to stop retriggering event immediately + t_resume = t0_next + t_eps + small_step = diffrax.diffeqsolve( + term, + solver, + t0_next, + t0_next + t_eps, + dt0=None, + y0=y0_next, + stepsize_controller=controller, + args=(p, tcl, h), + saveat=diffrax.SaveAt(t1=True), + event=None, + ) + + t_resume = small_step.ts[-1] + y_resume = small_step.ys[-1] + + return ys, t_resume, y_resume, hs, h_next, stats # run the loop until we have reached the end of the time points ys, _, _, hs, _, stats = eqxi.while_loop( @@ -429,17 +448,14 @@ def _run_segment( def _handle_event( t0_next: float, - t_max: float, y0_next: jt.Float[jt.Array, "nxs"], p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], h: jt.Float[jt.Array, "ne"], - solver: diffrax.AbstractSolver, - controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, - adjoint: diffrax.AbstractAdjoint, term: diffrax.ODETerm, root_cond_fn: Callable, + delta_x: Callable, stats: dict, ): args = (p, tcl, h) @@ -457,11 +473,23 @@ def _handle_event( ) roots_dir = jnp.sign(droot_dt) # direction of the root condition function - h_next = h + jnp.where( + overall_dir = jnp.sign(jnp.sum(h * roots_dir)) + h_next = h - (overall_dir * jnp.where( roots_found, roots_dir, jnp.zeros_like(h), - ) # update heaviside variables based on the root condition function + )) # update heaviside variables based on the root condition function + + mask = jnp.array( + [ + jnp.logical_and(roots_found, roots_dir > 0.0) + for _ in range(y0_next.shape[0]) + ] + ).T + delx = delta_x(y0_next, p) + ups_mat = delx.reshape(delx.size // y0_next.shape[0], y0_next.shape[0],) + y0_up = jnp.where(mask, ups_mat, 0.0) + y0_next = y0_next + jnp.sum(y0_up, axis=0) if os.getenv("JAX_DEBUG") == "1": jax.debug.print( diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index b5247d2eab..2f4cba95f5 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -126,23 +126,16 @@ def _root_cond_fn(self, t, y, args, **_): TPL_IROOT_EQ return TPL_IROOT_RET + + def _delta_x(self, y, p): + TPL_X_SYMS = y + TPL_P_SYMS = p + + TPL_X_OLD_EQ + + TPL_DELTAX_EQ - def _root_cond_fn_event(self, ie, t, y, args, **_): - """ - Root condition function for a specific event index. - """ - __, __, h = args - rval = self._root_cond_fn(t, y, args, **_) - # only allow root triggers where trigger function is negative (heaviside == 0) - masked_rval = jnp.where(h == 0.0, rval, 1.0) - return masked_rval.at[ie].get() - - def _root_cond_fns(self): - """Return root condition functions for discontinuities.""" - return [ - eqx.Partial(self._root_cond_fn_event, ie) - for ie in range(self.n_events) - ] + return TPL_DELTAX_RET @property def n_events(self): diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index c8a3f56014..f8e11ae3e5 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -15,6 +15,8 @@ import jaxtyping as jt from optimistix import AbstractRootFinder +import os + from ._simulation import eq, solve @@ -267,22 +269,6 @@ def _known_discs( """ ... - @abstractmethod - def _root_cond_fns( - self, - ) -> list[Callable[[float, jt.Float[jt.Array, "nxs"], tuple], jt.Float]]: - """Return condition functions for implicit discontinuities. - - These functions are passed to :class:`diffrax.Event` and must evaluate - to zero when a discontinuity is triggered. - - :param p: - model parameters - :return: - tuple of callable root functions - """ - ... - @abstractmethod def _root_cond_fn( self, @@ -308,6 +294,20 @@ def _root_cond_fn( """ ... + @abstractmethod + def _delta_x( + self, y: jt.Float[jt.Array, "nxs"] + ) -> jt.Float[jt.Array, "nxs"]: + """ + Compute the state vector changes at discontinuities. + + :param y: + state vector + :return: + changes in the state vector at discontinuities + """ + ... + @property @abstractmethod def n_events(self) -> int: @@ -362,12 +362,66 @@ def expression_ids(self) -> list[str]: """ ... + def _root_cond_fn_event( + self, + ie: int, + t: float, + y: jt.Float[jt.Array, "nxs"], + args: tuple, + **_ + ): + """ + Root condition function for a specific event index. + + :param ie: + event index + :param t: + time point + :param y: + state vector + :param args: + tuple of arguments required for _root_cond_fn + :return: + mask of root condition value for the specified event index + """ + __, __, h = args + rval = self._root_cond_fn(t, y, args, **_) + + if os.getenv("JAX_DEBUG") == "1": + jax.debug.print( + "rval: {}, ie: {}, h[ie]: {}, t: {}", + rval, + ie, + h[ie], + t, + ) + # only allow root triggers where trigger function is negative (heaviside == 0) + masked_rval = jnp.where(h == 0.0, rval, 1.0) + return masked_rval.at[ie].get() + + def _root_cond_fns(self) -> list[Callable[[float, jt.Float[jt.Array, "nxs"], tuple], jt.Float]]: + """Return condition functions for implicit discontinuities. + + These functions are passed to :class:`diffrax.Event` and must evaluate + to zero when a discontinuity is triggered. + + :param p: + model parameters + :return: + iterable of callable root functions + """ + return [ + eqx.Partial(self._root_cond_fn_event, ie) + for ie in range(self.n_events) + ] + def _initialise_heaviside_variables( self, t0: jt.Float[jt.Scalar, ""], x_solver: jt.Float[jt.Array, "nxs"], p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], + root_finder: AbstractRootFinder, ) -> jt.Float[jt.Array, "ne"]: """ Initialise the heaviside variables. @@ -384,9 +438,13 @@ def _initialise_heaviside_variables( heaviside variables """ h0 = jnp.zeros((self.n_events,)) # dummy values - roots_found = self._root_cond_fn(t0, x_solver, (p, tcl, h0)) + # QUERY: should roots_dir be accounted for here, as in _handle_event? + rootvals = self._root_cond_fn(t0, x_solver, (p, tcl, h0)) + roots_found = jnp.isclose( + rootvals, 0.0, atol=root_finder.atol, rtol=root_finder.rtol + ) return jnp.where( - roots_found >= 0.0, jnp.ones_like(h0), jnp.zeros_like(h0) + roots_found > 0.0, jnp.ones_like(h0), jnp.zeros_like(h0) ) def _x_rdatas( @@ -576,7 +634,7 @@ def simulate_condition_unjitted( x = jnp.where(mask_reinit, x_reinit, x) x_solver = self._x_solver(x) tcl = self._tcl(x, p) - h = self._initialise_heaviside_variables(t0, x_solver, p, tcl) + h = self._initialise_heaviside_variables(t0, x_solver, p, tcl, root_finder) # Dynamic simulation if ts_dyn.shape[0]: @@ -594,6 +652,7 @@ def simulate_condition_unjitted( diffrax.ODETerm(self._xdot), self._root_cond_fns(), self._root_cond_fn, + self._delta_x, self._known_discs(p), ) x_solver = x_dyn[-1, :] @@ -616,6 +675,7 @@ def simulate_condition_unjitted( diffrax.ODETerm(self._xdot), self._root_cond_fns(), self._root_cond_fn, + self._delta_x, self._known_discs(p), max_steps, ) @@ -852,6 +912,7 @@ def preequilibrate_condition( diffrax.ODETerm(self._xdot), self._root_cond_fns(), self._root_cond_fn, + self._delta_x, self._known_discs(p), max_steps, ) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index f1c6d60a95..7b51e68043 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -144,11 +144,6 @@ def __init__( """ set_log_level(logger, verbose) - if ode_model.has_event_assignments(): - raise NotImplementedError( - "The JAX backend does not support models with event assignments." - ) - if ode_model.has_algebraic_states(): raise NotImplementedError( "The JAX backend does not support models with algebraic states." @@ -201,6 +196,8 @@ def _generate_jax_code(self) -> None: "x_rdata", "total_cl", "iroot", + "deltax", + "x_old", ) sym_names = ( "p", @@ -215,6 +212,7 @@ def _generate_jax_code(self) -> None: "sigmay", "x_rdata", "iroot", + "x_old", ) indent = 8 diff --git a/tests/sbml/testSBMLSuiteJax.py b/tests/sbml/testSBMLSuiteJax.py index 2ddb36820e..e488f8d663 100644 --- a/tests/sbml/testSBMLSuiteJax.py +++ b/tests/sbml/testSBMLSuiteJax.py @@ -51,7 +51,7 @@ def get_expression_ids(self): def compile_model_jax(sbml_dir: Path, test_id: str, model_dir: Path): model_dir.mkdir(parents=True, exist_ok=True) sbml_file = find_model_file(sbml_dir, test_id) - sbml_importer = amici.SbmlImporter(sbml_file) + sbml_importer = amici.SbmlImporter(sbml_file, jax=True) model_name = f"SBMLTest{test_id}_jax" sbml_importer.sbml2jax(model_name, output_dir=model_dir) model_module = amici.import_model_module(model_dir.name, model_dir.parent) @@ -159,6 +159,8 @@ def test_sbml_testsuite_case_jax( 276, 277, 279, + 356, + 357, 1148, 1159, 1160, From 46ff88469c0c7259bb4cd612b84dcc116be992a9 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 5 Dec 2025 09:36:32 +0000 Subject: [PATCH 02/28] fix up sbml test cases - not implemented priority, update t_eps, fix heaviside init --- python/sdist/amici/_symbolic/de_model.py | 17 ++++++++++++---- python/sdist/amici/jax/_simulation.py | 11 +++++----- python/sdist/amici/jax/jax.template.py | 10 +++++---- python/sdist/amici/jax/model.py | 26 ++++++------------------ python/sdist/amici/jax/ode_export.py | 19 ++++++++++++++++- tests/sbml/testSBMLSuiteJax.py | 1 + 6 files changed, 50 insertions(+), 34 deletions(-) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 6f7a0852b4..4bcda4ebfe 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -2639,19 +2639,19 @@ def get_explicit_roots(self) -> set[sp.Expr]: """ return {root for e in self._events for root in e.get_trigger_times()} - def get_implicit_roots(self) -> set[sp.Expr]: + def get_implicit_roots(self) -> list[sp.Expr]: """ Returns implicit equations for all discontinuities (events) that have to be located via rootfinding :return: - set of symbolic roots + list of symbolic roots """ - return { + return [ e.get_val() for e in self._events if not e.has_explicit_trigger_times() - } + ] def has_algebraic_states(self) -> bool: """ @@ -2670,6 +2670,15 @@ def has_event_assignments(self) -> bool: boolean indicating if event assignments are present """ return any(event.updates_state for event in self._events) + + def has_priority_events(self) -> bool: + """ + Checks whether the model has events with priorities defined + + :return: + boolean indicating if priority events are present + """ + return any(event.get_priority() is not None for event in self._events) def toposort_expressions( self, reorder: bool = True diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index 62928b2bf4..4efda0d26d 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -188,7 +188,7 @@ def solve( root_cond_fn: Callable, delta_x: Callable, known_discs: jt.Float[jt.Array, "*nediscs"], - t_eps: jt.Float = 1e-6, + t_eps: jt.Float = 1e-5, ) -> tuple[jt.Float[jt.Array, "nt nxs"], jt.Float[jt.Array, "nt ne"], dict]: """ Simulate the ODE system for the specified timepoints. @@ -473,7 +473,7 @@ def _handle_event( ) roots_dir = jnp.sign(droot_dt) # direction of the root condition function - overall_dir = jnp.sign(jnp.sum(h * roots_dir)) + overall_dir = jnp.sign(jnp.sum(jnp.where(roots_found, h * roots_dir, 0.0))) h_next = h - (overall_dir * jnp.where( roots_found, roots_dir, @@ -486,9 +486,10 @@ def _handle_event( for _ in range(y0_next.shape[0]) ] ).T - delx = delta_x(y0_next, p) - ups_mat = delx.reshape(delx.size // y0_next.shape[0], y0_next.shape[0],) - y0_up = jnp.where(mask, ups_mat, 0.0) + delx = delta_x(y0_next, p, t0_next) + if y0_next.size: + delx = delx.reshape(delx.size // y0_next.shape[0], y0_next.shape[0],) + y0_up = jnp.where(mask, delx, 0.0) y0_next = y0_next + jnp.sum(y0_up, axis=0) if os.getenv("JAX_DEBUG") == "1": diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index 2f4cba95f5..95bc42c860 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -109,8 +109,9 @@ def _nllh(self, t, x, p, tcl, h, my, iy, op, np): return TPL_JY_RET.at[iy].get() - def _known_discs(self, p): + def _known_discs(self, p, y): TPL_P_SYMS = p + TPL_X_SYMS = y return TPL_ROOTS @@ -124,10 +125,11 @@ def _root_cond_fn(self, t, y, args, **_): TPL_W_SYMS = self._w(t, y, p, tcl, h) TPL_IROOT_EQ + TPL_EROOT_EQ - return TPL_IROOT_RET + return jnp.hstack((TPL_IROOT_RET, TPL_EROOT_RET)) - def _delta_x(self, y, p): + def _delta_x(self, y, p, t): TPL_X_SYMS = y TPL_P_SYMS = p @@ -139,7 +141,7 @@ def _delta_x(self, y, p): @property def n_events(self): - return TPL_N_IEVENTS + return TPL_N_IEVENTS + TPL_N_EEVENTS @property def observable_ids(self): diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index f8e11ae3e5..968b1f96f7 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -386,15 +386,6 @@ def _root_cond_fn_event( """ __, __, h = args rval = self._root_cond_fn(t, y, args, **_) - - if os.getenv("JAX_DEBUG") == "1": - jax.debug.print( - "rval: {}, ie: {}, h[ie]: {}, t: {}", - rval, - ie, - h[ie], - t, - ) # only allow root triggers where trigger function is negative (heaviside == 0) masked_rval = jnp.where(h == 0.0, rval, 1.0) return masked_rval.at[ie].get() @@ -421,7 +412,6 @@ def _initialise_heaviside_variables( x_solver: jt.Float[jt.Array, "nxs"], p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], - root_finder: AbstractRootFinder, ) -> jt.Float[jt.Array, "ne"]: """ Initialise the heaviside variables. @@ -438,13 +428,9 @@ def _initialise_heaviside_variables( heaviside variables """ h0 = jnp.zeros((self.n_events,)) # dummy values - # QUERY: should roots_dir be accounted for here, as in _handle_event? - rootvals = self._root_cond_fn(t0, x_solver, (p, tcl, h0)) - roots_found = jnp.isclose( - rootvals, 0.0, atol=root_finder.atol, rtol=root_finder.rtol - ) + roots_found = self._root_cond_fn(t0, x_solver, (p, tcl, h0)) return jnp.where( - roots_found > 0.0, jnp.ones_like(h0), jnp.zeros_like(h0) + roots_found >= 0.0, jnp.ones_like(h0), jnp.zeros_like(h0) ) def _x_rdatas( @@ -634,7 +620,7 @@ def simulate_condition_unjitted( x = jnp.where(mask_reinit, x_reinit, x) x_solver = self._x_solver(x) tcl = self._tcl(x, p) - h = self._initialise_heaviside_variables(t0, x_solver, p, tcl, root_finder) + h = self._initialise_heaviside_variables(t0, x_solver, p, tcl) # Dynamic simulation if ts_dyn.shape[0]: @@ -653,7 +639,7 @@ def simulate_condition_unjitted( self._root_cond_fns(), self._root_cond_fn, self._delta_x, - self._known_discs(p), + self._known_discs(p, x_solver), ) x_solver = x_dyn[-1, :] else: @@ -676,7 +662,7 @@ def simulate_condition_unjitted( self._root_cond_fns(), self._root_cond_fn, self._delta_x, - self._known_discs(p), + self._known_discs(p, x_solver), max_steps, ) else: @@ -913,7 +899,7 @@ def preequilibrate_condition( self._root_cond_fns(), self._root_cond_fn, self._delta_x, - self._known_discs(p), + self._known_discs(p, current_x), max_steps, ) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 7b51e68043..80499baf17 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -149,6 +149,11 @@ def __init__( "The JAX backend does not support models with algebraic states." ) + if ode_model.has_priority_events(): + raise NotImplementedError( + "The JAX backend does not support event priorities." + ) + self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG self.model_path: Path = Path() @@ -195,6 +200,7 @@ def _generate_jax_code(self) -> None: "x_solver", "x_rdata", "total_cl", + "eroot", "iroot", "deltax", "x_old", @@ -250,12 +256,13 @@ def _generate_jax_code(self) -> None: "P_VALUES": _jnp_array_str(self.model.val("p")), "ROOTS": _jnp_array_str( { - root + _parse_trigger_root(root) for e in self.model._events for root in e.get_trigger_times() } ), "N_IEVENTS": str(len(self.model.get_implicit_roots())), + "N_EEVENTS": str(len(self.model.get_explicit_roots())), **{ "MODEL_NAME": self.model_name, # keep track of the API version that the model was generated with so we @@ -331,3 +338,13 @@ def set_name(self, model_name: str) -> None: ) self.model_name = model_name + +def _parse_trigger_root(root: sp.Expr) -> str: + """Convert a trigger root expression into a string representation. + + :param root: The trigger root expression. + :return: A string representation of the trigger root. + """ + if root.is_number: + return float(root) + return str(root).replace(" ", "") diff --git a/tests/sbml/testSBMLSuiteJax.py b/tests/sbml/testSBMLSuiteJax.py index e488f8d663..772c881e49 100644 --- a/tests/sbml/testSBMLSuiteJax.py +++ b/tests/sbml/testSBMLSuiteJax.py @@ -161,6 +161,7 @@ def test_sbml_testsuite_case_jax( 279, 356, 357, + 752, 1148, 1159, 1160, From 728e82800de5817a1f2102e4f8b71493aca1b6c3 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 5 Dec 2025 10:29:03 +0000 Subject: [PATCH 03/28] initialValue False not implemented --- python/sdist/amici/importers/sbml/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 9e46f63363..17b4ab5085 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -1883,6 +1883,10 @@ def _process_events(self) -> None: } if self.jax: + if not initial_value: + raise NotImplementedError( + "The JAX backend does not support events with False initialValue." + ) # Add a negative event for JAX models to handle neg_event_id = event_id + "_negative" neg_event_sym = sp.Symbol(neg_event_id) From 20495d28ae57a85c99547e0f3aba04f8e349ec38 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 5 Dec 2025 11:44:38 +0000 Subject: [PATCH 04/28] try fix other test cases --- python/sdist/amici/_symbolic/de_model.py | 20 ++++++++++++++------ python/sdist/amici/jax/ode_export.py | 5 +++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 4bcda4ebfe..c8444ff829 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -1592,7 +1592,7 @@ def _compute_equation(self, name: str) -> None: else: event_eqs.append(state_update) - self._eqs[name] = sp.Matrix(event_eqs) + self._eqs[name] = event_eqs elif name == "x_old": self._eqs[name] = sp.Matrix( @@ -2639,20 +2639,19 @@ def get_explicit_roots(self) -> set[sp.Expr]: """ return {root for e in self._events for root in e.get_trigger_times()} - def get_implicit_roots(self) -> list[sp.Expr]: + def get_implicit_roots(self) -> set[sp.Expr]: """ Returns implicit equations for all discontinuities (events) that have to be located via rootfinding :return: - list of symbolic roots + set of symbolic roots """ - return [ + return { e.get_val() for e in self._events if not e.has_explicit_trigger_times() - ] - + } def has_algebraic_states(self) -> bool: """ Checks whether the model has algebraic states @@ -2679,6 +2678,15 @@ def has_priority_events(self) -> bool: boolean indicating if priority events are present """ return any(event.get_priority() is not None for event in self._events) + + def has_implicit_event_assignments(self) -> bool: + """ + Checks whether the model has event assignments with implicit triggers + + :return: + boolean indicating if event assignments with implicit triggers are present + """ + return any(event.updates_state and not event.has_explicit_trigger_times() for event in self._events) def toposort_expressions( self, reorder: bool = True diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 80499baf17..9103979782 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -153,6 +153,11 @@ def __init__( raise NotImplementedError( "The JAX backend does not support event priorities." ) + + if ode_model.has_implicit_event_assignments(): + raise NotImplementedError( + "The JAX backend does not support event assignments with implicit triggers." + ) self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG From 1533f8f6914aefbd45df6330d02f9a8f97736730 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 5 Dec 2025 11:50:41 +0000 Subject: [PATCH 05/28] Matrix only for JAX event assignments --- python/sdist/amici/jax/ode_export.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 9103979782..1b7534e34d 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -60,7 +60,8 @@ def _jax_variable_equations( f"{eq_name.upper()}_EQ": "\n".join( code_printer._get_sym_lines( (s.name for s in model.sym(eq_name)), - model.eq(eq_name).subs(subs), + # sp.Matrix to support event assignments which are lists + sp.Matrix(model.eq(eq_name)).subs(subs), indent, ) )[indent:] # remove indent for first line From 67c2eb0b9f7f65332c4156c5600cdd5e54d1a0a9 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 5 Dec 2025 12:47:17 +0000 Subject: [PATCH 06/28] params only in explicit triggers - and matrix only in JAX again --- python/sdist/amici/_symbolic/de_model.py | 18 ++++++++++-------- python/sdist/amici/jax/jax.template.py | 3 +-- python/sdist/amici/jax/model.py | 6 +++--- python/sdist/amici/jax/ode_export.py | 4 +++- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index c8444ff829..19fbbb428b 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -2629,29 +2629,30 @@ def _process_hybridization(self, hybridization: dict) -> None: if added_expressions: self.toposort_expressions() - def get_explicit_roots(self) -> set[sp.Expr]: + def get_explicit_roots(self) -> list[sp.Expr]: """ Returns explicit formulas for all discontinuities (events) that can be precomputed :return: - set of symbolic roots + list of symbolic roots """ - return {root for e in self._events for root in e.get_trigger_times()} + return [root for e in self._events for root in e.get_trigger_times()] - def get_implicit_roots(self) -> set[sp.Expr]: + def get_implicit_roots(self) -> list[sp.Expr]: """ Returns implicit equations for all discontinuities (events) that have to be located via rootfinding :return: - set of symbolic roots + list of symbolic roots """ - return { + return [ e.get_val() for e in self._events if not e.has_explicit_trigger_times() - } + ] + def has_algebraic_states(self) -> bool: """ Checks whether the model has algebraic states @@ -2686,7 +2687,8 @@ def has_implicit_event_assignments(self) -> bool: :return: boolean indicating if event assignments with implicit triggers are present """ - return any(event.updates_state and not event.has_explicit_trigger_times() for event in self._events) + allowed_syms = set(self.sym("p")) + return any(event.updates_state and not event.has_explicit_trigger_times(allowed_syms) for event in self._events) def toposort_expressions( self, reorder: bool = True diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index 95bc42c860..a1c05eb799 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -109,9 +109,8 @@ def _nllh(self, t, x, p, tcl, h, my, iy, op, np): return TPL_JY_RET.at[iy].get() - def _known_discs(self, p, y): + def _known_discs(self, p): TPL_P_SYMS = p - TPL_X_SYMS = y return TPL_ROOTS diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 968b1f96f7..113009511e 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -639,7 +639,7 @@ def simulate_condition_unjitted( self._root_cond_fns(), self._root_cond_fn, self._delta_x, - self._known_discs(p, x_solver), + self._known_discs(p), ) x_solver = x_dyn[-1, :] else: @@ -662,7 +662,7 @@ def simulate_condition_unjitted( self._root_cond_fns(), self._root_cond_fn, self._delta_x, - self._known_discs(p, x_solver), + self._known_discs(p), max_steps, ) else: @@ -899,7 +899,7 @@ def preequilibrate_condition( self._root_cond_fns(), self._root_cond_fn, self._delta_x, - self._known_discs(p, current_x), + self._known_discs(p), max_steps, ) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 1b7534e34d..38560abdb0 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -77,7 +77,7 @@ def _jax_return_variables( f"{eq_name.upper()}_RET": _jnp_array_str( s.name for s in model.sym(eq_name) ) - if model.sym(eq_name) + if model.sym(eq_name) and sp.Matrix(model.eq(eq_name)).shape[0] else "jnp.array([])" for eq_name in eq_names } @@ -285,6 +285,8 @@ def _generate_jax_code(self) -> None: ), } + breakpoint() + apply_template( Path(amiciModulePath) / "jax" / "jax.template.py", self.model_path / "__init__.py", From 533b97e03dd6ce56ee9f12709a5f4d7f982bac31 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 5 Dec 2025 14:02:07 +0000 Subject: [PATCH 07/28] oops committed breakpoint --- python/sdist/amici/jax/ode_export.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 38560abdb0..c60ed1c07f 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -285,8 +285,6 @@ def _generate_jax_code(self) -> None: ), } - breakpoint() - apply_template( Path(amiciModulePath) / "jax" / "jax.template.py", self.model_path / "__init__.py", From 6b53bb4463f901f42fc7ac69061ef0d9893170ac Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 5 Dec 2025 14:53:24 +0000 Subject: [PATCH 08/28] fix delta variables in deltax --- python/sdist/amici/_symbolic/de_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 19fbbb428b..836a604ace 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -1142,6 +1142,8 @@ def _generate_symbol(self, name: str) -> None: ] ) return + elif name == "deltax": + length = sp.Matrix(self.eq(name)).shape[0] else: length = len(self.eq(name)) self._syms[name] = sp.Matrix( From af08ed3ac76e65a178b2bd4a97004aadb8ce3c01 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 5 Dec 2025 17:16:10 +0000 Subject: [PATCH 09/28] new param _delta_x missing in solve calls --- python/tests/test_jax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index ae12bcfe9c..dccc5a8ae0 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -372,6 +372,7 @@ def test_time_dependent_discontinuity(tmp_path): diffrax.ODETerm(model._xdot), model._root_cond_fns(), model._root_cond_fn, + model._delta_x, model._known_discs(p), ) @@ -422,6 +423,7 @@ def test_time_dependent_discontinuity_equilibration(tmp_path): diffrax.ODETerm(model._xdot), model._root_cond_fns(), model._root_cond_fn, + model._delta_x, model._known_discs(p), 1000, ) From b4b3219a182b836271df3b6db7959d38542f0756 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 8 Dec 2025 10:30:25 +0000 Subject: [PATCH 10/28] try simpler roots direction logic in handle event --- python/sdist/amici/jax/_simulation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index 4efda0d26d..b6f7d29222 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -473,10 +473,9 @@ def _handle_event( ) roots_dir = jnp.sign(droot_dt) # direction of the root condition function - overall_dir = jnp.sign(jnp.sum(jnp.where(roots_found, h * roots_dir, 0.0))) - h_next = h - (overall_dir * jnp.where( + h_next = h - (jnp.where( roots_found, - roots_dir, + h * roots_dir, jnp.zeros_like(h), )) # update heaviside variables based on the root condition function From 0545a070cd82f90c2179ddbbb5107b32861df82c Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 8 Dec 2025 11:31:10 +0000 Subject: [PATCH 11/28] try not logic in handle event --- python/sdist/amici/jax/_simulation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index b6f7d29222..d25d19a08e 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -473,11 +473,11 @@ def _handle_event( ) roots_dir = jnp.sign(droot_dt) # direction of the root condition function - h_next = h - (jnp.where( + h_next = jnp.where( roots_found, - h * roots_dir, - jnp.zeros_like(h), - )) # update heaviside variables based on the root condition function + jnp.logical_not(h), + h, + ) # update heaviside variables based on the root condition function mask = jnp.array( [ From 06a208b644f752ae455c87a02a98b01540d44957 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 8 Dec 2025 13:31:54 +0000 Subject: [PATCH 12/28] looking for initialValue test cases support explicit initial value event assignments fix deltax tcl --- python/sdist/amici/importers/sbml/__init__.py | 4 ---- python/sdist/amici/jax/_simulation.py | 2 +- python/sdist/amici/jax/jax.template.py | 7 ++++++- python/sdist/amici/jax/model.py | 21 ++++++++++++++++--- python/sdist/amici/jax/ode_export.py | 5 +++++ 5 files changed, 30 insertions(+), 9 deletions(-) diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 17b4ab5085..9e46f63363 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -1883,10 +1883,6 @@ def _process_events(self) -> None: } if self.jax: - if not initial_value: - raise NotImplementedError( - "The JAX backend does not support events with False initialValue." - ) # Add a negative event for JAX models to handle neg_event_id = event_id + "_negative" neg_event_sym = sp.Symbol(neg_event_id) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index d25d19a08e..93d7eb85d8 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -485,7 +485,7 @@ def _handle_event( for _ in range(y0_next.shape[0]) ] ).T - delx = delta_x(y0_next, p, t0_next) + delx = delta_x(y0_next, p, tcl) if y0_next.size: delx = delx.reshape(delx.size // y0_next.shape[0], y0_next.shape[0],) y0_up = jnp.where(mask, delx, 0.0) diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index a1c05eb799..eaa0c54982 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -128,15 +128,20 @@ def _root_cond_fn(self, t, y, args, **_): return jnp.hstack((TPL_IROOT_RET, TPL_EROOT_RET)) - def _delta_x(self, y, p, t): + def _delta_x(self, y, p, tcl): TPL_X_SYMS = y TPL_P_SYMS = p + TPL_TCL_SYMS = tcl TPL_X_OLD_EQ TPL_DELTAX_EQ return TPL_DELTAX_RET + + @property + def event_initial_values(self): + return TPL_EVENT_INITIAL_VALUES @property def n_events(self): diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 113009511e..cee5d73f3a 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -17,7 +17,7 @@ import os -from ._simulation import eq, solve +from ._simulation import eq, solve, _handle_event class ReturnValue(enum.Enum): @@ -427,10 +427,12 @@ def _initialise_heaviside_variables( :return: heaviside variables """ - h0 = jnp.zeros((self.n_events,)) # dummy values + h0 = self.event_initial_values.astype(float) # dummy values roots_found = self._root_cond_fn(t0, x_solver, (p, tcl, h0)) return jnp.where( - roots_found >= 0.0, jnp.ones_like(h0), jnp.zeros_like(h0) + jnp.logical_and(roots_found >= 0.0, h0 == 1.0), + jnp.ones_like(h0), + jnp.zeros_like(h0) ) def _x_rdatas( @@ -622,6 +624,19 @@ def simulate_condition_unjitted( tcl = self._tcl(x, p) h = self._initialise_heaviside_variables(t0, x_solver, p, tcl) + x_solver, _, h, _ = _handle_event( + t0, + x_solver, + p, + tcl, + h, + root_finder, + diffrax.ODETerm(self._xdot), + self._root_cond_fn, + self._delta_x, + {}, + ) + # Dynamic simulation if ts_dyn.shape[0]: x_dyn, h_dyn, stats_dyn = solve( diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index c60ed1c07f..5b7e750f56 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -269,6 +269,11 @@ def _generate_jax_code(self) -> None: ), "N_IEVENTS": str(len(self.model.get_implicit_roots())), "N_EEVENTS": str(len(self.model.get_explicit_roots())), + "EVENT_INITIAL_VALUES": _jnp_array_str( + [ + e.get_initial_value() for e in self.model._events + ] + ), **{ "MODEL_NAME": self.model_name, # keep track of the API version that the model was generated with so we From 8398a389cb352400f9ade988085c7b5fd9fd85f9 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 9 Dec 2025 10:09:02 +0000 Subject: [PATCH 13/28] add h = 0 check to handle event --- python/sdist/amici/jax/_simulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index 93d7eb85d8..5126fb900f 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -481,7 +481,7 @@ def _handle_event( mask = jnp.array( [ - jnp.logical_and(roots_found, roots_dir > 0.0) + (roots_found & (roots_dir > 0.0) & (h == 0.0)) for _ in range(y0_next.shape[0]) ] ).T From 9aa98660f65efe602f44828952804d1b29774ed4 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 9 Dec 2025 10:09:54 +0000 Subject: [PATCH 14/28] do not update h pre-solve --- python/sdist/amici/jax/model.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index cee5d73f3a..c6a54b13ff 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -427,7 +427,12 @@ def _initialise_heaviside_variables( :return: heaviside variables """ - h0 = self.event_initial_values.astype(float) # dummy values + h0 = self.event_initial_values.astype(float) + if os.getenv("JAX_DEBUG") == "1": + jax.debug.print( + "h0: {}", + h0, + ) roots_found = self._root_cond_fn(t0, x_solver, (p, tcl, h0)) return jnp.where( jnp.logical_and(roots_found >= 0.0, h0 == 1.0), @@ -624,7 +629,7 @@ def simulate_condition_unjitted( tcl = self._tcl(x, p) h = self._initialise_heaviside_variables(t0, x_solver, p, tcl) - x_solver, _, h, _ = _handle_event( + x_solver, _, _, _ = _handle_event( t0, x_solver, p, From 523cf200a1f8b751e9994a0816a6ebd2c38dcd39 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 9 Dec 2025 15:46:20 +0000 Subject: [PATCH 15/28] handle_t0_event --- python/sdist/amici/jax/_simulation.py | 64 ++++++++++++++------------- python/sdist/amici/jax/model.py | 61 +++++++++++++++++++++---- 2 files changed, 86 insertions(+), 39 deletions(-) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index 5126fb900f..f021e30609 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -271,7 +271,7 @@ def body_fn(carry): max_steps, # TODO: figure out how to pass `max_steps - stats['num_steps']` here adjoint, root_cond_fns, - [None] * len(root_cond_fns), + [True] * len(root_cond_fns), diffrax.SaveAt( subs=[ diffrax.SubSaveAt( @@ -315,25 +315,7 @@ def body_fn(carry): after_event = sol.ts[1] < ts hs = jnp.where(after_event[:, None], h_next[None, :], hs) - # Advance state to stop retriggering event immediately - t_resume = t0_next + t_eps - small_step = diffrax.diffeqsolve( - term, - solver, - t0_next, - t0_next + t_eps, - dt0=None, - y0=y0_next, - stepsize_controller=controller, - args=(p, tcl, h), - saveat=diffrax.SaveAt(t1=True), - event=None, - ) - - t_resume = small_step.ts[-1] - y_resume = small_step.ys[-1] - - return ys, t_resume, y_resume, hs, h_next, stats + return ys, t0_next, y0_next, hs, h_next, stats # run the loop until we have reached the end of the time points ys, _, _, hs, _, stats = eqxi.while_loop( @@ -473,6 +455,37 @@ def _handle_event( ) roots_dir = jnp.sign(droot_dt) # direction of the root condition function + y0_next, h_next = _apply_event_assignments( + roots_found, + roots_dir, + y0_next, + p, + tcl, + h, + delta_x, + ) + + if os.getenv("JAX_DEBUG") == "1": + jax.debug.print( + "rootvals: {}, roots_found: {}, roots_dir: {}, h: {}, h_next: {}", + rootvals, + roots_found, + roots_dir, + h, + h_next, + ) + + return y0_next, t0_next, h_next, stats + +def _apply_event_assignments( + roots_found, + roots_dir, + y0_next, + p, + tcl, + h, + delta_x, +): h_next = jnp.where( roots_found, jnp.logical_not(h), @@ -491,13 +504,4 @@ def _handle_event( y0_up = jnp.where(mask, delx, 0.0) y0_next = y0_next + jnp.sum(y0_up, axis=0) - if os.getenv("JAX_DEBUG") == "1": - jax.debug.print( - "rootvals: {}, roots_found: {}, roots_dir: {}, h: {}, h_next: {}", - rootvals, - roots_found, - roots_dir, - h, - h_next, - ) - return y0_next, t0_next, h_next, stats + return y0_next, h_next diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index c6a54b13ff..77d2c955e2 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -17,7 +17,7 @@ import os -from ._simulation import eq, solve, _handle_event +from ._simulation import eq, solve, _apply_event_assignments class ReturnValue(enum.Enum): @@ -627,16 +627,12 @@ def simulate_condition_unjitted( x = jnp.where(mask_reinit, x_reinit, x) x_solver = self._x_solver(x) tcl = self._tcl(x, p) - h = self._initialise_heaviside_variables(t0, x_solver, p, tcl) - x_solver, _, _, _ = _handle_event( + x_solver, _, h, _ = self._handle_t0_event( t0, x_solver, p, tcl, - h, - root_finder, - diffrax.ODETerm(self._xdot), self._root_cond_fn, self._delta_x, {}, @@ -902,10 +898,19 @@ def preequilibrate_condition( if x_reinit.shape[0]: x0 = jnp.where(mask_reinit, x_reinit, x0) tcl = self._tcl(x0, p) - h = self._initialise_heaviside_variables( - t0, self._x_solver(x0), p, tcl - ) + current_x = self._x_solver(x0) + + current_x, _, h, _ = self._handle_t0_event( + t0, + self._x_solver(x0), + p, + tcl, + self._root_cond_fn, + self._delta_x, + {}, + ) + current_x, _, stats_preeq = eq( p, tcl, @@ -925,6 +930,44 @@ def preequilibrate_condition( return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq) + def _handle_t0_event( + self, + t0_next: float, + y0_next: jt.Float[jt.Array, "nxs"], + p: jt.Float[jt.Array, "np"], + tcl: jt.Float[jt.Array, "ncl"], + root_cond_fn: Callable, + delta_x: Callable, + stats: dict, + ): + h = self.event_initial_values + args = (p, tcl, h) + rfx = root_cond_fn(t0_next, y0_next, args) + roots_dir = jnp.sign(rfx - h) + roots_found = (rfx - h) == 0.0 + + y0_next, h_next = _apply_event_assignments( + roots_found, + roots_dir, + y0_next, + p, + tcl, + h, + delta_x, + ) + + if os.getenv("JAX_DEBUG") == "1": + jax.debug.print( + "h: {}, rfx: {}, roots_found: {}, roots_dir: {}, h: {}, h_next: {}", + h, + rfx, + roots_found, + roots_dir, + h, + h_next, + ) + + return y0_next, t0_next, h_next, stats def safe_log(x: jnp.float_) -> jnp.float_: """ From 7f5fdab461a5f88b67debc2afbc1e3f79a8db20f Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 9 Dec 2025 16:49:45 +0000 Subject: [PATCH 16/28] reinstate time skip (hack diffrax bug?) --- python/sdist/amici/jax/_simulation.py | 30 ++++++++++++++++++++------- python/sdist/amici/jax/model.py | 11 +++++----- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index f021e30609..f65e896d3c 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -139,7 +139,7 @@ def body_fn(carry): ) t0_next = jnp.where(jnp.isfinite(sol.ts), sol.ts, -jnp.inf).max() - y0_next, t0_next, h_next, stats = _handle_event( + y0_next, h_next, stats = _handle_event( t0_next, y0_next, p, @@ -295,11 +295,8 @@ def body_fn(carry): y0_next = sol.ys[1][ -1 ] # next initial state is the last state of the current segment - ts_next = jnp.where( - ts > t0_next, ts, ts[-1] - ).min() # timepoint of next datapoint, don't step over that - y0_next, t0_next, h_next, stats = _handle_event( + y0_next, h_next, stats = _handle_event( t0_next, y0_next, p, @@ -315,7 +312,26 @@ def body_fn(carry): after_event = sol.ts[1] < ts hs = jnp.where(after_event[:, None], h_next[None, :], hs) - return ys, t0_next, y0_next, hs, h_next, stats + # TODO: file issue on diffrax where integration goes the wrong way after + # event trigger - causing event to trigger over and over again + t_resume = t0_next + t_eps + small_step = diffrax.diffeqsolve( + term, + solver, + t0_next, + t0_next + t_eps, + dt0=None, + y0=y0_next, + stepsize_controller=controller, + args=(p, tcl, h), + saveat=diffrax.SaveAt(t1=True), + event=None, + ) + + t_resume = small_step.ts[-1] + y_resume = small_step.ys[-1] + + return ys, t_resume, y_resume, hs, h_next, stats # run the loop until we have reached the end of the time points ys, _, _, hs, _, stats = eqxi.while_loop( @@ -475,7 +491,7 @@ def _handle_event( h_next, ) - return y0_next, t0_next, h_next, stats + return y0_next, h_next, stats def _apply_event_assignments( roots_found, diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 77d2c955e2..e721835555 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -940,11 +940,12 @@ def _handle_t0_event( delta_x: Callable, stats: dict, ): - h = self.event_initial_values + rf0 = self.event_initial_values - 0.5 + h = jnp.heaviside(rf0, 0.0) args = (p, tcl, h) rfx = root_cond_fn(t0_next, y0_next, args) - roots_dir = jnp.sign(rfx - h) - roots_found = (rfx - h) == 0.0 + roots_dir = jnp.sign(rfx - rf0) + roots_found = jnp.sign(rfx) != jnp.sign(rf0) y0_next, h_next = _apply_event_assignments( roots_found, @@ -958,12 +959,12 @@ def _handle_t0_event( if os.getenv("JAX_DEBUG") == "1": jax.debug.print( - "h: {}, rfx: {}, roots_found: {}, roots_dir: {}, h: {}, h_next: {}", + "h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}", h, + rf0, rfx, roots_found, roots_dir, - h, h_next, ) From 3b9471ecff2a911ad9feedb10fc18e03836e72a2 Mon Sep 17 00:00:00 2001 From: BSnelling Date: Wed, 10 Dec 2025 09:31:20 +0000 Subject: [PATCH 17/28] Update python/sdist/amici/jax/_simulation.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Fabian Fröhlich --- python/sdist/amici/jax/_simulation.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index f65e896d3c..ef702eb17c 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -508,16 +508,11 @@ def _apply_event_assignments( h, ) # update heaviside variables based on the root condition function - mask = jnp.array( - [ - (roots_found & (roots_dir > 0.0) & (h == 0.0)) - for _ in range(y0_next.shape[0]) - ] - ).T + mask = roots_found & (roots_dir > 0.0) & (h == 0.0) delx = delta_x(y0_next, p, tcl) if y0_next.size: delx = delx.reshape(delx.size // y0_next.shape[0], y0_next.shape[0],) - y0_up = jnp.where(mask, delx, 0.0) + y0_up = jnp.where(mask[None, :], delx, 0.0) y0_next = y0_next + jnp.sum(y0_up, axis=0) return y0_next, h_next From b93d8846b6b21a10c877d21e28202f47f1ecf26d Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Wed, 10 Dec 2025 09:33:43 +0000 Subject: [PATCH 18/28] Revert "Update python/sdist/amici/jax/_simulation.py" This reverts commit 82caa9408f6faee0fa482470c001798efaa462c6. --- python/sdist/amici/jax/_simulation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index ef702eb17c..f65e896d3c 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -508,11 +508,16 @@ def _apply_event_assignments( h, ) # update heaviside variables based on the root condition function - mask = roots_found & (roots_dir > 0.0) & (h == 0.0) + mask = jnp.array( + [ + (roots_found & (roots_dir > 0.0) & (h == 0.0)) + for _ in range(y0_next.shape[0]) + ] + ).T delx = delta_x(y0_next, p, tcl) if y0_next.size: delx = delx.reshape(delx.size // y0_next.shape[0], y0_next.shape[0],) - y0_up = jnp.where(mask[None, :], delx, 0.0) + y0_up = jnp.where(mask, delx, 0.0) y0_next = y0_next + jnp.sum(y0_up, axis=0) return y0_next, h_next From a4be71809cc7113a3e8f3bfd85e1a96e428e1cc2 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Wed, 10 Dec 2025 10:36:42 +0000 Subject: [PATCH 19/28] rm clip controller --- python/sdist/amici/jax/_simulation.py | 37 +-------------------------- 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index f65e896d3c..7b19e61517 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -89,7 +89,6 @@ def eq( [None], diffrax.SaveAt(t1=True), term, - known_discs, dict(**STARTING_STATS), ) y1 = jnp.where( @@ -126,7 +125,6 @@ def body_fn(carry): [None] + [True] * len(root_cond_fns), diffrax.SaveAt(t1=True), term, - known_discs, stats, ) y0_next = jnp.where( @@ -188,7 +186,6 @@ def solve( root_cond_fn: Callable, delta_x: Callable, known_discs: jt.Float[jt.Array, "*nediscs"], - t_eps: jt.Float = 1e-5, ) -> tuple[jt.Float[jt.Array, "nt nxs"], jt.Float[jt.Array, "nt ne"], dict]: """ Simulate the ODE system for the specified timepoints. @@ -241,7 +238,6 @@ def solve( [], diffrax.SaveAt(ts=ts), term, - known_discs, dict(**STARTING_STATS), ) return sol.ys, jnp.repeat(h[None, :], sol.ys.shape[0]), stats @@ -281,7 +277,6 @@ def body_fn(carry): ] ), term, - known_discs, stats, ) # update the solution for all timepoints in the simulated segment @@ -312,26 +307,7 @@ def body_fn(carry): after_event = sol.ts[1] < ts hs = jnp.where(after_event[:, None], h_next[None, :], hs) - # TODO: file issue on diffrax where integration goes the wrong way after - # event trigger - causing event to trigger over and over again - t_resume = t0_next + t_eps - small_step = diffrax.diffeqsolve( - term, - solver, - t0_next, - t0_next + t_eps, - dt0=None, - y0=y0_next, - stepsize_controller=controller, - args=(p, tcl, h), - saveat=diffrax.SaveAt(t1=True), - event=None, - ) - - t_resume = small_step.ts[-1] - y_resume = small_step.ys[-1] - - return ys, t_resume, y_resume, hs, h_next, stats + return ys, t0_next, y0_next, hs, h_next, stats # run the loop until we have reached the end of the time points ys, _, _, hs, _, stats = eqxi.while_loop( @@ -368,7 +344,6 @@ def _run_segment( cond_dirs: list[None | bool], saveat: diffrax.SaveAt, term: diffrax.ODETerm, - known_discs: jt.Float[jt.Array, "*nediscs"], stats: dict, ) -> tuple[diffrax.Solution, int, dict]: """Solve a single integration segment and return triggered event index, start time for the next segment, @@ -390,16 +365,6 @@ def _run_segment( else None ) - # manage events with explicit discontinuities - controller = ( - diffrax.ClipStepSizeController( - controller, - jump_ts=known_discs, - ) - if known_discs.size - else controller - ) - sol = diffrax.diffeqsolve( term, solver, From 0ab7c685e95561870ee75632beb638ed99c729c9 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Wed, 10 Dec 2025 11:26:02 +0000 Subject: [PATCH 20/28] handle t0 event near zero --- python/sdist/amici/jax/model.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index e721835555..7a1636c651 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -633,6 +633,7 @@ def simulate_condition_unjitted( x_solver, p, tcl, + root_finder, self._root_cond_fn, self._delta_x, {}, @@ -906,6 +907,7 @@ def preequilibrate_condition( self._x_solver(x0), p, tcl, + root_finder, self._root_cond_fn, self._delta_x, {}, @@ -936,6 +938,7 @@ def _handle_t0_event( y0_next: jt.Float[jt.Array, "nxs"], p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], + root_finder: AbstractRootFinder, root_cond_fn: Callable, delta_x: Callable, stats: dict, @@ -957,6 +960,23 @@ def _handle_t0_event( delta_x, ) + roots_zero = jnp.isclose( + rfx, 0.0, atol=root_finder.atol, rtol=root_finder.rtol + ) + droot_dt = ( + # ∂root_cond_fn/∂t + jax.jacfwd(root_cond_fn, argnums=0)(t0_next, y0_next, args) + + + # ∂root_cond_fn/∂y * ∂y/∂t + jax.jacfwd(root_cond_fn, argnums=1)(t0_next, y0_next, args) + @ self._xdot(t0_next, y0_next, args) + ) + h_next = jnp.where( + roots_zero, + droot_dt >= 0.0, + h_next, + ) + if os.getenv("JAX_DEBUG") == "1": jax.debug.print( "h: {}, rf0: {}, rfx: {}, roots_found: {}, roots_dir: {}, h_next: {}", From 4a47d3be155d7dd3f2a0f1808cededffb99b5fd7 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Wed, 10 Dec 2025 17:55:09 +0000 Subject: [PATCH 21/28] skip non-time dependent event assignment cases --- python/sdist/amici/_symbolic/de_model.py | 10 ++++++++++ python/sdist/amici/jax/ode_export.py | 5 +++++ tests/benchmark_models/test_petab_benchmark_jax.py | 5 +++++ 3 files changed, 20 insertions(+) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 836a604ace..91976e8f82 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -2691,6 +2691,16 @@ def has_implicit_event_assignments(self) -> bool: """ allowed_syms = set(self.sym("p")) return any(event.updates_state and not event.has_explicit_trigger_times(allowed_syms) for event in self._events) + + def has_only_time_dependent_event_assignments(self) -> bool: + """ + Checks whether the model has only time dependent event assignments + + :return: + boolean indicating if solely event assignments with explicit time dependent + triggers are present + """ + return all(len(event.get_trigger_times()) > 0 for event in self._events) def toposort_expressions( self, reorder: bool = True diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 5b7e750f56..a5457c7e74 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -149,6 +149,11 @@ def __init__( raise NotImplementedError( "The JAX backend does not support models with algebraic states." ) + + if not ode_model.has_only_time_dependent_event_assignments(): + raise NotImplementedError( + "The JAX backend does not support event assignments with explicit non-time dependent triggers." + ) if ode_model.has_priority_events(): raise NotImplementedError( diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 61fc26dd2d..2eff9dc3fc 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -40,6 +40,11 @@ def test_jax_llh(benchmark_problem): "Skipping Smith_BMCSystBiol2013 due to non-supported events in JAX." ) + if problem_id == "Oliveira_NatCommun2021": + pytest.skip( + "Skipping Oliveira_NatCommun2021 due to non-supported events in JAX." + ) + amici_solver = amici_model.create_solver() cur_settings = settings[problem_id] amici_solver.set_absolute_tolerance(1e-8) From bc51bb4402a62a16d7db6892f191b9be4ff6c5e5 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Thu, 11 Dec 2025 09:16:16 +0000 Subject: [PATCH 22/28] fix sbml _symbols --- python/sdist/amici/importers/sbml/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 9e46f63363..88a5283f97 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -1886,7 +1886,7 @@ def _process_events(self) -> None: # Add a negative event for JAX models to handle neg_event_id = event_id + "_negative" neg_event_sym = sp.Symbol(neg_event_id) - self.symbols[SymbolId.EVENT][neg_event_sym] = { + self._symbols[SymbolId.EVENT][neg_event_sym] = { "name": neg_event_id, "value": -trigger, "assignments": None, From 449ed7888680b026f2b865328e7a6307a906fda1 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Thu, 11 Dec 2025 13:39:56 +0000 Subject: [PATCH 23/28] skip some more tests - NotImplemented discs --- python/tests/test_jax.py | 141 +++++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 64 deletions(-) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index dccc5a8ae0..dcb8b91f74 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -343,40 +343,47 @@ def test_time_dependent_discontinuity(tmp_path): sbml = antimony2sbml(ant_model) importer = SbmlImporter(sbml, from_file=False) - importer.sbml2jax("time_disc", output_dir=tmp_path) - - module = amici._module_from_path("time_disc", tmp_path / "__init__.py") - model = module.Model() - - p = jnp.array([1.0]) - x0_full = model._x0(0.0, p) - tcl = model._tcl(x0_full, p) - x0 = model._x_solver(x0_full) - ts = jnp.array([0.0, 1.0, 2.0]) - h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) - - assert len(model._root_cond_fns()) > 0 - assert model._known_discs(p).size == 0 - - ys, _, _ = solve( - p, - ts, - tcl, - h, - x0, - diffrax.Tsit5(), - diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), - optimistix.Newton(atol=1e-8, rtol=1e-8), - 1000, - diffrax.DirectAdjoint(), - diffrax.ODETerm(model._xdot), - model._root_cond_fns(), - model._root_cond_fn, - model._delta_x, - model._known_discs(p), - ) - assert ys.shape[0] == ts.shape[0] + try: + importer.sbml2jax("time_disc", output_dir=tmp_path) + + module = amici._module_from_path("time_disc", tmp_path / "__init__.py") + model = module.Model() + + p = jnp.array([1.0]) + x0_full = model._x0(0.0, p) + tcl = model._tcl(x0_full, p) + x0 = model._x_solver(x0_full) + ts = jnp.array([0.0, 1.0, 2.0]) + h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) + + assert len(model._root_cond_fns()) > 0 + assert model._known_discs(p).size == 0 + + ys, _, _ = solve( + p, + ts, + tcl, + h, + x0, + diffrax.Tsit5(), + diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), + optimistix.Newton(atol=1e-8, rtol=1e-8), + 1000, + diffrax.DirectAdjoint(), + diffrax.ODETerm(model._xdot), + model._root_cond_fns(), + model._root_cond_fn, + model._delta_x, + model._known_discs(p), + ) + + assert ys.shape[0] == ts.shape[0] + + except NotImplementedError as err: + if "The JAX backend does not support" in str(err): + pytest.skip(str(err)) + raise err @skip_on_valgrind @@ -397,35 +404,41 @@ def test_time_dependent_discontinuity_equilibration(tmp_path): sbml = antimony2sbml(ant_model) importer = SbmlImporter(sbml, from_file=False) - importer.sbml2jax("time_disc_eq", output_dir=tmp_path) - - module = amici._module_from_path("time_disc_eq", tmp_path / "__init__.py") - model = module.Model() - - p = jnp.array([1.0]) - x0_full = model._x0(0.0, p) - tcl = model._tcl(x0_full, p) - x0 = model._x_solver(x0_full) - h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) - - assert len(model._root_cond_fns()) > 0 - assert model._known_discs(p).size == 0 - - xs, _, _ = eq( - p, - tcl, - h, - x0, - diffrax.Tsit5(), - diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), - optimistix.Newton(atol=1e-8, rtol=1e-8), - diffrax.steady_state_event(rtol=1e-8, atol=1e-8), - diffrax.ODETerm(model._xdot), - model._root_cond_fns(), - model._root_cond_fn, - model._delta_x, - model._known_discs(p), - 1000, - ) + try: + importer.sbml2jax("time_disc_eq", output_dir=tmp_path) + + module = amici._module_from_path("time_disc_eq", tmp_path / "__init__.py") + model = module.Model() + + p = jnp.array([1.0]) + x0_full = model._x0(0.0, p) + tcl = model._tcl(x0_full, p) + x0 = model._x_solver(x0_full) + h = model._initialise_heaviside_variables(0.0, model._x_solver(x0), p, tcl) + + assert len(model._root_cond_fns()) > 0 + assert model._known_discs(p).size == 0 + + xs, _, _ = eq( + p, + tcl, + h, + x0, + diffrax.Tsit5(), + diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), + optimistix.Newton(atol=1e-8, rtol=1e-8), + diffrax.steady_state_event(rtol=1e-8, atol=1e-8), + diffrax.ODETerm(model._xdot), + model._root_cond_fns(), + model._root_cond_fn, + model._delta_x, + model._known_discs(p), + 1000, + ) + + assert_allclose(xs[0], 0.0, atol=1e-2) - assert_allclose(xs[0], 0.0, atol=1e-2) + except NotImplementedError as err: + if "The JAX backend does not support" in str(err): + pytest.skip(str(err)) + raise err From e9f47e30dcafb2207bbfffd936814bbd7c5d2e31 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 12 Dec 2025 09:19:10 +0000 Subject: [PATCH 24/28] update implicit check and skip SalazarCavazos_MBoC2020 benchmark --- python/sdist/amici/_symbolic/de_model.py | 13 +------------ python/sdist/amici/jax/ode_export.py | 5 ----- tests/benchmark_models/test_petab_benchmark_jax.py | 5 +++++ 3 files changed, 6 insertions(+), 17 deletions(-) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 91976e8f82..15d9d32ca7 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -2689,18 +2689,7 @@ def has_implicit_event_assignments(self) -> bool: :return: boolean indicating if event assignments with implicit triggers are present """ - allowed_syms = set(self.sym("p")) - return any(event.updates_state and not event.has_explicit_trigger_times(allowed_syms) for event in self._events) - - def has_only_time_dependent_event_assignments(self) -> bool: - """ - Checks whether the model has only time dependent event assignments - - :return: - boolean indicating if solely event assignments with explicit time dependent - triggers are present - """ - return all(len(event.get_trigger_times()) > 0 for event in self._events) + return any(event.updates_state and not event.has_explicit_trigger_times() for event in self._events) def toposort_expressions( self, reorder: bool = True diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index a5457c7e74..5b7e750f56 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -149,11 +149,6 @@ def __init__( raise NotImplementedError( "The JAX backend does not support models with algebraic states." ) - - if not ode_model.has_only_time_dependent_event_assignments(): - raise NotImplementedError( - "The JAX backend does not support event assignments with explicit non-time dependent triggers." - ) if ode_model.has_priority_events(): raise NotImplementedError( diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 2eff9dc3fc..3ace99ac86 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -45,6 +45,11 @@ def test_jax_llh(benchmark_problem): "Skipping Oliveira_NatCommun2021 due to non-supported events in JAX." ) + if problem_id == "SalazarCavazos_MBoC2020": + pytest.skip( + "Skipping SalazarCavazos_MBoC2020 due to non-supported events in JAX." + ) + amici_solver = amici_model.create_solver() cur_settings = settings[problem_id] amici_solver.set_absolute_tolerance(1e-8) From 3618db84a65134fcec25048e927e4666f92c354f Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 12 Dec 2025 09:54:25 +0000 Subject: [PATCH 25/28] empty set is not None --- python/sdist/amici/_symbolic/de_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 15d9d32ca7..9d43aad13d 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -2689,7 +2689,7 @@ def has_implicit_event_assignments(self) -> bool: :return: boolean indicating if event assignments with implicit triggers are present """ - return any(event.updates_state and not event.has_explicit_trigger_times() for event in self._events) + return any(event.updates_state and not event.has_explicit_trigger_times({}) for event in self._events) def toposort_expressions( self, reorder: bool = True From 900eaefab6dda71fbaa93be3f4ab01b71e576550 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 12 Dec 2025 14:44:52 +0000 Subject: [PATCH 26/28] update solver settings for jax benchmarks --- tests/benchmark_models/test_petab_benchmark_jax.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 3ace99ac86..19aa30e482 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -23,6 +23,8 @@ settings, ) +import diffrax + jax.config.update("jax_enable_x64", True) @@ -115,7 +117,14 @@ def test_jax_llh(benchmark_problem): beartype(run_simulations)(jax_problem) (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( run_simulations, has_aux=True - )(jax_problem) + )( + jax_problem, + max_steps = 2 * 10**5, + controller=diffrax.PIDController( + atol=cur_settings.atol_sim, + rtol=cur_settings.rtol_sim, + ) + ) else: llh_jax, _ = beartype(run_simulations)(jax_problem) From 51a7e9c70f5fef71f5407c29a32a23b59963dbd6 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 12 Dec 2025 16:35:29 +0000 Subject: [PATCH 27/28] keep Weber settings specific --- tests/benchmark_models/test_petab_benchmark_jax.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 19aa30e482..a9b93218c6 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -114,15 +114,23 @@ def test_jax_llh(benchmark_problem): ) if problem_id in problems_for_gradient_check: + if problem_id == "Weber_BMC2015": + atol = cur_settings.atol_sim + rtol = cur_settings.rtol_sim + max_steps = 2 * 10**5 + else: + atol = 1e-8 + rtol = 1e-8 + max_steps = 1024 beartype(run_simulations)(jax_problem) (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( run_simulations, has_aux=True )( jax_problem, - max_steps = 2 * 10**5, + max_steps=max_steps, controller=diffrax.PIDController( - atol=cur_settings.atol_sim, - rtol=cur_settings.rtol_sim, + atol=atol, + rtol=rtol, ) ) else: From d49572e9df05c8deae00455204384414284710a3 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 15 Dec 2025 12:10:58 +0000 Subject: [PATCH 28/28] review comments - remove x_old usage and add TODO/FIXMEs --- python/sdist/amici/_symbolic/de_model.py | 5 ----- python/sdist/amici/importers/sbml/__init__.py | 2 ++ python/sdist/amici/jax/jax.template.py | 4 ++-- python/sdist/amici/jax/ode_export.py | 5 ++--- tests/benchmark_models/test_petab_benchmark_jax.py | 14 +++----------- 5 files changed, 9 insertions(+), 21 deletions(-) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 9d43aad13d..cb186c4023 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -1596,11 +1596,6 @@ def _compute_equation(self, name: str) -> None: self._eqs[name] = event_eqs - elif name == "x_old": - self._eqs[name] = sp.Matrix( - [state.get_x_rdata() for state in self.states()] - ) - elif name == "z": event_observables = [ sp.zeros(self.num_eventobs(), 1) for _ in self._events diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 88a5283f97..3da3513be7 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -1884,6 +1884,8 @@ def _process_events(self) -> None: if self.jax: # Add a negative event for JAX models to handle + # TODO: remove once condition function directions can be + # traced through diffrax solve neg_event_id = event_id + "_negative" neg_event_sym = sp.Symbol(neg_event_id) self._symbols[SymbolId.EVENT][neg_event_sym] = { diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index eaa0c54982..fe0ff12d8d 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -132,8 +132,8 @@ def _delta_x(self, y, p, tcl): TPL_X_SYMS = y TPL_P_SYMS = p TPL_TCL_SYMS = tcl - - TPL_X_OLD_EQ + # FIXME: workaround until state from event time is properly passed + TPL_X_OLD_SYMS = y TPL_DELTAX_EQ diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 5b7e750f56..6cc61dd561 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -209,7 +209,6 @@ def _generate_jax_code(self) -> None: "eroot", "iroot", "deltax", - "x_old", ) sym_names = ( "p", @@ -262,7 +261,7 @@ def _generate_jax_code(self) -> None: "P_VALUES": _jnp_array_str(self.model.val("p")), "ROOTS": _jnp_array_str( { - _parse_trigger_root(root) + _print_trigger_root(root) for e in self.model._events for root in e.get_trigger_times() } @@ -350,7 +349,7 @@ def set_name(self, model_name: str) -> None: self.model_name = model_name -def _parse_trigger_root(root: sp.Expr) -> str: +def _print_trigger_root(root: sp.Expr) -> str: """Convert a trigger root expression into a string representation. :param root: The trigger root expression. diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index a9b93218c6..2b8e265aef 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -37,19 +37,11 @@ def test_jax_llh(benchmark_problem): problem_id, flat_petab_problem, petab_problem, amici_model = ( benchmark_problem ) - if problem_id == "Smith_BMCSystBiol2013": - pytest.skip( - "Skipping Smith_BMCSystBiol2013 due to non-supported events in JAX." - ) - - if problem_id == "Oliveira_NatCommun2021": - pytest.skip( - "Skipping Oliveira_NatCommun2021 due to non-supported events in JAX." - ) - if problem_id == "SalazarCavazos_MBoC2020": + to_skip = ["Smith_BMCSystBiol2013", "Oliveira_NatCommun2021", "SalazarCavazos_MBoC2020"] + if problem_id in to_skip: pytest.skip( - "Skipping SalazarCavazos_MBoC2020 due to non-supported events in JAX." + f"Skipping {problem_id} due to non-supported events in JAX." ) amici_solver = amici_model.create_solver()