Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions python/sdist/amici/de_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
cast_to_sym,
generate_measurement_symbol,
generate_regularization_symbol,
contains_periodic_subexpression,
)
from .constants import SymbolId

Expand Down Expand Up @@ -753,11 +754,18 @@ def __init__(
self._use_values_from_trigger_time = use_values_from_trigger_time

# expression(s) for the timepoint(s) at which the event triggers
try:
self._t_root = sp.solve(self.get_val(), amici_time_symbol)
except NotImplementedError:
# the trigger can't be solved for `t`
self._t_root = []
self._t_root = []

if not contains_periodic_subexpression(
self.get_val(), amici_time_symbol
):
# `solve` will solve, e.g., sin(t), but will only return [0, pi],
# so we better skip any periodic expressions here
try:
self._t_root = sp.solve(self.get_val(), amici_time_symbol)
except NotImplementedError:
# the trigger can't be solved for `t`
pass

def get_state_update(
self, x: sp.Matrix, x_old: sp.Matrix
Expand Down Expand Up @@ -841,6 +849,13 @@ def get_trigger_times(self) -> set[sp.Expr]:

Returns a set of expressions, which may contain multiple time points
for events that trigger at multiple time points.

If the return value is empty, the trigger function cannot be solved
for `t`. I.e., the event does not explicitly depend on time,
or sympy is unable to solve the trigger function for `t`.

If the return value is non-empty, it contains expressions for *all*
time points at which the event triggers.
"""
return set(self._t_root)

Expand Down
15 changes: 15 additions & 0 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,3 +816,18 @@ def _default_simplify(x):
# We need this as a free function instead of a lambda to have it picklable
# for parallel simplification
return sp.powsimp(x, deep=True)


def contains_periodic_subexpression(expr: sp.Expr, symbol: sp.Symbol) -> bool:
"""
Check if a sympy expression contains any periodic subexpression.

:param expr: The sympy expression to check.
:param symbol: The variable with respect to which periodicity is checked.
:return: `True` if the expression contains a periodic subexpression,
`False` otherwise.
"""
for subexpr in expr.atoms(sp.Function):
if sp.periodicity(subexpr, symbol) is not None:
return True
return False
13 changes: 13 additions & 0 deletions python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,3 +1143,16 @@ def test_t0(tempdir):
solver = model.getSolver()
rdata = amici.runAmiciSimulation(model, solver)
assert rdata.x == [[2.0]], rdata.x


@skip_on_valgrind
def test_contains_periodic_subexpression():
"""Test that periodic subexpressions are detected."""
from amici.import_utils import contains_periodic_subexpression as cps

t = sp.Symbol("t")

assert cps(t, t) is False
assert cps(sp.sin(t), t) is True
assert cps(sp.cos(t), t) is True
assert cps(t + sp.sin(t), t) is True
Loading