diff --git a/python/sdist/amici/cxxcodeprinter.py b/python/sdist/amici/cxxcodeprinter.py index b5b4ed809c..67d38a3a7a 100644 --- a/python/sdist/amici/cxxcodeprinter.py +++ b/python/sdist/amici/cxxcodeprinter.py @@ -71,12 +71,15 @@ def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str: def _print_min_max(self, expr, cpp_fun: str, sympy_fun): # C++ doesn't like mixing int and double for arguments for min/max, # therefore, we just always convert to float - arg0 = ( - sp.Float(expr.args[0]) if expr.args[0].is_number else expr.args[0] - ) + args = [ + self._print(sp.Float(arg) if arg.is_number else arg) + for arg in expr.args + ] if len(expr.args) == 1: - return self._print(arg0) - return f"{self._ns}{cpp_fun}({self._print(arg0)}, {self._print(sympy_fun(*expr.args[1:]))})" + return args[0] + if len(expr.args) == 2: + return f"{self._ns}{cpp_fun}({args[0]}, {args[1]})" + return f"{self._ns}{cpp_fun}({get_initializer_list(args)})" def _print_Min(self, expr): from sympy.functions.elementary.miscellaneous import Min diff --git a/python/tests/test_cxxcodeprinter.py b/python/tests/test_cxxcodeprinter.py index 8eec3beb7f..e01fb74c1a 100644 --- a/python/tests/test_cxxcodeprinter.py +++ b/python/tests/test_cxxcodeprinter.py @@ -49,3 +49,16 @@ def test_print_infinity(): ) assert cp.doprint(sp.zoo) == "std::numeric_limits::infinity()" assert cp.doprint(-sp.zoo) == "std::numeric_limits::infinity()" + + +@skip_on_valgrind +def test_min_max(): + """Check that AmiciCxxCodePrinter prints min() and max() correctly.""" + a, b, c = sp.symbols("a b c") + cp = AmiciCxxCodePrinter() + assert cp.doprint(sp.Min(a)) == "a" + assert cp.doprint(sp.Max(a)) == "a" + assert cp.doprint(sp.Min(a, b)) == "std::min(a, b)" + assert cp.doprint(sp.Max(a, b)) == "std::max(a, b)" + assert cp.doprint(sp.Min(a, b, c)) == "std::min({a, b, c})" + assert cp.doprint(sp.Max(a, b, c)) == "std::max({a, b, c})"