diff --git a/python/sdist/amici/de_model_components.py b/python/sdist/amici/de_model_components.py index 85af284716..bfea153009 100644 --- a/python/sdist/amici/de_model_components.py +++ b/python/sdist/amici/de_model_components.py @@ -13,6 +13,7 @@ cast_to_sym, generate_measurement_symbol, generate_regularization_symbol, + contains_periodic_subexpression, ) from .constants import SymbolId @@ -753,11 +754,18 @@ def __init__( self._use_values_from_trigger_time = use_values_from_trigger_time # expression(s) for the timepoint(s) at which the event triggers - try: - self._t_root = sp.solve(self.get_val(), amici_time_symbol) - except NotImplementedError: - # the trigger can't be solved for `t` - self._t_root = [] + self._t_root = [] + + if not contains_periodic_subexpression( + self.get_val(), amici_time_symbol + ): + # `solve` will solve, e.g., sin(t), but will only return [0, pi], + # so we better skip any periodic expressions here + try: + self._t_root = sp.solve(self.get_val(), amici_time_symbol) + except NotImplementedError: + # the trigger can't be solved for `t` + pass def get_state_update( self, x: sp.Matrix, x_old: sp.Matrix @@ -841,6 +849,13 @@ def get_trigger_times(self) -> set[sp.Expr]: Returns a set of expressions, which may contain multiple time points for events that trigger at multiple time points. + + If the return value is empty, the trigger function cannot be solved + for `t`. I.e., the event does not explicitly depend on time, + or sympy is unable to solve the trigger function for `t`. + + If the return value is non-empty, it contains expressions for *all* + time points at which the event triggers. """ return set(self._t_root) diff --git a/python/sdist/amici/import_utils.py b/python/sdist/amici/import_utils.py index 7549eb6f1d..6ca7b1a10b 100644 --- a/python/sdist/amici/import_utils.py +++ b/python/sdist/amici/import_utils.py @@ -816,3 +816,18 @@ def _default_simplify(x): # We need this as a free function instead of a lambda to have it picklable # for parallel simplification return sp.powsimp(x, deep=True) + + +def contains_periodic_subexpression(expr: sp.Expr, symbol: sp.Symbol) -> bool: + """ + Check if a sympy expression contains any periodic subexpression. + + :param expr: The sympy expression to check. + :param symbol: The variable with respect to which periodicity is checked. + :return: `True` if the expression contains a periodic subexpression, + `False` otherwise. + """ + for subexpr in expr.atoms(sp.Function): + if sp.periodicity(subexpr, symbol) is not None: + return True + return False diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index 8d2546f3a5..09e6ad4cc1 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -1143,3 +1143,16 @@ def test_t0(tempdir): solver = model.getSolver() rdata = amici.runAmiciSimulation(model, solver) assert rdata.x == [[2.0]], rdata.x + + +@skip_on_valgrind +def test_contains_periodic_subexpression(): + """Test that periodic subexpressions are detected.""" + from amici.import_utils import contains_periodic_subexpression as cps + + t = sp.Symbol("t") + + assert cps(t, t) is False + assert cps(sp.sin(t), t) is True + assert cps(sp.cos(t), t) is True + assert cps(t + sp.sin(t), t) is True