diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 7d8508a658..95f05be858 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -648,7 +648,7 @@ def add_conservation_law( try: ix = next( filter( - lambda is_s: is_s[1].get_id() == state, + lambda is_s: is_s[1].get_sym() == state, enumerate(self._differential_states), ) )[0] @@ -657,7 +657,7 @@ def add_conservation_law( f"Specified state {state} was not found in the model states." ) - state_id = self._differential_states[ix].get_id() + state_id = self._differential_states[ix].get_sym() # \sum_{i≠j}(a_i * x_i)/a_j target_expression = ( @@ -704,7 +704,7 @@ def add_spline(self, spline: AbstractSpline, spline_expr: sp.Expr) -> None: self._splines.append(spline) self.add_component( Expression( - identifier=spline.sbml_id, + symbol=spline.sbml_id, name=str(spline.sbml_id), value=spline_expr, ) @@ -1136,7 +1136,7 @@ def _generate_symbol(self, name: str) -> None: components = sorted( components, key=lambda x: int( - str(strip_pysb(x.get_id())).replace( + str(strip_pysb(x.get_sym())).replace( "observableParameter", "" ) ), @@ -1145,13 +1145,13 @@ def _generate_symbol(self, name: str) -> None: components = sorted( components, key=lambda x: int( - str(strip_pysb(x.get_id())).replace( + str(strip_pysb(x.get_sym())).replace( "noiseParameter", "" ) ), ) self._syms[name] = sp.Matrix( - [comp.get_id() for comp in components] + [comp.get_sym() for comp in components] ) if name == "y": self._syms["my"] = sp.Matrix( @@ -1168,7 +1168,7 @@ def _generate_symbol(self, name: str) -> None: elif name == "x": self._syms[name] = sp.Matrix( [ - state.get_id() + state.get_sym() for state in self.states() if not state.has_conservation_law() ] @@ -1214,8 +1214,8 @@ def _generate_symbol(self, name: str) -> None: [ [ sp.Symbol( - f"s{strip_pysb(tcl.get_id())}__" - f"{strip_pysb(par.get_id())}", + f"s{strip_pysb(tcl.get_sym())}__" + f"{strip_pysb(par.get_sym())}", real=True, ) for par in self._parameters @@ -1312,7 +1312,7 @@ def parse_events(self) -> None: w_toposorted = toposort_symbols( dict( zip( - [expr.get_id() for expr in self._expressions], + [expr.get_sym() for expr in self._expressions], [expr.get_val() for expr in self._expressions], strict=True, ) @@ -1393,9 +1393,11 @@ def get_appearance_counts(self, idxs: list[int]) -> list[int]: ) return [ - free_symbols_dt.count(str(self._differential_states[idx].get_id())) + free_symbols_dt.count( + str(self._differential_states[idx].get_sym()) + ) + free_symbols_expr.count( - str(self._differential_states[idx].get_id()) + str(self._differential_states[idx].get_sym()) ) for idx in idxs ] @@ -1528,7 +1530,7 @@ def _compute_equation(self, name: str) -> None: elif name == "x_solver": self._eqs[name] = sp.Matrix( [ - state.get_id() + state.get_sym() for state in self.states() if not state.has_conservation_law() ] @@ -1704,7 +1706,7 @@ def _compute_equation(self, name: str) -> None: event_observables = [ sp.zeros(self.num_eventobs(), 1) for _ in self._events ] - event_ids = [e.get_id() for e in self._events] + event_ids = [e.get_sym() for e in self._events] z2event = [ event_ids.index(event_obs.get_event()) for event_obs in self._event_observables @@ -2285,7 +2287,7 @@ def get_conservation_laws(self) -> list[tuple[sp.Symbol, sp.Expr]]: list of state identifiers """ return [ - (state.get_id(), state.get_x_rdata()) + (state.get_sym(), state.get_x_rdata()) for state in self.states() if state.has_conservation_law() ] @@ -2338,7 +2340,7 @@ def state_has_fixed_parameter_initial_condition(self, ix: int) -> bool: if not isinstance(ic, sp.Basic): return False return any( - fp in (c.get_id() for c in self._constants) + fp in (c.get_sym() for c in self._constants) for fp in ic.free_symbols ) @@ -2450,20 +2452,20 @@ def _get_unique_root( for root in roots: if sp.simplify(root_found - root.get_val()).is_zero: - return root.get_id() + return root.get_sym() # create an event for a new root function root_symstr = f"Heaviside_{len(roots)}" roots.append( Event( - identifier=sp.Symbol(root_symstr), + symbol=sp.Symbol(root_symstr), name=root_symstr, value=root_found, assignments=None, use_values_from_trigger_time=True, ) ) - return roots[-1].get_id() + return roots[-1].get_sym() def _collect_heaviside_roots( self, @@ -2579,22 +2581,22 @@ def _process_hybridization(self, hybridization: dict) -> None: https://petab-sciml.readthedocs.io/latest/format.html#problem-yaml-file """ added_expressions = False - orig_obs = tuple([s.get_id() for s in self._observables]) + orig_obs = tuple([s.get_sym() for s in self._observables]) for net_id, net in hybridization.items(): if net["static"]: continue # do not integrate into ODEs, handle in amici.jax.petab inputs = [ comp for comp in self._components - if str(comp.get_id()) in net["input_vars"] + if str(comp.get_sym()) in net["input_vars"] ] # sort inputs by order in input_vars inputs = sorted( inputs, - key=lambda comp: net["input_vars"].index(str(comp.get_id())), + key=lambda comp: net["input_vars"].index(str(comp.get_sym())), ) if len(inputs) != len(net["input_vars"]): - found_vars = {str(comp.get_id()) for comp in inputs} + found_vars = {str(comp.get_sym()) for comp in inputs} missing_vars = set(net["input_vars"]) - found_vars raise ValueError( f"Could not find all input variables for neural network {net_id}. " @@ -2616,9 +2618,9 @@ def _process_hybridization(self, hybridization: dict) -> None: outputs = { out_var: {"comp": comp, "ind": net["output_vars"][out_var]} for comp in self._components - if (out_var := str(comp.get_id())) in net["output_vars"] + if (out_var := str(comp.get_sym())) in net["output_vars"] # TODO: SYNTAX NEEDS to CHANGE - or (out_var := str(comp.get_id()) + "_dot") + or (out_var := str(comp.get_sym()) + "_dot") in net["output_vars"] } if len(outputs.keys()) != len(net["output_vars"]): @@ -2645,7 +2647,7 @@ def _process_hybridization(self, hybridization: dict) -> None: # generate dummy Function out_val = sp.Function(net_id)( - *[input.get_id() for input in inputs], parts["ind"] + *[input.get_sym() for input in inputs], parts["ind"] ) # add to the model @@ -2659,7 +2661,7 @@ def _process_hybridization(self, hybridization: dict) -> None: else: self.add_component( Expression( - identifier=comp.get_id(), + symbol=comp.get_sym(), name=net_id, value=out_val, ) @@ -2669,7 +2671,7 @@ def _process_hybridization(self, hybridization: dict) -> None: observables = { ob_var: {"comp": comp, "ind": net["observable_vars"][ob_var]} for comp in self._components - if (ob_var := str(comp.get_id())) in net["observable_vars"] + if (ob_var := str(comp.get_sym())) in net["observable_vars"] # # TODO: SYNTAX NEEDS to CHANGE # or (ob_var := str(comp.get_id()) + "_dot") # in net["observable_vars"] @@ -2691,18 +2693,18 @@ def _process_hybridization(self, hybridization: dict) -> None: f"{comp.get_name()} ({type(comp)}) is not an observable." ) out_val = sp.Function(net_id)( - *[input.get_id() for input in inputs], parts["ind"] + *[input.get_sym() for input in inputs], parts["ind"] ) # add to the model self.add_component( Observable( - identifier=comp.get_id(), + symbol=comp.get_sym(), name=net_id, value=out_val, ) ) - new_order = [orig_obs.index(s.get_id()) for s in self._observables] + new_order = [orig_obs.index(s.get_sym()) for s in self._observables] self._observables = [self._observables[i] for i in new_order] if added_expressions: diff --git a/python/sdist/amici/de_model_components.py b/python/sdist/amici/de_model_components.py index 0922772bc3..0641fe9211 100644 --- a/python/sdist/amici/de_model_components.py +++ b/python/sdist/amici/de_model_components.py @@ -47,15 +47,15 @@ class ModelQuantity: def __init__( self, - identifier: sp.Symbol, + symbol: sp.Symbol, name: str, value: SupportsFloat | numbers.Number | sp.Expr, ): """ Create a new ModelQuantity instance. - :param identifier: - unique identifier of the quantity + :param symbol: + Symbol of the quantity with unique identifier. :param name: individual name of the quantity (does not need to be unique) @@ -64,18 +64,17 @@ def __init__( either formula, numeric value or initial value """ - if not isinstance(identifier, sp.Symbol): - raise TypeError( - f"identifier must be sympy.Symbol, was {type(identifier)}" - ) + if not isinstance(symbol, sp.Symbol): + raise TypeError(f"symbol must be sympy.Symbol, was {type(symbol)}") - if str(identifier) in RESERVED_SYMBOLS or ( - hasattr(identifier, "name") and identifier.name in RESERVED_SYMBOLS + if str(symbol) in RESERVED_SYMBOLS or ( + hasattr(symbol, "name") and symbol.name in RESERVED_SYMBOLS ): raise ValueError( - f'Cannot add model quantity with name "{name}", please rename.' + f'Cannot add model quantity with reserved name "{name}", ' + "please rename." ) - self._identifier: sp.Symbol = identifier + self._symbol: sp.Symbol = symbol if not isinstance(name, str): raise TypeError(f"name must be str, was {type(name)}") @@ -91,16 +90,29 @@ def __repr__(self) -> str: :return: string representation of the ModelQuantity """ - return str(self._identifier) + return str(self._symbol) + + def get_sym(self) -> sp.Symbol: + """ + ModelQuantity symbol + + :return: + Symbol of the ModelQuantity + """ + return self._symbol - def get_id(self) -> sp.Symbol: + def get_id(self) -> str: """ ModelQuantity identifier :return: identifier of the ModelQuantity """ - return self._identifier + return ( + self._symbol.name + if hasattr(self._symbol, "name") + else str(self._symbol) + ) def get_name(self) -> str: """ @@ -139,7 +151,7 @@ class ConservationLaw(ModelQuantity): def __init__( self, - identifier: sp.Symbol, + symbol: sp.Symbol, name: str, value: sp.Expr, coefficients: dict[sp.Symbol, sp.Expr], @@ -148,8 +160,8 @@ def __init__( """ Create a new ConservationLaw instance. - :param identifier: - unique identifier of the ConservationLaw + :param symbol: + unique symbol of the ConservationLaw :param name: individual name of the ConservationLaw (does not need to be @@ -161,26 +173,26 @@ def __init__( coefficients of the states in the sum :param state_id: - identifier of the state that this conservation law replaces + Symbol of the state that this conservation law replaces """ - self._state_expr: sp.Symbol = identifier - (value - state_id) + self._state_expr: sp.Expr = symbol - (value - state_id) self._coefficients: dict[sp.Symbol, sp.Expr] = coefficients self._ncoeff: sp.Expr = coefficients[state_id] - super().__init__(identifier, name, value) + super().__init__(symbol, name, value) - def get_ncoeff(self, state_id) -> sp.Expr | int | float: + def get_ncoeff(self, state_sym: sp.Symbol) -> sp.Expr | int | float: """ Computes the normalized coefficient a_i/a_j where i is the index of the provided state_id and j is the index of the state that is replaced by this conservation law. This can be used to compute both dtotal_cl/dx_rdata (=ncoeff) and dx_rdata/dx_solver (=-ncoeff). - :param state_id: - identifier of the state + :param state_sym: + Symbol of the state :return: normalized coefficient of the state """ - return self._coefficients.get(state_id, 0.0) / self._ncoeff + return self._coefficients.get(state_sym, 0.0) / self._ncoeff def get_x_rdata(self): """ @@ -197,7 +209,7 @@ class AlgebraicEquation(ModelQuantity): An AlgebraicEquation defines an algebraic equation. """ - def __init__(self, identifier: str, value: sp.Expr): + def __init__(self, symbol: sp.Symbol, value: sp.Expr): """ Create a new AlgebraicEquation instance. @@ -205,7 +217,7 @@ def __init__(self, identifier: str, value: sp.Expr): Formula of the algebraic equation, the solution is given by ``formula == 0`` """ - super().__init__(sp.Symbol(identifier), identifier, value) + super().__init__(symbol, symbol.name, value) def get_free_symbols(self): return self._value.free_symbols @@ -229,7 +241,7 @@ def get_x_rdata(self): :return: x_rdata expression """ if self._conservation_law is None: - return self.get_id() + return self.get_sym() else: return self._conservation_law.get_x_rdata() @@ -242,7 +254,7 @@ def get_dx_rdata_dx_solver(self, state_id): :return: dx_rdata_dx_solver expression """ if self._conservation_law is None: - return sp.Integer(self._identifier == state_id) + return sp.Integer(self._symbol == state_id) else: return -self._conservation_law.get_ncoeff(state_id) @@ -261,12 +273,12 @@ class AlgebraicState(State): An AlgebraicState defines an entity that is algebraically determined """ - def __init__(self, identifier: sp.Symbol, name: str, init: sp.Expr): + def __init__(self, symbol: sp.Symbol, name: str, init: sp.Expr): """ Create a new AlgebraicState instance. - :param identifier: - unique identifier of the AlgebraicState + :param symbol: + unique symbol of the AlgebraicState :param name: individual name of the AlgebraicState (does not need to be unique) @@ -274,9 +286,9 @@ def __init__(self, identifier: sp.Symbol, name: str, init: sp.Expr): :param init: initial value of the AlgebraicState """ - super().__init__(identifier, name, init) + super().__init__(symbol, name, init) - def has_conservation_law(self): + def has_conservation_law(self) -> bool: """ Checks whether this state has a conservation law assigned. @@ -288,7 +300,7 @@ def get_free_symbols(self): return self._value.free_symbols def get_x_rdata(self): - return self._identifier + return self._symbol class DifferentialState(State): @@ -306,14 +318,14 @@ class DifferentialState(State): """ def __init__( - self, identifier: sp.Symbol, name: str, init: sp.Expr, dt: sp.Expr + self, symbol: sp.Symbol, name: str, init: sp.Expr, dt: sp.Expr ): """ Create a new State instance. Extends :meth:`ModelQuantity.__init__` by ``dt`` - :param identifier: - unique identifier of the state + :param symbol: + unique symbol of the state :param name: individual name of the state (does not need to be unique) @@ -324,7 +336,7 @@ def __init__( :param dt: time derivative """ - super().__init__(identifier, name, init) + super().__init__(symbol, name, init) self._dt = cast_to_sym(dt, "dt") self._conservation_law: ConservationLaw | None = None @@ -401,7 +413,7 @@ class Observable(ModelQuantity): def __init__( self, - identifier: sp.Symbol, + symbol: sp.Symbol, name: str, value: sp.Expr, measurement_symbol: sp.Symbol | None = None, @@ -411,8 +423,8 @@ def __init__( """ Create a new Observable instance. - :param identifier: - unique identifier of the Observable + :param symbol: + unique symbol of the Observable :param name: individual name of the Observable (does not need to be unique) @@ -424,7 +436,7 @@ def __init__( observable transformation, only applies when evaluating objective function or residuals """ - super().__init__(identifier, name, value) + super().__init__(symbol, name, value) self._measurement_symbol = measurement_symbol self._regularization_symbol = None self.trafo = transformation @@ -432,7 +444,7 @@ def __init__( def get_measurement_symbol(self) -> sp.Symbol: if self._measurement_symbol is None: self._measurement_symbol = generate_measurement_symbol( - self.get_id() + self.get_sym() ) return self._measurement_symbol @@ -440,7 +452,7 @@ def get_measurement_symbol(self) -> sp.Symbol: def get_regularization_symbol(self) -> sp.Symbol: if self._regularization_symbol is None: self._regularization_symbol = generate_regularization_symbol( - self.get_id() + self.get_sym() ) return self._regularization_symbol @@ -457,7 +469,7 @@ class EventObservable(Observable): def __init__( self, - identifier: sp.Symbol, + symbol: sp.Symbol, name: str, value: sp.Expr, event: sp.Symbol, @@ -467,7 +479,7 @@ def __init__( """ Create a new EventObservable instance. - :param identifier: + :param symbol: See :py:meth:`Observable.__init__`. :param name: @@ -483,7 +495,7 @@ def __init__( Symbolic identifier of the corresponding event. """ super().__init__( - identifier, name, value, measurement_symbol, transformation + symbol, name, value, measurement_symbol, transformation ) self._event: sp.Symbol = event @@ -503,12 +515,12 @@ class Sigma(ModelQuantity): abbreviated by ``sigma{y,z}``. """ - def __init__(self, identifier: sp.Symbol, name: str, value: sp.Expr): + def __init__(self, symbol: sp.Symbol, name: str, value: sp.Expr): """ Create a new Standard Deviation instance. - :param identifier: - unique identifier of the Standard Deviation + :param symbol: + unique symbol of the Standard Deviation :param name: individual name of the Standard Deviation (does not need to @@ -521,7 +533,7 @@ def __init__(self, identifier: sp.Symbol, name: str, value: sp.Expr): raise RuntimeError( "This class is meant to be sub-classed, not used directly." ) - super().__init__(identifier, name, value) + super().__init__(symbol, name, value) class SigmaY(Sigma): @@ -544,12 +556,12 @@ class Expression(ModelQuantity): Abbreviated by ``w``. """ - def __init__(self, identifier: sp.Symbol, name: str, value: sp.Expr): + def __init__(self, symbol: sp.Symbol, name: str, value: sp.Expr): """ Create a new Expression instance. - :param identifier: - unique identifier of the Expression + :param symbol: + unique symbol of the Expression :param name: individual name of the Expression (does not need to be unique) @@ -557,7 +569,7 @@ def __init__(self, identifier: sp.Symbol, name: str, value: sp.Expr): :param value: formula """ - super().__init__(identifier, name, value) + super().__init__(symbol, name, value) class Parameter(ModelQuantity): @@ -566,14 +578,12 @@ class Parameter(ModelQuantity): sensitivities may be computed, abbreviated by ``p``. """ - def __init__( - self, identifier: sp.Symbol, name: str, value: numbers.Number - ): + def __init__(self, symbol: sp.Symbol, name: str, value: numbers.Number): """ Create a new Expression instance. - :param identifier: - unique identifier of the Parameter + :param symbol: + unique symbol of the Parameter :param name: individual name of the Parameter (does not need to be @@ -582,7 +592,7 @@ def __init__( :param value: numeric value """ - super().__init__(identifier, name, value) + super().__init__(symbol, name, value) class Constant(ModelQuantity): @@ -591,14 +601,12 @@ class Constant(ModelQuantity): sensitivities cannot be computed, abbreviated by ``k``. """ - def __init__( - self, identifier: sp.Symbol, name: str, value: numbers.Number - ): + def __init__(self, symbol: sp.Symbol, name: str, value: numbers.Number): """ Create a new Expression instance. - :param identifier: - unique identifier of the Constant + :param symbol: + unique symbol of the Constant :param name: individual name of the Constant (does not need to be unique) @@ -606,7 +614,7 @@ def __init__( :param value: numeric value """ - super().__init__(identifier, name, value) + super().__init__(symbol, name, value) class NoiseParameter(ModelQuantity): @@ -615,18 +623,18 @@ class NoiseParameter(ModelQuantity): specific manner, abbreviated by ``np``. Only used for jax models. """ - def __init__(self, identifier: sp.Symbol, name: str): + def __init__(self, symbol: sp.Symbol, name: str): """ Create a new Expression instance. - :param identifier: - unique identifier of the NoiseParameter + :param symbol: + unique symbol of the NoiseParameter :param name: individual name of the NoiseParameter (does not need to be unique) """ - super().__init__(identifier, name, 0.0) + super().__init__(symbol, name, 0.0) class ObservableParameter(ModelQuantity): @@ -635,18 +643,18 @@ class ObservableParameter(ModelQuantity): manner, abbreviated by ``op``. Only used for jax models. """ - def __init__(self, identifier: sp.Symbol, name: str): + def __init__(self, symbol: sp.Symbol, name: str): """ Create a new Expression instance. - :param identifier: - unique identifier of the ObservableParameter + :param symbol: + unique symbol of the ObservableParameter :param name: individual name of the ObservableParameter (does not need to be unique) """ - super().__init__(identifier, name, 0.0) + super().__init__(symbol, name, 0.0) class LogLikelihood(ModelQuantity): @@ -657,12 +665,12 @@ class LogLikelihood(ModelQuantity): instances evaluated at all timepoints, abbreviated by ``Jy``. """ - def __init__(self, identifier: sp.Symbol, name: str, value: sp.Expr): + def __init__(self, symbol: sp.Symbol, name: str, value: sp.Expr): """ Create a new Expression instance. - :param identifier: - unique identifier of the LogLikelihood + :param symbol: + unique symbol of the LogLikelihood :param name: individual name of the LogLikelihood (does not need to be @@ -675,7 +683,7 @@ def __init__(self, identifier: sp.Symbol, name: str, value: sp.Expr): raise RuntimeError( "This class is meant to be sub-classed, not used directly." ) - super().__init__(identifier, name, value) + super().__init__(symbol, name, value) class LogLikelihoodY(LogLikelihood): @@ -707,7 +715,7 @@ class Event(ModelQuantity): def __init__( self, - identifier: sp.Symbol, + symbol: sp.Symbol, name: str, value: sp.Expr, use_values_from_trigger_time: bool, @@ -718,8 +726,8 @@ def __init__( """ Create a new Event instance. - :param identifier: - unique identifier of the Event + :param symbol: + unique symbol of the Event :param name: individual name of the Event (does not need to be unique) @@ -741,7 +749,7 @@ def __init__( the time point at which the event triggered (True), or at the time point at which the event assignment is evaluated (False). """ - super().__init__(identifier, name, value) + super().__init__(symbol, name, value) # add the Event specific components self._assignments = assignments if assignments is not None else {} self._initial_value = initial_value diff --git a/python/sdist/amici/pysb_import.py b/python/sdist/amici/pysb_import.py index 9f55cfc60a..e48faedd1d 100644 --- a/python/sdist/amici/pysb_import.py +++ b/python/sdist/amici/pysb_import.py @@ -691,7 +691,7 @@ def _add_expression( ) cost_fun_str = noise_distribution_to_cost_function(noise_dist)(name) - my = generate_measurement_symbol(obs.get_id()) + my = generate_measurement_symbol(obs.get_sym()) cost_fun_expr = sp.sympify( cost_fun_str, locals=dict( diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 4824f38de7..7f248eb95d 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -632,7 +632,7 @@ def _build_ode_model( # create all basic components of the DE model and add them. for symbol_name in self.symbols: # transform dict of lists into a list of dicts - args = ["name", "identifier"] + args = ["name", "symbol"] if symbol_name == SymbolId.SPECIES: args += ["dt", "init"] @@ -655,7 +655,7 @@ def _build_ode_model( comp_kwargs = [ { - "identifier": var_id, + "symbol": var_id, **{k: v for k, v in var.items() if k in args}, } for var_id, var in self.symbols[symbol_name].items() @@ -672,7 +672,7 @@ def _build_ode_model( # replace splines inside fluxes flux = flux.subs(spline_subs) ode_model.add_component( - Expression(identifier=flux_id, name=str(flux_id), value=flux) + Expression(symbol=flux_id, name=str(flux_id), value=flux) ) if compute_conservation_laws: @@ -1506,7 +1506,9 @@ def _process_rule_algebraic(self, rule: libsbml.AlgebraicRule): assert len(free_variables) >= 1 self.symbols[SymbolId.ALGEBRAIC_EQUATION][ - f"ae{len(self.symbols[SymbolId.ALGEBRAIC_EQUATION])}" + symbol_with_assumptions( + f"ae{len(self.symbols[SymbolId.ALGEBRAIC_EQUATION])}" + ) ] = {"value": formula} # remove the symbol from the original definition and add to # algebraic symbols (if not already done) @@ -2371,9 +2373,7 @@ def _get_conservation_laws_demartino( stoichiometric_list, *sm.shape, rng_seed=32, - species_names=[ - str(x.get_id()) for x in ode_model._differential_states - ], + species_names=[x.get_id() for x in ode_model._differential_states], ) # Sparsify conserved quantities @@ -2502,7 +2502,7 @@ def _add_conservation_for_non_constant_species( # previously removed constant species eliminated_state_ids = {cl["state"] for cl in conservation_laws} - all_state_ids = [x.get_id() for x in model.states()] + all_state_ids = [x.get_sym() for x in model.states()] all_compartment_sizes = [] for state_id in all_state_ids: symbol = { @@ -3166,7 +3166,7 @@ def _add_conservation_for_constant_species( if ode_model.state_is_constant(ix): # dont use sym('x') here since conservation laws need to be # added before symbols are generated - target_state = ode_model._differential_states[ix].get_id() + target_state = ode_model._differential_states[ix].get_sym() total_abundance = symbol_with_assumptions(f"tcl_{target_state}") conservation_laws.append( { diff --git a/python/tests/test_de_model.py b/python/tests/test_de_model.py index 1c253f4c67..5ffe245d27 100644 --- a/python/tests/test_de_model.py +++ b/python/tests/test_de_model.py @@ -7,7 +7,7 @@ @skip_on_valgrind def test_event_trigger_time(): e = Event( - identifier=sp.Symbol("event1"), + symbol=sp.Symbol("event1"), name="event name", value=amici_time_symbol - 10, assignments=sp.Float(1), @@ -18,7 +18,7 @@ def test_event_trigger_time(): # fixed, but multiple timepoints - not (yet) supported e = Event( - identifier=sp.Symbol("event1"), + symbol=sp.Symbol("event1"), name="event name", value=sp.sin(amici_time_symbol), assignments=sp.Float(1), @@ -27,7 +27,7 @@ def test_event_trigger_time(): assert e.triggers_at_fixed_timepoint() is False e = Event( - identifier=sp.Symbol("event1"), + symbol=sp.Symbol("event1"), name="event name", value=amici_time_symbol / 2, assignments=sp.Float(1), @@ -38,7 +38,7 @@ def test_event_trigger_time(): # parameter-dependent triggers - not (yet) supported e = Event( - identifier=sp.Symbol("event1"), + symbol=sp.Symbol("event1"), name="event name", value=amici_time_symbol - sp.Symbol("delay"), assignments=sp.Float(1), diff --git a/python/tests/test_pysb.py b/python/tests/test_pysb.py index 4803494550..fef1d4110e 100644 --- a/python/tests/test_pysb.py +++ b/python/tests/test_pysb.py @@ -409,7 +409,7 @@ def test_pysb_event(tempdir): events = [ Event( # note that unlike for SBML import, we must omit the real=True here - identifier=sp.Symbol("event1"), + symbol=sp.Symbol("event1"), name="Event1", value=amici_time_symbol - 5, assignments={sp.Symbol("__s0"): sp.Symbol("__s0") + 1000}, diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index e5626e5398..4a51df9024 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -1181,9 +1181,9 @@ def test_time_dependent_initial_assignment(compute_conservation_laws: bool): # "species", because differential state assert symbol_with_assumptions("x1") in si.symbols[SymbolId.SPECIES].keys() - assert "p0" in [str(p.get_id()) for p in de_model.parameters()] - assert "p1" not in [str(p.get_id()) for p in de_model.parameters()] - assert "p2" not in [str(p.get_id()) for p in de_model.parameters()] + assert "p0" in [p.get_id() for p in de_model.parameters()] + assert "p1" not in [p.get_id() for p in de_model.parameters()] + assert "p2" not in [p.get_id() for p in de_model.parameters()] assert list(de_model.sym("x_rdata")) == [ symbol_with_assumptions("p2"),