Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,18 +1305,36 @@ def parse_events(self) -> None:
and replaces the formulae of the found roots by identifiers of AMICI's
Heaviside function implementation in the right-hand side
"""
# toposorted w_sym -> w_expr for substitution of 'w' in trigger function
# do only once. `w` is not modified during this function.
w_toposorted = toposort_symbols(
dict(
zip(
[expr.get_id() for expr in self._expressions],
[expr.get_val() for expr in self._expressions],
strict=True,
)
)
)

# Track all roots functions in the right-hand side
roots = copy.deepcopy(self._events)
for state in self._differential_states:
state.set_dt(self._process_heavisides(state.get_dt(), roots))
state.set_dt(
self._process_heavisides(state.get_dt(), roots, w_toposorted)
)

for expr in self._expressions:
expr.set_val(self._process_heavisides(expr.get_val(), roots))
expr.set_val(
self._process_heavisides(expr.get_val(), roots, w_toposorted)
)

# remove all possible Heavisides from roots, which may arise from
# the substitution of `'w'` in `_collect_heaviside_roots`
for root in roots:
root.set_val(self._process_heavisides(root.get_val(), roots))
root.set_val(
self._process_heavisides(root.get_val(), roots, w_toposorted)
)

# Now add the found roots to the model components
for root in roots:
Expand All @@ -1326,6 +1344,11 @@ def parse_events(self) -> None:
# add roots of heaviside functions
self.add_component(root)

# Substitute 'w' expressions into root expressions, to avoid rewriting
# 'root.cpp' and 'stau.cpp' headers to include 'w.h'.
for event in self.events():
event.set_val(event.get_val().subs(w_toposorted))

# re-order events - first those that require root tracking, then the others
constant_syms = set(self.sym("k")) | set(self.sym("p"))
self._events = list(
Expand Down Expand Up @@ -2391,7 +2414,7 @@ def _expr_is_time_dependent(self, expr: sp.Expr) -> bool:
expr_syms = {str(sym) for sym in expr.free_symbols}

# Check if the time variable is in the expression.
if "t" in expr_syms:
if amici_time_symbol.name in expr_syms:
return True

# Check if any time-dependent states are in the expression.
Expand Down Expand Up @@ -2464,33 +2487,11 @@ def _collect_heaviside_roots(

return root_funs

def _substitute_w_in_roots(
self,
root_funs: list[tuple[sp.Expr, sp.Expr]],
) -> list[tuple[sp.Expr, sp.Expr]]:
"""
Substitute 'w' expressions into root expressions, to avoid rewriting
'root.cpp' and 'stau.cpp' headers to include 'w.h'.
"""
w_sorted = toposort_symbols(
dict(
zip(
[expr.get_id() for expr in self._expressions],
[expr.get_val() for expr in self._expressions],
strict=True,
)
)
)
root_funs = [
(r[0].subs(w_sorted), r[1].subs(w_sorted)) for r in root_funs
]

return root_funs

def _process_heavisides(
self,
dxdt: sp.Expr,
roots: list[Event],
w_toposorted: dict[sp.Symbol, sp.Expr],
) -> sp.Expr:
"""
Parses the RHS of a state variable, checks for Heaviside functions,
Expand All @@ -2502,7 +2503,8 @@ def _process_heavisides(
right-hand side of state variable
:param roots:
list of known root functions with identifier

:param w_toposorted:
`w` symbols->expressions sorted in topological order
:returns:
dxdt with Heaviside functions replaced by amici helper variables
"""
Expand All @@ -2511,7 +2513,15 @@ def _process_heavisides(
heavisides = []
# run through the expression tree and get the roots
tmp_roots_old = self._collect_heaviside_roots((dxdt,))
tmp_roots_old = self._substitute_w_in_roots(tmp_roots_old)
# substitute 'w' symbols in the root expression by their equations,
# because currently,
# 1) root functions must not depend on 'w'
# 2) the check for time-dependence currently assumes only state
# variables are implicitly time-dependent
tmp_roots_old = [
(a.subs(w_toposorted), b.subs(w_toposorted))
for a, b in tmp_roots_old
]
for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old):
# we want unique identifiers for the roots
tmp_root_new = self._get_unique_root(tmp_root_old, roots)
Expand Down
12 changes: 11 additions & 1 deletion python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,17 @@
from sympy.logic.boolalg import BooleanAtom
from toposort import toposort

RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"]
RESERVED_SYMBOLS = [
"x",
"k",
"p",
"y",
"w",
"h",
"t",
"AMICI_EMPTY_BOLUS",
"NULL",
]

try:
import pysb
Expand Down
9 changes: 8 additions & 1 deletion python/tests/test_bngl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os

import amici
import numpy as np
import pytest

Expand All @@ -10,6 +9,7 @@
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from pysb.importers.bngl import model_from_bngl
from pysb.simulator import ScipyOdeSimulator
from contextlib import suppress

tests = [
"CaOscillate_Func",
Expand Down Expand Up @@ -39,6 +39,13 @@
@skip_on_valgrind
@pytest.mark.parametrize("example", tests)
def test_compare_to_pysb_simulation(example):
import amici.import_utils

# allow "NULL" as model symbol
# (used in CaOscillate_Func and Repressilator examples)
with suppress(ValueError):
amici.import_utils.RESERVED_SYMBOLS.remove("NULL")

atol = 1e-8
rtol = 1e-8

Expand Down
Loading