diff --git a/benchmarks/ppo/continual_ppo_trainer.py b/benchmarks/ppo/continual_ppo_trainer.py index cb3d747f..4865b2b9 100644 --- a/benchmarks/ppo/continual_ppo_trainer.py +++ b/benchmarks/ppo/continual_ppo_trainer.py @@ -270,8 +270,9 @@ def __init__( gradient_accumulation_steps=args.gradient_accumulation_steps ) self.accelerator = accelerator + self.gather_function = self.accelerator.gather_for_metrics ContinualPPOTrainer.shared_accelerator = accelerator - else: + elif False: self.accelerator = ContinualPPOTrainer.shared_accelerator self.gather_function = self.accelerator.gather_for_metrics if ( @@ -336,7 +337,7 @@ def __init__( self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) ContinualPPOTrainer.policy_value_models = self.model self.model.config = self.policy_model.config # needed for pushing to hub - else: + elif False: # Subsequent tasks: Reuse existing model self.model = ContinualPPOTrainer.policy_value_models self.model.config = self.policy_model.config # needed for pushing to hub @@ -407,7 +408,7 @@ def __init__( self.model, self.optimizer, self.dataloader ) ContinualPPOTrainer.ds_wrapped_models = self.model - else: + elif False: # For subsequent tasks, only prepare optimizer and dataloader self.optimizer, self.dataloader = self.accelerator.prepare( self.optimizer, self.dataloader @@ -971,6 +972,7 @@ def repeat_generator() -> DataLoader: ContinualPPOTrainer.class_ref_model = original_ref_model # Ensure the class variable is updated + # TODO: Double check this is fine to keep ContinualPPOTrainer.class_ref_model = self.ref_model if self.is_deepspeed_enabled: ContinualPPOTrainer.ds_wrapped_models = self.deepspeed diff --git a/benchmarks/ppo/ppo_continual.py b/benchmarks/ppo/ppo_continual.py index 8db6aff3..b62aece9 100644 --- a/benchmarks/ppo/ppo_continual.py +++ b/benchmarks/ppo/ppo_continual.py @@ -52,13 +52,7 @@ def main( quantization_config=quantization_config, ) - # Load main model and (optionally) reference model model = 
str(training_args.sft_model_path) - policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, - trust_remote_code=model_args.trust_remote_code, - **model_kwargs, - ) peft_config = get_peft_config(model_args) if peft_config is None: ref_policy = AutoModelForCausalLM.from_pretrained( @@ -69,13 +63,6 @@ def main( else: ref_policy = None - # Load value model and policy model (main model) - value_model = AutoModelForSequenceClassification.from_pretrained( - script_args.value_model_path, - trust_remote_code=model_args.trust_remote_code, - num_labels=1, - ) - # Load tokenizer and set chat template if needed tokenizer = AutoTokenizer.from_pretrained( training_args.sft_model_path, @@ -117,6 +104,24 @@ def main( # Task Loop for i, dataset in enumerate(continual_dataset): + # Load policy from the SFT model (task 0) or the previous 'last' checkpoint + if i == 0: + model_path = training_args.sft_model_path + else: + model_path = os.path.join(training_args.output_dir, 'last') + policy = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=model_path, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + + # Load value model (reloaded for every task) + value_model = AutoModelForSequenceClassification.from_pretrained( + script_args.value_model_path, + trust_remote_code=model_args.trust_remote_code, + num_labels=1, + ) + # Build custom repository name for this task custom_repo_name = ( model.split('/')[-1] + '_' + clean_dataset_name + '_PPO_' + str(i) @@ -174,9 +179,8 @@ def main( wb.log({f'task/{custom_repo_name}/last': metrics}) # type: ignore[attr-defined] # Save model checkpoint and optionally push - if not training_args.push_to_hub: - trainer.save_model(os.path.join(training_args.output_dir, 'last')) - else: + trainer.save_model(os.path.join(training_args.output_dir, 'last')) + if training_args.push_to_hub: trainer.push_to_hub( model_name=custom_repo_name, dataset_name='Continual_PPO_' + clean_dataset_name + '_' + str(i),