Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions python/sdist/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from sympy.utilities.iterables import numbered_symbols
from toposort import toposort

from .import_utils import symbol_with_assumptions


class AmiciCxxCodePrinter(CXX11CodePrinter):
"""
Expand Down Expand Up @@ -322,7 +324,7 @@ def csc_matrix(
colnames: list[sp.Symbol],
identifier: int | None = 0,
pattern_only: bool | None = False,
) -> tuple[list[int], list[int], sp.Matrix, list[str], sp.Matrix]:
) -> tuple[list[int], list[int], sp.Matrix, list[sp.Symbol], sp.Matrix]:
"""
Generates the sparse symbolic identifiers, symbolic identifiers,
sparse matrix, column pointers and row values for a symbolic
Expand Down Expand Up @@ -371,11 +373,11 @@ def csc_matrix(
symbol_name = f"d{rownames[row].name}_d{colnames[col].name}"
if identifier:
symbol_name += f"_{identifier}"
symbol_list.append(symbol_name)
symbol_list.append(symbol_with_assumptions(symbol_name))
if pattern_only:
continue

sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True)
sparse_matrix[row, col] = symbol_with_assumptions(symbol_name)
sparse_list.append(matrix[row, col])

if idx == 0:
Expand Down
11 changes: 6 additions & 5 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def _write_index_files(self, name: str) -> None:
lines = []
for index, symbol in enumerate(symbols):
symbol_name = strip_pysb(symbol)
if str(symbol) == "0":
if symbol.is_zero:
continue
if str(symbol_name) == "":
raise ValueError(f'{name} contains a symbol called ""')
Expand Down Expand Up @@ -816,10 +816,11 @@ def _get_function_body(
function in self.model.sym_names()
and function not in non_unique_id_symbols
):
if function in sparse_functions:
symbols = list(map(sp.Symbol, self.model.sparsesym(function)))
else:
symbols = self.model.sym(function)
symbols = (
self.model.sparsesym(function)
if function in sparse_functions
else self.model.sym(function)
)

if function in ("w", "dwdw", "dwdx", "dwdp"):
# Split into a block of static and dynamic expressions.
Expand Down
24 changes: 14 additions & 10 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class DEModel:
model

:ivar _sparsesyms:
carries linear list of all symbolic identifiers for sparsified
carries linear list of all symbols for sparsified
variables

:ivar _colptrs:
Expand Down Expand Up @@ -193,7 +193,7 @@ class DEModel:
res and FIM make sense.

:ivar _static_indices:
dict of lists list of indices of static variables for different
dict of lists of indices of static variables for different
model entities.

:ivar _z2event:
Expand Down Expand Up @@ -254,7 +254,9 @@ def __init__(
self._vals: dict[str, list[sp.Expr]] = dict()
self._names: dict[str, list[str]] = dict()
self._syms: dict[str, sp.Matrix | list[sp.Matrix]] = dict()
self._sparsesyms: dict[str, list[str] | list[list[str]]] = dict()
self._sparsesyms: dict[
str, list[sp.Symbol] | list[list[sp.Symbol]]
] = dict()
self._colptrs: dict[str, list[int] | list[list[int]]] = dict()
self._rowvals: dict[str, list[int] | list[list[int]]] = dict()

Expand Down Expand Up @@ -616,11 +618,11 @@ def add_conservation_law(
variables.

:param state:
symbolic identifier of the state that should be replaced by
Symbol of the state that should be replaced by
the conservation law (:math:`x_j`)

:param total_abundance:
symbolic identifier of the total abundance (:math:`T/a_j`)
Symbol of the total abundance (:math:`T/a_j`)

:param coefficients:
Dictionary of coefficients {x_i: a_i}
Expand Down Expand Up @@ -814,16 +816,18 @@ def sym(self, name: str) -> sp.Matrix:
name of the symbolic variable

:return:
matrix of symbolic identifiers
matrix of symbols
"""
if name not in self._syms:
self._generate_symbol(name)

return self._syms[name]

def sparsesym(self, name: str, force_generate: bool = True) -> list[str]:
def sparsesym(
self, name: str, force_generate: bool = True
) -> list[sp.Symbol]:
"""
Returns (and constructs if necessary) the sparsified identifiers for
Returns (and constructs if necessary) the sparsified symbols for
a sparsified symbolic variable.

:param name:
Expand All @@ -833,7 +837,7 @@ def sparsesym(self, name: str, force_generate: bool = True) -> list[str]:
whether the symbols should be generated if not available

:return:
linearized Matrix containing the symbolic identifiers
linearized Matrix containing the symbols
"""
if name not in sparse_functions:
raise ValueError(f"{name} is not marked as sparse")
Expand Down Expand Up @@ -1101,7 +1105,7 @@ def dynamic_indices(self, name: str) -> list[int]:

def _generate_symbol(self, name: str) -> None:
"""
Generates the symbolic identifiers for a symbolic variable
Generates the symbols for a symbolic variable

:param name:
name of the symbolic variable
Expand Down
6 changes: 4 additions & 2 deletions python/sdist/amici/de_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
"LogLikelihoodZ",
"LogLikelihoodRZ",
"ModelQuantity",
"NoiseParameter",
"Observable",
"ObservableParameter",
"Parameter",
"SigmaY",
"SigmaZ",
Expand Down Expand Up @@ -176,7 +178,7 @@ def get_ncoeff(self, state_id) -> sp.Expr | int | float:
:param state_id:
identifier of the state

:return: normalized coefficent of the state
:return: normalized coefficient of the state
"""
return self._coefficients.get(state_id, 0.0) / self._ncoeff

Expand Down Expand Up @@ -404,7 +406,7 @@ def __init__(
value: sp.Expr,
measurement_symbol: sp.Symbol | None = None,
transformation: None
| (ObservableTransformation) = ObservableTransformation.LIN,
| ObservableTransformation = ObservableTransformation.LIN,
):
"""
Create a new Observable instance.
Expand Down
14 changes: 11 additions & 3 deletions python/tests/test_ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sympy as sp
from amici.cxxcodeprinter import csc_matrix
from amici.testing import skip_on_valgrind
from amici.import_utils import symbol_with_assumptions


@skip_on_valgrind
Expand All @@ -24,7 +25,11 @@ def test_csc_matrix():
assert symbol_col_ptrs == [0, 2, 3]
assert symbol_row_vals == [0, 1, 1]
assert sparse_list == sp.Matrix([[1], [2], [3]])
assert symbol_list == ["da1_db1", "da2_db1", "da2_db2"]
assert symbol_list == [
symbol_with_assumptions("da1_db1"),
symbol_with_assumptions("da2_db1"),
symbol_with_assumptions("da2_db2"),
]
assert str(sparse_matrix) == "Matrix([[da1_db1, 0], [da2_db1, da2_db2]])"


Expand Down Expand Up @@ -66,7 +71,10 @@ def test_csc_matrix_vector():
assert symbol_col_ptrs == [0, 2]
assert symbol_row_vals == [0, 1]
assert sparse_list == sp.Matrix([[1], [2]])
assert symbol_list == ["da1_db", "da2_db"]
assert symbol_list == [
symbol_with_assumptions("da1_db"),
symbol_with_assumptions("da2_db"),
]
assert str(sparse_matrix) == "Matrix([[da1_db], [da2_db]])"

# Test continuation of numbering of symbols
Expand All @@ -86,7 +94,7 @@ def test_csc_matrix_vector():
assert symbol_col_ptrs == [0, 1]
assert symbol_row_vals == [1]
assert sparse_list == sp.Matrix([[3]])
assert symbol_list == ["da2_db_1"]
assert symbol_list == [symbol_with_assumptions("da2_db_1")]
assert str(sparse_matrix) == "Matrix([[0], [da2_db_1]])"


Expand Down
Loading