diff --git a/evojax/policy/__init__.py b/evojax/policy/__init__.py
index a924c4da..c3496494 100644
--- a/evojax/policy/__init__.py
+++ b/evojax/policy/__init__.py
@@ -17,7 +17,7 @@
 from .mlp_pi import PermutationInvariantPolicy
 from .convnet import ConvNetPolicy
 from .seq2seq import Seq2seqPolicy
-
+from .fast_learner import FastLearner
 
 __all__ = ['PolicyNetwork', 'MLPPolicy', 'PermutationInvariantPolicy',
-           'ConvNetPolicy', 'Seq2seqPolicy']
+           'ConvNetPolicy', 'Seq2seqPolicy', 'FastLearner']
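
The new policy is exported next to the existing ones, so it can be constructed straight from `evojax.policy`. A minimal usage sketch for the class defined in the next file (the argument values are illustrative, not tuned settings):

from evojax.policy import FastLearner

# 5-way classification with 3 inner-loop SGD steps per few-shot set.
policy = FastLearner(n_classes=5, num_grad_steps=3)
print(policy.num_params)  # size of the flat parameter vector the solver evolves
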
diff --git a/evojax/policy/fast_learner.py b/evojax/policy/fast_learner.py
new file mode 100644
index 00000000..8510508e
--- /dev/null
+++ b/evojax/policy/fast_learner.py
@@ -0,0 +1,116 @@
+import logging
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+from jax import random
+from jax.tree_util import tree_map
+from flax import linen as nn
+
+from evojax.policy.base import PolicyNetwork
+from evojax.policy.base import PolicyState
+from evojax.task.base import TaskState
+from evojax.util import create_logger
+from evojax.util import get_params_format_fn
+
+
+class CNN(nn.Module):
+    """CNN for 28x28 grayscale inputs (5-way Omniglot by default)."""
+
+    n_classes: int = 5
+
+    @nn.compact
+    def __call__(self, x):
+        x = nn.Conv(features=8, kernel_size=(3, 3), padding='SAME')(x)
+        x = nn.relu(x)
+        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
+        x = nn.Conv(features=16, kernel_size=(3, 3), padding='SAME')(x)
+        x = nn.relu(x)
+        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
+        x = nn.Conv(features=32, kernel_size=(3, 3), padding='SAME')(x)
+        x = nn.relu(x)
+        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
+        x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME')(x)
+        x = nn.relu(x)
+        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
+        x = x.reshape((*x.shape[:-3], -1))  # flatten
+        x = nn.Dense(features=32)(x)
+        x = nn.Dense(features=self.n_classes)(x)
+        x = nn.log_softmax(x)
+        return x
+
+
+def get_update_params_fn(apply_fn):
+    """Build the inner-loop SGD updates (could be something fancier than vanilla SGD).
+
+    params has shape [pop_size, ...].
+    x has shape [pop_size, batch_size, ways * shots, 28, 28, 1].
+    y has shape [pop_size, batch_size, ways * shots, 1].
+    """
+    lr = 0.4
+
+    # Loss for one parameter set and one few-shot set.
+    def loss_fn(params, x, y):
+        pred = apply_fn(params, x)  # [ways * shots, n_classes]
+        return -jnp.take_along_axis(pred, y, axis=-1).mean()
+
+    def update_one_set(params, x, y):
+        grad = jax.grad(loss_fn)(params, x, y)
+        new_params = tree_map(lambda p, g: p - lr * g, params, grad)
+        return new_params
+
+    # Maps over pop_size.
+    update_multi_set = jax.vmap(update_one_set, in_axes=[0, 0, 0], out_axes=0)
+    # Maps over batch_size; the first step shares one parameter set per member.
+    update_multi_params_multi_set = jax.vmap(
+        update_multi_set, in_axes=[None, 1, 1], out_axes=1)
+    # Subsequent steps already carry per-(member, set) parameters, so params are
+    # mapped over the batch axis as well.
+    subsequent_update_multi_params_multi_set = jax.vmap(
+        update_multi_set, in_axes=[1, 1, 1], out_axes=1)
+    return update_multi_params_multi_set, subsequent_update_multi_params_multi_set
+
+
+class FastLearner(PolicyNetwork):
+    """A CNN policy that adapts with a few gradient steps per few-shot task."""
+
+    def __init__(self,
+                 n_classes: int,
+                 num_grad_steps: int,
+                 logger: logging.Logger = None):
+        if logger is None:
+            self._logger = create_logger('FastLearner')
+        else:
+            self._logger = logger
+        assert num_grad_steps > 0
+        self.num_grad_steps = num_grad_steps
+        model = CNN(n_classes=n_classes)
+        params = model.init(random.PRNGKey(0), jnp.zeros([1, 28, 28, 1]))
+        self.num_params, format_params_fn = get_params_format_fn(params)
+        self._logger.info(
+            'FastLearner.num_params = {}'.format(self.num_params))
+        # Maps over population members.
+        self._format_params_fn = jax.vmap(format_params_fn)
+        self._model_apply = model.apply
+        self._update_fn, self._subsequent_update_fn = get_update_params_fn(
+            self._model_apply)
+
+    def get_actions(self,
+                    t_states: TaskState,
+                    params: jnp.ndarray,
+                    p_states: PolicyState) -> Tuple[jnp.ndarray, PolicyState]:
+        params = self._format_params_fn(params)
+        x, y = t_states.obs, t_states.labels
+        # x shape: [pop_size, batch_size, ways * shots, 28, 28, 1]
+        # y shape: [pop_size, batch_size, ways * shots, 1]
+        # params shape: [pop_size, ...]
+        updated_params = self._update_fn(params, x, y)
+        # updated_params shape: [pop_size, batch_size, ...]
+        for _ in range(self.num_grad_steps - 1):
+            updated_params = self._subsequent_update_fn(updated_params, x, y)
+        # Map over pop_size and batch_size; the prediction has shape
+        # [pop_size, batch_size, ways * test_shots, n_classes].
+        apply_batched_params = jax.vmap(jax.vmap(self._model_apply))
+        return apply_batched_params(updated_params, t_states.test_obs), p_states
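
The nested vmaps above determine the parameter shapes: the first update broadcasts one parameter set per population member across the task batch, while subsequent updates carry a separate parameter set per (member, set) pair. A small shape check, assuming the toy sizes below (the numbers are illustrative, not taken from the code):

import jax
import jax.numpy as jnp
from jax import random
from evojax.policy.fast_learner import CNN, get_update_params_fn

pop_size, batch_size, ways, shots = 4, 2, 5, 5
model = CNN()
single = model.init(random.PRNGKey(0), jnp.zeros([1, 28, 28, 1]))
# One parameter copy per member: every leaf gains a leading pop_size axis.
params = jax.tree_util.tree_map(lambda p: jnp.stack([p] * pop_size), single)

x = jnp.zeros([pop_size, batch_size, ways * shots, 28, 28, 1])
y = jnp.zeros([pop_size, batch_size, ways * shots, 1], dtype=jnp.int32)

first_step, later_steps = get_update_params_fn(model.apply)
adapted = first_step(params, x, y)    # leaves now have shape [pop_size, batch_size, ...]
adapted = later_steps(adapted, x, y)  # same shapes after each further SGD step
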
diff --git a/evojax/task/few_shot_omniglot.py b/evojax/task/few_shot_omniglot.py
new file mode 100644
index 00000000..8dfd8292
--- /dev/null
+++ b/evojax/task/few_shot_omniglot.py
@@ -0,0 +1,122 @@
+# Copyright 2022 The EvoJAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+from flax.struct import dataclass
+
+from evojax.task.base import VectorizedTask
+from evojax.task.base import TaskState
+
+
+# step_fn and reset_fn below are written for a single policy network; they are
+# then vmapped over the population, so every member sees the same observation.
+@dataclass
+class State(TaskState):
+    obs: jnp.ndarray
+    labels: jnp.ndarray
+    test_obs: jnp.ndarray
+    test_labels: jnp.ndarray
+
+
+def loss(prediction: jnp.ndarray, labels: jnp.ndarray) -> jnp.float32:
+    # prediction holds log-probabilities; labels carry a trailing axis of size 1.
+    return -jnp.take_along_axis(prediction, labels, axis=-1).mean()
+
+
+def accuracy(prediction: jnp.ndarray, target: jnp.ndarray) -> jnp.float32:
+    predicted_class = jnp.argmax(prediction, axis=-1, keepdims=True)
+    return jnp.mean(predicted_class == target)
+
+
+class Omniglot(VectorizedTask):
+    """Omniglot few-shot classification task."""
+
+    def __init__(self,
+                 batch_size: int = 16,
+                 test: bool = False):
+        num_workers = 4
+        test_shots = 15
+        shots = 5
+        self.ways = 5
+        self.max_steps = 1
+        self.obs_shape = tuple([28, 28, 1])
+        self.act_shape = tuple([self.ways, ])
+
+        # Delayed import of torchmeta.
+        try:
+            # TODO: torchmeta seems to be unmaintained and does not support newer
+            # versions of torch; maybe replace it with the processing code from
+            # https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch
+            from torchmeta.datasets.helpers import omniglot
+            from torchmeta.utils.data import BatchMetaDataLoader
+        except ModuleNotFoundError:
+            print('You need to install torchmeta for this task.')
+            print('  pip install torchmeta')
+            sys.exit(1)
+
+        dataset = omniglot(
+            'data', ways=self.ways, shots=shots, test_shots=test_shots,
+            shuffle=not test, meta_train=not test, meta_test=test, download=True)
+        dataloader = BatchMetaDataLoader(
+            dataset, batch_size=batch_size, num_workers=num_workers)
+        iterator = iter(dataloader)
+
+        def reset_fn(key):
+            nonlocal iterator
+            try:
+                batch = next(iterator)
+            except StopIteration:
+                iterator = iter(dataloader)
+                batch = next(iterator)
+            # train_inputs shape: (batch_size, ways * shots, 1, 28, 28)
+            # train_labels shape: (batch_size, ways * shots)
+            train_inputs, train_labels = batch['train']
+            train_inputs = jnp.transpose(jnp.array(train_inputs), (0, 1, 3, 4, 2))
+            train_inputs = (train_inputs - train_inputs.mean(axis=1, keepdims=True)) / (
+                train_inputs.std(axis=1, keepdims=True) + 1e-8)
+            train_labels = jnp.array(train_labels.unsqueeze(-1))
+            test_inputs, test_labels = batch['test']
+            test_inputs = jnp.transpose(jnp.array(test_inputs), (0, 1, 3, 4, 2))
+            test_inputs = (test_inputs - test_inputs.mean(axis=1, keepdims=True)) / (
+                test_inputs.std(axis=1, keepdims=True) + 1e-8)
+            test_labels = jnp.array(test_labels.unsqueeze(-1))
+            # The shapes in this State are for a single population member.
+            return State(obs=train_inputs, labels=train_labels,
+                         test_obs=test_inputs, test_labels=test_labels)
+
+        # vmap over pop_size so all members see the same state.  reset_fn pulls
+        # data from the dataloader as a Python side effect, so it is vmapped but
+        # not jitted (jit would bake the first batch in as a compile-time constant).
+        self._reset_fn = jax.vmap(reset_fn)
+
+        def step_fn(state, action):
+            # state: the State returned by reset_fn.
+            # action: predictions from one population member, with shape
+            #         [batch_size, ways * test_shots, ways].
+            if test:
+                reward = accuracy(action, state.test_labels)
+            else:
+                reward = -loss(action, state.test_labels)
+            return state, reward, jnp.ones(())
+
+        # vmap over pop_size so all members are scored on the same state.
+        self._step_fn = jax.jit(jax.vmap(step_fn))
+
+    def reset(self, key: jnp.ndarray) -> State:
+        return self._reset_fn(key)
+
+    def step(self,
+             state: TaskState,
+             action: jnp.ndarray) -> Tuple[TaskState, jnp.ndarray, jnp.ndarray]:
+        return self._step_fn(state, action)
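
Both `loss` and `accuracy` expect labels with a trailing singleton axis so that `jnp.take_along_axis` can index the class dimension directly. A toy check (the values below are illustrative):

import jax.numpy as jnp

# Two examples, 5-way log-probabilities, true classes 2 and 0.
log_probs = jnp.log(jnp.array([[0.10, 0.10, 0.60, 0.10, 0.10],
                               [0.70, 0.10, 0.10, 0.05, 0.05]]))
labels = jnp.array([[2], [0]])  # note the trailing axis of size 1

nll = -jnp.take_along_axis(log_probs, labels, axis=-1).mean()  # what loss() computes
pred = jnp.argmax(log_probs, axis=-1, keepdims=True)
acc = jnp.mean(pred == labels)                                 # what accuracy() computes
print(nll, acc)  # acc == 1.0 for these values
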
diff --git a/examples/train_fast_learner.py b/examples/train_fast_learner.py
new file mode 100644
index 00000000..ab789a33
--- /dev/null
+++ b/examples/train_fast_learner.py
@@ -0,0 +1,115 @@
+# Copyright 2022 The EvoJAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Train a FastLearner agent for Omniglot few-shot classification.
+
+Example command to run this script: `python train_fast_learner.py --gpu-id=0`
+"""
+
+import argparse
+import os
+import shutil
+
+from evojax import Trainer
+from evojax.task.few_shot_omniglot import Omniglot
+from evojax.policy import FastLearner
+from evojax.algo import PGPE
+from evojax import util
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--pop-size', type=int, default=32, help='NE population size.')
+    parser.add_argument(
+        '--batch-size', type=int, default=32,
+        help='Batch size for training (# of few-shot sets per meta-batch; '
+             'each population member sees the same data).')
+    parser.add_argument(
+        '--num-grad-steps', type=int, default=3, help='# of gradient steps.')
+    parser.add_argument(
+        '--max-iter', type=int, default=20000, help='Max training iterations.')
+    parser.add_argument(
+        '--test-interval', type=int, default=1000, help='Test interval.')
+    parser.add_argument(
+        '--log-interval', type=int, default=100, help='Logging interval.')
+    parser.add_argument(
+        '--seed', type=int, default=42, help='Random seed for training.')
+    parser.add_argument(
+        '--center-lr', type=float, default=0.003, help='Center learning rate.')
+    parser.add_argument(
+        '--std-lr', type=float, default=0.089, help='Std learning rate.')
+    parser.add_argument(
+        '--init-std', type=float, default=0.039, help='Initial std.')
+    parser.add_argument(
+        '--gpu-id', type=str, help='GPU(s) to use.')
+    parser.add_argument(
+        '--debug', action='store_true', help='Debug mode.')
+    config, _ = parser.parse_known_args()
+    return config
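+
+
+# Illustrative invocations (flag names as defined above; the values are only
+# examples, not tuned settings):
+#   python train_fast_learner.py --gpu-id=0
+#   python train_fast_learner.py --pop-size=64 --batch-size=16 --num-grad-steps=5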
+
+
+def main(config):
+    log_dir = './log/omniglot'
+    if not os.path.exists(log_dir):
+        os.makedirs(log_dir, exist_ok=True)
+    logger = util.create_logger(
+        name='Omniglot', log_dir=log_dir, debug=config.debug)
+    logger.info('EvoJAX Omniglot Demo')
+    logger.info('=' * 30)
+
+    ways = n_classes = 5
+
+    policy = FastLearner(n_classes=n_classes,
+                         num_grad_steps=config.num_grad_steps,
+                         logger=logger)
+    train_task = Omniglot(batch_size=config.batch_size, test=False)
+    test_task = Omniglot(batch_size=config.batch_size, test=True)
+    solver = PGPE(
+        pop_size=config.pop_size,
+        param_size=policy.num_params,
+        optimizer='adam',
+        center_learning_rate=config.center_lr,
+        stdev_learning_rate=config.std_lr,
+        init_stdev=config.init_std,
+        logger=logger,
+        seed=config.seed,
+    )
+
+    # Train.
+    trainer = Trainer(
+        policy=policy,
+        solver=solver,
+        train_task=train_task,
+        test_task=test_task,
+        max_iter=config.max_iter,
+        log_interval=config.log_interval,
+        test_interval=config.test_interval,
+        n_repeats=1,
+        n_evaluations=1,
+        seed=config.seed,
+        log_dir=log_dir,
+        logger=logger,
+    )
+    trainer.run(demo_mode=False)
+
+    # Test the final model.
+    src_file = os.path.join(log_dir, 'best.npz')
+    tar_file = os.path.join(log_dir, 'model.npz')
+    shutil.copy(src_file, tar_file)
+    trainer.model_dir = log_dir
+    trainer.run(demo_mode=True)
+
+
+if __name__ == '__main__':
+    configs = parse_args()
+    if configs.gpu_id is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = configs.gpu_id
+    main(configs)
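
For reference, the Trainer call above boils down to an ask/evaluate/tell loop. The sketch below is a simplified illustration, not the actual Trainer implementation: it assumes PGPE's defaults for the solver arguments that are omitted, and it passes `p_states=None` only because FastLearner.get_actions never reads the policy state.

import jax
from evojax.algo import PGPE
from evojax.policy import FastLearner
from evojax.task.few_shot_omniglot import Omniglot

policy = FastLearner(n_classes=5, num_grad_steps=3)
task = Omniglot(batch_size=4, test=False)
solver = PGPE(pop_size=8, param_size=policy.num_params, init_stdev=0.039)

key = jax.random.PRNGKey(0)
for _ in range(3):
    pop_params = solver.ask()                       # [pop_size, num_params]
    t_state = task.reset(jax.random.split(key, 8))  # same meta-batch for every member
    actions, _ = policy.get_actions(t_state, pop_params, None)
    _, fitness, _ = task.step(t_state, actions)     # one reward per member
    solver.tell(fitness)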