Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions python/sdist/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,15 @@ def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str:
def _print_min_max(self, expr, cpp_fun: str, sympy_fun):
# C++ doesn't like mixing int and double for arguments for min/max,
# therefore, we just always convert to float
arg0 = (
sp.Float(expr.args[0]) if expr.args[0].is_number else expr.args[0]
)
args = [
self._print(sp.Float(arg) if arg.is_number else arg)
for arg in expr.args
]
if len(expr.args) == 1:
return self._print(arg0)
return f"{self._ns}{cpp_fun}({self._print(arg0)}, {self._print(sympy_fun(*expr.args[1:]))})"
return args[0]
if len(expr.args) == 2:
return f"{self._ns}{cpp_fun}({args[0]}, {args[1]})"
return f"{self._ns}{cpp_fun}({get_initializer_list(args)})"

def _print_Min(self, expr):
from sympy.functions.elementary.miscellaneous import Min
Expand Down
13 changes: 13 additions & 0 deletions python/tests/test_cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,16 @@ def test_print_infinity():
)
assert cp.doprint(sp.zoo) == "std::numeric_limits<double>::infinity()"
assert cp.doprint(-sp.zoo) == "std::numeric_limits<double>::infinity()"


@skip_on_valgrind
def test_min_max():
"""Check that AmiciCxxCodePrinter prints min() and max() correctly."""
a, b, c = sp.symbols("a b c")
cp = AmiciCxxCodePrinter()
assert cp.doprint(sp.Min(a)) == "a"
assert cp.doprint(sp.Max(a)) == "a"
assert cp.doprint(sp.Min(a, b)) == "std::min(a, b)"
assert cp.doprint(sp.Max(a, b)) == "std::max(a, b)"
assert cp.doprint(sp.Min(a, b, c)) == "std::min({a, b, c})"
assert cp.doprint(sp.Max(a, b, c)) == "std::max({a, b, c})"
Loading