From 014aed1e98d37c402a131b5402115d675f70abd3 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 24 Nov 2025 09:03:11 +0100 Subject: [PATCH] PySB import: Replace all PySB objects by plain sympy objects when creating the DEModel Ensure that DEModel only contains plain `sympy.Symbol`s, no `pysb.Compmenents`. Also, consistently use `symbol_with_assumptions`. Closes #3035. --- python/sdist/amici/_symbolic/de_model.py | 11 +++ python/sdist/amici/importers/pysb/__init__.py | 90 +++++++++++-------- python/sdist/amici/importers/utils.py | 14 +++ python/tests/test_pysb.py | 11 ++- 4 files changed, 86 insertions(+), 40 deletions(-) diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index 542ee8284d..6ac2650628 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -78,10 +78,21 @@ class DEModel: """ Defines a Differential Equation as set of ModelQuantities. + This class provides general purpose interfaces to compute arbitrary symbolic derivatives that are necessary for model simulation or sensitivity computation. + All occurrences of a symbolic variable with a given name must use the same + assumptions (e.g. real, positive, etc.) throughout the model. Mixing + different assumptions for the same variable name will result in incorrect + derivatives and potentially other errors. + + All symbols in the model are expected to be of type sympy.Symbol. If any + subtypes are used, they must behave identically to sympy.Symbol in all + relevant aspects (e.g. hashing, equality testing, etc.). In particular, + `str(symbol)` is expected to return the same value as `symbol.name`. + :ivar _differential_states: differential state variables diff --git a/python/sdist/amici/importers/pysb/__init__.py b/python/sdist/amici/importers/pysb/__init__.py index 12f7757e6b..e4d7c9bb7a 100644 --- a/python/sdist/amici/importers/pysb/__init__.py +++ b/python/sdist/amici/importers/pysb/__init__.py @@ -38,16 +38,19 @@ ObservableParameter, SigmaY, ) -from amici.importers.utils import ( +from amici.logging import get_logger, log_execution_time, set_log_level + +from ..utils import ( MeasurementChannel, _default_simplify, + _expr_to_amici, _get_str_symbol_identifiers, _parse_special_functions, generate_measurement_symbol, noise_distribution_to_cost_function, noise_distribution_to_observable_transformation, + symbol_with_assumptions, ) -from amici.logging import get_logger, log_execution_time, set_log_level CL_Prototype = dict[str, dict[str, Any]] ConservationLaw = dict[str, dict | str | sp.Basic] @@ -166,7 +169,7 @@ def pysb2amici( model_name: str | None = None, pysb_model_has_obs_and_noise: bool = False, _events: list[Event] = None, -) -> amici.Model | None: +) -> amici.sim.sundials.Model | None: r""" Generate AMICI C++ files for the provided model. @@ -264,7 +267,7 @@ def pysb2amici( cache_simplify=cache_simplify, verbose=verbose, pysb_model_has_obs_and_noise=pysb_model_has_obs_and_noise, - events=_events, + _events=_events, ) from amici.exporters.sundials.de_export import ( @@ -312,7 +315,7 @@ def ode_model_from_pysb_importer( verbose: int | bool = False, jax: bool = False, pysb_model_has_obs_and_noise: bool = False, - events: list[Event] = None, + _events: list[Event] = None, ) -> DEModel: """ Creates an :class:`amici.DEModel` instance from a :class:`pysb.Model` @@ -349,7 +352,7 @@ def ode_model_from_pysb_importer( observables and noise variables added :return: - New DEModel instance according to pysbModel + New DEModel instance according to `model` """ ode = DEModel( @@ -370,7 +373,7 @@ def ode_model_from_pysb_importer( _process_pysb_species(model, ode) _process_pysb_parameters(model, ode, fixed_parameters, jax) if compute_conservation_laws: - if events: + if _events: raise NotImplementedError( "Conservation law computation is not supported for models " "with events. Use `compute_conservation_laws=False`." @@ -389,7 +392,7 @@ def ode_model_from_pysb_importer( pysb_model_has_obs_and_noise, ) - for event in events or []: + for event in _events or []: ode.add_component(event) ode._has_quadratic_nllh = all( @@ -460,7 +463,9 @@ def get_cached_index(symbol, sarray, index_cache): idx = get_cached_index(x_rdata[ix], w, wx_idx) values = dflux_dw_dict - values[(ir, idx)] = sp.diff(rxn["rate"], x_rdata[ix]) + values[(ir, idx)] = sp.diff( + _expr_to_amici(rxn["rate"]), x_rdata[ix] + ) # typically <= 3 free symbols in rate, we already account for # species above so we only need to account for propensity, which @@ -481,8 +486,10 @@ def get_cached_index(symbol, sarray, index_cache): else: continue - idx = get_cached_index(fs, var, idx_cache) - values[(ir, idx)] = sp.diff(rxn["rate"], fs) + idx = get_cached_index( + symbol_with_assumptions(fs.name), var, idx_cache + ) + values[(ir, idx)] = _expr_to_amici(sp.diff(rxn["rate"], fs)) dflux_dx = sp.ImmutableSparseMatrix(n_r, n_x, dflux_dx_dict) dflux_dw = sp.ImmutableSparseMatrix(n_r, n_w, dflux_dw_dict) @@ -514,6 +521,7 @@ def _process_pysb_species(pysb_model: pysb.Model, ode_model: DEModel) -> None: DEModel instance """ xdot = sp.Matrix(pysb_model.odes) + xdot = _expr_to_amici(xdot) for ix, specie in enumerate(pysb_model.species): init = sp.sympify("0.0") @@ -527,9 +535,13 @@ def _process_pysb_species(pysb_model: pysb.Model, ode_model: DEModel) -> None: else: init = ic.value + init = _expr_to_amici(init) ode_model.add_component( DifferentialState( - sp.Symbol(f"__s{ix}"), f"{specie}", init, xdot[ix] + symbol_with_assumptions(f"__s{ix}"), + f"{specie}", + init, + xdot[ix], ) ) logger.debug("Finished Processing PySB species ") @@ -559,7 +571,7 @@ def _process_pysb_parameters( DEModel instance """ for par in pysb_model.parameters: - args = [par, f"{par.name}"] + args = [symbol_with_assumptions(par.name), par.name] if par.name in fixed_parameters: comp = FixedParameter args.append(par.value) @@ -621,7 +633,6 @@ def _process_pysb_expressions( # we skip them. continue _add_expression( - expr, expr.name, expr.expr, pysb_model, @@ -632,7 +643,6 @@ def _process_pysb_expressions( def _add_expression( - sym: sp.Symbol, name: str, expr: sp.Basic, pysb_model: pysb.Model, @@ -644,9 +654,6 @@ def _add_expression( Adds expressions to the ODE model given and adds observables/sigmas if appropriate - :param sym: - symbol how the expression is referenced in the model - :param name: name of the expression @@ -674,22 +681,25 @@ def _add_expression( component = SigmaY else: component = Expression + expr = _parse_special_functions(_expr_to_amici(expr)) ode_model.add_component( - component(sym, name, _parse_special_functions(expr)) + component(symbol_with_assumptions(name), name, expr) ) if name in observation_model: noise_dist = observation_model[name].noise_distribution - y = sp.Symbol(name) + y = symbol_with_assumptions(name) trafo = noise_distribution_to_observable_transformation(noise_dist) - # note that this is a bit iffy since we are potentially using the same _symbolic identifier in expressions (w) - # and observables (y). This is not a problem as there currently are no model functions that use both. If this - # changes, I would expect symbol redefinition warnings in CPP models and overwriting in JAX models, but as both - # symbols refer to the same _symbolic entity, this should not be a problem (untested) - obs = Observable( - y, name, _parse_special_functions(expr), transformation=trafo - ) + # note that this is a bit iffy since we are potentially using the same + # _symbolic identifier in expressions (w) and observables (y). + # This is not a problem as there currently are no model functions that + # use both. + # If this changes, I would expect symbol redefinition warnings in CPP + # models and overwriting in JAX models, but as both symbols refer to + # the same _symbolic entity, this should not be a problem (untested) + expr = _parse_special_functions(_expr_to_amici(expr)) + obs = Observable(y, name, expr, transformation=trafo) ode_model.add_component(obs) sigma_name = observation_model[name].sigma @@ -714,9 +724,12 @@ def _add_expression( ) ), ) + cost_fun_expr = _expr_to_amici(cost_fun_expr) ode_model.add_component( LogLikelihoodY( - sp.Symbol(f"llh_{name}"), f"llh_{name}", cost_fun_expr + symbol_with_assumptions(f"llh_{name}"), + f"llh_{name}", + cost_fun_expr, ) ) @@ -743,10 +756,10 @@ def _get_sigma( _symbolic variable representing the standard deviation of the observable """ if sigma_name is None: - return sp.Symbol(f"sigma_{obs_name}") + return symbol_with_assumptions(f"sigma_{obs_name}") if sigma_name in pysb_model.expressions.keys(): - return pysb_model.expressions[sigma_name] + return _expr_to_amici(pysb_model.expressions[sigma_name]) raise ValueError(f"value of sigma {obs_name} is not a valid expression.") @@ -780,7 +793,6 @@ def _process_pysb_observables( # Observables as expressions for obs in pysb_model.observables: _add_expression( - obs, obs.name, obs.expand_obs(), pysb_model, @@ -1296,12 +1308,14 @@ def _construct_conservation_from_prototypes( for ix, specie in enumerate(pysb_model.species): count = extract_monomers(specie).count(monomer_name) if count > 0: - coefficients[sp.Symbol(f"__s{ix}")] = count + coefficients[symbol_with_assumptions(f"__s{ix}")] = count conservation_laws.append( { - "state": sp.Symbol(f"__s{target_index}"), - "total_abundance": sp.Symbol(f"tcl__s{target_index}"), + "state": symbol_with_assumptions(f"__s{target_index}"), + "total_abundance": symbol_with_assumptions( + f"tcl__s{target_index}" + ), "coefficients": coefficients, } ) @@ -1328,9 +1342,11 @@ def _add_conservation_for_constant_species( if ode_model.state_is_constant(ix): conservation_laws.append( { - "state": sp.Symbol(f"__s{ix}"), - "total_abundance": sp.Symbol(f"tcl__s{ix}"), - "coefficients": {sp.Symbol(f"__s{ix}"): 1.0}, + "state": symbol_with_assumptions(f"__s{ix}"), + "total_abundance": symbol_with_assumptions(f"tcl__s{ix}"), + "coefficients": { + symbol_with_assumptions(f"__s{ix}"): sp.Integer(1) + }, } ) diff --git a/python/sdist/amici/importers/utils.py b/python/sdist/amici/importers/utils.py index ff05260d81..274a0f1196 100644 --- a/python/sdist/amici/importers/utils.py +++ b/python/sdist/amici/importers/utils.py @@ -1014,3 +1014,17 @@ def contains_periodic_subexpression(expr: sp.Expr, symbol: sp.Symbol) -> bool: if sp.periodicity(subexpr, symbol) is not None: return True return False + + +def _expr_to_amici(expr: sp.Basic | sp.MatrixBase): + """Convert the given sympy expression to an AMICI-compatible expression. + + Replaces all symbols by plain sympy symbols with the expected assumptions. + + :param expr: The sympy expression to convert. + :return: The AMICI-compatible sympy expression. + """ + replacements = { + s: symbol_with_assumptions(s.name) for s in expr.free_symbols + } + return expr.xreplace(replacements) diff --git a/python/tests/test_pysb.py b/python/tests/test_pysb.py index 1ca7a4a83e..50e7aea6ea 100644 --- a/python/tests/test_pysb.py +++ b/python/tests/test_pysb.py @@ -19,7 +19,7 @@ from amici import MeasurementChannel, import_model_module from amici._symbolic.de_model_components import Event from amici.importers.pysb import pysb2amici -from amici.importers.utils import amici_time_symbol +from amici.importers.utils import amici_time_symbol, symbol_with_assumptions from amici.sim.sundials import ( ExpData, ParameterScaling, @@ -417,10 +417,15 @@ def test_pysb_event(tempdir): events = [ Event( # note that unlike for SBML import, we must omit the real=True here - symbol=sp.Symbol("event1"), + symbol=symbol_with_assumptions("event1"), name="Event1", value=amici_time_symbol - 5, - assignments={sp.Symbol("__s0"): sp.Symbol("__s0") + 1000}, + assignments={ + symbol_with_assumptions("__s0"): symbol_with_assumptions( + "__s0" + ) + + 1000 + }, use_values_from_trigger_time=False, ) ]