Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/tranquilo/acceptance_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def get_acceptance_decider(
"naive_noisy": accept_naive_noisy,
"noisy": accept_noisy,
"classic_line_search": accept_classic_line_search,
"greedy": accept_greedy,
}

out = get_component(
Expand All @@ -38,6 +39,82 @@ def get_acceptance_decider(
return out


def accept_greedy(
subproblem_solution,
state,
history,
*,
wrapped_criterion,
min_improvement,
):
"""Do a greedy acceptance step for a trustregion algorithm.

Args:
subproblem_solution (SubproblemResult): Result of the subproblem solution.
state (State): Namedtuple containing the trustregion, criterion value of
previously accepted point, indices of model points, etc.
wrapped_criterion (callable): The criterion function.
min_improvement (float): Minimum improvement required to accept a point.

Returns:
AcceptanceResult

"""
candidate_x = subproblem_solution.x
candidate_index = history.add_xs(candidate_x)
wrapped_criterion({candidate_index: 1})

candidate_fval = np.mean(history.get_fvals(candidate_index))
candidate_improvement = -(candidate_fval - state.fval)

rho = calculate_rho(
actual_improvement=candidate_improvement,
expected_improvement=subproblem_solution.expected_improvement,
)

best_x, best_fval, best_index = history.get_best()

assert np.isfinite(best_fval)
assert isinstance(best_x, np.ndarray)
assert isinstance(best_index, int)
assert isinstance(best_fval, float)
assert best_x.ndim == 1
assert np.mean(history.get_fvals(best_index)) == best_fval

if best_fval < candidate_fval and best_fval < state.fval:
candidate_x = best_x
candidate_fval = best_fval
candidate_index = best_index
overall_improvement = -(best_fval - state.fval)
else:
overall_improvement = candidate_improvement

is_accepted = overall_improvement >= min_improvement

if np.isfinite(candidate_fval):
res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=1,
)
else:
res = _get_acceptance_result(
candidate_x=state.x,
candidate_fval=state.fval,
candidate_index=state.index,
rho=-np.inf,
is_accepted=False,
old_state=state,
n_evals=1,
)

return res


def _accept_classic(
subproblem_solution,
state,
Expand Down
17 changes: 17 additions & 0 deletions src/tranquilo/history.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd


class History:
Expand Down Expand Up @@ -173,6 +174,22 @@ def get_fvals(self, x_indices):
)
return out

def get_best(self):
"""Retrieve best fval and corresponding x and index.

If there are multiple evaluations per x, the best fval is computed as the mean
of all evaluations per x.

Returns:
tuple: (x, fval, index), the current minimizer x, the corresponding fval,
which is a mean if there are multiple evaluations per x, and the index.

"""
fvals = self.get_fvals(np.arange(self.n_xs))
average_fvals = {key: np.mean(val) for key, val in fvals.items()}
index = int(pd.Series(average_fvals).idxmin())
return self.get_xs(index), average_fvals[index], index

def get_n_evals(self, x_indices):
fvals = self.get_fvals(x_indices)
n_evals = {k: len(v) for k, v in fvals.items()}
Expand Down
54 changes: 54 additions & 0 deletions tests/test_acceptance_decision.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections import namedtuple

from functools import partial
import numpy as np
import pytest
from tranquilo.sample_points import get_sampler
from tranquilo.acceptance_decision import (
accept_greedy,
_accept_simple,
_get_acceptance_result,
calculate_rho,
Expand Down Expand Up @@ -82,6 +84,58 @@ def wrapped_criterion(eval_info):
assert_array_equal(res_got.candidate_x, 1.0 + np.arange(2))


# ======================================================================================
# Test accept_greedy
# ======================================================================================


@pytest.mark.parametrize("state", states)
def test_accept_greedy(
state,
subproblem_solution,
):
"""Test accept greedy.

Tests that the best point is chosen in the acceptance step, even though it is added
to the history before the acceptance step.

"""
history = History(functype="scalar")

def criterion(x):
return np.sum(x**2)

def _wrapped_criterion(eval_info, history):
for x_index, _ in eval_info.items():
xs = history.get_xs(x_index)
crit_value = criterion(xs)
history.add_evals(np.array([x_index]), crit_value)

wrapped_criterion = partial(_wrapped_criterion, history=history)

# Add existing xs to history and evaluate wrapped criterion
existing_xs = np.zeros((1, 2))
existing_xs_indices = history.add_xs(existing_xs)

eval_info = {x_index: 1 for x_index in existing_xs_indices}
wrapped_criterion(eval_info)

res_got = accept_greedy(
subproblem_solution=subproblem_solution,
state=state,
history=history,
wrapped_criterion=wrapped_criterion,
min_improvement=0.0,
)

assert res_got.accepted
assert res_got.index == 0
assert res_got.candidate_index == 0
assert res_got.fval == 0.0
assert_array_equal(res_got.x, np.zeros(2))
assert_array_equal(res_got.candidate_x, np.zeros(2))


# ======================================================================================
# Test _get_acceptance_result
# ======================================================================================
Expand Down
8 changes: 8 additions & 0 deletions tests/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,11 @@ def test_get_model_data_with_repeated_evaluations(noisy_history, average):
else:
aaae(got_xs, np.arange(6).reshape(2, 3).repeat([2, 3], axis=0))
aaae(got_fvecs, np.arange(25).reshape(5, 5))


def test_get_best(noisy_history):
x, fval, index = noisy_history.get_best()
assert index == 0
assert isinstance(index, int)
assert fval == 142.5
aaae(x, np.array([0, 1, 2]))