diff --git a/requirements.txt b/requirements.txt index 369306da..f02aaf9e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ torchsummary~=1.5.1 #webdataset introduces breaking changes in 0.1.49, so setting this to an exact equality webdataset==0.1.40 tqdm~=4.48.0 +procgen==0.10.4 # Jupyter Lab is used for our experiment analysis notebook jupyterlab~=2.2.6 diff --git a/src/il_representations/algos/__init__.py b/src/il_representations/algos/__init__.py index 5f29bff0..4283a685 100644 --- a/src/il_representations/algos/__init__.py +++ b/src/il_representations/algos/__init__.py @@ -1,7 +1,7 @@ from il_representations.algos.representation_learner import RepresentationLearner, DEFAULT_HARDCODED_PARAMS from il_representations.algos.encoders import MomentumEncoder, InverseDynamicsEncoder, TargetStoringActionEncoder, \ RecurrentEncoder, BaseEncoder, VAEEncoder, ActionEncodingEncoder, ActionEncodingInverseDynamicsEncoder, \ - infer_action_shape_info + infer_action_shape_info, SimCLRModel from il_representations.algos.decoders import NoOp, MomentumProjectionHead, \ BYOLProjectionHead, ActionConditionedVectorDecoder, ContrastiveInverseDynamicsConcatenationHead, \ ActionPredictionHead, PixelDecoder, SymmetricProjectionHead, AsymmetricProjectionHead diff --git a/src/il_representations/algos/augmenters.py b/src/il_representations/algos/augmenters.py index 24b0ebda..18da951b 100644 --- a/src/il_representations/algos/augmenters.py +++ b/src/il_representations/algos/augmenters.py @@ -11,12 +11,18 @@ either augment just the context, or both the context and the target, depending on the algorithm. """ +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + class Augmenter(ABC): - def __init__(self, augmenter_spec, color_space): - augment_op = StandardAugmentations.from_string_spec( - augmenter_spec, color_space) - self.augment_op = augment_op + def __init__(self, augmenter_spec, color_space, augment_func=None): + self.augment_func = augment_func + if augment_func: + self.augment_op = augment_func + else: + augment_op = StandardAugmentations.from_string_spec( + augmenter_spec, color_space) + self.augment_op = augment_op @abstractmethod def __call__(self, contexts, targets): @@ -33,6 +39,21 @@ def __call__(self, contexts, targets): class AugmentContextAndTarget(Augmenter): def __call__(self, contexts, targets): + pil_process_func = transforms.Compose([ + transforms.ToPILImage() + ]) + if self.augment_func: + context_ret, target_ret = [], [] + for context, target in zip(contexts, targets): + if isinstance(context, torch.Tensor) and \ + isinstance(self.augment_op.transforms[0], + transforms.RandomResizedCrop): + context, target = pil_process_func(context.cpu()), \ + pil_process_func(target.cpu()) + context_ret.append(self.augment_op(context)) + target_ret.append(self.augment_op(target)) + return torch.stack(context_ret, dim=0).to(device), \ + torch.stack(target_ret, dim=0).to(device) return self.augment_op(contexts), self.augment_op(targets) diff --git a/src/il_representations/algos/decoders.py b/src/il_representations/algos/decoders.py index 3319adfd..087e0336 100644 --- a/src/il_representations/algos/decoders.py +++ b/src/il_representations/algos/decoders.py @@ -64,8 +64,8 @@ def get_sequential_from_architecture(architecture, representation_dim, projectio input_dim = representation_dim for layer_def in architecture: layers.append(nn.Linear(input_dim, layer_def['output_dim'])) - layers.append(nn.ReLU()) layers.append(nn.BatchNorm1d(num_features=layer_def['output_dim'])) + 
layers.append(nn.ReLU(inplace=True)) input_dim = layer_def['output_dim'] layers.append(nn.Linear(input_dim, projection_dim)) return nn.Sequential(*layers) @@ -131,7 +131,7 @@ def _apply_projection_layer(self, z_dist, mean_layer, stdev_layer): # We better not have had a learned standard deviation in # the encoder, since there's no clear way on how to pass # it forward - assert np.all((z_dist.stddev == 1).numpy()) + assert np.all((z_dist.stddev == 1).cpu().numpy()) stddev = self.ones_like_projection_dim(mean) else: stddev = stdev_layer(z_vector) diff --git a/src/il_representations/algos/encoders.py b/src/il_representations/algos/encoders.py index 1fb5091a..f927eaf1 100644 --- a/src/il_representations/algos/encoders.py +++ b/src/il_representations/algos/encoders.py @@ -10,6 +10,8 @@ from torchvision.models.resnet import BasicBlock as BasicResidualBlock import torch from torch import nn +from torchvision.models.resnet import resnet50, resnet34 +import torch.nn.functional as F from pyro.distributions import Delta from gym import spaces @@ -197,8 +199,10 @@ def __init__(self, use_sn=False, arch_str='MAGICALCNN-resnet-128', ActivationCls=torch.nn.ReLU): + super().__init__() + # If block_type == resnet, use ResNet's basic block. # If block_type == magical, use MAGICAL block from its paper. assert arch_str in NETWORK_ARCHITECTURE_DEFINITIONS.keys() @@ -265,11 +269,35 @@ def forward(self, x): warn_on_non_image_tensor(x) return self.shared_network(x) + +class SimCLRModel(nn.Module): + def __init__(self, observation_space, representation_dim=128): + super(SimCLRModel, self).__init__() + + self.f = [] + in_channel = observation_space.shape[0] + for name, module in resnet34().named_children(): + if name == 'conv1': + module = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False) + if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d): + self.f.append(module) + # encoder + # Temporarily add an extra layer to be closer to our model implementation + self.f = nn.Sequential(*self.f) + + + def forward(self, x): + x = self.f(x) + feature = torch.flatten(x, start_dim=1) + return F.normalize(feature, dim=-1) + + # string names for convolutional networks; this makes it easier to choose # between them from the command line NETWORK_SHORT_NAMES = { 'BasicCNN': BasicCNN, 'MAGICALCNN': MAGICALCNN, + 'SimCLRModel': SimCLRModel } @@ -348,22 +376,22 @@ class BaseEncoder(Encoder): def __init__(self, obs_space, representation_dim, obs_encoder_cls=None, learn_scale=False, latent_dim=None, scale_constant=1, obs_encoder_cls_kwargs=None): """ - :param obs_space: The observation space that this Encoder will be used on - :param representation_dim: The number of dimensions of the representation - that will be learned - :param obs_encoder_cls: An internal architecture implementing `forward` - to return a single vector representing the mean representation z - of a fixed-variance representation distribution (in the deterministic - case), or a latent dimension, in the stochastic case. This is - expected NOT to end in a ReLU (i.e. final layer should be linear). - :param learn_scale: A flag for whether we want to learn a parametrized - standard deviation. If this is set to False, a constant value of - will be returned as the standard deviation - :param latent_dim: Dimension of the latents that feed into mean and std networks - If not set, this defaults to representation_dim * 2. - :param scale_constant: The constant value that will be returned if learn_scale is - set to False. 
-        :param obs_encoder_cls_kwargs: kwargs the encoder class will take.
+        :param obs_space: The observation space that this Encoder will be used on
+        :param representation_dim: The number of dimensions of the representation
+            that will be learned
+        :param obs_encoder_cls: An internal architecture implementing `forward`
+            to return a single vector representing the mean representation z
+            of a fixed-variance representation distribution (in the deterministic
+            case), or a latent dimension, in the stochastic case. This is
+            expected NOT to end in a ReLU (i.e. the final layer should be linear).
+        :param learn_scale: A flag for whether we want to learn a parametrized
+            standard deviation. If this is set to False, a constant value of
+            `scale_constant` will be returned as the standard deviation.
+        :param latent_dim: Dimension of the latents that feed into the mean and
+            std networks. If not set, this defaults to representation_dim * 2.
+        :param scale_constant: The constant value that will be returned if
+            learn_scale is set to False.
+        :param obs_encoder_cls_kwargs: kwargs the encoder class will take.
        """
        super().__init__()
        if obs_encoder_cls_kwargs is None:
@@ -380,6 +408,13 @@ def __init__(self, obs_space, representation_dim, obs_encoder_cls=None,
            self.network = obs_encoder_cls(obs_space, representation_dim,
                                           **obs_encoder_cls_kwargs)
        self.scale_constant = scale_constant
+        if torch.cuda.device_count() > 1:
+            print("Using", torch.cuda.device_count(), "GPUs!")
+            self.network = nn.DataParallel(self.network)
+
+        self.network.to(self.device)
+
+
    def forward(self, x, traj_info):
        if self.learn_scale:
            return self.forward_with_stddev(x, traj_info)
diff --git a/src/il_representations/algos/losses.py b/src/il_representations/algos/losses.py
index b999860d..c5014f12 100644
--- a/src/il_representations/algos/losses.py
+++ b/src/il_representations/algos/losses.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 import torch
+import numpy as np
 import torch.nn.functional as F
 import stable_baselines3.common.logger as sb_logger
 from pyro.distributions import Delta
@@ -161,11 +162,12 @@ class SymmetricContrastiveLoss(RepresentationLoss):
     all similarities with J, and also all similarities with I, and calculates
     cross-entropy on both
     """
-    def __init__(self, device, sample=False, temp=0.1, normalize=True):
+    def __init__(self, device, sample=False, temp=0.1, normalize=True, use_repo_loss=False):
         super(SymmetricContrastiveLoss, self).__init__(device, sample)
         self.criterion = torch.nn.CrossEntropyLoss()
         self.temp = temp
+        self.use_repo_loss = use_repo_loss

    # Most methods use either cosine similarity or matrix multiplication similarity. Since cosine similarity equals
    # taking MatMul on normalized vectors, setting normalize=True is equivalent to using torch.CosineSimilarity().
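The `use_repo_loss` path added in the hunk below follows the NT-Xent (normalized temperature-scaled cross-entropy) objective from the reference SimCLR implementation. For reference, a minimal standalone sketch of that computation, assuming two [B, D] batches of projector outputs for two augmented views of the same images (`nt_xent` is an illustrative name, not a function in this codebase):

import torch
import torch.nn.functional as F

def nt_xent(z_i, z_j, temp=0.1):
    # Unit-normalize so dot products are cosine similarities.
    z_i, z_j = F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)
    b = z_i.shape[0]
    out = torch.cat([z_i, z_j], dim=0)                  # [2B, D]
    sim = torch.exp(out @ out.t() / temp)               # [2B, 2B]
    # Drop each row's similarity with itself before summing the denominator.
    mask = ~torch.eye(2 * b, dtype=torch.bool, device=sim.device)
    denom = sim.masked_select(mask).view(2 * b, -1).sum(dim=-1)
    # Positive-pair similarity, once per view ordering.
    pos = torch.exp((z_i * z_j).sum(dim=-1) / temp)     # [B]
    pos = torch.cat([pos, pos], dim=0)                  # [2B]
    return (-torch.log(pos / denom)).mean()

The pre-existing `normalize` branch in the same hunk computes essentially the same objective through explicit logit matrices and `CrossEntropyLoss`.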
@@ -180,50 +182,74 @@ def __call__(self, decoded_context_dist, target_dist, encoded_context_dist=None) # decoded_context -> representation of context + optional projection head # target -> representation of target + optional projection head # encoded_context -> not used by this loss + decoded_contexts, targets = self.get_vector_forms(decoded_context_dist, target_dist) z_i = decoded_contexts z_j = targets batch_size = z_i.shape[0] - if self.normalize: # Use cosine similarity + + if self.use_repo_loss: + # Normalize to avoid infinities z_i = F.normalize(z_i, dim=1) z_j = F.normalize(z_j, dim=1) + out = torch.cat([z_i, z_j], dim=0) + # [2*B, 2*B] + sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / self.temp) + mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool() + # [2*B, 2*B-1] + sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1) + + # compute loss + pos_sim = torch.exp(torch.sum(z_i * z_j, dim=-1) / self.temp) + # [2*B] + pos_sim = torch.cat([pos_sim, pos_sim], dim=0) + loss = (- torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean() + if torch.isnan(loss): + breakpoint() + return loss + else: + if not self.normalize: + breakpoint() + if self.normalize: # Use cosine similarity + z_i = F.normalize(z_i, dim=1) + z_j = F.normalize(z_j, dim=1) - mask = (torch.eye(batch_size) * self.large_num).to(self.device) - - # Similarity of the original images with all other original images in current batch. Return a matrix of NxN. - logits_aa = torch.matmul(z_i, z_i.T) # NxN - - # Values on the diagonal line are each image's similarity with itself - logits_aa = logits_aa - mask - # Similarity of the augmented images with all other augmented images. - logits_bb = torch.matmul(z_j, z_j.T) # NxN - logits_bb = logits_bb - mask - # Similarity of original images and augmented images - logits_ab = torch.matmul(z_i, z_j.T) # NxN - logits_ba = torch.matmul(z_j, z_i.T) # NxN - avg_self_similarity = logits_ab.diag().mean().item() - logits_other_sim_mask = ~torch.eye(batch_size, dtype=bool, device=logits_ab.device) - avg_other_similarity = logits_ab.masked_select(logits_other_sim_mask).mean().item() + mask = (torch.eye(batch_size) * self.large_num).to(self.device) - sb_logger.record('avg_self_similarity', avg_self_similarity) - sb_logger.record('avg_other_similarity', avg_other_similarity) - sb_logger.record('self_other_sim_delta', avg_self_similarity - avg_other_similarity) + # Similarity of the original images with all other original images in current batch. Return a matrix of NxN. + logits_aa = torch.matmul(z_i, z_i.T) # NxN - # Each row now contains an image's similarity with the batch's augmented images & original images. This applies - # to both original and augmented images (hence "symmetric"). - logits_i = torch.cat((logits_ab, logits_aa), 1) # Nx2N - logits_j = torch.cat((logits_ba, logits_bb), 1) # Nx2N - logits = torch.cat((logits_i, logits_j), axis=0) # 2Nx2N - logits /= self.temp + # Values on the diagonal line are each image's similarity with itself + logits_aa = logits_aa - mask + # Similarity of the augmented images with all other augmented images. 
+ logits_bb = torch.matmul(z_j, z_j.T) # NxN + logits_bb = logits_bb - mask + # Similarity of original images and augmented images + logits_ab = torch.matmul(z_i, z_j.T) # NxN + logits_ba = torch.matmul(z_j, z_i.T) # NxN + + avg_self_similarity = logits_ab.diag().mean().item() + logits_other_sim_mask = ~torch.eye(batch_size, dtype=bool, device=logits_ab.device) + avg_other_similarity = logits_ab.masked_select(logits_other_sim_mask).mean().item() + sb_logger.record('avg_self_similarity', avg_self_similarity) + sb_logger.record('avg_other_similarity', avg_other_similarity) + sb_logger.record('self_other_sim_delta', avg_self_similarity - avg_other_similarity) + + # Each row now contains an image's similarity with the batch's augmented images & original images. This applies + # to both original and augmented images (hence "symmetric"). + logits_i = torch.cat((logits_ab, logits_aa), 1) # Nx2N + logits_j = torch.cat((logits_ba, logits_bb), 1) # Nx2N + logits = torch.cat((logits_i, logits_j), axis=0) # 2Nx2N + logits /= self.temp - # The values we want to maximize lie on the i-th index of each row i. i.e. the dot product of - # represent(image_i) and represent(augmented_image_i). - label = torch.arange(batch_size, dtype=torch.long).to(self.device) - labels = torch.cat((label, label), axis=0) + # The values we want to maximize lie on the i-th index of each row i. i.e. the dot product of + # represent(image_i) and represent(augmented_image_i). + label = torch.arange(batch_size, dtype=torch.long).to(self.device) + labels = torch.cat((label, label), axis=0) - return self.criterion(logits, labels) + return self.criterion(logits, labels) class NegativeLogLikelihood(RepresentationLoss): diff --git a/src/il_representations/algos/representation_learner.py b/src/il_representations/algos/representation_learner.py index 234ec6f7..4b75dea1 100644 --- a/src/il_representations/algos/representation_learner.py +++ b/src/il_representations/algos/representation_learner.py @@ -13,8 +13,37 @@ from il_representations.algos.base_learner import BaseEnvironmentLearner from il_representations.algos.batch_extenders import QueueBatchExtender +from il_representations.algos.encoders import warn_on_non_image_tensor from il_representations.algos.utils import AverageMeter, LinearWarmupCosine from il_representations.data.read_dataset import datasets_to_loader, SubdatasetExtractor +from il_representations.utils import save_rgb_tensor +from torch.utils.data import DataLoader + +from PIL import Image +from torchvision.datasets import CIFAR10 +import os +import numpy as np +from torchvision import transforms + +class CIFAR10Pair(CIFAR10): + """CIFAR10 Dataset. 
+ """ + + def __getitem__(self, index): + img, target = self.data[index], self.targets[index] + img = Image.fromarray(img) + id_val = np.random.randint(0, 50000) + #save_image(img, f'results/{id_val}_img_pre_trans.png') + if self.transform is not None: + pos_1 = self.transform(img) + pos_2 = self.transform(img) + # save_rgb_tensor(pos_1, f'results/{id_val}_pos1.png') + # save_rgb_tensor(pos_2, f'results/{id_val}_pos2.png') + if self.target_transform is not None: + target = self.target_transform(target) + + return pos_1, pos_2, target + DEFAULT_HARDCODED_PARAMS = [ 'encoder', 'decoder', 'loss_calculator', 'augmenter', @@ -57,6 +86,7 @@ def __init__(self, *, representation_dim=512, projection_dim=None, device=None, + normalize=True, shuffle_batches=True, shuffle_buffer_size=1024, batch_size=256, @@ -91,6 +121,7 @@ def __init__(self, *, os.makedirs(self.decoder_checkpoints_path, exist_ok=True) self.device = get_device("auto" if device is None else device) + self.normalize = normalize self.shuffle_batches = shuffle_batches self.shuffle_buffer_size = shuffle_buffer_size self.batch_size = batch_size @@ -106,6 +137,15 @@ def __init__(self, *, # This doesn't have any meaningful effect unless you specify a projection head. projection_dim = representation_dim + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(90), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.ToTensor()]) + # transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]) + + # augmenter_kwargs["augment_func"] = train_transform self.augmenter = augmenter(**augmenter_kwargs) self.target_pair_constructor = target_pair_constructor(**to_dict(target_pair_constructor_kwargs)) @@ -154,7 +194,8 @@ def _calculate_norms(self, norm_type=2): norm_type = float(norm_type) encoder_params, decoder_params = self._get_trainable_parameters() - trainable_params = encoder_params + decoder_params + # TODO undo + trainable_params = encoder_params # + decoder_params stacked_gradient_norms = torch.stack([torch.norm(p.grad.detach(), norm_type).to(self.device) for p in trainable_params]) stacked_weight_norms = torch.stack([torch.norm(p.detach(), norm_type).to(self.device) for p in trainable_params]) @@ -275,11 +316,17 @@ def learn(self, datasets, batches_per_epoch, n_epochs, n_trajs=None, callbacks=( self.encoder.train(True) self.decoder.train(True) + # for pname, pval in sorted(self.encoder.named_parameters()): + # print(f'{pname}: {pval.float().mean().item():.4g} pm {pval.float().std().item():.4g}, shape {pval.shape}') batches_trained = 0 logging.debug( f"Training for {n_epochs} epochs, each of {batches_per_epoch} " f"batches (batch size {self.batch_size})") + # TODO add transform back in, and probably comment out our augmenter line? 
+ # train_data = CIFAR10Pair(root='data', train=True, transform=train_transform, download=True) + # train_loader = iter(DataLoader(train_data, batch_size=self.batch_size, shuffle=True, num_workers=16, pin_memory=True, + # drop_last=True)) for epoch_num in range(1, n_epochs + 1): loss_meter = AverageMeter() # Set encoder and decoder to be in training mode @@ -290,28 +337,44 @@ def learn(self, datasets, batches_per_epoch, n_epochs, n_trajs=None, callbacks=( for step, batch in enumerate(dataloader): # Construct batch (currently just using Torch's default batch-creator) contexts, targets, traj_ts_info, extra_context = self.unpack_batch(batch) + # contexts, targets, _ = train_loader.next() + # if step == 0: + # for i in range(10): + # breakpoint() + # save_rgb_tensor(contexts[i][:3], os.path.join(self.log_dir, 'saved_images', f'contexts_from_disk_{i}.png')) + # save_rgb_tensor(targets[i][:3], os.path.join(self.log_dir, 'saved_images', f'targets_from_disk_{i}.png')) # Use an algorithm-specific augmentation strategy to augment either # just context, or both context and targets contexts, targets = self._prep_tensors(contexts), self._prep_tensors(targets) extra_context = self._prep_tensors(extra_context) traj_ts_info = self._prep_tensors(traj_ts_info) # Note: preprocessing might be better to do on CPU if, in future, we can parallelize doing so + # TODO this may not make sense for CIFAR10, maybe double normalizing contexts = self._preprocess(contexts) if self.preprocess_target: targets = self._preprocess(targets) + if step == 0: + for i in range(10): + save_rgb_tensor(contexts[i][:3], os.path.join(self.log_dir, 'saved_images', f'contexts_pre_aug_{i}.png')) + save_rgb_tensor(targets[i][:3], os.path.join(self.log_dir, 'saved_images', f'targets_pre_aug_{i}.png')) + # TODO put back in when done with "swap their data in" test contexts, targets = self.augmenter(contexts, targets) + if step == 0: + for i in range(10): + save_rgb_tensor(contexts[i][:3], os.path.join(self.log_dir, 'saved_images', f'contexts_{i}.png')) + save_rgb_tensor(targets[i][:3], os.path.join(self.log_dir, 'saved_images', f'targets_{i}.png')) extra_context = self._preprocess_extra_context(extra_context) # This is typically a noop, but sometimes we also augment the extra context extra_context = self.augmenter.augment_extra_context(extra_context) - + warn_on_non_image_tensor(contexts) + warn_on_non_image_tensor(targets) # These will typically just use the forward() function for the encoder, but can optionally # use a specific encode_context and encode_target if one is implemented encoded_contexts = self.encoder.encode_context(contexts, traj_ts_info) encoded_targets = self.encoder.encode_target(targets, traj_ts_info) # Typically the identity function encoded_extra_context = self.encoder.encode_extra_context(extra_context, traj_ts_info) - # Use an algorithm-specific decoder to "decode" the representations into a loss-compatible tensor # As with encode, these will typically just use forward() decoded_contexts = self.decoder.decode_context(encoded_contexts, traj_ts_info, encoded_extra_context) @@ -324,6 +387,7 @@ def learn(self, datasets, batches_per_epoch, n_epochs, n_trajs=None, callbacks=( # Use an algorithm-specific loss function. 
Typically this only requires decoded_contexts and # decoded_targets, but VAE requires encoded_contexts, so we pass it in here + # loss = self.loss_calculator(encoded_contexts, encoded_targets, encoded_contexts) loss = self.loss_calculator(decoded_contexts, decoded_targets, encoded_contexts) if batches_trained % self.calc_log_interval == 0: loss_item = loss.item() diff --git a/src/il_representations/algos/utils.py b/src/il_representations/algos/utils.py index e40757dd..599daae2 100644 --- a/src/il_representations/algos/utils.py +++ b/src/il_representations/algos/utils.py @@ -99,26 +99,26 @@ def log(self, msg): class LinearWarmupCosine(_LRScheduler): - def __init__(self, optimizer, T_max, warmup_epoch=30, eta_min=0, last_epoch=-1): + def __init__(self, optimizer, T_max, total_epochs, warmup_epoch=30, eta_min=0, last_epoch=-1): self.T_max = T_max self.eta_min = eta_min self.warmup_epoch = warmup_epoch + self.cosine_epochs = total_epochs - warmup_epoch super(LinearWarmupCosine, self).__init__(optimizer, last_epoch) def get_lr(self): - if self.warmup_epoch > 0: - if self.last_epoch <= self.warmup_epoch: - return [base_lr / self.warmup_epoch * self.last_epoch for base_lr in self.base_lrs] - if ((self.last_epoch - self.warmup_epoch) - 1 - (self.T_max - self.warmup_epoch)) % (2 * (self.T_max - self.warmup_epoch)) == 0: - return [group['lr'] + (base_lr - self.eta_min) * - (1 - math.cos(math.pi / (self.T_max - self.warmup_epoch))) / 2 - for base_lr, group in - zip(self.base_lrs, self.optimizer.param_groups)] - else: - return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_epoch) / (self.T_max - self.warmup_epoch))) / - (1 + math.cos(math.pi * ((self.last_epoch - self.warmup_epoch) - 1) / (self.T_max - self.warmup_epoch))) * - (group['lr'] - self.eta_min) + self.eta_min - for group in self.optimizer.param_groups] + # Linear scaling if we are in the warmup stage + use_linear_scaling = self.warmup_epoch > 0 and self.last_epoch < self.warmup_epoch + result = [] + for base_lr in self.base_lrs: + delta = base_lr - self.eta_min + if use_linear_scaling: + fraction = (self.last_epoch + 1) / self.warmup_epoch + else: + rescaled_epoch = self.last_epoch - self.warmup_epoch + fraction = 0.5 * (1 + math.cos(math.pi * rescaled_epoch / self.cosine_epochs)) + result.append(self.eta_min + fraction * delta) + return result def set_global_seeds(seed): diff --git a/src/il_representations/envs/auto.py b/src/il_representations/envs/auto.py index bd819426..47c46164 100644 --- a/src/il_representations/envs/auto.py +++ b/src/il_representations/envs/auto.py @@ -21,6 +21,9 @@ from il_representations.envs.minecraft_envs import (MinecraftVectorWrapper, get_env_name_minecraft, load_dataset_minecraft) +from il_representations.envs.cifar_envs import load_dataset_cifar, MockGymEnv +from il_representations.envs.procgen_envs import (load_dataset_procgen, + get_procgen_env_name) from il_representations.scripts.utils import update as dict_update ERROR_MESSAGE = "no support for benchmark_name={benchmark_name!r}" @@ -74,6 +77,10 @@ def load_dict_dataset(benchmark_name, n_traj=None): dataset_dict = load_dataset_atari(n_traj=n_traj) elif benchmark_name == 'minecraft': dataset_dict = load_dataset_minecraft(n_traj=n_traj) + elif benchmark_name == 'cifar-10': + dataset_dict = load_dataset_cifar() + elif benchmark_name == 'procgen': + dataset_dict = load_dataset_procgen(n_traj=n_traj) else: raise NotImplementedError(ERROR_MESSAGE.format(**locals())) @@ -100,6 +107,10 @@ def get_gym_env_name(benchmark_name, dm_control_full_env_names, 
task_name): return task_name elif benchmark_name == 'minecraft': return get_env_name_minecraft() # uses task_name implicitly through config param + elif benchmark_name == 'cifar-10': + return 'cifar-10-cls' + elif benchmark_name == 'procgen': + return get_procgen_env_name() raise NotImplementedError(ERROR_MESSAGE.format(**locals())) @@ -163,6 +174,17 @@ def load_vec_env(benchmark_name, dm_control_full_env_names, parallel=venv_parallel, wrapper_class=MinecraftVectorWrapper, max_episode_steps=minecraft_max_env_steps) + elif benchmark_name == 'cifar-10': + return MockGymEnv() + elif benchmark_name == 'procgen': + raw_procgen_env = make_vec_env(gym_env_name, + n_envs=n_envs, + parallel=venv_parallel, + parallel_workers=parallel_workers) + final_env = VecFrameStack(VecTransposeImage(raw_procgen_env), 4) + assert final_env.observation_space.shape == (12, 64, 64), \ + final_env.observation_space.shape + return final_env raise NotImplementedError(ERROR_MESSAGE.format(**locals())) @@ -266,7 +288,9 @@ def load_color_space(benchmark_name): 'magical': ColorSpace.RGB, 'dm_control': ColorSpace.RGB, 'atari': ColorSpace.GRAY, - 'minecraft': ColorSpace.RGB + 'minecraft': ColorSpace.RGB, + 'cifar-10': ColorSpace.RGB, + 'procgen': ColorSpace.RGB } try: return color_spaces[benchmark_name] diff --git a/src/il_representations/envs/cifar_envs.py b/src/il_representations/envs/cifar_envs.py new file mode 100644 index 00000000..82e0db1b --- /dev/null +++ b/src/il_representations/envs/cifar_envs.py @@ -0,0 +1,47 @@ +import torch +import numpy as np +import torchvision +import torchvision.transforms as transforms + +from imitation.augment.color import ColorSpace +from gym.spaces import Discrete, Box + + +def load_dataset_cifar(): + """Return a dataset dict""" + dataset = torchvision.datasets.CIFAR10(root='./cifar', train=True, download=True, + transform=transforms.ToTensor()) + + obs, acts = [], [] + for i in range(len(dataset)): + img, label = dataset[i] + obs.append(img.cpu().numpy()) + acts.append(label) + + obs = np.stack([o for o in obs], axis=0) + acts = np.array(acts) + + data_dict = { + 'obs': obs, + 'acts': acts, + 'dones': np.array([False] * len(dataset)), + } + + return data_dict + + +class MockGymEnv(object): + """A mock Gym env for a supervised learning dataset pretending to be an RL + task. Action space is set to Discrete(1), observation space corresponds to + the original supervised learning task. 
+ """ + def __init__(self): + self.observation_space = Box(low=0.0, high=1.0, shape=(3, 32, 32), dtype=np.float32) + self.action_space = Discrete(1) + self.color_space = ColorSpace.RGB + + def seed(self, seed): + pass + + def close(self): + pass \ No newline at end of file diff --git a/src/il_representations/envs/config.py b/src/il_representations/envs/config.py index 6a6c3bff..0f3ca505 100644 --- a/src/il_representations/envs/config.py +++ b/src/il_representations/envs/config.py @@ -6,7 +6,9 @@ from sacred import Ingredient -ALL_BENCHMARK_NAMES = {"atari", "magical", "dm_control", "minecraft"} + +ALL_BENCHMARK_NAMES = {"atari", "magical", "dm_control", "minecraft", "procgen", "cifar-10"} + # see env_cfg_defaults docstring for description of this ingredient env_cfg_ingredient = Ingredient('env_cfg') @@ -70,6 +72,17 @@ def env_cfg_defaults(): # ############################### minecraft_max_env_steps = None + # ############################### + # CIFAR-10-specific config variables + # (none currently present) + # ############################### + + # ############################### + # Procgen-specific config variables + # ############################### + procgen_frame_stack = 4 + + _ = locals() del _ @@ -171,5 +184,12 @@ def env_data_defaults(): 'data/atari/PongNoFrameskip-v4_rollouts_500_ts_100_traj.npz', } + # ########################### + # ProcGen config variables + # ########################### + procgen_demo_paths = { + 'coinrun': 'procgen/demo_coinrun.pickle' + } + _ = locals() del _ diff --git a/src/il_representations/envs/procgen_envs.py b/src/il_representations/envs/procgen_envs.py new file mode 100644 index 00000000..9062b1ac --- /dev/null +++ b/src/il_representations/envs/procgen_envs.py @@ -0,0 +1,63 @@ +import os +import random +import numpy as np + +from procgen.gym_registration import make_env, register_environments + +from il_representations.envs.config import (env_cfg_ingredient, + env_data_ingredient) + + +@env_data_ingredient.capture +def _get_procgen_data_opts(data_root, procgen_demo_paths): + # workaround for Sacred issue #206 + return data_root, procgen_demo_paths + + +@env_cfg_ingredient.capture +def load_dataset_procgen(task_name, procgen_frame_stack, n_traj=None, + chans_first=True): + data_root, procgen_demo_paths = _get_procgen_data_opts() + + # load trajectories from disk + full_rollouts_path = os.path.join(data_root, procgen_demo_paths[task_name]) + trajectories = np.load(full_rollouts_path, allow_pickle=True) + + cat_obs = np.concatenate(trajectories['obs'], axis=0) + cat_acts = np.concatenate(trajectories['acts'], axis=0) + cat_rews = np.concatenate(trajectories['rews'], axis=0) + cat_dones = np.concatenate(trajectories['dones'], axis=0) + + dataset_dict = { + 'obs': cat_obs, + 'acts': cat_acts, + 'rews': cat_rews, + 'dones': cat_dones, + } + + if chans_first: + for key in ('obs', ): + dataset_dict[key] = np.transpose(dataset_dict[key], (0, 3, 1, 2)) + dataset_dict['obs'] = _stack_obs_oldest_first(dataset_dict['obs'], + procgen_frame_stack) + + return dataset_dict + + +@env_cfg_ingredient.capture +def get_procgen_env_name(task_name): + return f'procgen-{task_name}-v0' + + +@env_cfg_ingredient.capture +def _stack_obs_oldest_first(obs_arr, procgen_frame_stack): + frame_accumulator = np.repeat([obs_arr[0]], procgen_frame_stack, axis=0) + c, h, w = obs_arr.shape[1:] + out_sequence = [] + for in_frame in obs_arr: + frame_accumulator = np.concatenate( + [frame_accumulator[1:], [in_frame]], axis=0) + out_sequence.append(frame_accumulator.reshape( + 
procgen_frame_stack * c, h, w))
+    out_sequence = np.stack(out_sequence, axis=0)
+    return out_sequence
diff --git a/src/il_representations/scripts/chain_configs.py b/src/il_representations/scripts/chain_configs.py
index ad3d51bb..0ece1f29 100644
--- a/src/il_representations/scripts/chain_configs.py
+++ b/src/il_representations/scripts/chain_configs.py
@@ -52,6 +52,23 @@ def cfg_base_3seed_4cpu_pt3gpu():
        _ = locals()
        del _

+    @experiment_obj.named_config
+    def cfg_base_3seed_1cpu_pt2gpu():
+        """Basic config that does three samples per config, using one CPU core
+        and 0.2 of a GPU."""
+        use_skopt = False
+        tune_run_kwargs = dict(num_samples=3,
+                               # retry on (node) failure
+                               max_failures=2,
+                               fail_fast=False,
+                               resources_per_trial=dict(
+                                   cpu=1,
+                                   gpu=0.2,
+                               ))
+
+        _ = locals()
+        del _
+
    @experiment_obj.named_config
    def cfg_base_3seed_1cpu_pt2gpu_2envs():
        """Another config that uses only one CPU per run, and .2 of a GPU. Good for
@@ -206,6 +223,20 @@ def cfg_bench_micro_sweep_dm_control():
        _ = locals()
        del _

+    @experiment_obj.named_config
+    def cfg_run_few_trajs_long_dm_control():
+        """For experiments that run BC on very few trajectories."""
+        spec = dict(il_train={
+            'bc': {
+                'n_batches': 10000000,
+                # 'n_trajs': tune.grid_search([1, 10, 30]),
+                'save_every_n_batches': 5e4
+            }
+        })
+
+        _ = locals()
+        del _
+
    @experiment_obj.named_config
    def cfg_bench_one_task_magical():
        """Just one simple MAGICAL config."""
@@ -302,8 +333,10 @@ def cfg_repl_simclr():
        stages_to_run = StagesToRun.REPL_AND_IL
        repl = {
            'algo': 'SimCLR',
+            'algo_params': {
+                'optimizer_kwargs': {'lr': 3e-4},
+            }
        }
-
        _ = locals()
        del _
diff --git a/src/il_representations/scripts/il_test.py b/src/il_representations/scripts/il_test.py
index a3bbe374..756b28f8 100644
--- a/src/il_representations/scripts/il_test.py
+++ b/src/il_representations/scripts/il_test.py
@@ -106,7 +106,8 @@ def run(policy_path, env_cfg, venv_opts, seed, n_rollouts, device_name, run_id,
            'return_mean': eval_data_frame['mean_score'].mean(),
        }

-    elif (env_cfg['benchmark_name'] in ('dm_control', 'atari', 'minecraft')):
+    elif (env_cfg['benchmark_name'] in ('dm_control', 'atari', 'minecraft',
+                                        'procgen')):
        # must import this to register envs
        from il_representations.envs import dm_control_envs  # noqa: F401
diff --git a/src/il_representations/scripts/run_cifar.py b/src/il_representations/scripts/run_cifar.py
new file mode 100644
index 00000000..c1464c52
--- /dev/null
+++ b/src/il_representations/scripts/run_cifar.py
@@ -0,0 +1,470 @@
+import numpy as np
+import os
+import PIL
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+from torchvision.models.resnet import resnet50
+from torchvision.datasets import CIFAR10
+from PIL import Image
+
+from tqdm import tqdm
+from math import ceil
+import pandas as pd
+
+import time
+from sacred import Experiment
+from sacred.observers import FileStorageObserver
+from il_representations import algos
+from il_representations.algos.utils import LinearWarmupCosine
+from il_representations.envs.auto import load_wds_datasets
+from il_representations.algos.encoders import SimCLRModel
+from il_representations.envs.config import (env_cfg_ingredient,
+                                            env_data_ingredient,
+                                            venv_opts_ingredient)
+
+
+cifar_ex = Experiment('cifar', ingredients=[
+    env_cfg_ingredient, env_data_ingredient,
+    venv_opts_ingredient
+    ])
+
+
+class LinearHead(nn.Module):
+    def __init__(self, encoder, encoder_dim, output_dim):
+        super().__init__()
+        self.encoder = encoder
+        # Linear probe over the flattened encoder features.
+        self.fc = nn.Linear(encoder_dim, output_dim, bias=True)
+        # self.encoder.fc = nn.Linear(2048, output_dim)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        feature = torch.flatten(x, start_dim=1)
+        out = self.fc(feature)
+        return out
+
+
+def train_classifier(classifier, data_dir, num_epochs, device):
+    # transform = transforms.Compose([
+    #     transforms.RandomResizedCrop(32, interpolation=PIL.Image.BICUBIC),
+    #     transforms.RandomHorizontalFlip(),
+    #     # No color jitter or grayscale for finetuning
+    #     # SimCLR doesn't use blur for CIFAR-10
+    #     transforms.ToTensor(),
+    #     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+    # ])
+    train_transform = transforms.Compose([
+        transforms.RandomResizedCrop(32),
+        transforms.RandomHorizontalFlip(p=0.5),
+        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4,
+                                                       0.1)], p=0.8),
+        transforms.RandomGrayscale(p=0.2),
+        transforms.ToTensor(),
+        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994,
+                                                        0.2010])])
+
+    test_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994,
+                                                        0.2010])])
+    trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True,
+                                            download=True, transform=train_transform)
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
+    criterion = nn.CrossEntropyLoss().to(device)
+    optimizer = optim.Adam(classifier.fc.parameters(), lr=1e-3,
+                           weight_decay=1e-6)
+    # optimizer = optim.Adam(classifier.encoder.fc.parameters(), lr=3e-4, momentum=0.9, weight_decay=0.0, nesterov=True)
+    # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
+
+    # test_transform = transforms.Compose([
+    #     transforms.ToTensor(),
+    #     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+    # ])
+    testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_transform)
+    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)
+
+    progress_dict = {'loss': [], 'train_acc': [], 'test_acc': []}
+
+    start_time = time.time()
+
+    for epoch in range(num_epochs):
+        loss_meter = AverageMeter()
+        train_acc_meter = AverageMeter()
+        classifier.train()
+
+        print(f"Epoch {epoch}/{num_epochs} with lr {optimizer.param_groups[0]['lr']}")
+        running_loss = 0.0
+        for i, (inputs, labels) in enumerate(trainloader, 0):
+            inputs, labels = inputs.to(device), labels.to(device)
+            optimizer.zero_grad()
+            outputs = classifier(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            # print statistics
+            train_acc_meter.update(accuracy(outputs, labels)[0].item())
+            loss_meter.update(loss.item())
+            running_loss += loss.item()
+
+            if i % 20 == 19:    # print every 20 mini-batches
+                hours, rem = divmod(time.time() - start_time, 3600)
+                minutes, seconds = divmod(rem, 60)
+                print(f"[{int(hours)}:{int(minutes)}:{int(seconds)}] "
+                      f"Epoch {epoch}, Batch {i} "
+                      f"Average loss: {loss_meter.avg} "
+                      f"Average acc: {train_acc_meter.avg} "
+                      f"Running loss: {running_loss / 20}")
+                running_loss = 0.0
+
+        #scheduler.step()
+        test_acc = evaluate_classifier(testloader, classifier, device)
+
+        progress_dict['loss'].append(loss_meter.avg)
+        progress_dict['train_acc'].append(train_acc_meter.avg)
+        progress_dict['test_acc'].append(test_acc)
+
+        with open('./progress.json', 'w') as f:
+            json.dump(progress_dict, f)
+
+
+def evaluate_classifier(testloader, classifier, device):
+    # trainset =
torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True) + # testloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True) + classifier.eval() + total = 0 + test_acc_meter = AverageMeter() + with torch.no_grad(): + for images, labels in testloader: + images, labels = images.to(device), labels.to(device) + outputs = classifier(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + test_acc_meter.update(accuracy(outputs, labels)[0].item()) + print(f"Test acc: {test_acc_meter.avg}") + + return test_acc_meter.avg + + +def train_from_simclr_repo(model, batch_size, epochs): + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(32), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.ToTensor(), + transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]) + + test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]) + + train_data = CIFAR10(root='data', train=True, transform=train_transform, download=True) + train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True) + test_data = CIFAR10(root='data', train=False, transform=test_transform, download=True) + test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True) + + # flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),)) + # flops, params = clever_format([flops, params]) + # print('# Model Params: {} FLOPs: {}'.format(params, flops)) + optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6) + loss_criterion = nn.CrossEntropyLoss() + results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [], + 'test_loss': [], 'test_acc@1': [], 'test_acc@5': []} + + best_acc = 0.0 + for epoch in range(1, epochs + 1): + train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer, loss_criterion, + epoch, epochs) + results['train_loss'].append(train_loss) + results['train_acc@1'].append(train_acc_1) + results['train_acc@5'].append(train_acc_5) + test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None, loss_criterion, + epoch, epochs) + results['test_loss'].append(test_loss) + results['test_acc@1'].append(test_acc_1) + results['test_acc@5'].append(test_acc_5) + # save statistics + data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1)) + data_frame.to_csv('./linear_statistics.csv', index_label='epoch') + if test_acc_1 > best_acc: + best_acc = test_acc_1 + torch.save(model.state_dict(), './linear_model.pth') + + +# train or test for one epoch +def train_val(net, data_loader, train_optimizer, loss_criterion, epoch, epochs): + is_train = train_optimizer is not None + net.train() if is_train else net.eval() + + total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader) + with (torch.enable_grad() if is_train else torch.no_grad()): + for data, target in data_bar: + data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) + out = net(data) + loss = loss_criterion(out, target) + + if is_train: + train_optimizer.zero_grad() + loss.backward() + train_optimizer.step() + + total_num += data.size(0) + total_loss += loss.item() * data.size(0) + prediction = torch.argsort(out, dim=-1, descending=True) + total_correct_1 += 
torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
+            total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
+
+            data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%'
+                                     .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num,
+                                             total_correct_1 / total_num * 100, total_correct_5 / total_num * 100))
+
+    return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100
+
+
+def representation_learning(algo, device, log_dir, config):
+    print('Train representation learner')
+    if isinstance(algo, str):
+        algo = getattr(algos, algo)
+    assert issubclass(algo, algos.RepresentationLearner)
+
+    rep_learning_augmentations = transforms.Compose([
+        transforms.Lambda(lambda x: np.transpose((x.cpu().numpy() * 255).astype(np.uint8),
+                                                 axes=(1, 2, 0))),
+        transforms.ToPILImage(),
+        transforms.RandomResizedCrop(32, interpolation=PIL.Image.BICUBIC),
+        transforms.RandomHorizontalFlip(p=0.5),
+        transforms.RandomApply([
+            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
+        ], p=0.8),
+        transforms.RandomGrayscale(p=0.2),
+        transforms.ToTensor(),
+        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
+        #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        # SimCLR doesn't use blur for CIFAR-10
+    ])
+
+    rep_learning_data, combined_meta = load_wds_datasets([{}])
+    augmenter_kwargs = {
+        "augmenter_spec": "translate,flip_lr,color_jitter_ex,gray",
+        "color_space": combined_meta['color_space'],
+
+        # (Cynthia) Here I'm using augment_func because I want our settings
+        # to be as close to SimCLR as possible
+        "augment_func": rep_learning_augmentations
+    }
+    optimizer_kwargs = {
+        "lr": 1e-3,
+        "weight_decay": 1e-6
+    }
+
+    # This is currently erroneously 1
+    num_epochs = config['pretrain_epochs']
+    batch_size = config['pretrain_batch_size']
+    batches_per_epoch = config['pretrain_batches_per_epoch']
+
+    # Modify resnet according to SimCLR paper Appendix B.9
+    # simclr_resnet = resnet50()
+    # simclr_resnet.fc = torch.nn.Linear(2048, config['representation_dim'])
+    # simclr_resnet.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
+    # simclr_resnet.maxpool = torch.nn.Identity()
+
+    model = algo(
+        observation_space=combined_meta['observation_space'],
+        action_space=combined_meta['action_space'],
+        log_dir=log_dir,
+        batch_size=batch_size,
+        representation_dim=config['representation_dim'],
+        projection_dim=config['projection_dim'],
+        device=device,
+        normalize=False,
+        shuffle_batches=True,
+        color_space=combined_meta['color_space'],
+        save_interval=config['pretrain_save_interval'],
+        # BaseEncoder invokes obs_encoder_cls(obs_space, representation_dim, ...);
+        # SimCLRModel only needs the observation space (for its input channel
+        # count), so forward just that argument.
+        encoder_kwargs={'obs_encoder_cls': lambda obs_space, rep_dim: SimCLRModel(obs_space)},
+        decoder_kwargs={'projection_architecture': [{'output_dim': 512}]},
+        augmenter_kwargs=augmenter_kwargs,
+        optimizer=torch.optim.Adam,
+        optimizer_kwargs=optimizer_kwargs,
+        #scheduler=LinearWarmupCosine,
+        scheduler_kwargs={'warmup_epoch': 2, 'total_epochs': num_epochs},
+        loss_calculator_kwargs={'temp': config['pretrain_temperature'],
+                                'use_repo_loss': config['use_repo_loss']},
+        log_interval=1,
+        calc_log_interval=1
+    )
+    _, encoder_checkpoint_path = model.learn(rep_learning_data, batches_per_epoch, num_epochs)
+    print("Representation Learning trained!")
+    pretrained_model = torch.load(encoder_checkpoint_path)
+    return pretrained_model
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions
for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +class AverageMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +# test for one epoch, use weighted knn to find the most similar images' label to assign the test image +def test(net, memory_data_loader, test_data_loader, k, num_classes, temperature, epoch): + net.eval() + total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] + with torch.no_grad(): + # generate feature bank + for data, _, target in tqdm(memory_data_loader, desc='Feature extracting'): + feature = net(data.cuda(non_blocking=True)) + feature_bank.append(feature) + # [D, N] + feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() + # [N] + feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device) + # loop test data to predict the label by weighted knn search + test_bar = tqdm(test_data_loader) + for data, _, target in test_bar: + data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) + feature = net(data) + + total_num += data.size(0) + # compute cos similarity between each feature vector and feature bank ---> [B, N] + sim_matrix = torch.mm(feature, feature_bank) + # [B, K] + sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1) + # [B, K] + sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices) + sim_weight = (sim_weight / temperature).exp() + + # counts for each class + one_hot_label = torch.zeros(data.size(0) * k, num_classes, device=sim_labels.device) + # [B*K, C] + one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) + # weighted score ---> [B, C] + pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, num_classes) * sim_weight.unsqueeze(dim=-1), dim=1) + + pred_labels = pred_scores.argsort(dim=-1, descending=True) + total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() + total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() + test_bar.set_description('Test Epoch: [{}] Acc@1:{:.2f}% Acc@5:{:.2f}%' + .format(epoch, total_top1 / total_num * 100, total_top5 / total_num * 100)) + + return total_top1 / total_num * 100, total_top5 / total_num * 100 + +## data handling class copied from SimCLR implementation +class CIFAR10Pair(CIFAR10): + """CIFAR10 Dataset. 
+ """ + + def __getitem__(self, index): + img, target = self.data[index], self.targets[index] + img = Image.fromarray(img) + + if self.transform is not None: + pos_1 = self.transform(img) + pos_2 = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return pos_1, pos_2, target + + +@cifar_ex.config +def default_config(): + seed = 1 + algo = 'SimCLR' + data_dir = 'cifar10/' + pretrain_epochs = 1000 + pretrain_batches_per_epoch = 390 + finetune_epochs = 100 + finetune_batch_size = 512 + representation_dim = 2048 # TODO change back + projection_dim = 128 + pretrain_lr = 3e-4 + pretrain_weight_decay = 1e-4 + pretrain_momentum = 0.9 + pretrain_batch_size = 512 + pretrain_save_interval = 100 + pretrain_temperature = 0.1 + pretrained_model = None + use_repo_loss = False + eval_knn = True + _ = locals() + del _ + + +@cifar_ex.main +def run(seed, algo, data_dir, pretrain_epochs, finetune_epochs, representation_dim, + pretrained_model, pretrain_batch_size, finetune_batch_size, eval_knn, _config): + # TODO fix this hacky nonsense + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if pretrained_model is None: + log_dir = os.path.join(cifar_ex.observers[0].dir, 'training_logs') + os.mkdir(log_dir) + os.makedirs(data_dir, exist_ok=True) + + model = representation_learning(algo, device, log_dir, _config) + + else: + model = torch.load(pretrained_model) + + if eval_knn: + test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]) + + memory_data = CIFAR10Pair(root='data', train=True, transform=test_transform, download=True) + memory_loader = torch.utils.data.DataLoader(memory_data, batch_size=pretrain_batch_size, shuffle=False, + num_workers=16, pin_memory=True) + test_data = CIFAR10Pair(root='data', train=False, transform=test_transform, download=True) + test_loader = torch.utils.data.DataLoader(test_data, batch_size=pretrain_batch_size, shuffle=False, + num_workers=16, pin_memory=True) + + # KNN testing from SimCLR repo for comparison + test(model.network, memory_loader, test_loader, k=200, num_classes=10, + temperature=_config['pretrain_temperature'], epoch=-1) + else: + print('Train linear head') + classifier = LinearHead(model.network, representation_dim, output_dim=10).to(device) + train_from_simclr_repo(classifier, finetune_batch_size, finetune_epochs) + # train_classifier(classifier, data_dir, num_epochs=finetune_epochs, device=device) + + print('Evaluate accuracy on test set') + evaluate_classifier(classifier, data_dir, device=device) + + +if __name__ == '__main__': + cifar_ex.observers.append(FileStorageObserver('runs/cifar_runs')) + cifar_ex.run_commandline() diff --git a/src/il_representations/scripts/run_il.sh b/src/il_representations/scripts/run_il.sh index ac9fef06..bfad3935 100755 --- a/src/il_representations/scripts/run_il.sh +++ b/src/il_representations/scripts/run_il.sh @@ -1,16 +1,16 @@ #!/usr/bin/env bash -CUDA_VISIBLE_DEVICES=3 xvfb-run -a python src/il_representations/scripts/pretrain_n_adapt.py with \ - cfg_base_3seed_1cpu_pt2gpu_2envs \ +CUDA_VISIBLE_DEVICES=1 xvfb-run -a python src/il_representations/scripts/pretrain_n_adapt.py with \ cfg_repl_none \ cfg_il_bc_nofreeze \ - tune_run_kwargs.num_samples=2 \ - tune_run_kwargs.resources_per_trial.gpu=0.5 \ - exp_ident=magical-small \ - il_train.bc.n_batches=400000 \ - il_train.bc.batch_size=512 \ - il_train.encoder_kwargs.obs_encoder_cls=MAGICALCNN \ - 
-    il_train.encoder_kwargs.obs_encoder_cls_kwargs.arch_str=MAGICALCNN-small \
-    env_cfg.benchmark_name=dm_control \
-    env_cfg.task_name=finger-spin
+    cfg_bench_micro_sweep_dm_control \
+    cfg_run_few_trajs_long_dm_control \
+    il_train.bc.n_trajs=10 \
+    exp_ident=dmc_long_ntrajs_10 \
+    tune_run_kwargs.num_samples=1 \
+    tune_run_kwargs.resources_per_trial.gpu=0.3
+    # il_train.bc.n_batches=400000 \
+    # il_train.bc.nominal_length=10000 \
+    # env_cfg.benchmark_name=dm_control \
+    # env_cfg.task_name=finger-spin
diff --git a/src/il_representations/scripts/run_simclr.sh b/src/il_representations/scripts/run_simclr.sh
new file mode 100755
index 00000000..61e45bba
--- /dev/null
+++ b/src/il_representations/scripts/run_simclr.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+repl_epochs=100
+bc_trajs=10
+bc_batches=4000000
+
+CUDA_VISIBLE_DEVICES=0 python src/il_representations/scripts/pretrain_n_adapt.py with \
+    cfg_repl_simclr \
+    cfg_il_bc_nofreeze \
+    cfg_bench_micro_sweep_dm_control \
+    tune_run_kwargs.num_samples=1 \
+    tune_run_kwargs.resources_per_trial.gpu=0.3 \
+    repl.n_epochs=$repl_epochs \
+    repl.n_trajs=$bc_trajs \
+    repl.algo_params.batch_size=256 \
+    il_train.bc.n_trajs=$bc_trajs \
+    il_train.bc.n_batches=$bc_batches \
+    exp_ident=repl_epoch_${repl_epochs}_bc_${bc_trajs}_trajs_${bc_batches}_batches

    # repl.algo_params.representation_dim=512 \
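The rewritten LinearWarmupCosine in algos/utils.py above can be sanity-checked by dumping the per-epoch learning rate from a throwaway optimizer. A sketch under the constructor signature introduced in this diff (T_max is kept for call-site compatibility; the cosine phase length is derived from total_epochs - warmup_epoch):

import torch
from il_representations.algos.utils import LinearWarmupCosine

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)
sched = LinearWarmupCosine(opt, T_max=100, total_epochs=100, warmup_epoch=10)
for epoch in range(100):
    opt.step()    # optimizer step first, to avoid PyTorch's scheduler-ordering warning
    sched.step()
    print(epoch, sched.get_last_lr())

The printed schedule should ramp linearly up to the base learning rate over the first 10 epochs, then decay along a half cosine to eta_min by epoch 100.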