Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.1
rev: v0.15.2
hooks:
# Run the linter.
- id: ruff
Expand Down
6 changes: 3 additions & 3 deletions doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
"metadata": {},
"outputs": [],
"source": [
"import petab.v1 as petab\n",
"from amici.importers.petab import *\n",
"from petab.v2 import Problem\n",
"\n",
Expand Down Expand Up @@ -403,7 +402,9 @@
"nps = jax_problem._np_numeric[ic, :]\n",
"\n",
"# Load parameters for the specified condition\n",
"p = jax_problem.load_model_parameters(jax_problem._petab_problem.experiments[0], is_preeq=False)\n",
"p = jax_problem.load_model_parameters(\n",
" jax_problem._petab_problem.experiments[0], is_preeq=False\n",
")\n",
"\n",
"\n",
"# Define a function to compute the gradient with respect to dynamic timepoints\n",
Expand Down Expand Up @@ -576,7 +577,6 @@
"outputs": [],
"source": [
"from amici.sim.sundials import SensitivityMethod, SensitivityOrder\n",
"from amici.sim.sundials.petab.v1 import simulate_petab\n",
"\n",
"# Import the PEtab problem as a standard AMICI model\n",
"pi = PetabImporter(\n",
Expand Down
14 changes: 9 additions & 5 deletions python/sdist/amici/_symbolic/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2648,7 +2648,7 @@ def get_implicit_roots(self) -> list[sp.Expr]:
for e in self._events
if not e.has_explicit_trigger_times()
]

def has_algebraic_states(self) -> bool:
"""
Checks whether the model has algebraic states
Expand All @@ -2666,7 +2666,7 @@ def has_event_assignments(self) -> bool:
boolean indicating if event assignments are present
"""
return any(event.updates_state for event in self._events)

def has_priority_events(self) -> bool:
"""
Checks whether the model has events with priorities defined
Expand All @@ -2675,7 +2675,7 @@ def has_priority_events(self) -> bool:
boolean indicating if priority events are present
"""
return any(event.get_priority() is not None for event in self._events)

def has_implicit_event_assignments(self) -> bool:
"""
Checks whether the model has event assignments with implicit triggers
Expand All @@ -2686,9 +2686,13 @@ def has_implicit_event_assignments(self) -> bool:
"""
fixed_symbols = set([k._symbol for k in self._fixed_parameters])
allowed_symbols = fixed_symbols | {amici_time_symbol}
# TODO: update to use has_explicit_trigger_times once
# TODO: update to use has_explicit_trigger_times once
# https://github.com/AMICI-dev/AMICI/issues/3126 is resolved
return any(event.updates_state and event._has_implicit_triggers(allowed_symbols) for event in self._events)
return any(
event.updates_state
and event._has_implicit_triggers(allowed_symbols)
for event in self._events
)

def toposort_expressions(
self, reorder: bool = True
Expand Down
9 changes: 4 additions & 5 deletions python/sdist/amici/_symbolic/de_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def __init__(
:param priority: The priority of the event assignment.

:param is_negative_event:
Whether this event is a "negative" event, i.e., an event that is
Whether this event is a "negative" event, i.e., an event that is
added to mirror an existing event with inverted trigger condition
to avoid immediate retriggering of the original event (JAX simulations).

Expand Down Expand Up @@ -863,17 +863,16 @@ def has_explicit_trigger_times(
"""
if allowed_symbols is None:
return len(self._t_root) > 0

return len(self._t_root) > 0 and all(
t.is_Number or t.free_symbols.issubset(allowed_symbols)
for t in self._t_root
)

def _has_implicit_triggers(
self, allowed_symbols: set[sp.Symbol] | None = None
) -> bool:
"""Check whether the event has implicit triggers.
"""
"""Check whether the event has implicit triggers."""
t = self.get_val()
return not t.free_symbols.issubset(allowed_symbols)

Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import enum


class SymbolId(str, enum.Enum):
class SymbolId(enum.StrEnum):
"""
Defines the different fields in the symbol dict to which sbml entities
get parsed to.
Expand Down
4 changes: 2 additions & 2 deletions python/sdist/amici/importers/petab/_petab_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from amici import get_model_dir
from amici._symbolic import DEModel, Event
from amici.importers.utils import MeasurementChannel, amici_time_symbol
from amici.logging import get_logger
from amici.jax.petab import JAXProblem
from amici.logging import get_logger

from .v1.sbml_import import _add_global_parameter

Expand Down Expand Up @@ -608,7 +608,7 @@ def create_simulator(
model_module = self.import_module(force_import=force_import)
model = model_module.Model()
return JAXProblem(model, self.petab_problem)

model = self.import_module(force_import=force_import).get_model()
em = ExperimentManager(model=model, petab_problem=self.petab_problem)
return PetabSimulator(em=em)
Expand Down
3 changes: 0 additions & 3 deletions python/sdist/amici/importers/pysb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import itertools
import logging
import os
import re
import sys
from collections.abc import Callable, Iterable
from pathlib import Path
Expand All @@ -31,9 +30,7 @@
FixedParameter,
FreeParameter,
LogLikelihoodY,
NoiseParameter,
Observable,
ObservableParameter,
SigmaY,
)
from amici.logging import get_logger, log_execution_time, set_log_level
Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/importers/sbml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1902,7 +1902,7 @@ def _process_events(self) -> None:

if self.jax:
# Add a negative event for JAX models to handle
# TODO: remove once condition function directions can be
# TODO: remove once condition function directions can be
# traced through diffrax solve
neg_event_id = event_id + "_negative"
neg_event_sym = sp.Symbol(neg_event_id)
Expand Down
7 changes: 3 additions & 4 deletions python/sdist/amici/importers/sbml/splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,10 +1511,9 @@ def _spline_user_functions(
"AmiciSplineSensitivity": [
(
lambda *args: True,
lambda spline_id,
x,
param_id,
*p: f"sspl_{spline_ids.index(spline_id)}_{p_index[param_id]}",
lambda spline_id, x, param_id, *p: (
f"sspl_{spline_ids.index(spline_id)}_{p_index[param_id]}"
),
)
],
}
Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/importers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, data):
annotation_namespace = "https://github.com/AMICI-dev/AMICI"


class ObservableTransformation(str, enum.Enum):
class ObservableTransformation(enum.StrEnum):
"""
Different modes of observable transformation.
"""
Expand Down
4 changes: 3 additions & 1 deletion python/sdist/amici/jax/_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def _handle_event(

return y0_next, h_next, stats


def _check_cascading_events(
t0_next: float,
y0_next: jt.Float[jt.Array, "nxs"],
Expand Down Expand Up @@ -521,6 +522,7 @@ def _check_cascading_events(

return y0_next


def _apply_event_assignments(
roots_found,
roots_dir,
Expand All @@ -543,7 +545,7 @@ def _apply_event_assignments(
]
).T

# apply one event at a time
# apply one event at a time
if h_next.shape[0] and y0_next.shape[0]:
n_pairs = h_next.shape[0] // 2
inds_seq = jnp.arange(n_pairs)
Expand Down
6 changes: 3 additions & 3 deletions python/sdist/amici/jax/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,18 @@ def _root_cond_fn(self, t, y, args, **_):
TPL_EROOT_EQ

return jnp.hstack((TPL_IROOT_RET, TPL_EROOT_RET))

def _delta_x(self, y, p, tcl):
TPL_X_SYMS = y
TPL_ALL_P_SYMS = p
TPL_TCL_SYMS = tcl
# FIXME: workaround until state from event time is properly passed
TPL_X_OLD_SYMS = y

TPL_DELTAX_EQ

return TPL_DELTAX_RET

@property
def event_initial_values(self):
return TPL_EVENT_INITIAL_VALUES
Expand Down
Loading
Loading