diff --git a/robustness/attacker.py b/robustness/attacker.py
index 4e51bfa..33584ac 100644
--- a/robustness/attacker.py
+++ b/robustness/attacker.py
@@ -271,6 +271,7 @@ class AttackerModel(ch.nn.Module):
     def __init__(self, model, dataset):
         super(AttackerModel, self).__init__()
         self.normalizer = helpers.InputNormalize(dataset.mean, dataset.std)
+        self.normalizer = ch.jit.script(self.normalizer)
         self.model = model

     def forward(self, inp):
diff --git a/robustness/imagenet_models/resnet.py b/robustness/imagenet_models/resnet.py
index 8d093b4..498cbea 100644
--- a/robustness/imagenet_models/resnet.py
+++ b/robustness/imagenet_models/resnet.py
@@ -7,7 +7,7 @@
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
            'wide_resnet50_2', 'wide_resnet101_2',
-           'wide_resnet50_3', 'wide_resnet50_4', 'wide_resnet50_5',
+           'wide_resnet50_3', 'wide_resnet50_4', 'wide_resnet50_5', 'wide_resnet50_6',
            ]


@@ -57,7 +57,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
         self.downsample = downsample
         self.stride = stride

-    def forward(self, x, fake_relu=False, no_relu=False):
+    def forward(self, x):
         identity = x

         out = self.conv1(x)
@@ -72,10 +72,6 @@ def forward(self, x, fake_relu=False, no_relu=False):

         out += identity

-        if fake_relu:
-            return FakeReLU.apply(out)
-        if no_relu:
-            return out
         return self.relu(out)


@@ -201,10 +197,10 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
                                 base_width=self.base_width, dilation=self.dilation,
                                 norm_layer=norm_layer))

-        return SequentialWithArgs(*layers)
+        return nn.Sequential(*layers)
         # return nn.Sequential(*layers)

-    def _forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
+    def _forward(self, x):
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
@@ -213,13 +209,11 @@ def _forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
         x = self.layer1(x)
         x = self.layer2(x)
         x = self.layer3(x)
-        x = self.layer4(x, fake_relu=fake_relu, no_relu=no_relu)
+        x = self.layer4(x)

         x = self.avgpool(x)
         pre_out = torch.flatten(x, 1)
         final = self.fc(pre_out)
-        if with_latent:
-            return final, pre_out
         return final

     # Allow for accessing forward method in a inherited class
@@ -360,7 +354,7 @@ def wide_resnet50_3(pretrained=False, progress=True, **kwargs):


 def wide_resnet50_4(pretrained=False, progress=True, **kwargs):
-    r"""Wide ResNet-50-4 model 
+    r"""Wide ResNet-50-4 model
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
diff --git a/robustness/main.py b/robustness/main.py
index 806b84a..2b13663 100644
--- a/robustness/main.py
+++ b/robustness/main.py
@@ -56,7 +56,7 @@ def main(args, store=None):

     if not args.resume_optimizer: checkpoint = None
     model = train_model(args, model, loaders, store=store,
-                        checkpoint=checkpoint)
+                        checkpoint=checkpoint, val_loader=val_loader)
     return model

 def setup_args(args):
diff --git a/robustness/tools/custom_modules.py b/robustness/tools/custom_modules.py
index a82295c..5b8ab69 100644
--- a/robustness/tools/custom_modules.py
+++ b/robustness/tools/custom_modules.py
@@ -1,4 +1,4 @@
-import torch 
+import torch
 from torch import nn
 ch = torch

@@ -16,12 +16,12 @@ def forward(self, x):
         return FakeReLU.apply(x)

 class SequentialWithArgs(torch.nn.Sequential):
-    def forward(self, input, *args, **kwargs):
+    def forward(self, input):
         vs = list(self._modules.values())
         l = len(vs)
         for i in range(l):
             if i == l-1:
-                input = vs[i](input, *args, **kwargs)
+                input = vs[i](input)
             else:
                 input = vs[i](input)
         return input
diff --git a/robustness/tools/helpers.py b/robustness/tools/helpers.py
index cdc7501..272803b 100644
--- a/robustness/tools/helpers.py
+++ b/robustness/tools/helpers.py
@@ -96,63 +96,141 @@ def __init__(self, new_mean, new_std):
         self.register_buffer("new_std", new_std)

     def forward(self, x):
-        x = ch.clamp(x, 0, 1)
+        # x = ch.clamp(x, 0, 1)
         x_normalized = (x - self.new_mean)/self.new_std
         return x_normalized

+# class DataPrefetcher():
+#     def __init__(self, loader, stop_after=None):
+#         self.loader = loader
+#         self.dataset = loader.dataset
+
+#     def __len__(self):
+#         return len(self.loader)
+
+#     def __iter__(self):
+#         # count = 0
+#         # self.loaditer = iter(self.loader)
+#         # self.preload()
+#         # for i in range(loaditer)
+#         # while self.next_input is not None:
+#         #     ch.cuda.current_stream().wait_stream(self.stream)
+#         #     input = self.next_input
+#         #     target = self.next_target
+#         #     self.preload()
+#         #     count += 1
+#         #     yield input, target
+#         #     if type(self.stop_after) is int and (count > self.stop_after):
+#         #         break
+#         prev_x, prev_y = None, None
+#         for _, (x, y) in enumerate(self.loader):
+#             x = x.to(device='cuda', memory_format=ch.channels_last,
+#                      non_blocking=True).to(dtype=ch.float32, non_blocking=True)
+#             y = y.to(device='cuda', non_blocking=True)
+#             if prev_x is None:
+#                 prev_x, prev_y = x, y
+#             else:
+#                 yield prev_x, prev_y
+#                 prev_x, prev_y = x, y
+
+#         yield prev_x, prev_y
+
 class DataPrefetcher():
     def __init__(self, loader, stop_after=None):
         self.loader = loader
-        self.dataset = loader.dataset
-        self.stream = ch.cuda.Stream()
-        self.stop_after = stop_after
-        self.next_input = None
-        self.next_target = None
+        self.dataset = self.loader.dataset

     def __len__(self):
         return len(self.loader)

+    def __iter__(self):
+        prefetcher = data_prefetcher(self.loader)
+        x, y = prefetcher.next()
+        while x is not None:
+            yield x, y
+            x, y = prefetcher.next()
+
+class data_prefetcher():
+    def __init__(self, loader):
+        self.loader = iter(loader)
+        self.stream = ch.cuda.Stream()
+        # With Amp, it isn't necessary to manually convert data to half.
+        # if args.fp16:
+        #     self.mean = self.mean.half()
+        #     self.std = self.std.half()
+        self.preload()
+
     def preload(self):
         try:
-            self.next_input, self.next_target = next(self.loaditer)
+            self.next_input, self.next_target = next(self.loader)
         except StopIteration:
             self.next_input = None
             self.next_target = None
             return
+        # if record_stream() doesn't work, another option is to make sure device inputs are created
+        # on the main stream.
+        # self.next_input_gpu = ch.empty_like(self.next_input, device='cuda')
+        # self.next_target_gpu = ch.empty_like(self.next_target, device='cuda')
+        # Need to make sure the memory allocated for next_* is not still in use by the main stream
+        # at the time we start copying to next_*:
+        # self.stream.wait_stream(ch.cuda.current_stream())
         with ch.cuda.stream(self.stream):
-            self.next_input = self.next_input.cuda(non_blocking=True)
             self.next_target = self.next_target.cuda(non_blocking=True)
+            # more code for the alternative if record_stream() doesn't work:
+            # copy_ will record the use of the pinned source tensor in this side stream.
+            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
+            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
+            # self.next_input = self.next_input_gpu
+            # self.next_target = self.next_target_gpu

-    def __iter__(self):
-        count = 0
-        self.loaditer = iter(self.loader)
+            # With Amp, it isn't necessary to manually convert data to half.
+            # if args.fp16:
+            #     self.next_input = self.next_input.half()
+            # else:
+            self.next_input = self.next_input.to(device='cuda', memory_format=ch.channels_last,
+                                                 non_blocking=True).to(dtype=ch.float32, non_blocking=True)
+            # to(dtype=ch.float32, non_blocking=True)
+
+    def next(self):
+        ch.cuda.current_stream().wait_stream(self.stream)
+        input = self.next_input
+        target = self.next_target
+        if input is not None:
+            input.record_stream(ch.cuda.current_stream())
+        if target is not None:
+            target.record_stream(ch.cuda.current_stream())
         self.preload()
-        while self.next_input is not None:
-            ch.cuda.current_stream().wait_stream(self.stream)
-            input = self.next_input
-            target = self.next_target
-            self.preload()
-            count += 1
-            yield input, target
-            if type(self.stop_after) is int and (count > self.stop_after):
-                break
+        return input, target
+

 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self):
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
         self.reset()

     def reset(self):
         self.val = 0
-        self.avg = 0
         self.sum = 0
         self.count = 0

-    def update(self, val, n=1):
+    @property
+    def avg(self):
+        return self.sum / max(self.count, 1)
+
+    def update(self, val, _=0):
         self.val = val
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
+        self.sum += val
+        self.count += 1
+        # self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg2' + self.fmt + '} ({sum' + self.fmt + '})'
+        self.avg2 = self.avg
+        # self.avg = (self.sum / max(self.count, 1))
+        return fmtstr.format(**self.__dict__)

 # ImageNet label mappings
 def get_label_mapping(dataset_name, ranges):
diff --git a/robustness/train.py b/robustness/train.py
index e33d6cf..9872063 100644
--- a/robustness/train.py
+++ b/robustness/train.py
@@ -3,6 +3,7 @@ import torch.nn as nn
 from torch.optim import SGD, Adam, lr_scheduler
 from torchvision.utils import make_grid
+from torch.nn.utils import parameters_to_vector as flatten
 from cox.utils import Parameters

 from .tools import helpers
@@ -12,8 +13,10 @@ import os
 import time
 import warnings
+from pytorch_loss import LabelSmoothSoftmaxCEV3

-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import autocast
+from apex import amp

 if int(os.environ.get("NOTEBOOK_MODE", 0)) == 1:
     from tqdm import tqdm_notebook as tqdm
@@ -54,7 +57,7 @@ def check_args(args_list):
             without a custom adversarial loss (see docs)")


-def make_optimizer_and_schedule(args, model, checkpoint, params):
+def make_optimizer_and_schedule(args, model, checkpoint, params, iters_per_epoch):
     """
     *Internal Function* (called directly from train_model)
@@ -76,9 +79,20 @@ def make_optimizer_and_schedule(args, model, checkpoint, params):
     """
     # Make optimizer
     param_list = model.parameters() if params is None else params
+
     if args.optimizer == 'Adam':
         optimizer = Adam(param_list, lr=args.lr, betas=(0.9, 0.999), eps=1e-08,
                          weight_decay=args.weight_decay)
+    elif args.optimizer == 'SGD_no_bn_wd':
+        all_params = {k: v for (k, v) in model.named_parameters()}
+        param_groups = [{
+            'params': [all_params[k] for k in all_params if ('bn' in k)],
+            'weight_decay': 0.
+        }, {
+            'params': [all_params[k] for k in all_params if not ('bn' in k)],
+            'weight_decay': args.weight_decay
+        }]
+        optimizer = SGD(param_groups, args.lr, momentum=args.momentum)
     else:
         optimizer = SGD(param_list, args.lr, momentum=args.momentum,
                         weight_decay=args.weight_decay)
@@ -88,23 +102,30 @@ def make_optimizer_and_schedule(args, model, checkpoint, params):
     if args.custom_lr_multiplier == 'reduce_on_plateau':
         schedule = lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1,
                                                   patience=5, mode='min')
-    elif args.custom_lr_multiplier[:6] == 'cyclic':
+    elif args.custom_lr_multiplier.startswith('cyclic'):
         # E.g. `cyclic_5` for peaking at 5 epochs
-        eps = args.epochs
-        peak = int(args.custom_lr_multiplier.split('_')[-1])
-        lr_func = lambda t: np.interp([t+1], [0, peak, eps], [0, 1, 0])[0]
-        schedule = lr_scheduler.LambdaLR(optimizer, lr_func)
+        # peak = int(args.custom_lr_multiplier.split('_')[-1])
+        peak = float(args.custom_lr_multiplier.split('_')[-1])
+        cyc_lr_func = lambda t: np.interp([t+1], [0, peak, args.epochs], [0, 1, 0])[0]
+        schedule = lr_scheduler.LambdaLR(optimizer, cyc_lr_func)
+    elif args.custom_lr_multiplier.startswith('itercyclic'):
+        # peak = int(args.custom_lr_multiplier.split('_')[-1])
+        peak = float(args.custom_lr_multiplier.split('_')[-1])
+        # itercyc_lr_func = lambda t: np.interp([float(t+1) / iters_per_epoch], [0, peak, args.epochs], [0, 1, 0])[0]
+        itercyc_lr_func = lambda t: np.interp([float(t+1) / iters_per_epoch], [0, peak, args.epochs], [0, 1, 0])[0]
+        schedule = lr_scheduler.LambdaLR(optimizer, itercyc_lr_func)
     elif args.custom_lr_multiplier:
         cs = args.custom_lr_multiplier
         periods = eval(cs) if type(cs) is str else cs
         if args.lr_interpolation == 'linear':
-            lr_func = lambda t: np.interp([t], *zip(*periods))[0]
+            lin_lr_func = lambda t: np.interp([t], *zip(*periods))[0]
         else:
             def lr_func(ep):
                 for (milestone, lr) in reversed(periods):
                     if ep >= milestone: return lr
                 return 1.0
-        schedule = lr_scheduler.LambdaLR(optimizer, lr_func)
+            lin_lr_func = lr_func
+        schedule = lr_scheduler.LambdaLR(optimizer, lin_lr_func)
     elif args.step_lr:
         schedule = lr_scheduler.StepLR(optimizer, step_size=args.step_lr,
                                        gamma=args.step_lr_gamma)
@@ -114,8 +135,8 @@ def lr_func(ep):

     return optimizer, schedule

+"""
 def eval_model(args, model, loader, store):
-    """
     Evaluate a model for standard (and optionally adversarial) accuracy.
     Args:
@@ -125,7 +146,6 @@ def eval_model(args, model, loader, store):
         loader (iterable) : a dataloader serving `(input, label)` batches from
             the validation set
         store (cox.Store) : store for saving results in (via tensorboardX)
-    """
     check_required_args(args, eval_only=True)
     start_time = time.time()

@@ -158,6 +178,7 @@ def eval_model(args, model, loader, store):
     # Log info into the logs table
     if store: store[consts.LOGS_TABLE].append_row(log_info)
     return log_info
+"""

 def train_model(args, model, data_aug, loaders, *, checkpoint=None, dp_device_ids=None,
         store=None, update_params=None, disable_no_grad=False,
@@ -261,7 +282,7 @@ def train_model(args, model, data_aug, loaders, *, checkpoint=None, dp_device_id
         disable_no_grad (bool) : if True, then even model evaluation will be
             run with autograd enabled (otherwise it will be wrapped in a ch.no_grad())
     """
-    scaler = GradScaler()
     # Logging setup
     writer = store.tensorboard if store else None
     prec1_key = f"{'adv' if args.adv_train else 'nat'}_prec1"
@@ -280,13 +301,16 @@ def train_model(args, model, data_aug, loaders, *, checkpoint=None, dp_device_id
     train_loader, val_loader = loaders
     if not args.opt_model_and_schedule:
         opt, schedule = make_optimizer_and_schedule(args, model, checkpoint,
-                                                    update_params)
+                                                    update_params, len(train_loader))
         assert not hasattr(model, "module"), "model is already in DataParallel."
         model = ch.nn.DataParallel(model, device_ids=dp_device_ids).cuda()
     else:
         opt, _, schedule = args.opt_model_and_schedule

-    # Put the model into parallel mode
+    if args.custom_lr_multiplier.startswith('itercyclic'):
+        assert args.iteration_hook is None
+        def iter_hook(*_): schedule.step()
+        args.iteration_hook = iter_hook

     best_prec1, start_epoch = (0, 0)
     if checkpoint:
@@ -299,10 +323,11 @@ def train_model(args, model, data_aug, loaders, *, checkpoint=None, dp_device_id

     start_time = time.time()
     for epoch in range(start_epoch, args.epochs):
+        if hasattr(train_loader.dataset, 'next_epoch'):
+            train_loader.dataset.next_epoch()
         # train for one epoch
         train_prec1, train_loss = _model_loop(args, 'train', train_loader,
-                model, opt, epoch, args.adv_train, writer, data_aug=data_aug,
-                scaler=scaler)
+                model, opt, epoch, args.adv_train, writer, data_aug=data_aug)
         last_epoch = (epoch == (args.epochs - 1))

         # evaluate on validation set
@@ -323,7 +348,7 @@ def save_checkpoint(filename):
         should_save_ckpt = (epoch % save_its == 0) and (save_its > 0)
         should_log = (epoch % args.log_iters == 0)

-        if should_log or last_epoch or should_save_ckpt:
+        if epoch > 0 and (should_log or last_epoch or should_save_ckpt):
             # log + get best
             ctx = ch.enable_grad() if disable_no_grad else ch.no_grad()
             with ctx:
@@ -371,14 +396,14 @@ def save_checkpoint(filename):
         if schedule:
             if 'reduce_on_plateau' in args.custom_lr_multiplier:
                 schedule.step(nat_loss)
-            else:
+            elif not args.custom_lr_multiplier.startswith('itercyclic'):
                 schedule.step()
         if has_attr(args, 'epoch_hook'): args.epoch_hook(model, log_info)

     return model

 def _model_loop(args, loop_type, loader, model, opt, epoch, adv, writer,
-        data_aug=None, scaler=None):
+        data_aug=None):
     """
     *Internal function* (refer to the train_model and eval_model functions for
     how to train and evaluate models).
@@ -400,14 +425,16 @@ def _model_loop(args, loop_type, loader, model, opt, epoch, adv, writer,
     Returns:
         The average top1 accuracy and the average loss across the epoch.
""" + if not loop_type in ['train', 'val']: err_msg = "loop_type ({0}) must be 'train' or 'val'".format(loop_type) raise ValueError(err_msg) is_train = (loop_type == 'train') - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() + # losses = AverageMeter('losses') + # losses = AverageMeter('Loss', ':.4e') + # top1 = AverageMeter() + # top5 = AverageMeter() prec = 'NatPrec' if not adv else 'AdvPrec' loop_msg = 'Train' if loop_type == 'train' else 'Val' @@ -423,96 +450,50 @@ def _model_loop(args, loop_type, loader, model, opt, epoch, adv, writer, # Custom training criterion has_custom_train_loss = has_attr(args, 'custom_train_loss') + default_train_crit = LabelSmoothSoftmaxCEV3(lb_smooth=args.label_smoothing) if \ + args.label_smoothing is not None else ch.nn.CrossEntropyLoss() train_criterion = args.custom_train_loss if has_custom_train_loss \ - else ch.nn.CrossEntropyLoss() + else default_train_crit has_custom_adv_loss = has_attr(args, 'custom_adv_loss') adv_criterion = args.custom_adv_loss if has_custom_adv_loss else None attack_kwargs = {} - if adv: - attack_kwargs = { - 'constraint': args.constraint, - 'eps': eps, - 'step_size': args.attack_lr, - 'iterations': args.attack_steps, - 'random_start': args.random_start, - 'custom_loss': adv_criterion, - 'random_restarts': random_restarts, - 'use_best': bool(args.use_best) - } - iterator = tqdm(enumerate(loader), total=len(loader)) + + should_crop = hasattr(loader.dataset, 'current_crop') + if should_crop: + crop_x, crop_y = loader.dataset.current_crop.numpy() + + total_correct, total = 0., 0. for i, (inp, target) in iterator: if loop_type == 'train': - inp = data_aug(inp) - - # If we have tensor cores we use channel_last - if ch.cuda.get_device_capability()[0] >= 8: - inp = inp.to(memory_format=ch.channels_last) - - # measure data loading time - target = target.cuda(non_blocking=True) - with autocast(): - output = model(inp) - loss = train_criterion(output, target) - - if len(loss.shape) > 0: loss = loss.mean() - - model_logits = output[0] if (type(output) is tuple) else output - - # measure accuracy and record loss - top1_acc = float('nan') - top5_acc = float('nan') - try: - maxk = min(5, model_logits.shape[-1]) - if has_attr(args, "custom_accuracy"): - prec1, prec5 = args.custom_accuracy(model_logits, target) - else: - prec1, prec5 = helpers.accuracy(model_logits, target, topk=(1, maxk)) - prec1, prec5 = prec1[0], prec5[0] - - losses.update(loss.item(), inp.size(0)) - top1.update(prec1, inp.size(0)) - top5.update(prec5, inp.size(0)) - - top1_acc = top1.avg - top5_acc = top5.avg - except Exception as e: - warnings.warn('Failed to calculate the accuracy.') - - reg_term = 0.0 - if has_attr(args, "regularizer"): - reg_term = args.regularizer(model, inp, target) - loss = loss + reg_term - - # compute gradient and do SGD step - if is_train: - opt.zero_grad() - scaler.scale(loss).backward() - scaler.step(opt) - scaler.update() - - # ITERATOR - desc = ('{2} Epoch:{0} | Loss {loss.avg:.8f} | ' - '{1}1 {top1_acc:.3f} | {1}5 {top5_acc:.6f} | ' - 'Reg term: {reg} ||'.format( epoch, prec, loop_msg, - loss=losses, top1_acc=top1_acc, top5_acc=top5_acc, reg=reg_term)) - - # USER-DEFINED HOOK - if has_attr(args, 'iteration_hook'): - args.iteration_hook(model, i, loop_type, inp, target) + with autocast(): + inp = data_aug(inp) + + if should_crop: + inp = inp[:, :, :crop_x, :crop_y] + output = model(inp) + loss = train_criterion(output, target) + # with ch.no_grad(): + # losses.update(loss) - iterator.set_description(desc) - 
-        iterator.refresh()
+        if is_train:
+            with amp.scale_loss(loss, opt) as scaled_loss:
+                scaled_loss.backward()
+            opt.step()
+            opt.zero_grad(set_to_none=True)
+        else:
+            corrects = output.argmax(1).eq(target)
+            total_correct += corrects.sum()
+            total += corrects.shape[0]

-    if writer is not None:
-        prec_type = 'adv' if adv else 'nat'
-        descs = ['loss', 'top1', 'top5']
-        vals = [losses, top1, top5]
-        for d, v in zip(descs, vals):
-            writer.add_scalar('_'.join([prec_type, loop_type, d]), v.avg,
-                    epoch)
+        if has_attr(args, 'iteration_hook'):
+            args.iteration_hook(None)

-    return top1.avg, losses.avg
+    if not is_train:
+        print(f'Val epoch {epoch}, accuracy {total_correct / total * 100:.2f}%', flush=True)
+    # print(f'{loop_msg} avg loss', losses.avg, flush=True)
+    # return 0., losses.avg
+    return 0., 0.
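
Note (sketch, not part of the patch): the rewritten training step in _model_loop calls apex's amp.scale_loss(loss, opt), which only works after the model and optimizer have been wrapped by amp.initialize; that setup call does not appear anywhere in this diff. A minimal sketch of the setup it assumes is below; the helper name and the 'O1' opt_level are assumptions, not taken from the patch.

# Sketch only: the apex mixed-precision setup assumed by the patched _model_loop.
# wrap_for_amp and opt_level='O1' are illustrative choices, not part of the diff.
import torch as ch
from apex import amp  # requires NVIDIA apex to be installed

def wrap_for_amp(model, opt, opt_level='O1'):
    # amp.initialize patches the model and optimizer so that
    # `with amp.scale_loss(loss, opt)` scales gradients correctly.
    model = model.cuda()
    model, opt = amp.initialize(model, opt, opt_level=opt_level)
    return model, opt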