From b2c3119cd6581096d8afd884593121c436f49bab Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 12 Nov 2025 22:43:47 +0100 Subject: [PATCH] Faster substitutions with `smart_subs_dict` II Flatten out expressions first to avoid expensive repeated substitutions in matrices. For my current benchmark, this yields another 10x improvement compared to #3025. --- python/sdist/amici/importers/utils.py | 28 ++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/python/sdist/amici/importers/utils.py b/python/sdist/amici/importers/utils.py index a4348783a8..2ad7e32322 100644 --- a/python/sdist/amici/importers/utils.py +++ b/python/sdist/amici/importers/utils.py @@ -422,20 +422,30 @@ def smart_subs_dict( Substituted symbolic expression """ if field is None: - s = [(eid, expr) for eid, expr in subs.items()] + s = list(subs.items()) else: s = [(eid, expr[field]) for eid, expr in subs.items()] - if reverse: + if not reverse: + # counter-intuitive, but we need to reverse the order for reverse=False s.reverse() with sp.evaluate(False): - for old, new in s: - # note that substitution may change free symbols, so we have to do - # this recursively - if sym.has(old): - sym = sym.xreplace({old: new}) - + # The new expressions may themselves contain symbols to be substituted. + # We flatten them out first, so that the substitutions in `sym` can be + # performed simultaneously, which is usually more efficient than + # repeatedly substituting into `sym`. + # TODO(performance): This could probably be made more efficient by + # combining with toposort used to order `subs` in the first place. + # Some substitutions could be combined, and some terms not present in + # `sym` could be skipped. + for i in range(len(s) - 1): + for j in range(i + 1, len(s)): + if s[j][1].has(s[i][0]): + s[j] = s[j][0], s[j][1].xreplace({s[i][0]: s[i][1]}) + + s = dict(s) + sym = sym.xreplace(s) return sym @@ -450,7 +460,7 @@ def smart_subs(element: sp.Expr, old: sp.Symbol, new: sp.Expr) -> sp.Expr: to be substituted :param new: - subsitution value + substitution value :return: substituted expression