diff --git a/python/sdist/amici/cxxcodeprinter.py b/python/sdist/amici/cxxcodeprinter.py index 67d38a3a7a..2a5cb4b4a2 100644 --- a/python/sdist/amici/cxxcodeprinter.py +++ b/python/sdist/amici/cxxcodeprinter.py @@ -13,6 +13,8 @@ from sympy.utilities.iterables import numbered_symbols from toposort import toposort +from .import_utils import symbol_with_assumptions + class AmiciCxxCodePrinter(CXX11CodePrinter): """ @@ -322,7 +324,7 @@ def csc_matrix( colnames: list[sp.Symbol], identifier: int | None = 0, pattern_only: bool | None = False, -) -> tuple[list[int], list[int], sp.Matrix, list[str], sp.Matrix]: +) -> tuple[list[int], list[int], sp.Matrix, list[sp.Symbol], sp.Matrix]: """ Generates the sparse symbolic identifiers, symbolic identifiers, sparse matrix, column pointers and row values for a symbolic @@ -371,11 +373,11 @@ def csc_matrix( symbol_name = f"d{rownames[row].name}_d{colnames[col].name}" if identifier: symbol_name += f"_{identifier}" - symbol_list.append(symbol_name) + symbol_list.append(symbol_with_assumptions(symbol_name)) if pattern_only: continue - sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True) + sparse_matrix[row, col] = symbol_with_assumptions(symbol_name) sparse_list.append(matrix[row, col]) if idx == 0: diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 7c6de98d9f..a8e8764956 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -399,7 +399,7 @@ def _write_index_files(self, name: str) -> None: lines = [] for index, symbol in enumerate(symbols): symbol_name = strip_pysb(symbol) - if str(symbol) == "0": + if symbol.is_zero: continue if str(symbol_name) == "": raise ValueError(f'{name} contains a symbol called ""') @@ -816,10 +816,11 @@ def _get_function_body( function in self.model.sym_names() and function not in non_unique_id_symbols ): - if function in sparse_functions: - symbols = list(map(sp.Symbol, self.model.sparsesym(function))) - else: - symbols = self.model.sym(function) + symbols = ( + self.model.sparsesym(function) + if function in sparse_functions + else self.model.sym(function) + ) if function in ("w", "dwdw", "dwdx", "dwdp"): # Split into a block of static and dynamic expressions. diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 829b99fe56..4aaead9770 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -143,7 +143,7 @@ class DEModel: model :ivar _sparsesyms: - carries linear list of all symbolic identifiers for sparsified + carries linear list of all symbols for sparsified variables :ivar _colptrs: @@ -193,7 +193,7 @@ class DEModel: res and FIM make sense. :ivar _static_indices: - dict of lists list of indices of static variables for different + dict of lists of indices of static variables for different model entities. :ivar _z2event: @@ -254,7 +254,9 @@ def __init__( self._vals: dict[str, list[sp.Expr]] = dict() self._names: dict[str, list[str]] = dict() self._syms: dict[str, sp.Matrix | list[sp.Matrix]] = dict() - self._sparsesyms: dict[str, list[str] | list[list[str]]] = dict() + self._sparsesyms: dict[ + str, list[sp.Symbol] | list[list[sp.Symbol]] + ] = dict() self._colptrs: dict[str, list[int] | list[list[int]]] = dict() self._rowvals: dict[str, list[int] | list[list[int]]] = dict() @@ -616,11 +618,11 @@ def add_conservation_law( variables. :param state: - symbolic identifier of the state that should be replaced by + Symbol of the state that should be replaced by the conservation law (:math:`x_j`) :param total_abundance: - symbolic identifier of the total abundance (:math:`T/a_j`) + Symbol of the total abundance (:math:`T/a_j`) :param coefficients: Dictionary of coefficients {x_i: a_i} @@ -814,16 +816,18 @@ def sym(self, name: str) -> sp.Matrix: name of the symbolic variable :return: - matrix of symbolic identifiers + matrix of symbols """ if name not in self._syms: self._generate_symbol(name) return self._syms[name] - def sparsesym(self, name: str, force_generate: bool = True) -> list[str]: + def sparsesym( + self, name: str, force_generate: bool = True + ) -> list[sp.Symbol]: """ - Returns (and constructs if necessary) the sparsified identifiers for + Returns (and constructs if necessary) the sparsified symbols for a sparsified symbolic variable. :param name: @@ -833,7 +837,7 @@ def sparsesym(self, name: str, force_generate: bool = True) -> list[str]: whether the symbols should be generated if not available :return: - linearized Matrix containing the symbolic identifiers + linearized Matrix containing the symbols """ if name not in sparse_functions: raise ValueError(f"{name} is not marked as sparse") @@ -1101,7 +1105,7 @@ def dynamic_indices(self, name: str) -> list[int]: def _generate_symbol(self, name: str) -> None: """ - Generates the symbolic identifiers for a symbolic variable + Generates the symbols for a symbolic variable :param name: name of the symbolic variable diff --git a/python/sdist/amici/de_model_components.py b/python/sdist/amici/de_model_components.py index 4f12ec70be..a269338274 100644 --- a/python/sdist/amici/de_model_components.py +++ b/python/sdist/amici/de_model_components.py @@ -26,7 +26,9 @@ "LogLikelihoodZ", "LogLikelihoodRZ", "ModelQuantity", + "NoiseParameter", "Observable", + "ObservableParameter", "Parameter", "SigmaY", "SigmaZ", @@ -176,7 +178,7 @@ def get_ncoeff(self, state_id) -> sp.Expr | int | float: :param state_id: identifier of the state - :return: normalized coefficent of the state + :return: normalized coefficient of the state """ return self._coefficients.get(state_id, 0.0) / self._ncoeff @@ -404,7 +406,7 @@ def __init__( value: sp.Expr, measurement_symbol: sp.Symbol | None = None, transformation: None - | (ObservableTransformation) = ObservableTransformation.LIN, + | ObservableTransformation = ObservableTransformation.LIN, ): """ Create a new Observable instance. diff --git a/python/tests/test_ode_export.py b/python/tests/test_ode_export.py index 6dc4d85f05..9e6917718e 100644 --- a/python/tests/test_ode_export.py +++ b/python/tests/test_ode_export.py @@ -3,6 +3,7 @@ import sympy as sp from amici.cxxcodeprinter import csc_matrix from amici.testing import skip_on_valgrind +from amici.import_utils import symbol_with_assumptions @skip_on_valgrind @@ -24,7 +25,11 @@ def test_csc_matrix(): assert symbol_col_ptrs == [0, 2, 3] assert symbol_row_vals == [0, 1, 1] assert sparse_list == sp.Matrix([[1], [2], [3]]) - assert symbol_list == ["da1_db1", "da2_db1", "da2_db2"] + assert symbol_list == [ + symbol_with_assumptions("da1_db1"), + symbol_with_assumptions("da2_db1"), + symbol_with_assumptions("da2_db2"), + ] assert str(sparse_matrix) == "Matrix([[da1_db1, 0], [da2_db1, da2_db2]])" @@ -66,7 +71,10 @@ def test_csc_matrix_vector(): assert symbol_col_ptrs == [0, 2] assert symbol_row_vals == [0, 1] assert sparse_list == sp.Matrix([[1], [2]]) - assert symbol_list == ["da1_db", "da2_db"] + assert symbol_list == [ + symbol_with_assumptions("da1_db"), + symbol_with_assumptions("da2_db"), + ] assert str(sparse_matrix) == "Matrix([[da1_db], [da2_db]])" # Test continuation of numbering of symbols @@ -86,7 +94,7 @@ def test_csc_matrix_vector(): assert symbol_col_ptrs == [0, 1] assert symbol_row_vals == [1] assert sparse_list == sp.Matrix([[3]]) - assert symbol_list == ["da2_db_1"] + assert symbol_list == [symbol_with_assumptions("da2_db_1")] assert str(sparse_matrix) == "Matrix([[0], [da2_db_1]])"