Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/sdist/amici/_symbolic/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,21 @@
class DEModel:
"""
Defines a Differential Equation as set of ModelQuantities.

This class provides general purpose interfaces to compute arbitrary
symbolic derivatives that are necessary for model simulation or
sensitivity computation.

All occurrences of a symbolic variable with a given name must use the same
assumptions (e.g. real, positive, etc.) throughout the model. Mixing
different assumptions for the same variable name will result in incorrect
derivatives and potentially other errors.

All symbols in the model are expected to be of type sympy.Symbol. If any
subtypes are used, they must behave identically to sympy.Symbol in all
relevant aspects (e.g. hashing, equality testing, etc.). In particular,
`str(symbol)` is expected to return the same value as `symbol.name`.

:ivar _differential_states:
differential state variables

Expand Down
90 changes: 53 additions & 37 deletions python/sdist/amici/importers/pysb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,19 @@
ObservableParameter,
SigmaY,
)
from amici.importers.utils import (
from amici.logging import get_logger, log_execution_time, set_log_level

from ..utils import (
MeasurementChannel,
_default_simplify,
_expr_to_amici,
_get_str_symbol_identifiers,
_parse_special_functions,
generate_measurement_symbol,
noise_distribution_to_cost_function,
noise_distribution_to_observable_transformation,
symbol_with_assumptions,
)
from amici.logging import get_logger, log_execution_time, set_log_level

CL_Prototype = dict[str, dict[str, Any]]
ConservationLaw = dict[str, dict | str | sp.Basic]
Expand Down Expand Up @@ -166,7 +169,7 @@ def pysb2amici(
model_name: str | None = None,
pysb_model_has_obs_and_noise: bool = False,
_events: list[Event] = None,
) -> amici.Model | None:
) -> amici.sim.sundials.Model | None:
r"""
Generate AMICI C++ files for the provided model.

Expand Down Expand Up @@ -264,7 +267,7 @@ def pysb2amici(
cache_simplify=cache_simplify,
verbose=verbose,
pysb_model_has_obs_and_noise=pysb_model_has_obs_and_noise,
events=_events,
_events=_events,
)

from amici.exporters.sundials.de_export import (
Expand Down Expand Up @@ -312,7 +315,7 @@ def ode_model_from_pysb_importer(
verbose: int | bool = False,
jax: bool = False,
pysb_model_has_obs_and_noise: bool = False,
events: list[Event] = None,
_events: list[Event] = None,
) -> DEModel:
"""
Creates an :class:`amici.DEModel` instance from a :class:`pysb.Model`
Expand Down Expand Up @@ -349,7 +352,7 @@ def ode_model_from_pysb_importer(
observables and noise variables added

:return:
New DEModel instance according to pysbModel
New DEModel instance according to `model`
"""

ode = DEModel(
Expand All @@ -370,7 +373,7 @@ def ode_model_from_pysb_importer(
_process_pysb_species(model, ode)
_process_pysb_parameters(model, ode, fixed_parameters, jax)
if compute_conservation_laws:
if events:
if _events:
raise NotImplementedError(
"Conservation law computation is not supported for models "
"with events. Use `compute_conservation_laws=False`."
Expand All @@ -389,7 +392,7 @@ def ode_model_from_pysb_importer(
pysb_model_has_obs_and_noise,
)

for event in events or []:
for event in _events or []:
ode.add_component(event)

ode._has_quadratic_nllh = all(
Expand Down Expand Up @@ -460,7 +463,9 @@ def get_cached_index(symbol, sarray, index_cache):
idx = get_cached_index(x_rdata[ix], w, wx_idx)
values = dflux_dw_dict

values[(ir, idx)] = sp.diff(rxn["rate"], x_rdata[ix])
values[(ir, idx)] = sp.diff(
_expr_to_amici(rxn["rate"]), x_rdata[ix]
)

# typically <= 3 free symbols in rate, we already account for
# species above so we only need to account for propensity, which
Expand All @@ -481,8 +486,10 @@ def get_cached_index(symbol, sarray, index_cache):
else:
continue

idx = get_cached_index(fs, var, idx_cache)
values[(ir, idx)] = sp.diff(rxn["rate"], fs)
idx = get_cached_index(
symbol_with_assumptions(fs.name), var, idx_cache
)
values[(ir, idx)] = _expr_to_amici(sp.diff(rxn["rate"], fs))

dflux_dx = sp.ImmutableSparseMatrix(n_r, n_x, dflux_dx_dict)
dflux_dw = sp.ImmutableSparseMatrix(n_r, n_w, dflux_dw_dict)
Expand Down Expand Up @@ -514,6 +521,7 @@ def _process_pysb_species(pysb_model: pysb.Model, ode_model: DEModel) -> None:
DEModel instance
"""
xdot = sp.Matrix(pysb_model.odes)
xdot = _expr_to_amici(xdot)

for ix, specie in enumerate(pysb_model.species):
init = sp.sympify("0.0")
Expand All @@ -527,9 +535,13 @@ def _process_pysb_species(pysb_model: pysb.Model, ode_model: DEModel) -> None:
else:
init = ic.value

init = _expr_to_amici(init)
ode_model.add_component(
DifferentialState(
sp.Symbol(f"__s{ix}"), f"{specie}", init, xdot[ix]
symbol_with_assumptions(f"__s{ix}"),
f"{specie}",
init,
xdot[ix],
)
)
logger.debug("Finished Processing PySB species ")
Expand Down Expand Up @@ -559,7 +571,7 @@ def _process_pysb_parameters(
DEModel instance
"""
for par in pysb_model.parameters:
args = [par, f"{par.name}"]
args = [symbol_with_assumptions(par.name), par.name]
if par.name in fixed_parameters:
comp = FixedParameter
args.append(par.value)
Expand Down Expand Up @@ -621,7 +633,6 @@ def _process_pysb_expressions(
# we skip them.
continue
_add_expression(
expr,
expr.name,
expr.expr,
pysb_model,
Expand All @@ -632,7 +643,6 @@ def _process_pysb_expressions(


def _add_expression(
sym: sp.Symbol,
name: str,
expr: sp.Basic,
pysb_model: pysb.Model,
Expand All @@ -644,9 +654,6 @@ def _add_expression(
Adds expressions to the ODE model given and adds observables/sigmas if
appropriate

:param sym:
symbol how the expression is referenced in the model

:param name:
name of the expression

Expand Down Expand Up @@ -674,22 +681,25 @@ def _add_expression(
component = SigmaY
else:
component = Expression
expr = _parse_special_functions(_expr_to_amici(expr))
ode_model.add_component(
component(sym, name, _parse_special_functions(expr))
component(symbol_with_assumptions(name), name, expr)
)

if name in observation_model:
noise_dist = observation_model[name].noise_distribution

y = sp.Symbol(name)
y = symbol_with_assumptions(name)
trafo = noise_distribution_to_observable_transformation(noise_dist)
# note that this is a bit iffy since we are potentially using the same _symbolic identifier in expressions (w)
# and observables (y). This is not a problem as there currently are no model functions that use both. If this
# changes, I would expect symbol redefinition warnings in CPP models and overwriting in JAX models, but as both
# symbols refer to the same _symbolic entity, this should not be a problem (untested)
obs = Observable(
y, name, _parse_special_functions(expr), transformation=trafo
)
# note that this is a bit iffy since we are potentially using the same
# _symbolic identifier in expressions (w) and observables (y).
# This is not a problem as there currently are no model functions that
# use both.
# If this changes, I would expect symbol redefinition warnings in CPP
# models and overwriting in JAX models, but as both symbols refer to
# the same _symbolic entity, this should not be a problem (untested)
expr = _parse_special_functions(_expr_to_amici(expr))
obs = Observable(y, name, expr, transformation=trafo)
ode_model.add_component(obs)

sigma_name = observation_model[name].sigma
Expand All @@ -714,9 +724,12 @@ def _add_expression(
)
),
)
cost_fun_expr = _expr_to_amici(cost_fun_expr)
ode_model.add_component(
LogLikelihoodY(
sp.Symbol(f"llh_{name}"), f"llh_{name}", cost_fun_expr
symbol_with_assumptions(f"llh_{name}"),
f"llh_{name}",
cost_fun_expr,
)
)

Expand All @@ -743,10 +756,10 @@ def _get_sigma(
_symbolic variable representing the standard deviation of the observable
"""
if sigma_name is None:
return sp.Symbol(f"sigma_{obs_name}")
return symbol_with_assumptions(f"sigma_{obs_name}")

if sigma_name in pysb_model.expressions.keys():
return pysb_model.expressions[sigma_name]
return _expr_to_amici(pysb_model.expressions[sigma_name])

raise ValueError(f"value of sigma {obs_name} is not a valid expression.")

Expand Down Expand Up @@ -780,7 +793,6 @@ def _process_pysb_observables(
# Observables as expressions
for obs in pysb_model.observables:
_add_expression(
obs,
obs.name,
obs.expand_obs(),
pysb_model,
Expand Down Expand Up @@ -1296,12 +1308,14 @@ def _construct_conservation_from_prototypes(
for ix, specie in enumerate(pysb_model.species):
count = extract_monomers(specie).count(monomer_name)
if count > 0:
coefficients[sp.Symbol(f"__s{ix}")] = count
coefficients[symbol_with_assumptions(f"__s{ix}")] = count

conservation_laws.append(
{
"state": sp.Symbol(f"__s{target_index}"),
"total_abundance": sp.Symbol(f"tcl__s{target_index}"),
"state": symbol_with_assumptions(f"__s{target_index}"),
"total_abundance": symbol_with_assumptions(
f"tcl__s{target_index}"
),
"coefficients": coefficients,
}
)
Expand All @@ -1328,9 +1342,11 @@ def _add_conservation_for_constant_species(
if ode_model.state_is_constant(ix):
conservation_laws.append(
{
"state": sp.Symbol(f"__s{ix}"),
"total_abundance": sp.Symbol(f"tcl__s{ix}"),
"coefficients": {sp.Symbol(f"__s{ix}"): 1.0},
"state": symbol_with_assumptions(f"__s{ix}"),
"total_abundance": symbol_with_assumptions(f"tcl__s{ix}"),
"coefficients": {
symbol_with_assumptions(f"__s{ix}"): sp.Integer(1)
},
}
)

Expand Down
14 changes: 14 additions & 0 deletions python/sdist/amici/importers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,3 +1014,17 @@ def contains_periodic_subexpression(expr: sp.Expr, symbol: sp.Symbol) -> bool:
if sp.periodicity(subexpr, symbol) is not None:
return True
return False


def _expr_to_amici(expr: sp.Basic | sp.MatrixBase):
"""Convert the given sympy expression to an AMICI-compatible expression.

Replaces all symbols by plain sympy symbols with the expected assumptions.

:param expr: The sympy expression to convert.
:return: The AMICI-compatible sympy expression.
"""
replacements = {
s: symbol_with_assumptions(s.name) for s in expr.free_symbols
}
return expr.xreplace(replacements)
11 changes: 8 additions & 3 deletions python/tests/test_pysb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from amici import MeasurementChannel, import_model_module
from amici._symbolic.de_model_components import Event
from amici.importers.pysb import pysb2amici
from amici.importers.utils import amici_time_symbol
from amici.importers.utils import amici_time_symbol, symbol_with_assumptions
from amici.sim.sundials import (
ExpData,
ParameterScaling,
Expand Down Expand Up @@ -417,10 +417,15 @@ def test_pysb_event(tempdir):
events = [
Event(
# note that unlike for SBML import, we must omit the real=True here
symbol=sp.Symbol("event1"),
symbol=symbol_with_assumptions("event1"),
name="Event1",
value=amici_time_symbol - 5,
assignments={sp.Symbol("__s0"): sp.Symbol("__s0") + 1000},
assignments={
symbol_with_assumptions("__s0"): symbol_with_assumptions(
"__s0"
)
+ 1000
},
use_values_from_trigger_time=False,
)
]
Expand Down
Loading