diff --git a/pytest.ini b/pytest.ini index 481b60e04f..5a478bc1f9 100644 --- a/pytest.ini +++ b/pytest.ini @@ -30,4 +30,4 @@ filterwarnings = ignore:jax.* is deprecated:DeprecationWarning -norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples +norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples *build* diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 4aaead9770..538bd33f69 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -454,14 +454,6 @@ def get_rate(symbol: sp.Symbol): ) self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs) - # replace rateOf-instances in x0 by xdot equation - for i_state in range(len(self.eq("x0"))): - new, replacement = self._eqs["x0"][i_state].replace( - rate_of_func, get_rate, map=True - ) - if replacement: - self._eqs["x0"][i_state] = new - # replace rateOf-instances in w by xdot equation # here we may need toposort, as xdot may depend on w made_substitutions = False @@ -509,6 +501,30 @@ def get_rate(symbol: sp.Symbol): self._syms["w"] = sp.Matrix(topo_expr_syms) self._eqs["w"] = sp.Matrix(list(w_sorted.values())) + # replace rateOf-instances in x0 by xdot equation + # indices of state variables whose x0 was modified + changed_indices = [] + for i_state in range(len(self.eq("x0"))): + new, replacement = self._eqs["x0"][i_state].replace( + rate_of_func, get_rate, map=True + ) + if replacement: + self._eqs["x0"][i_state] = new + changed_indices.append(i_state) + if changed_indices: + # Replace any newly introduced state variables + # by their x0 expressions. + # Also replace any newly introduced `w` symbols by their + # expressions (after `w` was toposorted above). + subs = toposort_symbols( + dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True)) + ) + subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs + for i_state in changed_indices: + self._eqs["x0"][i_state] = smart_subs_dict( + self._eqs["x0"][i_state], subs + ) + for component in chain( self.observables(), self.events(), diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 8d4db42e91..ec827a5ed4 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -1386,13 +1386,20 @@ def _process_parameters( # Parameters that need to be turned into expressions or species # so far, this concerns parameters with symbolic initial assignments # (those have been skipped above) that are not rate rule targets + + # Set of symbols in initial assignments that still allows handling them + # via amici expressions + syms_allowed_in_expr_ia = set(self.symbols[SymbolId.PARAMETER]) | set( + self.symbols[SymbolId.FIXED_PARAMETER] + ) + for par in self.sbml.getListOfParameters(): if ( (ia := par_id_to_ia.get(par.getId())) is not None and not ia.is_Number and not self.is_rate_rule_target(par) ): - if not ia.has(sbml_time_symbol): + if not (ia.free_symbols - syms_allowed_in_expr_ia): self.symbols[SymbolId.EXPRESSION][ _get_identifier_symbol(par) ] = { @@ -1407,6 +1414,10 @@ def _process_parameters( # We can't represent that as expression, since the # initial simulation time is only known at the time of the # simulation, so we can't substitute it. + # Also, any parameter with an initial assignment + # that expression that is implicitly time-dependent + # must be converted to a species to avoid re-evaluating + # the initial assignment at every time step. self.symbols[SymbolId.SPECIES][ _get_identifier_symbol(par) ] = { @@ -1515,13 +1526,36 @@ def _process_rules(self) -> None: self.symbols[SymbolId.EXPRESSION], "value" ) - # expressions must not occur in definition of x0 + # expressions must not occur in the definition of x0 + allowed_syms = ( + set(self.symbols[SymbolId.PARAMETER]) + | set(self.symbols[SymbolId.FIXED_PARAMETER]) + | {sbml_time_symbol} + ) for species in self.symbols[SymbolId.SPECIES].values(): - species["init"] = self._make_initial( - smart_subs_dict( - species["init"], self.symbols[SymbolId.EXPRESSION], "value" + # only parameters are allowed as free symbols + while True: + species["init"] = species["init"].subs(self.compartments) + sym_math, rateof_to_dummy = _rateof_to_dummy(species["init"]) + old_init = species["init"] + if ( + sym_math.free_symbols + - allowed_syms + - set(rateof_to_dummy.values()) + == set() + ): + break + species["init"] = self._make_initial( + smart_subs_dict( + species["init"], + self.symbols[SymbolId.EXPRESSION], + "value", + ) ) - ) + if species["init"] == old_init: + raise AssertionError( + f"Infinite loop detected in _process_rules {species}." + ) def _process_rule_algebraic(self, rule: libsbml.AlgebraicRule): formula = self._sympify(rule) @@ -2359,6 +2393,10 @@ def _make_initial( sym_math = sym_math.subs( var, self.symbols[SymbolId.SPECIES][var]["init"] ) + elif var in self.symbols[SymbolId.ALGEBRAIC_STATE]: + sym_math = sym_math.subs( + var, self.symbols[SymbolId.ALGEBRAIC_STATE][var]["init"] + ) elif ( element := self.sbml.getElementBySId(element_id) ) and self.is_rate_rule_target(element): diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index eaa6896fab..3f124f8e11 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -5,16 +5,16 @@ import sys from numbers import Number from pathlib import Path - import amici import libsbml import numpy as np import pytest from amici.gradient_check import check_derivatives -from amici.sbml_import import SbmlImporter -from amici.testing import skip_on_valgrind +from amici.sbml_import import SbmlImporter, SymbolId +from amici.import_utils import symbol_with_assumptions from numpy.testing import assert_allclose, assert_array_equal from amici import import_model_module +from amici.testing import skip_on_valgrind from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory from conftest import MODEL_STEADYSTATE_SCALED_XML import sympy as sp @@ -1142,3 +1142,42 @@ def test_contains_periodic_subexpression(): assert cps(sp.sin(t), t) is True assert cps(sp.cos(t), t) is True assert cps(t + sp.sin(t), t) is True + + +@skip_on_valgrind +@pytest.mark.parametrize("compute_conservation_laws", [True, False]) +def test_time_dependent_initial_assignment(compute_conservation_laws: bool): + """Check that dynamic expressions for initial assignments are only + evaluated at t=t0.""" + from amici.antimony_import import antimony2sbml + + ant_model = """ + x1' = 1 + x1 = p0 + p0 = 1 + p1 = x1 + x2 := x1 + p2 = x2 + """ + + sbml_model = antimony2sbml(ant_model) + print(sbml_model) + si = SbmlImporter(sbml_model, from_file=False) + de_model = si._build_ode_model( + observables={"obs_p1": {"formula": "p1"}, "obs_p2": {"formula": "p2"}}, + compute_conservation_laws=compute_conservation_laws, + ) + # "species", because the initial assignment expression is time-dependent + assert symbol_with_assumptions("p2") in si.symbols[SymbolId.SPECIES].keys() + # "species", because differential state + assert symbol_with_assumptions("x1") in si.symbols[SymbolId.SPECIES].keys() + + assert "p0" in [str(p.get_id()) for p in de_model.parameters()] + assert "p1" not in [str(p.get_id()) for p in de_model.parameters()] + assert "p2" not in [str(p.get_id()) for p in de_model.parameters()] + + assert list(de_model.sym("x_rdata")) == [ + symbol_with_assumptions("p2"), + symbol_with_assumptions("x1"), + ] + assert list(de_model.eq("x0")) == [symbol_with_assumptions("p0")] * 2