From 76caa860eb1790b131a84767e0840933993c79d5 Mon Sep 17 00:00:00 2001
From: Amulya Gupta
Date: Mon, 24 May 2021 16:30:18 -0500
Subject: [PATCH] Add AMP support in pretraining

---
 oscar/run_oscarplus_pretrain.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/oscar/run_oscarplus_pretrain.py b/oscar/run_oscarplus_pretrain.py
index 3dab083..0cbfb22 100644
--- a/oscar/run_oscarplus_pretrain.py
+++ b/oscar/run_oscarplus_pretrain.py
@@ -359,6 +359,8 @@ def main():
     # Every args.ckpt_period, report train_score and save model
     tr_loss = 0
     nb_tr_examples, nb_tr_steps = 0, 0
+    # Gradient scaler for mixed-precision (AMP) training.
+    scaler = torch.cuda.amp.GradScaler(enabled=True)
     for step, (batch, batch_extra) in enumerate(zip(train_dataloader, train_dataloader_extra), start_iter):
         if not clock_started:
             start_training_time = time.time()
@@ -391,15 +393,18 @@ def forward_backward(images, input_ids, input_mask, segment_ids,
         # feature as input
         image_features = torch.stack(images).to(args.device, non_blocking=True)
 
-    outputs = model(input_ids, segment_ids, input_mask,
-                    lm_label_ids, is_next, img_feats=image_features)
+    # Run the forward pass under autocast so eligible ops execute in float16.
+    with torch.cuda.amp.autocast(enabled=True):
+        outputs = model(input_ids, segment_ids, input_mask,
+                        lm_label_ids, is_next, img_feats=image_features)
 
-    loss = loss_weight * outputs[0]
+        loss = loss_weight * outputs[0]
 
     if args.n_gpu > 1:
         loss = loss.mean() # mean() to average on multi-gpu.
     if args.gradient_accumulation_steps > 1:
         loss = loss / args.gradient_accumulation_steps
-    loss.backward()
+    # Scale the loss to prevent float16 gradient underflow.
+    scaler.scale(loss).backward()
 
     return loss.item(), input_ids.size(0)
@@ -436,6 +441,10 @@ def forward_backward(images, input_ids, input_mask, segment_ids,
             if args.max_grad_norm > 0:
+                # Unscale gradients before clipping so max_grad_norm
+                # applies to the true gradient values.
+                scaler.unscale_(optimizer)
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
             # do the optimization steps
-            optimizer.step()
+            scaler.step(optimizer)
+            scaler.update()
             scheduler.step()  # Update learning rate schedule
             optimizer.zero_grad()
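
For reference, below is a minimal, self-contained sketch of the torch.cuda.amp training-step pattern this patch applies: autocast forward pass, scaled backward, unscale before clipping, and a scaler-driven optimizer step. The model, optimizer, data, and max_grad_norm value are hypothetical placeholders, not Oscar's actual pretraining setup.

import torch

# Minimal AMP training loop. The model, optimizer, and data below are
# hypothetical stand-ins; only the torch.cuda.amp plumbing mirrors the patch.
device = torch.device("cuda")
model = torch.nn.Linear(128, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
max_grad_norm = 1.0  # placeholder for args.max_grad_norm

scaler = torch.cuda.amp.GradScaler(enabled=True)

for _ in range(10):
    inputs = torch.randn(32, 128, device=device)
    targets = torch.randint(0, 2, (32,), device=device)

    optimizer.zero_grad()

    # Forward pass in mixed precision: eligible ops run in float16,
    # numerically sensitive ops stay in float32.
    with torch.cuda.amp.autocast(enabled=True):
        loss = loss_fn(model(inputs), targets)

    # Scale the loss before backward so small float16 gradients
    # do not underflow to zero.
    scaler.scale(loss).backward()

    # Unscale before clipping so the threshold applies to the true
    # gradient magnitudes rather than the scaled ones.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    # step() skips the optimizer update if gradients contain inf/NaN;
    # update() then adjusts the scale factor for the next iteration.
    scaler.step(optimizer)
    scaler.update()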