4 changes: 2 additions & 2 deletions evojax/policy/__init__.py
@@ -17,7 +17,7 @@
from .mlp_pi import PermutationInvariantPolicy
from .convnet import ConvNetPolicy
from .seq2seq import Seq2seqPolicy

from .fast_learner import FastLearner

__all__ = ['PolicyNetwork', 'MLPPolicy', 'PermutationInvariantPolicy',
           'ConvNetPolicy', 'Seq2seqPolicy']
           'ConvNetPolicy', 'Seq2seqPolicy', 'FastLearner']
116 changes: 116 additions & 0 deletions evojax/policy/fast_learner.py
@@ -0,0 +1,116 @@

import logging
from typing import Tuple

import jax
import jax.numpy as jnp
from jax import random
from jax.tree_util import tree_map
from flax import linen as nn

from evojax.policy.base import PolicyNetwork
from evojax.policy.base import PolicyState
from evojax.task.base import TaskState
from evojax.util import create_logger
from evojax.util import get_params_format_fn


class CNN(nn.Module):
    """Small CNN for 28x28 grayscale few-shot classification (Omniglot)."""

    n_classes: int = 5

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=8, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=32, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = x.reshape((*x.shape[:-3], -1))  # flatten the spatial dimensions
        x = nn.Dense(features=32)(x)
        x = nn.Dense(features=self.n_classes)(x)
        x = nn.log_softmax(x)
        return x
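# Shape walk-through (a sketch for a single [28, 28, 1] input): each 2x2
# max-pool roughly halves the spatial dims, 28 -> 14 -> 7 -> 3 -> 1, so the
# last conv output is [1, 1, 64]; flattening yields [64], followed by
# Dense(32) and Dense(n_classes) log-probabilities.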

def get_update_params_fn(apply_fn):
    # The inner update below is vanilla SGD; it could be replaced by
    # something more sophisticated (momentum, learned learning rates, etc.).
    # params has shape [pop_size, ...]
    # x has shape [pop_size, batch_size, ways * shots, 28, 28, 1]
    # y has shape [pop_size, batch_size, ways * shots, 1]
    lr = 0.4

    # Loss for one parameter pytree and one support set.
    def loss_fn(params, x, y):
        pred = apply_fn(params, x)  # [ways * shots, n_classes]
        return -jnp.take_along_axis(pred, y, axis=-1).mean()

    def update_one_set(params, x, y):
        grad = jax.grad(loss_fn)(params, x, y)
        return tree_map(lambda p, g: p - lr * g, params, grad)

    # Maps over pop_size.
    update_multi_set = jax.vmap(update_one_set, in_axes=[0, 0, 0], out_axes=0)
    # First step: one parameter pytree per member, mapped over batch_size.
    update_multi_params_multi_set = jax.vmap(
        update_multi_set, in_axes=[None, 1, 1], out_axes=1)
    # Subsequent steps: parameters already carry a batch axis.
    subsequent_update_multi_params_multi_set = jax.vmap(
        update_multi_set, in_axes=[1, 1, 1], out_axes=1)
    return update_multi_params_multi_set, subsequent_update_multi_params_multi_set
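# Axis bookkeeping for the vmaps above (a sketch with hypothetical sizes
# P = pop_size = 4, B = batch_size = 2, N = ways * shots = 25):
#   update_one_set:     params [...],       x [N, 28, 28, 1],     y [N, 1]
#   update_multi_set:   params [P, ...],    x [P, N, 28, 28, 1],  y [P, N, 1]
#   first update:       params [P, ...],    x [P, B, N, ...]  ->  params [P, B, ...]
#   subsequent updates: params [P, B, ...], x [P, B, N, ...]  ->  params [P, B, ...]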


class FastLearner(PolicyNetwork):
    """A CNN policy that adapts its weights with a few inner-loop SGD steps
    per task (MAML-style) for few-shot Omniglot classification."""

    def __init__(self,
                 n_classes: int,
                 num_grad_steps: int,
                 logger: logging.Logger = None,
                 ):
        if logger is None:
            self._logger = create_logger('FastLearner')
        else:
            self._logger = logger
        assert num_grad_steps > 0
        self.num_grad_steps = num_grad_steps
        model = CNN(n_classes=n_classes)
        params = model.init(random.PRNGKey(0), jnp.zeros([1, 28, 28, 1]))
        self.num_params, format_params_fn = get_params_format_fn(params)
        self._logger.info(
            'FastLearner.num_params = {}'.format(self.num_params))
        # Maps over population members.
        self._format_params_fn = jax.vmap(format_params_fn)
        self._model_apply = model.apply
        self._update_fn, self._subsequent_update_fn = get_update_params_fn(
            self._model_apply)

    def get_actions(self,
                    t_states: TaskState,
                    params: jnp.ndarray,
                    p_states: PolicyState) -> Tuple[jnp.ndarray, PolicyState]:
        # params: [pop_size, num_params] -> pytree with a leading pop_size axis.
        params = self._format_params_fn(params)
        # x shape: [pop_size, batch_size, ways * shots, 28, 28, 1]
        # y shape: [pop_size, batch_size, ways * shots, 1]
        x, y = t_states.obs, t_states.labels
        # updated_params shape: [pop_size, batch_size, ...]
        updated_params = self._update_fn(params, x, y)
        for _ in range(self.num_grad_steps - 1):
            updated_params = self._subsequent_update_fn(updated_params, x, y)

        # Predictions have shape [pop_size, batch_size, ways * test_shots, n_classes].
        apply_batched_params = jax.vmap(jax.vmap(self._model_apply))
        return apply_batched_params(updated_params, t_states.test_obs), p_states
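# Minimal usage sketch (hypothetical, mirroring examples/train_fast_learner.py):
#   policy = FastLearner(n_classes=5, num_grad_steps=3)
#   actions, p_states = policy.get_actions(task_state, flat_params, p_states)
# where flat_params has shape [pop_size, policy.num_params] and actions has
# shape [pop_size, batch_size, ways * test_shots, n_classes].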
122 changes: 122 additions & 0 deletions evojax/task/few_shot_omniglot.py
@@ -0,0 +1,122 @@
# Copyright 2022 The EvoJAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import Tuple

import jax
import jax.numpy as jnp
from jax import random
from flax.struct import dataclass

from evojax.task.base import VectorizedTask
from evojax.task.base import TaskState



# The step_fn and reset_fn defined here operate on a *single* policy network;
# they are then vmapped over the population so that every member receives
# the same observation.
@dataclass
class State(TaskState):
    obs: jnp.ndarray
    labels: jnp.ndarray
    test_obs: jnp.ndarray
    test_labels: jnp.ndarray

def loss(prediction: jnp.ndarray, labels: jnp.ndarray) -> jnp.float32:
    # prediction holds log-probabilities; labels carries a trailing singleton
    # axis so it can index the class dimension.
    return -jnp.take_along_axis(prediction, labels, axis=-1).mean()


def accuracy(prediction: jnp.ndarray, target: jnp.ndarray) -> jnp.float32:
    predicted_class = jnp.argmax(prediction, axis=-1, keepdims=True)
    return jnp.mean(predicted_class == target)
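# Worked example (hypothetical values): with log-probabilities
#   prediction = [[-0.1, -2.3], [-1.6, -0.2]] and labels = [[0], [0]],
# take_along_axis picks [-0.1, -1.6], so loss = 0.85; the argmax classes
# are [[0], [1]], so accuracy against the same labels is 0.5.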


class Omniglot(VectorizedTask):
    """Omniglot few-shot classification task."""

    def __init__(self,
                 batch_size: int = 16,
                 test: bool = False):
        num_workers = 4
        test_shots = 15
        shots = 5
        self.ways = 5
        self.max_steps = 1
        self.obs_shape = tuple([28, 28, 1])
        self.act_shape = tuple([self.ways, ])

        # Delayed import of torchmeta.
        try:
            # TODO: torchmeta seems to be unmaintained and does not support
            # newer versions of torch; maybe replace this with the processing
            # code from https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch
            from torchmeta.datasets.helpers import omniglot
            from torchmeta.utils.data import BatchMetaDataLoader
        except ModuleNotFoundError:
            print('You need to install torchmeta for this task.')
            print('  pip install torchmeta')
            sys.exit(1)

        dataset = omniglot(
            'data', ways=self.ways, shots=shots, test_shots=test_shots,
            shuffle=not test, meta_train=not test, meta_test=test,
            download=True)
        dataloader = BatchMetaDataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers)
        iterator = iter(dataloader)

        def reset_fn(key):
            nonlocal iterator
            try:
                batch = next(iterator)
            except StopIteration:
                iterator = iter(dataloader)
                batch = next(iterator)
            # train_inputs shape: (batch_size, ways * shots, 1, 28, 28)
            # train_labels shape: (batch_size, ways * shots)
            train_inputs, train_labels = batch['train']
            train_inputs = jnp.transpose(jnp.array(train_inputs), (0, 1, 3, 4, 2))
            train_inputs = (train_inputs - train_inputs.mean(axis=1, keepdims=True)) / (
                train_inputs.std(axis=1, keepdims=True) + 1e-8)
            train_labels = jnp.array(train_labels.unsqueeze(-1))
            test_inputs, test_labels = batch['test']
            test_inputs = jnp.transpose(jnp.array(test_inputs), (0, 1, 3, 4, 2))
            test_inputs = (test_inputs - test_inputs.mean(axis=1, keepdims=True)) / (
                test_inputs.std(axis=1, keepdims=True) + 1e-8)
            test_labels = jnp.array(test_labels.unsqueeze(-1))
            # This State describes the data seen by *one* member of the population.
            return State(obs=train_inputs, labels=train_labels,
                         test_obs=test_inputs, test_labels=test_labels)

        # vmap over pop_size: the batch does not depend on key, so vmap
        # broadcasts it and all members see the same tasks. reset_fn is not
        # jitted on purpose: jit would cache the traced function, and the
        # Python-side data loading would only ever run once.
        self._reset_fn = jax.vmap(reset_fn)

        def step_fn(state, action):
            # state: the State returned by reset_fn.
            # action: predictions from *one* member of the population, with
            # shape [batch_size, ways * test_shots, n_classes].
            if test:
                reward = accuracy(action, state.test_labels)
            else:
                reward = -loss(action, state.test_labels)
            return state, reward, jnp.ones(())

        # vmap over pop_size: all members output actions for the same state.
        self._step_fn = jax.jit(jax.vmap(step_fn))
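        # Note on reward semantics: during training the reward is the mean
        # log-likelihood on the query set; at test time it is classification
        # accuracy, so test scores lie in [0, 1].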

    def reset(self, key: jnp.ndarray) -> State:
        return self._reset_fn(key)

    def step(self,
             state: TaskState,
             action: jnp.ndarray) -> Tuple[TaskState, jnp.ndarray, jnp.ndarray]:
        return self._step_fn(state, action)
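# Minimal usage sketch (hypothetical):
#   task = Omniglot(batch_size=8, test=False)
#   keys = jax.random.split(jax.random.PRNGKey(0), pop_size)
#   state = task.reset(keys)
#   next_state, reward, done = task.step(state, actions)
# reset ignores the key values; only the leading axis of keys sizes the
# population dimension of the returned State.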
115 changes: 115 additions & 0 deletions examples/train_fast_learner.py
@@ -0,0 +1,115 @@
# Copyright 2022 The EvoJAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Train an agent for MNIST classification.

Example command to run this script: `python train_mnist.py --gpu-id=0`
"""

import argparse
import os
import shutil

from evojax import Trainer
from evojax.task.few_shot_omniglot import Omniglot
from evojax.policy import FastLearner
from evojax.algo import PGPE
from evojax import util


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--pop-size', type=int, default=32, help='NE population size.')
    parser.add_argument(
        '--batch-size', type=int, default=32,
        help='Training batch size (number of few-shot tasks per iteration; '
             'every member sees the same tasks).')
    parser.add_argument(
        '--num-grad-steps', type=int, default=3,
        help='Number of inner-loop gradient steps.')
    parser.add_argument(
        '--max-iter', type=int, default=20000, help='Max training iterations.')
    parser.add_argument(
        '--test-interval', type=int, default=1000, help='Test interval.')
    parser.add_argument(
        '--log-interval', type=int, default=100, help='Logging interval.')
    parser.add_argument(
        '--seed', type=int, default=42, help='Random seed for training.')
    parser.add_argument(
        '--center-lr', type=float, default=0.003, help='Center learning rate.')
    parser.add_argument(
        '--std-lr', type=float, default=0.089, help='Std learning rate.')
    parser.add_argument(
        '--init-std', type=float, default=0.039, help='Initial std.')
    parser.add_argument(
        '--gpu-id', type=str, help='GPU(s) to use.')
    parser.add_argument(
        '--debug', action='store_true', help='Debug mode.')
    config, _ = parser.parse_known_args()
    return config


def main(config):
    log_dir = './log/omniglot'
    os.makedirs(log_dir, exist_ok=True)
    logger = util.create_logger(
        name='Omniglot', log_dir=log_dir, debug=config.debug)
    logger.info('EvoJAX Omniglot Demo')
    logger.info('=' * 30)

    n_classes = 5  # Matches `ways` in the Omniglot task.

    policy = FastLearner(n_classes=n_classes,
                         num_grad_steps=config.num_grad_steps,
                         logger=logger)
    train_task = Omniglot(batch_size=config.batch_size, test=False)
    test_task = Omniglot(batch_size=config.batch_size, test=True)
    solver = PGPE(
        pop_size=config.pop_size,
        param_size=policy.num_params,
        optimizer='adam',
        center_learning_rate=config.center_lr,
        stdev_learning_rate=config.std_lr,
        init_stdev=config.init_std,
        logger=logger,
        seed=config.seed,
    )

    # Train.
    trainer = Trainer(
        policy=policy,
        solver=solver,
        train_task=train_task,
        test_task=test_task,
        max_iter=config.max_iter,
        log_interval=config.log_interval,
        test_interval=config.test_interval,
        n_repeats=1,
        n_evaluations=1,
        seed=config.seed,
        log_dir=log_dir,
        logger=logger,
    )
    trainer.run(demo_mode=False)

    # Test the final model.
    src_file = os.path.join(log_dir, 'best.npz')
    tar_file = os.path.join(log_dir, 'model.npz')
    shutil.copy(src_file, tar_file)
    trainer.model_dir = log_dir
    trainer.run(demo_mode=True)


if __name__ == '__main__':
    configs = parse_args()
    if configs.gpu_id is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = configs.gpu_id
    main(configs)
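# Quick smoke test (hypothetical settings, requires torchmeta):
#   python examples/train_fast_learner.py --pop-size=8 --batch-size=4 \
#       --num-grad-steps=1 --max-iter=10 --test-interval=5 --log-interval=5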