From 6e041ab17d2571d03ec09818183e210d60495a5d Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 14 Oct 2025 14:26:30 +0200 Subject: [PATCH] Speed up `DEModel._collect_heaviside_roots` Avoid unnecessary repeated toposorting of `w` during `_collect_heaviside_root`. Sort and substitute only once after all roots have been collected. This saves a couple of seconds for models with heavily nested piecewise functions. --- python/sdist/amici/de_model.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 975123a000..e423e10ec0 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -2462,12 +2462,16 @@ def _collect_heaviside_roots( elif arg.has(sp.Heaviside): root_funs.extend(self._collect_heaviside_roots(arg.args)) - if not root_funs: - return [] + return root_funs - # substitute 'w' expressions into root expressions now, to avoid - # rewriting 'root.cpp' and 'stau.cpp' headers - # to include 'w.h' + def _substitute_w_in_roots( + self, + root_funs: list[tuple[sp.Expr, sp.Expr]], + ) -> list[tuple[sp.Expr, sp.Expr]]: + """ + Substitute 'w' expressions into root expressions, to avoid rewriting + 'root.cpp' and 'stau.cpp' headers to include 'w.h'. + """ w_sorted = toposort_symbols( dict( zip( @@ -2507,6 +2511,7 @@ def _process_heavisides( heavisides = [] # run through the expression tree and get the roots tmp_roots_old = self._collect_heaviside_roots((dxdt,)) + tmp_roots_old = self._substitute_w_in_roots(tmp_roots_old) for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old): # we want unique identifiers for the roots tmp_root_new = self._get_unique_root(tmp_root_old, roots)