diff --git a/droidlet/perception/semantic_parsing/nsp_transformer_model/optimizer_warmup.py b/droidlet/perception/semantic_parsing/nsp_transformer_model/optimizer_warmup.py
index afcee721f8..27c523efa8 100644
--- a/droidlet/perception/semantic_parsing/nsp_transformer_model/optimizer_warmup.py
+++ b/droidlet/perception/semantic_parsing/nsp_transformer_model/optimizer_warmup.py
@@ -1,3 +1,4 @@
+import math
 from bisect import bisect_right
 
 from torch.optim import Adam, Adagrad
@@ -25,6 +26,9 @@ def __init__(self, model, args):
             "text_span_decoder": decoder_lr_schedules_list,
         }
         self.lr_ratio = args.lr_ratio
+        self.lr_scheduler = args.lr_scheduler
+        self.iter_per_epoch = args.dataset_size / args.batch_size
+        self.num_epochs = args.num_epochs
         # setup warmup stage
         self.warmup_steps = {
             "encoder": args.encoder_warmup_steps,
@@ -60,9 +64,27 @@ def _update_rate(self, stack):
             alpha = self._step / self.warmup_steps[stack]
             return self.lr[stack] * (self.warmup_factor * (1.0 - alpha) + alpha)
         else:
-            return self.lr[stack] * self.lr_ratio ** bisect_right(
-                self.lr_schedules[stack], self._step
-            )
+            num_training_steps = self.num_epochs * self.iter_per_epoch
+            if self.lr_scheduler == "constant":
+                return self.lr[stack]
+            elif self.lr_scheduler == "linear":
+                return self.lr[stack] * max(
+                    0.0, float(num_training_steps - self._step) / float(max(1, num_training_steps - self.warmup_steps[stack]))
+                )
+            elif self.lr_scheduler == "cosine":
+                num_cycles = 0.5
+                progress = float(self._step - self.warmup_steps[stack]) / float(max(1, num_training_steps - self.warmup_steps[stack]))
+                return self.lr[stack] * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+            elif self.lr_scheduler == "cosine_with_hardstop":
+                num_cycles = 1.0
+                progress = float(self._step - self.warmup_steps[stack]) / float(max(1, num_training_steps - self.warmup_steps[stack]))
+                if progress >= 1.0:
+                    return 0.0
+                return self.lr[stack] * max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+            else:
+                return self.lr[stack] * self.lr_ratio ** bisect_right(
+                    self.lr_schedules[stack], self._step
+                )
 
     def zero_grad(self):
         self.optimizer_decoder.zero_grad()
diff --git a/droidlet/perception/semantic_parsing/nsp_transformer_model/train_model.py b/droidlet/perception/semantic_parsing/nsp_transformer_model/train_model.py
index fa600e02cd..5a1bb32cf0 100644
--- a/droidlet/perception/semantic_parsing/nsp_transformer_model/train_model.py
+++ b/droidlet/perception/semantic_parsing/nsp_transformer_model/train_model.py
@@ -515,6 +515,12 @@ def build_grammar(args):
         type=float,
         help="Factor for learning rate in warmup stage",
     )
+    parser.add_argument(
+        "--lr_scheduler",
+        default="constant",
+        type=str,
+        help="Learning rate scheduler: constant, linear, cosine, or cosine_with_hardstop",
+    )
     parser.add_argument(
         "--node_label_smoothing",
         default=0.0,
@@ -657,6 +663,7 @@ def build_grammar(args):
     logging.info("====== Initializing NLU Model Trainer ======")
     if args.cuda:
         encoder_decoder = encoder_decoder.cuda()
+    args.dataset_size = len(train_dataset)
     model_trainer = NLUModelTrainer(
         args, encoder_decoder, tokenizer, model_identifier, full_tree_voc
     )
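
Note on the schedules above: after warmup, each branch computes a multiplier in [0, 1] from training progress and applies it to the base rate of the given stack (encoder, decoder, or text_span_decoder), with num_training_steps derived from num_epochs * dataset_size / batch_size. A minimal standalone sketch of the cosine case follows; the helper name cosine_rate and the concrete step counts are illustrative only and not part of the patch.

import math

def cosine_rate(base_lr, step, warmup_steps, num_training_steps, num_cycles=0.5):
    # Illustrative sketch mirroring the "cosine" branch of _update_rate.
    if step < warmup_steps:
        # simplified linear warmup toward base_lr (the patch also uses warmup_factor)
        return base_lr * step / max(1, warmup_steps)
    progress = float(step - warmup_steps) / float(max(1, num_training_steps - warmup_steps))
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

# e.g. with base_lr=1e-4, 1000 warmup steps, and 10000 total steps, the rate rises to
# 1e-4 at step 1000 and decays smoothly toward 0 by step 10000.
print(cosine_rate(1e-4, 5500, 1000, 10000))  # roughly half of base_lr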