diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc34200..266b412 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: check-useless-excludes # - id: identity # Prints all files passed to pre-commits. Debugging. - repo: https://github.com/lyz-code/yamlfix - rev: 1.9.0 + rev: 1.17.0 hooks: - id: yamlfix - repo: local @@ -61,17 +61,17 @@ repos: rev: 1.13.0 hooks: - id: blacken-docs -# - repo: https://github.com/PyCQA/docformatter -# rev: v1.5.1 -# hooks: -# - id: docformatter -# args: -# - --in-place -# - --wrap-summaries -# - '88' -# - --wrap-descriptions -# - '88' -# - --blank + - repo: https://github.com/PyCQA/docformatter + rev: v1.7.7 + hooks: + - id: docformatter + args: + - --in-place + - --wrap-summaries + - '88' + - --wrap-descriptions + - '88' + - --blank - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.0.261 hooks: @@ -91,7 +91,7 @@ repos: - id: nbqa-black - id: nbqa-ruff - repo: https://github.com/executablebooks/mdformat - rev: 0.7.16 + rev: 0.7.22 hooks: - id: mdformat additional_dependencies: diff --git a/src/tranquilo/aggregate_models.py b/src/tranquilo/aggregate_models.py index d5ea4d6..adff6db 100644 --- a/src/tranquilo/aggregate_models.py +++ b/src/tranquilo/aggregate_models.py @@ -69,7 +69,7 @@ def aggregator_identity(vector_model): """ n_params = vector_model.linear_terms.size - intercept = float(vector_model.intercepts) + intercept = float(vector_model.intercepts[0]) linear_terms = vector_model.linear_terms.flatten() if vector_model.square_terms is None: square_terms = np.zeros((n_params, n_params)) diff --git a/src/tranquilo/batch_evaluators.py b/src/tranquilo/batch_evaluators.py new file mode 100644 index 0000000..a549c2a --- /dev/null +++ b/src/tranquilo/batch_evaluators.py @@ -0,0 +1,187 @@ +"""A collection of batch evaluators for process based parallelism. + +All batch evaluators have the same interface and any function with the same interface +can be used used as batch evaluator. + +""" + +from joblib import Parallel, delayed + +try: + from pathos.pools import ProcessPool + + pathos_is_available = True +except ImportError: + pathos_is_available = False + +from typing import Any, Callable, Literal, TypeVar + +from tranquilo.config import DEFAULT_N_CORES as N_CORES +from tranquilo.decorators import catch, unpack +from tranquilo.options import ErrorHandling + +T = TypeVar("T") + + +def pathos_mp_batch_evaluator( + func: Callable[..., T], + arguments: list[Any], + *, + n_cores: int = N_CORES, + error_handling: ( + ErrorHandling | Literal["raise", "continue"] + ) = ErrorHandling.CONTINUE, + unpack_symbol: Literal["*", "**"] | None = None, +) -> list[T]: + """Batch evaluator based on pathos.multiprocess.ProcessPool. + + This uses a patched but older version of python multiprocessing that replaces + pickling with dill and can thus handle decorated functions. + + Args: + func (Callable): The function that is evaluated. + arguments (Iterable): Arguments for the functions. Their interperation + depends on the unpack argument. + n_cores (int): Number of cores used to evaluate the function in parallel. + Value below one are interpreted as one. If only one core is used, the + batch evaluator disables everything that could cause problems, i.e. in that + case func and arguments are never pickled and func is executed in the main + process. + error_handling (str): Can take the values "raise" (raise the error and stop all + tasks as soon as one task fails) and "continue" (catch exceptions and set + the traceback of the raised exception. + KeyboardInterrupt and SystemExit are always raised. + unpack_symbol (str or None). Can be "**", "*" or None. If None, func just takes + one argument. If "*", the elements of arguments are positional arguments for + func. If "**", the elements of arguments are keyword arguments for func. + + + Returns: + list: The function evaluations. + + """ + if not pathos_is_available: + raise NotImplementedError( + "To use the pathos_mp_batch_evaluator, install pathos with " + "conda install -c conda-forge pathos." + ) + + _check_inputs(func, arguments, n_cores, error_handling, unpack_symbol) + n_cores = int(n_cores) + + reraise = error_handling in [ + "raise", + ErrorHandling.RAISE, + ErrorHandling.RAISE_STRICT, + ] + + @unpack(symbol=unpack_symbol) + @catch(default="__traceback__", reraise=reraise) + def internal_func(*args: Any, **kwargs: Any) -> T: + return func(*args, **kwargs) + + if n_cores <= 1: + res = [internal_func(arg) for arg in arguments] + else: + p = ProcessPool(nodes=n_cores) + try: + res = p.map(internal_func, arguments) + except Exception as e: + p.terminate() + raise e + + return res + + +def joblib_batch_evaluator( + func: Callable[..., T], + arguments: list[Any], + *, + n_cores: int = N_CORES, + error_handling: ( + ErrorHandling | Literal["raise", "continue"] + ) = ErrorHandling.CONTINUE, + unpack_symbol: Literal["*", "**"] | None = None, +) -> list[T]: + """Batch evaluator based on joblib's Parallel. + + Args: + func (Callable): The function that is evaluated. + arguments (Iterable): Arguments for the functions. Their interperation + depends on the unpack argument. + n_cores (int): Number of cores used to evaluate the function in parallel. + Value below one are interpreted as one. If only one core is used, the + batch evaluator disables everything that could cause problems, i.e. in that + case func and arguments are never pickled and func is executed in the main + process. + error_handling (str): Can take the values "raise" (raise the error and stop all + tasks as soon as one task fails) and "continue" (catch exceptions and set + the output of failed tasks to the traceback of the raised exception. + KeyboardInterrupt and SystemExit are always raised. + unpack_symbol (str or None). Can be "**", "*" or None. If None, func just takes + one argument. If "*", the elements of arguments are positional arguments for + func. If "**", the elements of arguments are keyword arguments for func. + + + Returns: + list: The function evaluations. + + """ + _check_inputs(func, arguments, n_cores, error_handling, unpack_symbol) + n_cores = int(n_cores) if int(n_cores) >= 2 else 1 + + reraise = error_handling in [ + "raise", + ErrorHandling.RAISE, + ErrorHandling.RAISE_STRICT, + ] + + @unpack(symbol=unpack_symbol) + @catch(default="__traceback__", reraise=reraise) + def internal_func(*args: Any, **kwargs: Any) -> T: + return func(*args, **kwargs) + + if n_cores == 1: + res = [internal_func(arg) for arg in arguments] + else: + res = Parallel(n_jobs=n_cores)(delayed(internal_func)(arg) for arg in arguments) + + return res + + +def _check_inputs( + func: Callable[..., T], + arguments: list[Any], + n_cores: int, + error_handling: ErrorHandling | Literal["raise", "continue"], + unpack_symbol: Literal["*", "**"] | None, +) -> None: + if not callable(func): + raise TypeError("func must be callable.") + + try: + arguments = list(arguments) + except Exception as e: + raise ValueError("arguments must be list like.") from e + + try: + int(n_cores) + except Exception as e: + raise ValueError("n_cores must be an integer.") from e + + if unpack_symbol not in (None, "*", "**"): + raise ValueError( + f"unpack_symbol must be None, '*' or '**', not {unpack_symbol}" + ) + + if error_handling not in [ + "raise", + "continue", + ErrorHandling.RAISE, + ErrorHandling.CONTINUE, + ErrorHandling.RAISE_STRICT, + ]: + raise ValueError( + "error_handling must be 'raise' or 'continue' or ErrorHandling not " + f"{error_handling}" + ) diff --git a/src/tranquilo/config.py b/src/tranquilo/config.py index 121780f..99dbfc6 100644 --- a/src/tranquilo/config.py +++ b/src/tranquilo/config.py @@ -18,6 +18,18 @@ CRITERION_PENALTY_CONSTANT = 100 +# ====================================================================================== +# Check Available Packages +# ====================================================================================== + +try: + import optimagic # noqa: F401 +except ImportError: + IS_OPTIMAGIC_INSTALLED = False +else: + IS_OPTIMAGIC_INSTALLED = True + + # ================================================================================= # Dashboard Defaults # ================================================================================= diff --git a/src/tranquilo/decorators.py b/src/tranquilo/decorators.py new file mode 100644 index 0000000..296ea1f --- /dev/null +++ b/src/tranquilo/decorators.py @@ -0,0 +1,126 @@ +"""This module contains various decorators. + +There are two kinds of decorators defined in this module which consists of either two or +three nested functions. The former are decorators without and the latter with arguments. + +For more information on decorators, see this `guide +`_ on https://realpython.com + +which +provides a comprehensive overview. + +.. _guide: + +https://realpython.com/primer-on-python-decorators/ + +""" + +import sys +from traceback import format_exception + +import functools +import warnings + + +def catch( + func=None, + *, + exception=Exception, + exclude=(KeyboardInterrupt, SystemExit), + onerror=None, + default=None, + warn=True, + reraise=False, +): + """Catch and handle exceptions. + + This decorator can be used with and without additional arguments. + + Args: + exception (Exception or tuple): One or several exceptions that + are caught and handled. By default all Exceptions are + caught and handled. + exclude (Exception or tuple): One or several exceptionts that + are not caught. By default those are KeyboardInterrupt and + SystemExit. + onerror (None or Callable): Callable that takes an Exception + as only argument. This is called when an exception occurs. + default: Value that is returned when as the output of func when + an exception occurs. Can be one of the following: + - a constant + - "__traceback__", in this case a string with a traceback is returned. + - callable with the same signature as func. + warn (bool): If True, the exception is converted to a warning. + reraise (bool): If True, the exception is raised after handling it. + + """ + + def decorator_catch(func): + @functools.wraps(func) + def wrapper_catch(*args, **kwargs): + try: + res = func(*args, **kwargs) + except exclude: + raise + except exception as e: + if onerror is not None: + onerror(e) + + if reraise: + raise e + + tb = get_traceback() + + if warn: + msg = f"The following exception was caught:\n\n{tb}" + warnings.warn(msg) + + if default == "__traceback__": + res = tb + elif callable(default): + res = default(*args, **kwargs) + else: + res = default + return res + + return wrapper_catch + + if callable(func): + return decorator_catch(func) + else: + return decorator_catch + + +def unpack(func=None, symbol=None): + def decorator_unpack(func): + if symbol is None: + + @functools.wraps(func) + def wrapper_unpack(arg): + return func(arg) + + elif symbol == "*": + + @functools.wraps(func) + def wrapper_unpack(arg): + return func(*arg) + + elif symbol == "**": + + @functools.wraps(func) + def wrapper_unpack(arg): + return func(**arg) + + return wrapper_unpack + + if callable(func): + return decorator_unpack(func) + else: + return decorator_unpack + + +def get_traceback(): + tb = format_exception(*sys.exc_info()) + if isinstance(tb, list): + tb = "".join(tb) + return tb diff --git a/src/tranquilo/options.py b/src/tranquilo/options.py index c83eced..5158dca 100644 --- a/src/tranquilo/options.py +++ b/src/tranquilo/options.py @@ -1,4 +1,5 @@ from typing import NamedTuple +from enum import Enum from tranquilo.models import n_free_params import numpy as np @@ -256,3 +257,11 @@ def update_option_bundle(default_options, user_options=None): out = default_options._replace(**typed) return out + + +class ErrorHandling(Enum): + """Enum to specify the error handling strategy of the optimization algorithm.""" + + RAISE = "raise" + RAISE_STRICT = "raise_strict" + CONTINUE = "continue" diff --git a/src/tranquilo/sample_points.py b/src/tranquilo/sample_points.py index 5ff0745..292b608 100644 --- a/src/tranquilo/sample_points.py +++ b/src/tranquilo/sample_points.py @@ -374,7 +374,7 @@ def _minimal_pairwise_distance_on_hull( x = _project_onto_unit_hull(x, trustregion_shape=trustregion_shape) if existing_xs is not None: - sample = np.row_stack([x, existing_xs]) + sample = np.vstack([x, existing_xs]) n_existing_pairs = len(existing_xs) * (len(existing_xs) - 1) // 2 slc = slice(0, -n_existing_pairs) if n_existing_pairs else slice(None) else: @@ -416,7 +416,7 @@ def _determinant_on_hull(x, existing_xs, trustregion_shape, n_params): x = _project_onto_unit_hull(x, trustregion_shape=trustregion_shape) if existing_xs is not None: - sample = np.row_stack([x, existing_xs]) + sample = np.vstack([x, existing_xs]) else: sample = x diff --git a/src/tranquilo/subsolvers/fallback_subsolvers.py b/src/tranquilo/subsolvers/fallback_subsolvers.py index 5192aad..109b58b 100644 --- a/src/tranquilo/subsolvers/fallback_subsolvers.py +++ b/src/tranquilo/subsolvers/fallback_subsolvers.py @@ -1,6 +1,6 @@ import numpy as np from functools import partial -from scipy.optimize import Bounds, NonlinearConstraint, minimize +from scipy.optimize import Bounds, minimize from tranquilo.exploration_sample import draw_exploration_sample @@ -80,8 +80,9 @@ def robust_cube_solver_multistart(model, x_candidate): def robust_sphere_solver_inscribed_cube(model, x_candidate): """Robust sphere solver that uses a cube solver in an inscribed cube. - We let x be in the largest cube that is inscribed inside the unit sphere. Formula - is taken from http://tinyurl.com/4astpuwn. + We let x be in the largest cube that is inscribed inside the unit sphere. Formula is + taken from + http://tinyurl.com/4astpuwn. This solver cannot find solutions on the hull of the sphere. @@ -184,16 +185,26 @@ def _grad(x, g, h): def _get_constraint(): - def _constr_fun(x): - return x @ x - - def _constr_jac(x): - return 2 * x - - return NonlinearConstraint( - fun=_constr_fun, - lb=-np.inf, - ub=1, - jac=_constr_jac, - keep_feasible=True, - ) + """Constraint enforcing ||x||^2 <= 1 as a simple inequality for SLSQP.""" + return { + "type": "ineq", + "fun": lambda x: 1 - x @ x, + "jac": lambda x: -2 * x, + } + + +# def _get_constraint(): +# """Raises scipy warning.""" +# def _constr_fun(x): +# return x @ x + +# def _constr_jac(x): +# return 2 * x + +# return NonlinearConstraint( +# fun=_constr_fun, +# lb=-np.inf, +# ub=1, +# jac=_constr_jac, +# keep_feasible=True, +# ) diff --git a/src/tranquilo/visualize.py b/src/tranquilo/visualize.py index ca2092b..a8db4f2 100644 --- a/src/tranquilo/visualize.py +++ b/src/tranquilo/visualize.py @@ -8,11 +8,12 @@ from plotly import graph_objects as go from plotly.subplots import make_subplots -from optimagic.optimization.optimize_result import OptimizeResult from tranquilo.clustering import cluster from tranquilo.geometry import log_d_quality_calculator from tranquilo.volume import get_radius_after_volume_scaling +from typing import Any, Protocol, runtime_checkable + def visualize_tranquilo(results, iterations): """Plot diagnostic information of optimization result in given iteration(s). @@ -56,7 +57,7 @@ def visualize_tranquilo(results, iterations): if isinstance(iterations, int): iterations = {case: iterations for case in results} results = {case: _process_results(results[case]) for case in results} - elif isinstance(results, OptimizeResult): + elif isinstance(results, OptimizeResultLike): results = _process_results(results) results = {f"iteration {i}": results for i in iterations} iterations = {f"iteration {iteration}": iteration for iteration in iterations} @@ -588,3 +589,13 @@ def _get_model_indices(xs, state): for point in state.model_points: model_indices = np.concatenate([model_indices, _find_index(xs, point)]) return model_indices.astype(int) + + +@runtime_checkable +class OptimizeResultLike(Protocol): + """Runtime-checkable stand-in for optimagic's OptimizeResult object.""" + + algorithm: str + history: Any + params: Any + algorithm_output: dict diff --git a/src/tranquilo/wrap_criterion.py b/src/tranquilo/wrap_criterion.py index e432eae..0622a77 100644 --- a/src/tranquilo/wrap_criterion.py +++ b/src/tranquilo/wrap_criterion.py @@ -80,9 +80,9 @@ def process_batch_evaluator(batch_evaluator="joblib"): out = batch_evaluator elif isinstance(batch_evaluator, str): if batch_evaluator == "joblib": - from optimagic.batch_evaluators import joblib_batch_evaluator as out + from tranquilo.batch_evaluators import joblib_batch_evaluator as out elif batch_evaluator == "pathos": - from optimagic.batch_evaluators import pathos_mp_batch_evaluator as out + from tranquilo.batch_evaluators import pathos_mp_batch_evaluator as out else: raise ValueError( "Invalid batch evaluator requested. Currently only 'pathos' and " diff --git a/tests/subsolvers/test_gqtpar_lambdas.py b/tests/subsolvers/test_gqtpar_lambdas.py index 0c1a770..5be7840 100644 --- a/tests/subsolvers/test_gqtpar_lambdas.py +++ b/tests/subsolvers/test_gqtpar_lambdas.py @@ -1,11 +1,16 @@ -from optimagic.optimization.optimize import minimize -from optimagic.benchmarking.get_benchmark_problems import get_benchmark_problems +import pytest +from tranquilo.config import IS_OPTIMAGIC_INSTALLED +if IS_OPTIMAGIC_INSTALLED: + from optimagic.optimization.optimize import minimize + from optimagic.benchmarking.get_benchmark_problems import get_benchmark_problems + +@pytest.mark.skipif(not IS_OPTIMAGIC_INSTALLED, reason="optimagic is not installed.") def test_gqtpar_lambdas(): algo_options = { "disable_convergence": True, - "stopping_max_iterations": 30, + "stopping_maxiter": 30, "sample_filter": "keep_all", "sampler": "random_hull", "subsolver_options": {"k_hard": 0.001, "k_easy": 0.001}, @@ -13,7 +18,7 @@ def test_gqtpar_lambdas(): problem_info = get_benchmark_problems("more_wild")["freudenstein_roth_good_start"] minimize( - criterion=problem_info["inputs"]["fun"], + fun=problem_info["inputs"]["fun"], params=problem_info["inputs"]["params"], algo_options=algo_options, algorithm="tranquilo", diff --git a/tests/test_fit_models.py b/tests/test_fit_models.py index 84b5324..fa2887c 100644 --- a/tests/test_fit_models.py +++ b/tests/test_fit_models.py @@ -1,9 +1,16 @@ import numpy as np import pytest -from optimagic.differentiation.derivatives import first_derivative, second_derivative +from numpy.testing import assert_array_almost_equal, assert_array_equal + from tranquilo.fit_models import _quadratic_features, get_fitter from tranquilo.region import Region -from numpy.testing import assert_array_almost_equal, assert_array_equal +from tranquilo.config import IS_OPTIMAGIC_INSTALLED + +if IS_OPTIMAGIC_INSTALLED: + from optimagic.differentiation.derivatives import ( + first_derivative, + second_derivative, + ) def aaae(x, y, decimal=None, case=None): @@ -91,6 +98,7 @@ def test_fit_against_truth_quadratic(fitter, quadratic_case): ) +@pytest.mark.skipif(not IS_OPTIMAGIC_INSTALLED, reason="optimagic is not installed.") @pytest.mark.parametrize("model", ["ols", "ridge", "tranquilo"]) def test_fit_ols_against_gradient(model, quadratic_case): options = {"l2_penalty_square": 0} @@ -113,9 +121,10 @@ def test_fit_ols_against_gradient(model, quadratic_case): grad = a + hess @ quadratic_case["x0"] gradient = first_derivative(quadratic_case["func"], quadratic_case["x0"]) - aaae(gradient["derivative"], grad, case="gradient") + aaae(gradient.derivative, grad, case="gradient") +@pytest.mark.skipif(not IS_OPTIMAGIC_INSTALLED, reason="optimagic is not installed.") @pytest.mark.parametrize("model", ("ols", "ridge", "tranquilo", "powell")) def test_fit_ols_against_hessian(model, quadratic_case): options = {"l2_penalty_square": 0} @@ -134,7 +143,7 @@ def test_fit_ols_against_hessian(model, quadratic_case): ) hessian = second_derivative(quadratic_case["func"], quadratic_case["x0"]) hess = got.square_terms.reshape((4, 4)) - aaae(hessian["derivative"], hess, case="hessian") + aaae(hessian.derivative, hess, case="hessian") def test_quadratic_features(): diff --git a/tests/test_tranquilo.py b/tests/test_tranquilo.py index b966132..1882eea 100644 --- a/tests/test_tranquilo.py +++ b/tests/test_tranquilo.py @@ -1,12 +1,16 @@ import itertools -import numpy as np import pytest -from optimagic.optimization.optimize import minimize -from tranquilo.tranquilo import _tranquilo from functools import partial +import numpy as np from numpy.testing import assert_array_almost_equal as aaae -from optimagic import mark + +from tranquilo.tranquilo import _tranquilo +from tranquilo.config import IS_OPTIMAGIC_INSTALLED + +if IS_OPTIMAGIC_INSTALLED: + from optimagic.optimization.optimize import minimize + from optimagic import mark tranquilo = partial( @@ -119,9 +123,10 @@ def test_internal_tranquilo_scalar_sphere_imprecise_defaults( # ====================================================================================== +@pytest.mark.skipif(not IS_OPTIMAGIC_INSTALLED, reason="optimagic is not installed.") def test_external_tranquilo_scalar_sphere_defaults(): res = minimize( - criterion=lambda x: x @ x, + fun=lambda x: x @ x, params=np.arange(4), algorithm="tranquilo", ) @@ -172,9 +177,10 @@ def test_internal_tranquilo_ls_sphere_defaults( # ====================================================================================== +@pytest.mark.skipif(not IS_OPTIMAGIC_INSTALLED, reason="optimagic is not installed.") def test_external_tranquilo_ls_sphere_defaults(): res = minimize( - criterion=mark.least_squares(lambda x: x), + fun=mark.least_squares(lambda x: x), params=np.arange(5), algorithm="tranquilo_ls", ) @@ -187,31 +193,37 @@ def test_external_tranquilo_ls_sphere_defaults(): # ====================================================================================== -@pytest.mark.parametrize("algo", ["tranquilo", "tranquilo_ls"]) -def test_tranquilo_with_noise_handling_and_deterministic_function(algo): - def _f(x): - return {"root_contributions": x, "value": x @ x} - +@pytest.mark.skipif(not IS_OPTIMAGIC_INSTALLED, reason="optimagic is not installed.") +@pytest.mark.parametrize( + "algorithm, criterion", + [ + ("tranquilo", mark.scalar(lambda x: x @ x)), + ("tranquilo_ls", mark.least_squares(lambda x: x)), + ], +) +def test_tranquilo_with_noise_handling_and_deterministic_function(algorithm, criterion): res = minimize( - criterion=_f, + fun=criterion, params=np.arange(5), - algorithm=algo, + algorithm=algorithm, algo_options={"noisy": True}, ) aaae(res.params, np.zeros(5), decimal=3) +@pytest.mark.skipif(not IS_OPTIMAGIC_INSTALLED, reason="optimagic is not installed.") @pytest.mark.slow() def test_tranquilo_ls_with_noise_handling_and_noisy_function(): rng = np.random.default_rng(123) + @mark.least_squares def _f(x): x_n = x + rng.normal(0, 0.05, size=x.shape) - return {"root_contributions": x_n, "value": x_n @ x_n} + return x_n res = minimize( - criterion=_f, + fun=_f, params=np.ones(3), algorithm="tranquilo_ls", algo_options={"noisy": True, "n_evals_per_point": 10}, @@ -225,18 +237,19 @@ def _f(x): # ====================================================================================== -def sum_of_squares(x): - contribs = x**2 - return {"value": contribs.sum(), "contributions": contribs, "root_contributions": x} - - -@pytest.mark.parametrize("algorithm", ["tranquilo", "tranquilo_ls"]) -def test_tranquilo_with_binding_bounds(algorithm): +@pytest.mark.skipif(not IS_OPTIMAGIC_INSTALLED, reason="optimagic is not installed.") +@pytest.mark.parametrize( + "algorithm, criterion", + [ + ("tranquilo", mark.scalar(lambda x: x @ x)), + ("tranquilo_ls", mark.least_squares(lambda x: x)), + ], +) +def test_tranquilo_with_binding_bounds(algorithm, criterion): res = minimize( - criterion=sum_of_squares, + fun=criterion, params=np.array([3, 2, -3]), - lower_bounds=np.array([1, -np.inf, -np.inf]), - upper_bounds=np.array([np.inf, np.inf, -1]), + bounds=[(1, np.inf), (-np.inf, np.inf), (-np.inf, -1)], algorithm=algorithm, collect_history=True, skip_checks=True, diff --git a/tests/test_visualize.py b/tests/test_visualize.py index 6882164..67084e5 100644 --- a/tests/test_visualize.py +++ b/tests/test_visualize.py @@ -1,8 +1,10 @@ import pytest -from optimagic.optimization.optimize import minimize -from optimagic.benchmarking.get_benchmark_problems import get_benchmark_problems from tranquilo.visualize import visualize_tranquilo +from tranquilo.config import IS_OPTIMAGIC_INSTALLED +if IS_OPTIMAGIC_INSTALLED: + from optimagic.optimization.optimize import minimize + from optimagic.benchmarking.get_benchmark_problems import get_benchmark_problems cases = [] algo_options = { @@ -10,13 +12,13 @@ "sampler": "random_hull", "sphere_subsolver": "gqtpar_fast", "sample_filter": "keep_all", - "stopping_max_iterations": 10, + "stopping_maxiter": 10, }, "optimal_hull": { "sampler": "optimal_hull", "sphere_subsolver": "gqtpar_fast", "sample_filter": "keep_all", - "stopping_max_iterations": 10, + "stopping_maxiter": 10, }, } for problem in ["rosenbrock_good_start", "watson_6_good_start"]: @@ -27,7 +29,7 @@ results = {} for s, options in algo_options.items(): results[s] = minimize( - criterion=fun, + fun=fun, params=start_params, algo_options=options, algorithm=algorithm,