From 6e4c7cd05a10f0c0ee8e14d09ed7cedff8cb0695 Mon Sep 17 00:00:00 2001
From: moskomule
Date: Wed, 16 Feb 2022 11:54:15 +0900
Subject: [PATCH 1/3] add cmaes

---
 evojax/algo/__init__.py |   3 +-
 evojax/algo/base.py     |  21 +++-
 evojax/algo/cmaes.py    | 238 ++++++++++++++++++++++++++++++++++++++++
 evojax/algo/pgpe.py     |  16 +--
 4 files changed, 259 insertions(+), 19 deletions(-)
 create mode 100644 evojax/algo/cmaes.py

diff --git a/evojax/algo/__init__.py b/evojax/algo/__init__.py
index 3f24162c..b4e87bc3 100644
--- a/evojax/algo/__init__.py
+++ b/evojax/algo/__init__.py
@@ -15,6 +15,7 @@
 from .base import NEAlgorithm
 from .cma_wrapper import CMA
 from .pgpe import PGPE
+from .cmaes import CMAES

-__all__ = ['NEAlgorithm', 'CMA', 'PGPE']
+__all__ = ['NEAlgorithm', 'CMA', 'PGPE', 'CMAES']
diff --git a/evojax/algo/base.py b/evojax/algo/base.py
index 825ea116..031471e3 100644
--- a/evojax/algo/base.py
+++ b/evojax/algo/base.py
@@ -12,11 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from abc import ABC
-from abc import abstractmethod
+from abc import ABC, abstractmethod
+from functools import partial
 from typing import Union
-import numpy as np
+
+import jax
 import jax.numpy as jnp
+import numpy as np
+
+
+@partial(jax.jit, static_argnums=(1,))
+def process_scores(x: Union[np.ndarray, jnp.ndarray], use_ranking: bool) -> jnp.ndarray:
+    """Convert fitness scores to rank if necessary."""
+
+    x = jnp.array(x)
+    if use_ranking:
+        ranks = jnp.zeros(x.size, dtype=int)
+        ranks = ranks.at[x.argsort()].set(jnp.arange(x.size)).reshape(x.shape)
+        return ranks / ranks.max() - 0.5
+    else:
+        return x


 class NEAlgorithm(ABC):
diff --git a/evojax/algo/cmaes.py b/evojax/algo/cmaes.py
new file mode 100644
index 00000000..8688e928
--- /dev/null
+++ b/evojax/algo/cmaes.py
@@ -0,0 +1,238 @@
+""" Implementation of CMA-ES in JAX.
+
+Ref: https://github.com/CyberAgentAILab/cmaes/blob/main/cmaes/_cma.py
+"""
+
+from __future__ import annotations
+
+import functools
+import logging
+import math
+
+import jax
+import numpy as np
+from jax import numpy as jnp
+
+from evojax.algo.base import NEAlgorithm, process_scores
+from evojax.util import create_logger
+
+EPS = 1e-8
+MAX = 1e32
+
+
+class CMAES(NEAlgorithm):
+    """CMA-ES
+
+    Ref: https://arxiv.org/abs/1604.00772
+    Ref: https://github.com/CyberAgentAILab/cmaes/
+    """
+
+    def __init__(
+        self,
+        pop_size: int = None,
+        param_size: int = None,
+        init_params: np.ndarray | jnp.ndarray = None,
+        init_sigma: float = 0.1,
+        init_cov: np.ndarray | jnp.ndarray = None,
+        solution_ranking: bool = True,
+        seed: int = 0,
+        logger: logging.Logger = None,
+    ):
+        """Initialization function. Equation numbers are from Hansen's tutorial.
+
+        Args:
+            pop_size - Population size, recommended population size if not given.
+            param_size - Parameter size.
+            init_params - Initial parameters, all zeros if not given.
+            init_sigma - Initial sigma value.
+            init_cov - Initial covariance matrix, identity if not given.
+            solution_ranking - Should we treat the fitness as rankings or not.
+            seed - Random seed for parameters sampling.
+        """
+
+        assert init_sigma > 0
+        if logger is None:
+            self._logger = create_logger("cmaes")
+        else:
+            self._logger = logger
+
+        if init_params is None:
+            init_params = jnp.zeros(param_size)
+        mean = init_params
+
+        if init_cov is None:
+            self._C = jnp.eye(param_size)
+        else:
+            self._C = init_cov
+
+        if pop_size is None:
+            # eq (48)
+            pop_size = 4 + math.floor(3 * math.log(param_size))
+            self._logger.info(f"population size (pop_size) is set to {pop_size} (recommended size)")
+
+        mu = pop_size // 2
+
+        # eq (49)
+        weights_prime = jnp.array([math.log((pop_size + 1) / 2) - math.log1p(i) for i in range(pop_size)])
+        mu_eff = (jnp.sum(weights_prime[:mu]) ** 2) / jnp.sum(weights_prime[:mu] ** 2)
+        mu_eff_minus = (jnp.sum(weights_prime[mu:]) ** 2) / jnp.sum(weights_prime[mu:] ** 2)
+
+        # learning rate for rank-one update eq (57)
+        alpha_cov = 2
+        c_1 = alpha_cov / ((param_size + 1.3) ** 2 + mu_eff)
+
+        # learning rate for rank-mu update # eq (58)
+        c_mu = min(
+            1 - c_1 - 1e-8,
+            alpha_cov * (mu_eff - 2 + 1 / mu_eff) / ((param_size + 2) ** 2 + alpha_cov * mu_eff / 2),
+        )
+
+        assert c_1 <= 1 - c_mu
+        assert c_mu <= 1 - c_1
+
+        min_alpha = min(
+            1 + c_1 / c_mu,  # eq (50)
+            1 + (2 * mu_eff_minus) / (mu_eff + 2),  # eq (51)
+            (1 - c_1 - c_mu) / (param_size * c_mu),  # eq (52)
+        )
+
+        # eq (53)
+        positive_sum = jnp.sum(weights_prime[weights_prime > 0])
+        negative_sum = jnp.sum(jnp.abs(weights_prime[weights_prime < 0]))
+        weights = jnp.where(
+            weights_prime >= 0,
+            1 / positive_sum * weights_prime,
+            min_alpha / negative_sum * weights_prime,
+        )
+        c_m = 1  # eq (54)
+
+        # learning rate for the cumulation for the step-size control, eq (55)
+        c_sigma = (mu_eff + 2) / (param_size + mu_eff + 5)
+        d_sigma = 1 + 2 * max(0, math.sqrt((mu_eff - 1) / (param_size + 1)) - 1) + c_sigma
+        assert c_sigma < 1
+
+        # learning rate for cumulation for the rank-one update, eq (56)
+        c_c = (4 + mu_eff / param_size) / (param_size + 4 + 2 * mu_eff / param_size)
+        assert c_c <= 1
+
+        self._n_dim = param_size
+        self.pop_size = pop_size
+        self._mu = mu
+        self._mu_eff = mu_eff
+        self._c_c = c_c
+        self._c_1 = c_1
+        self._c_mu = c_mu
+        self._c_sigma = c_sigma
+        self._d_sigma = d_sigma
+        self._c_m = c_m
+
+        # approx of E||N(0, I)||
+        self._chi_n = math.sqrt(param_size) * (1 - (1 / (4 * param_size)) + 1 / (21 * param_size**2))
+
+        self._weights = weights
+
+        # path
+        self._p_sigma = jnp.zeros(param_size)
+        self._pc = jnp.zeros(param_size)
+
+        self._mean = mean
+        self._sigma = init_sigma
+        self._D = None
+        self._B = None
+        self._solutions = None
+
+        self._t = 0
+        self._solution_ranking = solution_ranking
+        self._key = jax.random.PRNGKey(seed=seed)
+
+    def _eigen_decomposition(self) -> tuple[jnp.ndarray, jnp.ndarray]:
+        if self._B is None or self._D is None:
+            self._C, self._B, self._D = _eigen_decomposition(self._C)
+
+        return self._B, self._D
+
+    def ask(self) -> jnp.ndarray:
+        # resampling is skipped in this implementation
+        # see cmaes for more details
+        B, D = self._eigen_decomposition()
+        self._key, key = jax.random.split(self._key)
+        z = jax.random.normal(key, (self.pop_size, self._n_dim))
+        self._solutions = _ask_impl(z, B, D, self._mean, self._sigma)
+        return self._solutions
+
+    def tell(self, fitness: np.ndarray | jnp.ndarray) -> None:
+
+        fitness_scores = process_scores(fitness, self._solution_ranking)
+        self._t += 1
+        B, D = self._eigen_decomposition()
+        self._B, self._D = None, None
+
+        # highest score, ..., lowest score
+        idx = jnp.argsort(-fitness_scores)
+        x_k = self._solutions[idx]
+        y_k = (x_k - self._mean) / self._sigma
+
+        # selection and recombination
+        y_w = jnp.sum(y_k[: self._mu].T * self._weights[: self._mu], axis=1)  # eq (41)
+        self._mean += self._c_m * self._sigma * y_w  # eq (42)
+
+        # step-size control
+        C_2 = B.dot(jnp.diag(1 / D)).dot(B.T)  # C^{-0.5}
+        self._p_sigma = (1 - self._c_sigma) * self._p_sigma + jnp.sqrt(  # eq (43)
+            self._c_sigma * (2 - self._c_sigma) * self._mu_eff
+        ) * C_2.dot(y_w)
+
+        norm_p_sigma = jnp.linalg.norm(self._p_sigma)
+        self._sigma *= jnp.minimum(
+            jnp.exp(self._c_sigma / self._d_sigma * (norm_p_sigma / self._chi_n - 1)),
+            MAX,
+        )  # eq (44)
+
+        # covariance matrix adaptation (p. 28)
+        h_sigma_cond_left = norm_p_sigma / jnp.sqrt((1 - (1 - self._c_sigma) ** (2 * (self._t + 1))))
+        h_sigma_cond_right = (1.4 + 2 / (self._n_dim + 1)) * self._chi_n
+        h_sigma = 1.0 if h_sigma_cond_left < h_sigma_cond_right else 0.0
+
+        self._pc = (1 - self._c_c) * self._pc + h_sigma * jnp.sqrt(
+            self._c_c * (2 - self._c_c) * self._mu_eff
+        ) * y_w  # eq (45)
+
+        w_io = self._weights * jnp.where(
+            self._weights >= 0,
+            1,
+            self._n_dim / (jnp.linalg.norm(C_2.dot(y_k.T), axis=0) ** 2 + EPS),
+        )  # eq (46)
+
+        delta_h_sigma = (1 - h_sigma) * self._c_c * (2 - self._c_c)
+        assert delta_h_sigma <= 1
+
+        rank_one = jnp.outer(self._pc, self._pc)
+        rank_mu = jnp.sum(jnp.array([w * jnp.outer(y, y) for w, y in zip(w_io, y_k)]), axis=0)
+        self._C = (
+            (1 + self._c_1 * delta_h_sigma - self._c_1 - self._c_mu * jnp.sum(self._weights)) * self._C
+            + self._c_1 * rank_one
+            + self._c_mu * rank_mu
+        )  # eq (47)
+
+    @property
+    def best_params(self) -> jnp.ndarray:
+        return jnp.array(self._mean, copy=True)
+
+    @best_params.setter
+    def best_params(self, params: np.ndarray | jnp.ndarray) -> None:
+        self._mean = jnp.array(params, copy=True)
+
+
+@jax.jit
+@functools.partial(jax.vmap, in_axes=(0, None, None, None, None))
+def _ask_impl(z, b, d, mean, sigma) -> jnp.ndarray:
+    y = b.dot(jnp.diag(d)).dot(z)  # ~N(0, C)
+    return mean + sigma * y
+
+
+@jax.jit
+def _eigen_decomposition(c):
+    c = (c + c.T) / 2
+    d2, b = jnp.linalg.eigh(c)
+    d = jnp.where(d2 < 0, EPS, d2)
+    return b.dot(jnp.diag(d)).dot(b.T), b, jnp.sqrt(d)
diff --git a/evojax/algo/pgpe.py b/evojax/algo/pgpe.py
index ac8a9dbe..f2e84d76 100644
--- a/evojax/algo/pgpe.py
+++ b/evojax/algo/pgpe.py
@@ -32,24 +32,10 @@
 except ModuleNotFoundError:
     from jax.experimental import optimizers

-from evojax.algo.base import NEAlgorithm
+from evojax.algo.base import NEAlgorithm, process_scores
 from evojax.util import create_logger


-@partial(jax.jit, static_argnums=(1,))
-def process_scores(x: Union[np.ndarray, jnp.ndarray],
-                   use_ranking: bool) -> jnp.ndarray:
-    """Convert fitness scores to rank if necessary."""
-
-    x = jnp.array(x)
-    if use_ranking:
-        ranks = jnp.zeros(x.size, dtype=int)
-        ranks = ranks.at[x.argsort()].set(jnp.arange(x.size)).reshape(x.shape)
-        return ranks / ranks.max() - 0.5
-    else:
-        return x
-
-
 @jax.jit
 def compute_reinforce_update(
     fitness_scores: jnp.ndarray,

From 3cf7c565e4a59f5eaf123f1d83e1fdbf9076161a Mon Sep 17 00:00:00 2001
From: moskomule
Date: Thu, 17 Feb 2022 19:00:25 +0900
Subject: [PATCH 2/3] fix names

---
 evojax/algo/__init__.py    | 6 +++---
 evojax/algo/cma_wrapper.py | 2 +-
 evojax/algo/cmaes.py       | 2 +-
 examples/train_cartpole.py | 4 ++--
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/evojax/algo/__init__.py b/evojax/algo/__init__.py
index b4e87bc3..84499444 100644
--- a/evojax/algo/__init__.py
+++ b/evojax/algo/__init__.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 from .base import NEAlgorithm
-from .cma_wrapper import CMA
+from .cma_wrapper import CMAES_OriginalCPU
 from .pgpe import PGPE
-from .cmaes import CMAES
+from .cmaes import CMAES_CyberAgent

-__all__ = ['NEAlgorithm', 'CMA', 'PGPE', 'CMAES']
+__all__ = ['NEAlgorithm', 'CMAES_OriginalCPU', 'PGPE', 'CMAES_CyberAgent']
diff --git a/evojax/algo/cma_wrapper.py b/evojax/algo/cma_wrapper.py
index bf9fc0ba..e46d5704 100644
--- a/evojax/algo/cma_wrapper.py
+++ b/evojax/algo/cma_wrapper.py
@@ -31,7 +31,7 @@
 from evojax.util import create_logger


-class CMA(NEAlgorithm):
+class CMAES_OriginalCPU(NEAlgorithm):
     """A wrapper of CMA-ES."""

     def __init__(self,
diff --git a/evojax/algo/cmaes.py b/evojax/algo/cmaes.py
index 8688e928..91adadda 100644
--- a/evojax/algo/cmaes.py
+++ b/evojax/algo/cmaes.py
@@ -20,7 +20,7 @@
 MAX = 1e32


-class CMAES(NEAlgorithm):
+class CMAES_CyberAgent(NEAlgorithm):
     """CMA-ES

     Ref: https://arxiv.org/abs/1604.00772
     Ref: https://github.com/CyberAgentAILab/cmaes/
diff --git a/examples/train_cartpole.py b/examples/train_cartpole.py
index c0ccfb2c..68b630ea 100644
--- a/examples/train_cartpole.py
+++ b/examples/train_cartpole.py
@@ -38,7 +38,7 @@
 from evojax.policy import MLPPolicy
 from evojax.policy import PermutationInvariantPolicy
 from evojax.algo import PGPE
-from evojax.algo import CMA
+from evojax.algo import CMAES_OriginalCPU
 from evojax import util
@@ -105,7 +105,7 @@ def main(config):
         output_dim=train_task.act_shape[0],
     )
     if config.cma:
-        solver = CMA(
+        solver = CMAES_OriginalCPU(
             pop_size=config.pop_size,
             param_size=policy.num_params,
             init_stdev=config.init_std,

From 9f7ad075df035ffa47adc0f73c8f14ef7e1673c5 Mon Sep 17 00:00:00 2001
From: moskomule
Date: Thu, 17 Feb 2022 19:36:34 +0900
Subject: [PATCH 3/3] use consistent naming to cma

---
 evojax/algo/cmaes.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/evojax/algo/cmaes.py b/evojax/algo/cmaes.py
index 91adadda..6e0e2749 100644
--- a/evojax/algo/cmaes.py
+++ b/evojax/algo/cmaes.py
@@ -11,10 +11,9 @@
 import jax
 import numpy as np
-from jax import numpy as jnp
-
 from evojax.algo.base import NEAlgorithm, process_scores
 from evojax.util import create_logger
+from jax import numpy as jnp

 EPS = 1e-8
 MAX = 1e32
@@ -32,7 +31,7 @@ def __init__(
         pop_size: int = None,
         param_size: int = None,
         init_params: np.ndarray | jnp.ndarray = None,
-        init_sigma: float = 0.1,
+        init_stdev: float = 0.1,
         init_cov: np.ndarray | jnp.ndarray = None,
         solution_ranking: bool = True,
         seed: int = 0,
@@ -44,13 +43,13 @@ def __init__(
             pop_size - Population size, recommended population size if not given.
             param_size - Parameter size.
            init_params - Initial parameters, all zeros if not given.
-            init_sigma - Initial sigma value.
+            init_stdev - Initial sigma value.
             init_cov - Initial covariance matrix, identity if not given.
             solution_ranking - Should we treat the fitness as rankings or not.
             seed - Random seed for parameters sampling.
         """

-        assert init_sigma > 0
+        assert init_stdev > 0
         if logger is None:
             self._logger = create_logger("cmaes")
         else:
             self._logger = logger
@@ -136,7 +135,7 @@ def __init__(
         self._pc = jnp.zeros(param_size)

         self._mean = mean
-        self._sigma = init_sigma
+        self._sigma = init_stdev
         self._D = None
         self._B = None
         self._solutions = None
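
Usage note (not part of the patch): below is a minimal sketch of how the new JAX solver added in this series might be driven through the NEAlgorithm ask/tell interface, using the final names from PATCH 3/3 (CMAES_CyberAgent, init_stdev, solution_ranking). The toy fitness function, the chosen param_size, and the loop length are illustrative assumptions and not taken from the patch.

    import jax.numpy as jnp

    from evojax.algo import CMAES_CyberAgent

    def fitness_fn(params: jnp.ndarray) -> jnp.ndarray:
        # Hypothetical toy objective: higher is better, optimum at the origin.
        return -jnp.sum(params ** 2, axis=-1)

    solver = CMAES_CyberAgent(
        param_size=8,           # dimensionality of the search space
        init_stdev=0.1,         # initial step size (sigma)
        solution_ranking=True,  # rank-transform fitness via process_scores
        seed=0,
    )  # pop_size defaults to 4 + floor(3 * ln(param_size)), eq (48)

    for _ in range(100):
        solutions = solver.ask()         # (pop_size, param_size) candidates
        fitness = fitness_fn(solutions)  # one score per candidate
        solver.tell(fitness)             # CMA-ES update, eqs (41)-(47)

    best = solver.best_params            # current distribution mean

Because tell() sorts candidates by descending processed score, the sketch treats larger fitness values as better, which matches how the wrapped PGPE solver in this repository consumes fitness as well.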