diff --git a/train_first.py b/train_first.py index 8a319d4..67dee8e 100644 --- a/train_first.py +++ b/train_first.py @@ -118,6 +118,7 @@ def main(config_path): if config.get('pretrained_model', '') != '': model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'], load_only_params=config.get('load_only_params', True)) + start_epoch = start_epoch + 1 else: start_epoch = 0 iters = 0 diff --git a/train_second.py b/train_second.py index 32fd30a..87560a1 100644 --- a/train_second.py +++ b/train_second.py @@ -132,6 +132,7 @@ def main(config_path): if config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False): model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'], load_only_params=config.get('load_only_params', True)) + start_epoch = start_epoch + 1 else: start_epoch = 0 iters = 0