Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ filterwarnings =
ignore:jax.* is deprecated:DeprecationWarning


norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples
norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples *build*
32 changes: 24 additions & 8 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,14 +454,6 @@ def get_rate(symbol: sp.Symbol):
)
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)

# replace rateOf-instances in x0 by xdot equation
for i_state in range(len(self.eq("x0"))):
new, replacement = self._eqs["x0"][i_state].replace(
rate_of_func, get_rate, map=True
)
if replacement:
self._eqs["x0"][i_state] = new

# replace rateOf-instances in w by xdot equation
# here we may need toposort, as xdot may depend on w
made_substitutions = False
Expand Down Expand Up @@ -509,6 +501,30 @@ def get_rate(symbol: sp.Symbol):
self._syms["w"] = sp.Matrix(topo_expr_syms)
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))

# replace rateOf-instances in x0 by xdot equation
# indices of state variables whose x0 was modified
changed_indices = []
for i_state in range(len(self.eq("x0"))):
new, replacement = self._eqs["x0"][i_state].replace(
rate_of_func, get_rate, map=True
)
if replacement:
self._eqs["x0"][i_state] = new
changed_indices.append(i_state)
if changed_indices:
# Replace any newly introduced state variables
# by their x0 expressions.
# Also replace any newly introduced `w` symbols by their
# expressions (after `w` was toposorted above).
subs = toposort_symbols(
dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True))
)
subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this has to be done in a loop as dict(zip(self._syms["w"], self.eq("w"), strict=True)) subs may introduce new self.sym("x_rdata") variables.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

w should not depend on x_rdata:

"w": _FunctionInfo(
"realtype *w, const realtype t, const realtype *x, "
"const realtype *p, const realtype *k, "
"const realtype *h, const realtype *tcl, const realtype *spl, "
"bool include_static",
assume_pow_positivity=True,
),

Unless the respective substitutions in w are only made later, but I don't think that is the case - the rateOf processing should happen last.

for i_state in changed_indices:
self._eqs["x0"][i_state] = smart_subs_dict(
self._eqs["x0"][i_state], subs
)

for component in chain(
self.observables(),
self.events(),
Expand Down
50 changes: 44 additions & 6 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,13 +1386,20 @@ def _process_parameters(
# Parameters that need to be turned into expressions or species
# so far, this concerns parameters with symbolic initial assignments
# (those have been skipped above) that are not rate rule targets

# Set of symbols in initial assignments that still allows handling them
# via amici expressions
syms_allowed_in_expr_ia = set(self.symbols[SymbolId.PARAMETER]) | set(
self.symbols[SymbolId.FIXED_PARAMETER]
)

for par in self.sbml.getListOfParameters():
if (
(ia := par_id_to_ia.get(par.getId())) is not None
and not ia.is_Number
and not self.is_rate_rule_target(par)
):
if not ia.has(sbml_time_symbol):
if not (ia.free_symbols - syms_allowed_in_expr_ia):
self.symbols[SymbolId.EXPRESSION][
_get_identifier_symbol(par)
] = {
Expand All @@ -1407,6 +1414,10 @@ def _process_parameters(
# We can't represent that as expression, since the
# initial simulation time is only known at the time of the
# simulation, so we can't substitute it.
# Also, any parameter with an initial assignment
# that expression that is implicitly time-dependent
# must be converted to a species to avoid re-evaluating
# the initial assignment at every time step.
self.symbols[SymbolId.SPECIES][
_get_identifier_symbol(par)
] = {
Expand Down Expand Up @@ -1515,13 +1526,36 @@ def _process_rules(self) -> None:
self.symbols[SymbolId.EXPRESSION], "value"
)

# expressions must not occur in definition of x0
# expressions must not occur in the definition of x0
allowed_syms = (
set(self.symbols[SymbolId.PARAMETER])
| set(self.symbols[SymbolId.FIXED_PARAMETER])
| {sbml_time_symbol}
)
for species in self.symbols[SymbolId.SPECIES].values():
species["init"] = self._make_initial(
smart_subs_dict(
species["init"], self.symbols[SymbolId.EXPRESSION], "value"
# only parameters are allowed as free symbols
while True:
species["init"] = species["init"].subs(self.compartments)
sym_math, rateof_to_dummy = _rateof_to_dummy(species["init"])
old_init = species["init"]
if (
sym_math.free_symbols
- allowed_syms
- set(rateof_to_dummy.values())
== set()
):
break
species["init"] = self._make_initial(
smart_subs_dict(
species["init"],
self.symbols[SymbolId.EXPRESSION],
"value",
)
)
)
if species["init"] == old_init:
raise AssertionError(
f"Infinite loop detected in _process_rules {species}."
)

def _process_rule_algebraic(self, rule: libsbml.AlgebraicRule):
formula = self._sympify(rule)
Expand Down Expand Up @@ -2359,6 +2393,10 @@ def _make_initial(
sym_math = sym_math.subs(
var, self.symbols[SymbolId.SPECIES][var]["init"]
)
elif var in self.symbols[SymbolId.ALGEBRAIC_STATE]:
sym_math = sym_math.subs(
var, self.symbols[SymbolId.ALGEBRAIC_STATE][var]["init"]
)
elif (
element := self.sbml.getElementBySId(element_id)
) and self.is_rate_rule_target(element):
Expand Down
45 changes: 42 additions & 3 deletions python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
import sys
from numbers import Number
from pathlib import Path

import amici
import libsbml
import numpy as np
import pytest
from amici.gradient_check import check_derivatives
from amici.sbml_import import SbmlImporter
from amici.testing import skip_on_valgrind
from amici.sbml_import import SbmlImporter, SymbolId
from amici.import_utils import symbol_with_assumptions
from numpy.testing import assert_allclose, assert_array_equal
from amici import import_model_module
from amici.testing import skip_on_valgrind
from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory
from conftest import MODEL_STEADYSTATE_SCALED_XML
import sympy as sp
Expand Down Expand Up @@ -1142,3 +1142,42 @@ def test_contains_periodic_subexpression():
assert cps(sp.sin(t), t) is True
assert cps(sp.cos(t), t) is True
assert cps(t + sp.sin(t), t) is True


@skip_on_valgrind
@pytest.mark.parametrize("compute_conservation_laws", [True, False])
def test_time_dependent_initial_assignment(compute_conservation_laws: bool):
"""Check that dynamic expressions for initial assignments are only
evaluated at t=t0."""
from amici.antimony_import import antimony2sbml

ant_model = """
x1' = 1
x1 = p0
p0 = 1
p1 = x1
x2 := x1
p2 = x2
"""

sbml_model = antimony2sbml(ant_model)
print(sbml_model)
si = SbmlImporter(sbml_model, from_file=False)
de_model = si._build_ode_model(
observables={"obs_p1": {"formula": "p1"}, "obs_p2": {"formula": "p2"}},
compute_conservation_laws=compute_conservation_laws,
)
# "species", because the initial assignment expression is time-dependent
assert symbol_with_assumptions("p2") in si.symbols[SymbolId.SPECIES].keys()
# "species", because differential state
assert symbol_with_assumptions("x1") in si.symbols[SymbolId.SPECIES].keys()

assert "p0" in [str(p.get_id()) for p in de_model.parameters()]
assert "p1" not in [str(p.get_id()) for p in de_model.parameters()]
assert "p2" not in [str(p.get_id()) for p in de_model.parameters()]

assert list(de_model.sym("x_rdata")) == [
symbol_with_assumptions("p2"),
symbol_with_assumptions("x1"),
]
assert list(de_model.eq("x0")) == [symbol_with_assumptions("p0")] * 2
Loading