Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions bayes_opt/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def get_acquisition_params(self) -> dict[str, Any]:
)
raise NotImplementedError(error_msg)

def set_acquisition_params(self, **params) -> None:
def set_acquisition_params(self, params: dict[str, Any]) -> None:
"""
Set the parameters of the acquisition function.

Parameters
----------
**params : dict
params : dict
The parameters of the acquisition function.
"""
error_msg = (
Expand Down Expand Up @@ -553,7 +553,7 @@ def decay_exploration(self) -> None:
):
self.kappa = self.kappa * self.exploration_decay

def get_acquisition_params(self) -> dict:
def get_acquisition_params(self) -> dict[str, Any]:
"""Get the current acquisition function parameters.

Returns
Expand All @@ -567,7 +567,7 @@ def get_acquisition_params(self) -> dict:
"exploration_decay_delay": self.exploration_decay_delay,
}

def set_acquisition_params(self, params: dict) -> None:
def set_acquisition_params(self, params: dict[str, Any]) -> None:
"""Set the acquisition function parameters.

Parameters
Expand Down Expand Up @@ -733,7 +733,7 @@ def decay_exploration(self) -> None:
):
self.xi = self.xi * self.exploration_decay

def get_acquisition_params(self) -> dict:
def get_acquisition_params(self) -> dict[str, Any]:
"""Get the current acquisition function parameters.

Returns
Expand All @@ -747,7 +747,7 @@ def get_acquisition_params(self) -> dict:
"exploration_decay_delay": self.exploration_decay_delay,
}

def set_acquisition_params(self, params: dict) -> None:
def set_acquisition_params(self, params: dict[str, Any]) -> None:
"""Set the acquisition function parameters.

Parameters
Expand Down Expand Up @@ -922,7 +922,7 @@ def decay_exploration(self) -> None:
):
self.xi = self.xi * self.exploration_decay

def get_acquisition_params(self) -> dict:
def get_acquisition_params(self) -> dict[str, Any]:
"""Get the current acquisition function parameters.

Returns
Expand All @@ -936,7 +936,7 @@ def get_acquisition_params(self) -> dict:
"exploration_decay_delay": self.exploration_decay_delay,
}

def set_acquisition_params(self, params: dict) -> None:
def set_acquisition_params(self, params: dict[str, Any]) -> None:
"""Set the acquisition function parameters.

Parameters
Expand Down Expand Up @@ -1147,7 +1147,7 @@ def suggest(

return x_max

def get_acquisition_params(self) -> dict:
def get_acquisition_params(self) -> dict[str, Any]:
"""Get the current acquisition function parameters.

Returns
Expand All @@ -1163,7 +1163,7 @@ def get_acquisition_params(self) -> dict:
"rtol": self.rtol,
}

def set_acquisition_params(self, params: dict) -> None:
def set_acquisition_params(self, params: dict[str, Any]) -> None:
"""Set the acquisition function parameters.

Parameters
Expand Down Expand Up @@ -1318,7 +1318,7 @@ def suggest(
idx = self._sample_idx_from_softmax_gains(random_state=random_state)
return x_max[idx]

def get_acquisition_params(self) -> dict:
def get_acquisition_params(self) -> dict[str, Any]:
"""Get the current acquisition function parameters.

Returns
Expand All @@ -1334,7 +1334,7 @@ def get_acquisition_params(self) -> dict:
else None,
}

def set_acquisition_params(self, params: dict) -> None:
def set_acquisition_params(self, params: dict[str, Any]) -> None:
"""Set the acquisition function parameters.

Parameters
Expand Down
14 changes: 7 additions & 7 deletions bayes_opt/bayesian_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def probe(self, params: ParamsType, lazy: bool = True) -> None:
self._space.keys, self._space.res()[-1], self._space.params_config, self.max
)

def random_sample(self, n: int = 1) -> dict[str, float | NDArray[Float]]:
def random_sample(self, n: int = 1) -> list[dict[str, float | NDArray[Float]]]:
"""Generate a random sample of parameters from the target space.

Parameters
Expand Down Expand Up @@ -442,13 +442,13 @@ def save_state(self, path: str | PathLike[str]) -> None:
"""
random_state = None
if self._random_state is not None:
state_tuple = self._random_state.get_state()
state_dict = self._random_state.get_state(legacy=False)
random_state = {
"bit_generator": state_tuple[0],
"state": state_tuple[1].tolist(),
"pos": state_tuple[2],
"has_gauss": state_tuple[3],
"cached_gaussian": state_tuple[4],
"bit_generator": state_dict["bit_generator"],
"state": state_dict["state"]["key"].tolist(),
"pos": state_dict["state"]["pos"],
"has_gauss": state_dict["has_gauss"],
"cached_gaussian": state_dict["gauss"],
}

# Get constraint values if they exist
Expand Down