From ee911b6fe278d3312641f1cf0c151b1528d4f22a Mon Sep 17 00:00:00 2001 From: Rujikorn Charakorn Date: Fri, 8 Sep 2023 14:16:22 +0200 Subject: [PATCH 1/6] add few-shot omniglot task and corspnd policy --- evojax/policy/__init__.py | 4 +- evojax/policy/fast_learner.py | 98 ++++++++++++++++++++++++++ evojax/task/few_shot_omniglot.py | 117 +++++++++++++++++++++++++++++++ examples/train_fast_learner.py | 112 +++++++++++++++++++++++++++++ 4 files changed, 329 insertions(+), 2 deletions(-) create mode 100644 evojax/policy/fast_learner.py create mode 100644 evojax/task/few_shot_omniglot.py create mode 100644 examples/train_fast_learner.py diff --git a/evojax/policy/__init__.py b/evojax/policy/__init__.py index a924c4da..c3496494 100644 --- a/evojax/policy/__init__.py +++ b/evojax/policy/__init__.py @@ -17,7 +17,7 @@ from .mlp_pi import PermutationInvariantPolicy from .convnet import ConvNetPolicy from .seq2seq import Seq2seqPolicy - +from .fast_learner import FastLearner __all__ = ['PolicyNetwork', 'MLPPolicy', 'PermutationInvariantPolicy', - 'ConvNetPolicy', 'Seq2seqPolicy'] + 'ConvNetPolicy', 'Seq2seqPolicy', 'FastLearner'] diff --git a/evojax/policy/fast_learner.py b/evojax/policy/fast_learner.py new file mode 100644 index 00000000..d804ca9d --- /dev/null +++ b/evojax/policy/fast_learner.py @@ -0,0 +1,98 @@ + +import logging +from typing import Tuple + +import jax +import jax.numpy as jnp +from jax import random +from flax import linen as nn + +from evojax.policy.base import PolicyNetwork +from evojax.policy.base import PolicyState +from evojax.task.base import TaskState +from evojax.util import create_logger +from evojax.util import get_params_format_fn + + +class CNN(nn.Module): + """CNN for MNIST.""" + + @nn.compact + def __call__(self, x): + x = nn.Conv(features=8, kernel_size=(5, 5), padding='SAME')(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=16, kernel_size=(5, 5), padding='SAME')(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[:-3], -1)) # flatten + x = nn.Dense(features=n_classes)(x) + x = nn.log_softmax(x) + return x + +def update_params(apply_fn, params, x, y): + # this function could be something much more complicated other than vanilla sgd... + # params has shape [pop_size, ...] 
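+    # (note: params carry the population axis, while the x / y batches below
+    #  are shared by every population member)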
+ # x has shape [batch_size, ways * shots, 28, 28 ,1] + # y has shape [batch_size, ways * shots, 1] + + # loss for one param and one set + def loss_fn(params, x, y): + pred = apply_fn(params, x) # [ways * shots, n_classes] + return -jnp.take_along_axis(pred, labels, axis=-1).mean() + + def update_one_set(params, x, y): + grad = jax.grad(loss_fn)(params, x, y) + new_params = tree_map(lambda p,g : p - lr * g, params, grad) + return new_params + + update_multi_set = jax.vmap(update_one_set, in_axes=[None,0,0], out_axes=0) + update_multi_params_multi_set = jax.vmap(update_multi_set, in_axes=[0,None,None], out_axes=0) + return update_multi_params_multi_set(params, x, y) + + +class FastLearner(PolicyNetwork): + """A convolutional neural network for the MNIST classification task.""" + + def __init__(self, + n_classes: int, + logger: logging.Logger = None, + ): + if logger is None: + self._logger = create_logger('FastLearner') + else: + self._logger = logger + + model = CNN() + params = model.init(random.PRNGKey(0), jnp.zeros([1, 28, 28, 1])) + self.num_params, format_params_fn = get_params_format_fn(params) + self._logger.info( + 'FastLearner.num_params = {}'.format(self.num_params)) + self._format_params_fn = jax.vmap(format_params_fn) # this maps over members + self._forward_fn = jax.vmap(model.apply) # this maps over members + + # lr = 1e-3 + # _loss_fn = partial(loss_fn, self._forward_fn, n_classes) + # def update_params_one_set(params, x, y): + # grads = jax.grad(_loss_fn)(params, x, y) + # new_params = params - lr * grads + # return new_params + + # update_params_multi_set = jax.vmap(update_params_one_set, in_axes=[None,0,0], out_axes=[0]) + # update_params_pop_multi_set = jax.vmap(update_params_multi_set, in_axes=[0,None,None], out_axes=[0]) + + def get_actions(self, + t_states: TaskState, + params: jnp.ndarray, + p_states: PolicyState) -> Tuple[jnp.ndarray, PolicyState]: + params = self._format_params_fn(params) + x, y = t_state.obs + # x shape: [pop_size, batch_size, ways x shots, 28, 28, 1] + # y shape: [pop_size, batch_size, ways x shots] + # params shape: [pop_size, ...] + updated_params = update_params(params, x, y) + # updated_params shape: [pop_size, batch_size, ...] + + apply_batched_params = jax.vmap(model.apply, in_axes=[0,None]) + # the output prediction should have shape: [pop_size, batch_size, ways * shots, n_classes] + return apply_batched_params(updated_params, t_states.test_inputs), p_states diff --git a/evojax/task/few_shot_omniglot.py b/evojax/task/few_shot_omniglot.py new file mode 100644 index 00000000..303de5b6 --- /dev/null +++ b/evojax/task/few_shot_omniglot.py @@ -0,0 +1,117 @@ +# Copyright 2022 The EvoJAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
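+#
+# NOTE: This task wraps torchmeta's Omniglot helpers. Every reset() draws a
+# meta-batch of support ("train") and query ("test") sets, and the same
+# meta-batch is shared by all population members (see the vmap over pop_size
+# below).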
+ +import sys +import numpy as np +from typing import Tuple + +import jax +import jax.numpy as jnp +from jax import random +from flax.struct import dataclass + +from evojax.task.base import VectorizedTask +from evojax.task.base import TaskState + + + +### step_fn and reset_fn defined here are meant for a ___single___ policy network +### which will then be mapped to all the members using vmap +### such that all members take as input the same observation +@dataclass +class State(TaskState): + obs: jnp.ndarray + labels: jnp.ndarray + +def loss(prediction: jnp.ndarray, labels: jnp.ndarray) -> jnp.float32: + # target = jax.nn.one_hot(target, ways) + # return -jnp.mean(jnp.sum(prediction * target, axis=1)) + return -jnp.take_along_axis(prediction, labels, axis=-1).mean() + +def accuracy(prediction: jnp.ndarray, target: jnp.ndarray) -> jnp.float32: + predicted_class = jnp.argmax(prediction, axis=1) + return jnp.mean(predicted_class == target) + + +class Omniglot(VectorizedTask): + """Omniglot few-show learning classification task.""" + + def __init__(self, + batch_size: int = 16, + test: bool = False): + num_workers = 4 + test_shots = 15 + shots = 5 + self.ways = 5 + self.max_steps = 1 + self.obs_shape = tuple([28, 28, 1]) + self.act_shape = tuple([self.ways, ]) + + + # Delayed importing of torchmeta + try: + # TODO: torchmeta seems to be unmaintained and does not support newer versions of torch + # maybe replace this with processing code from https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch + from torchmeta.datasets.helpers import omniglot + from torchmeta.utils.data import BatchMetaDataLoader + except ModuleNotFoundError: + print('You need to install torchmeta for this task.') + print(' pip install torchmeta') + sys.exit(1) + + dataset = omniglot("data", ways=self.ways, shots=shots, test_shots=test_shots, + shuffle=not test, meta_train=not test, meta_test=test, download=True) + dataloader = BatchMetaDataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + iterator = iter(dataloader) + + def reset_fn(key): + try: + batch = iterator.next() + except StopIteration: + iterator = iter(dataloader) + batch = iterator.next() + # train_inputs shape: (batch_size, ways x shots, 1, 28, 28) + # train_labels shape: (batch_size, ways x shots) + train_inputs, train_labels = batch['train'] + train_inputs = jnp.transpose(jnp.array(train_inputs), (0,1,3,4,2)) + train_labels = jnp.array(train_labels.unsqueeze(-1)) + test_inputs, test_labels = batch['test'] + test_inputs = jnp.transpose(jnp.array(test_inputs), (0,1,3,4,2)) + test_labels = jnp.array(test_labels.unsqueeze(-1)) + # the shape of this State is meant for ___one___ of the members of the population + return State(obs=(train_inputs, train_labels), test_inputs=test_inputs, test_labels=test_labels) + + # this vmap is for pop_size (so all members see the same state) + self._reset_fn = jax.jit(jax.vmap(reset_fn)) + + def step_fn(state, action): + # state: state returned by reset_fn + # action: predictions from ___one___ of the members in the population + # should have the shape of [batch_size, ways x shots, ways (n_classes)] + if test: + reward = accuracy(action, state.test_labels) + else: + reward = -loss(action, state.test_labels) + return state, reward, jnp.ones(()) + + # this vmap is for pop_size (so all members output actions given the same state) + self._step_fn = jax.jit(jax.vmap(step_fn)) + + def reset(self, key: jnp.ndarray) -> State: + return self._reset_fn(key) + + def step(self, + state: TaskState, + action: jnp.ndarray) -> 
Tuple[TaskState, jnp.ndarray, jnp.ndarray]: + return self._step_fn(state, action) diff --git a/examples/train_fast_learner.py b/examples/train_fast_learner.py new file mode 100644 index 00000000..eebf2c1b --- /dev/null +++ b/examples/train_fast_learner.py @@ -0,0 +1,112 @@ +# Copyright 2022 The EvoJAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Train an agent for MNIST classification. + +Example command to run this script: `python train_mnist.py --gpu-id=0` +""" + +import argparse +import os +import shutil + +from evojax import Trainer +from evojax.task.few_shot_omniglot import Omniglot +from evojax.policy import FastLearner +from evojax.algo import PGPE +from evojax import util + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--pop-size', type=int, default=64, help='NE population size.') + parser.add_argument( + '--batch-size', type=int, default=16, help='Batch size for training. (# of sets in few-shot learning, each member sees the same data)') + parser.add_argument( + '--max-iter', type=int, default=5000, help='Max training iterations.') + parser.add_argument( + '--test-interval', type=int, default=1000, help='Test interval.') + parser.add_argument( + '--log-interval', type=int, default=100, help='Logging interval.') + parser.add_argument( + '--seed', type=int, default=42, help='Random seed for training.') + parser.add_argument( + '--center-lr', type=float, default=0.006, help='Center learning rate.') + parser.add_argument( + '--std-lr', type=float, default=0.089, help='Std learning rate.') + parser.add_argument( + '--init-std', type=float, default=0.039, help='Initial std.') + parser.add_argument( + '--gpu-id', type=str, help='GPU(s) to use.') + parser.add_argument( + '--debug', action='store_true', help='Debug mode.') + config, _ = parser.parse_known_args() + return config + + +def main(config): + log_dir = './log/omniglot' + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + logger = util.create_logger( + name='Omniglot', log_dir=log_dir, debug=config.debug) + logger.info('EvoJAX Omniglot Demo') + logger.info('=' * 30) + + ways = n_classes = 5 + policy = FastLearner(n_classes=n_classes, logger=logger) + train_task = Omniglot(batch_size=config.batch_size, test=False) + test_task = Omniglot(batch_size=config.batch_size, test=True) + solver = PGPE( + pop_size=config.pop_size, + param_size=policy.num_params, + optimizer='adam', + center_learning_rate=config.center_lr, + stdev_learning_rate=config.std_lr, + init_stdev=config.init_std, + logger=logger, + seed=config.seed, + ) + + # Train. + trainer = Trainer( + policy=policy, + solver=solver, + train_task=train_task, + test_task=test_task, + max_iter=config.max_iter, + log_interval=config.log_interval, + test_interval=config.test_interval, + n_repeats=1, + n_evaluations=1, + seed=config.seed, + log_dir=log_dir, + logger=logger, + ) + trainer.run(demo_mode=False) + + # Test the final model. 
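+    # (assumes the Trainer's usual convention of writing the best parameters
+    #  to best.npz; copying it to model.npz lets the demo-mode run below load it)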
+ src_file = os.path.join(log_dir, 'best.npz') + tar_file = os.path.join(log_dir, 'model.npz') + shutil.copy(src_file, tar_file) + trainer.model_dir = log_dir + trainer.run(demo_mode=True) + + +if __name__ == '__main__': + configs = parse_args() + if configs.gpu_id is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = configs.gpu_id + main(configs) From ca1a1dc9bb1c682994e4f313991cbe4b1bb05841 Mon Sep 17 00:00:00 2001 From: 51616 Date: Fri, 8 Sep 2023 22:52:20 +0700 Subject: [PATCH 2/6] trainable meta learning omniglot --- evojax/policy/fast_learner.py | 28 ++++++++++++++-------------- evojax/task/few_shot_omniglot.py | 17 ++++++++++------- examples/train_fast_learner.py | 2 +- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/evojax/policy/fast_learner.py b/evojax/policy/fast_learner.py index d804ca9d..9992ae96 100644 --- a/evojax/policy/fast_learner.py +++ b/evojax/policy/fast_learner.py @@ -5,6 +5,7 @@ import jax import jax.numpy as jnp from jax import random +from jax.tree_util import tree_map from flax import linen as nn from evojax.policy.base import PolicyNetwork @@ -25,29 +26,28 @@ def __call__(self, x): x = nn.Conv(features=16, kernel_size=(5, 5), padding='SAME')(x) x = nn.relu(x) x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[:-3], -1)) # flatten - x = nn.Dense(features=n_classes)(x) + x = x.reshape((*x.shape[:-3], -1)) # flatten + x = nn.Dense(features=5)(x) x = nn.log_softmax(x) return x def update_params(apply_fn, params, x, y): # this function could be something much more complicated other than vanilla sgd... # params has shape [pop_size, ...] - # x has shape [batch_size, ways * shots, 28, 28 ,1] - # y has shape [batch_size, ways * shots, 1] - + # x has shape [pop_size, batch_size, ways * shots, 28, 28 ,1] + # y has shape [pop_size, batch_size, ways * shots, 1] + lr = 0.4 # loss for one param and one set def loss_fn(params, x, y): pred = apply_fn(params, x) # [ways * shots, n_classes] - return -jnp.take_along_axis(pred, labels, axis=-1).mean() + return -jnp.take_along_axis(pred, y, axis=-1).mean() def update_one_set(params, x, y): grad = jax.grad(loss_fn)(params, x, y) new_params = tree_map(lambda p,g : p - lr * g, params, grad) return new_params - - update_multi_set = jax.vmap(update_one_set, in_axes=[None,0,0], out_axes=0) - update_multi_params_multi_set = jax.vmap(update_multi_set, in_axes=[0,None,None], out_axes=0) + update_multi_set = jax.vmap(update_one_set, in_axes=[0,0,0], out_axes=0) + update_multi_params_multi_set = jax.vmap(update_multi_set,in_axes=[None,1,1], out_axes=1) return update_multi_params_multi_set(params, x, y) @@ -69,6 +69,7 @@ def __init__(self, self._logger.info( 'FastLearner.num_params = {}'.format(self.num_params)) self._format_params_fn = jax.vmap(format_params_fn) # this maps over members + self._model_apply = model.apply self._forward_fn = jax.vmap(model.apply) # this maps over members # lr = 1e-3 @@ -86,13 +87,12 @@ def get_actions(self, params: jnp.ndarray, p_states: PolicyState) -> Tuple[jnp.ndarray, PolicyState]: params = self._format_params_fn(params) - x, y = t_state.obs + x, y = t_states.obs, t_states.labels # x shape: [pop_size, batch_size, ways x shots, 28, 28, 1] # y shape: [pop_size, batch_size, ways x shots] # params shape: [pop_size, ...] - updated_params = update_params(params, x, y) + updated_params = update_params(self._model_apply, params, x, y) # updated_params shape: [pop_size, batch_size, ...] 
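+        # nested vmap below: the outer map runs over pop_size, the inner one
+        # over batch_size (task sets), matching the params shape above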
- - apply_batched_params = jax.vmap(model.apply, in_axes=[0,None]) + apply_batched_params = jax.vmap(jax.vmap(self._model_apply)) # the output prediction should have shape: [pop_size, batch_size, ways * shots, n_classes] - return apply_batched_params(updated_params, t_states.test_inputs), p_states + return apply_batched_params(updated_params, t_states.test_obs), p_states diff --git a/evojax/task/few_shot_omniglot.py b/evojax/task/few_shot_omniglot.py index 303de5b6..a373d09d 100644 --- a/evojax/task/few_shot_omniglot.py +++ b/evojax/task/few_shot_omniglot.py @@ -33,6 +33,8 @@ class State(TaskState): obs: jnp.ndarray labels: jnp.ndarray + test_obs: jnp.ndarray + test_labels: jnp.ndarray def loss(prediction: jnp.ndarray, labels: jnp.ndarray) -> jnp.float32: # target = jax.nn.one_hot(target, ways) @@ -40,7 +42,7 @@ def loss(prediction: jnp.ndarray, labels: jnp.ndarray) -> jnp.float32: return -jnp.take_along_axis(prediction, labels, axis=-1).mean() def accuracy(prediction: jnp.ndarray, target: jnp.ndarray) -> jnp.float32: - predicted_class = jnp.argmax(prediction, axis=1) + predicted_class = jnp.argmax(prediction, axis=-1, keepdims=True) return jnp.mean(predicted_class == target) @@ -76,11 +78,12 @@ def __init__(self, iterator = iter(dataloader) def reset_fn(key): - try: - batch = iterator.next() - except StopIteration: - iterator = iter(dataloader) - batch = iterator.next() + # try: + # batch = iterator.next() + # except StopIteration: + # iterator = iter(dataloader) + # batch = iterator.next() + batch = iterator.next() # train_inputs shape: (batch_size, ways x shots, 1, 28, 28) # train_labels shape: (batch_size, ways x shots) train_inputs, train_labels = batch['train'] @@ -90,7 +93,7 @@ def reset_fn(key): test_inputs = jnp.transpose(jnp.array(test_inputs), (0,1,3,4,2)) test_labels = jnp.array(test_labels.unsqueeze(-1)) # the shape of this State is meant for ___one___ of the members of the population - return State(obs=(train_inputs, train_labels), test_inputs=test_inputs, test_labels=test_labels) + return State(obs=train_inputs, labels=train_labels, test_obs=test_inputs, test_labels=test_labels) # this vmap is for pop_size (so all members see the same state) self._reset_fn = jax.jit(jax.vmap(reset_fn)) diff --git a/examples/train_fast_learner.py b/examples/train_fast_learner.py index eebf2c1b..4f98fa30 100644 --- a/examples/train_fast_learner.py +++ b/examples/train_fast_learner.py @@ -35,7 +35,7 @@ def parse_args(): parser.add_argument( '--batch-size', type=int, default=16, help='Batch size for training. 
(# of sets in few-shot learning, each member sees the same data)') parser.add_argument( - '--max-iter', type=int, default=5000, help='Max training iterations.') + '--max-iter', type=int, default=10000, help='Max training iterations.') parser.add_argument( '--test-interval', type=int, default=1000, help='Test interval.') parser.add_argument( From 11e8993e8fd9fb864625f237b179fcaa561ac035 Mon Sep 17 00:00:00 2001 From: Rujikorn Charakorn Date: Mon, 11 Sep 2023 11:14:25 +0200 Subject: [PATCH 3/6] multi step sgd --- evojax/policy/fast_learner.py | 24 ++++++++++++++++-------- examples/train_fast_learner.py | 5 ++++- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/evojax/policy/fast_learner.py b/evojax/policy/fast_learner.py index 9992ae96..92e14360 100644 --- a/evojax/policy/fast_learner.py +++ b/evojax/policy/fast_learner.py @@ -31,7 +31,7 @@ def __call__(self, x): x = nn.log_softmax(x) return x -def update_params(apply_fn, params, x, y): +def get_update_params_fn(apply_fn): # this function could be something much more complicated other than vanilla sgd... # params has shape [pop_size, ...] # x has shape [pop_size, batch_size, ways * shots, 28, 28 ,1] @@ -46,9 +46,12 @@ def update_one_set(params, x, y): grad = jax.grad(loss_fn)(params, x, y) new_params = tree_map(lambda p,g : p - lr * g, params, grad) return new_params - update_multi_set = jax.vmap(update_one_set, in_axes=[0,0,0], out_axes=0) - update_multi_params_multi_set = jax.vmap(update_multi_set,in_axes=[None,1,1], out_axes=1) - return update_multi_params_multi_set(params, x, y) + + update_multi_set = jax.vmap(update_one_set, in_axes=[0,0,0], out_axes=0) # maps over pop_size + update_multi_params_multi_set = jax.vmap(update_multi_set,in_axes=[None,1,1], out_axes=1) # maps over batch_size + # return update_multi_params_multi_set(params, x, y) + subsequent_update_multi_params_multi_set = jax.vmap(update_multi_set,in_axes=[1,1,1], out_axes=1) + return update_multi_params_multi_set, subsequent_update_multi_params_multi_set class FastLearner(PolicyNetwork): @@ -56,13 +59,14 @@ class FastLearner(PolicyNetwork): def __init__(self, n_classes: int, + num_grad_steps: int, logger: logging.Logger = None, ): if logger is None: self._logger = create_logger('FastLearner') else: self._logger = logger - + assert self.num_grad_steps > 0 model = CNN() params = model.init(random.PRNGKey(0), jnp.zeros([1, 28, 28, 1])) self.num_params, format_params_fn = get_params_format_fn(params) @@ -70,8 +74,8 @@ def __init__(self, 'FastLearner.num_params = {}'.format(self.num_params)) self._format_params_fn = jax.vmap(format_params_fn) # this maps over members self._model_apply = model.apply - self._forward_fn = jax.vmap(model.apply) # this maps over members - + # self._forward_fn = jax.vmap(model.apply) # this maps over members + self._update_fn, self._subsequent_update_fn = get_update_params_fn(self._model_apply) # lr = 1e-3 # _loss_fn = partial(loss_fn, self._forward_fn, n_classes) # def update_params_one_set(params, x, y): @@ -91,8 +95,12 @@ def get_actions(self, # x shape: [pop_size, batch_size, ways x shots, 28, 28, 1] # y shape: [pop_size, batch_size, ways x shots] # params shape: [pop_size, ...] - updated_params = update_params(self._model_apply, params, x, y) + # updated_params = update_params(self._model_apply, params, x, y) + updated_params = self._update_fn(params, x, y) # updated_params shape: [pop_size, batch_size, ...] 
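+        # the first update broadcasts the shared member params over the task
+        # sets; the remaining steps below operate on params that already carry
+        # a per-set axis, hence the separate _subsequent_update_fn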
+ for _ in range(self.num_grad_step-1): + updated_params = self._subsequent_update_fn(updated_params, x, y) + apply_batched_params = jax.vmap(jax.vmap(self._model_apply)) # the output prediction should have shape: [pop_size, batch_size, ways * shots, n_classes] return apply_batched_params(updated_params, t_states.test_obs), p_states diff --git a/examples/train_fast_learner.py b/examples/train_fast_learner.py index 4f98fa30..9264e301 100644 --- a/examples/train_fast_learner.py +++ b/examples/train_fast_learner.py @@ -34,6 +34,8 @@ def parse_args(): '--pop-size', type=int, default=64, help='NE population size.') parser.add_argument( '--batch-size', type=int, default=16, help='Batch size for training. (# of sets in few-shot learning, each member sees the same data)') + parser.add_argument( + '--num-grad-steps', type=int, default=4, help='# of gradient steps') parser.add_argument( '--max-iter', type=int, default=10000, help='Max training iterations.') parser.add_argument( @@ -66,7 +68,8 @@ def main(config): logger.info('=' * 30) ways = n_classes = 5 - policy = FastLearner(n_classes=n_classes, logger=logger) + + policy = FastLearner(n_classes=n_classes, num_grad_steps=config.num_grad_steps, logger=logger) train_task = Omniglot(batch_size=config.batch_size, test=False) test_task = Omniglot(batch_size=config.batch_size, test=True) solver = PGPE( From c62570240f814245b126e7812bd81f1943e7597a Mon Sep 17 00:00:00 2001 From: 51616 Date: Mon, 11 Sep 2023 20:02:19 +0700 Subject: [PATCH 4/6] multi-step inner sgd updates --- evojax/policy/fast_learner.py | 17 +++++++++++++---- examples/train_fast_learner.py | 8 ++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/evojax/policy/fast_learner.py b/evojax/policy/fast_learner.py index 92e14360..e13110c1 100644 --- a/evojax/policy/fast_learner.py +++ b/evojax/policy/fast_learner.py @@ -20,12 +20,20 @@ class CNN(nn.Module): @nn.compact def __call__(self, x): - x = nn.Conv(features=8, kernel_size=(5, 5), padding='SAME')(x) + x = nn.Conv(features=8, kernel_size=(3, 3), padding='SAME')(x) x = nn.relu(x) x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=16, kernel_size=(5, 5), padding='SAME')(x) + x = nn.Conv(features=16, kernel_size=(3, 3), padding='SAME')(x) x = nn.relu(x) x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=32, kernel_size=(3, 3), padding='SAME')(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) + + x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME')(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((*x.shape[:-3], -1)) # flatten x = nn.Dense(features=5)(x) x = nn.log_softmax(x) @@ -66,7 +74,8 @@ def __init__(self, self._logger = create_logger('FastLearner') else: self._logger = logger - assert self.num_grad_steps > 0 + assert num_grad_steps > 0 + self.num_grad_steps = num_grad_steps model = CNN() params = model.init(random.PRNGKey(0), jnp.zeros([1, 28, 28, 1])) self.num_params, format_params_fn = get_params_format_fn(params) @@ -98,7 +107,7 @@ def get_actions(self, # updated_params = update_params(self._model_apply, params, x, y) updated_params = self._update_fn(params, x, y) # updated_params shape: [pop_size, batch_size, ...] 
- for _ in range(self.num_grad_step-1): + for _ in range(self.num_grad_steps-1): updated_params = self._subsequent_update_fn(updated_params, x, y) apply_batched_params = jax.vmap(jax.vmap(self._model_apply)) diff --git a/examples/train_fast_learner.py b/examples/train_fast_learner.py index 9264e301..333af835 100644 --- a/examples/train_fast_learner.py +++ b/examples/train_fast_learner.py @@ -31,11 +31,11 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - '--pop-size', type=int, default=64, help='NE population size.') + '--pop-size', type=int, default=32, help='NE population size.') parser.add_argument( - '--batch-size', type=int, default=16, help='Batch size for training. (# of sets in few-shot learning, each member sees the same data)') + '--batch-size', type=int, default=32, help='Batch size for training. (# of sets in few-shot learning, each member sees the same data)') parser.add_argument( - '--num-grad-steps', type=int, default=4, help='# of gradient steps') + '--num-grad-steps', type=int, default=3, help='# of gradient steps') parser.add_argument( '--max-iter', type=int, default=10000, help='Max training iterations.') parser.add_argument( @@ -45,7 +45,7 @@ def parse_args(): parser.add_argument( '--seed', type=int, default=42, help='Random seed for training.') parser.add_argument( - '--center-lr', type=float, default=0.006, help='Center learning rate.') + '--center-lr', type=float, default=0.003, help='Center learning rate.') parser.add_argument( '--std-lr', type=float, default=0.089, help='Std learning rate.') parser.add_argument( From ad12fe6022a2ca53cb9562a3237daa223b085925 Mon Sep 17 00:00:00 2001 From: Rujikorn Charakorn Date: Mon, 11 Sep 2023 15:51:38 +0200 Subject: [PATCH 5/6] pre-compute batch norm of the input --- evojax/task/few_shot_omniglot.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/evojax/task/few_shot_omniglot.py b/evojax/task/few_shot_omniglot.py index a373d09d..8dfd8292 100644 --- a/evojax/task/few_shot_omniglot.py +++ b/evojax/task/few_shot_omniglot.py @@ -88,9 +88,11 @@ def reset_fn(key): # train_labels shape: (batch_size, ways x shots) train_inputs, train_labels = batch['train'] train_inputs = jnp.transpose(jnp.array(train_inputs), (0,1,3,4,2)) + train_inputs = (train_inputs - train_inputs.mean(axis=1, keepdims=True)) / (train_inputs.std(axis=1, keepdims=True) + 1e-8) train_labels = jnp.array(train_labels.unsqueeze(-1)) test_inputs, test_labels = batch['test'] test_inputs = jnp.transpose(jnp.array(test_inputs), (0,1,3,4,2)) + test_inputs = (test_inputs - test_inputs.mean(axis=1, keepdims=True)) / (test_inputs.std(axis=1, keepdims=True) + 1e-8) test_labels = jnp.array(test_labels.unsqueeze(-1)) # the shape of this State is meant for ___one___ of the members of the population return State(obs=train_inputs, labels=train_labels, test_obs=test_inputs, test_labels=test_labels) From 8689cd5fb53efe1fe14b7a1f3452ba2e0b1a2621 Mon Sep 17 00:00:00 2001 From: 51616 Date: Tue, 12 Sep 2023 16:11:44 +0700 Subject: [PATCH 6/6] reach acc=86% --- evojax/policy/fast_learner.py | 1 + examples/train_fast_learner.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/evojax/policy/fast_learner.py b/evojax/policy/fast_learner.py index e13110c1..8510508e 100644 --- a/evojax/policy/fast_learner.py +++ b/evojax/policy/fast_learner.py @@ -35,6 +35,7 @@ def __call__(self, x): x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((*x.shape[:-3], -1)) # flatten + x = nn.Dense(features=32)(x) x = nn.Dense(features=5)(x) x = 
nn.log_softmax(x) return x diff --git a/examples/train_fast_learner.py b/examples/train_fast_learner.py index 333af835..ab789a33 100644 --- a/examples/train_fast_learner.py +++ b/examples/train_fast_learner.py @@ -37,7 +37,7 @@ def parse_args(): parser.add_argument( '--num-grad-steps', type=int, default=3, help='# of gradient steps') parser.add_argument( - '--max-iter', type=int, default=10000, help='Max training iterations.') + '--max-iter', type=int, default=20000, help='Max training iterations.') parser.add_argument( '--test-interval', type=int, default=1000, help='Test interval.') parser.add_argument(
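
For reference, the core trick in FastLearner is a MAML-style inner SGD loop wrapped
in two vmaps, one over the evolved population and one over the sampled task sets.
Below is a minimal, self-contained sketch of that idea with a toy linear model; the
names and shapes (POP, SETS, adapt_fn, sgd_step, LR) are illustrative assumptions
for this sketch only and are not part of the EvoJAX API or of the patches above.

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

POP, SETS, N, D, C = 4, 3, 25, 784, 5  # members, task sets, examples, features, classes
LR = 0.4                               # inner-loop step size, as in the patch

def apply_fn(params, x):
    # toy stand-in for CNN.apply: a linear classifier with log-softmax outputs
    return jax.nn.log_softmax(x @ params['w'] + params['b'])

def loss_fn(params, x, y):
    # y holds integer labels of shape [N, 1]; same take_along_axis trick as the patch
    return -jnp.take_along_axis(apply_fn(params, x), y, axis=-1).mean()

def sgd_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return tree_map(lambda p, g: p - LR * g, params, grads)

# one SGD step per (member, task set): inner vmap over sets, outer vmap over members
adapt_fn = jax.vmap(jax.vmap(sgd_step, in_axes=(None, 0, 0)), in_axes=(0, None, None))

kx, ky = jax.random.split(jax.random.PRNGKey(0))
params = {'w': jnp.zeros((POP, D, C)), 'b': jnp.zeros((POP, C))}
x = jax.random.normal(kx, (SETS, N, D))
y = jax.random.randint(ky, (SETS, N, 1), 0, C)

adapted = adapt_fn(params, x, y)
print(tree_map(lambda p: p.shape, adapted))  # every leaf now has shape [POP, SETS, ...]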