Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions benchmarks/ppo/continual_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,9 @@ def __init__(
gradient_accumulation_steps=args.gradient_accumulation_steps
)
self.accelerator = accelerator
self.gather_function = self.accelerator.gather_for_metrics
ContinualPPOTrainer.shared_accelerator = accelerator
else:
elif False:
self.accelerator = ContinualPPOTrainer.shared_accelerator
self.gather_function = self.accelerator.gather_for_metrics
if (
Expand Down Expand Up @@ -336,7 +337,7 @@ def __init__(
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
ContinualPPOTrainer.policy_value_models = self.model
self.model.config = self.policy_model.config # needed for pushing to hub
else:
elif False:
# Subsequent tasks: Reuse existing model
self.model = ContinualPPOTrainer.policy_value_models
self.model.config = self.policy_model.config # needed for pushing to hub
Expand Down Expand Up @@ -407,7 +408,7 @@ def __init__(
self.model, self.optimizer, self.dataloader
)
ContinualPPOTrainer.ds_wrapped_models = self.model
else:
elif False:
# For subsequent tasks, only prepare optimizer and dataloader
self.optimizer, self.dataloader = self.accelerator.prepare(
self.optimizer, self.dataloader
Expand Down Expand Up @@ -971,6 +972,7 @@ def repeat_generator() -> DataLoader:
ContinualPPOTrainer.class_ref_model = original_ref_model

# Ensure the class variable is updated
# TODO: Double check this is fine to keep
ContinualPPOTrainer.class_ref_model = self.ref_model
if self.is_deepspeed_enabled:
ContinualPPOTrainer.ds_wrapped_models = self.deepspeed
Expand Down
36 changes: 20 additions & 16 deletions benchmarks/ppo/ppo_continual.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,7 @@ def main(
quantization_config=quantization_config,
)

# Load main model and (optionally) reference model
model = str(training_args.sft_model_path)
policy = AutoModelForCausalLM.from_pretrained(
training_args.sft_model_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
peft_config = get_peft_config(model_args)
if peft_config is None:
ref_policy = AutoModelForCausalLM.from_pretrained(
Expand All @@ -69,13 +63,6 @@ def main(
else:
ref_policy = None

# Load value model and policy model (main model)
value_model = AutoModelForSequenceClassification.from_pretrained(
script_args.value_model_path,
trust_remote_code=model_args.trust_remote_code,
num_labels=1,
)

# Load tokenizer and set chat template if needed
tokenizer = AutoTokenizer.from_pretrained(
training_args.sft_model_path,
Expand Down Expand Up @@ -117,6 +104,24 @@ def main(

# Task Loop
for i, dataset in enumerate(continual_dataset):
# Load main model and (optionally) reference model
if i == 0:
model_path = training_args.sft_model_path
else:
model_path = os.path.join(training_args.output_dir, 'last')
policy = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)

# Load value model and policy model (main model)
value_model = AutoModelForSequenceClassification.from_pretrained(
script_args.value_model_path,
trust_remote_code=model_args.trust_remote_code,
num_labels=1,
)

# Build custom repository name for this task
custom_repo_name = (
model.split('/')[-1] + '_' + clean_dataset_name + '_PPO_' + str(i)
Expand Down Expand Up @@ -174,9 +179,8 @@ def main(
wb.log({f'task/{custom_repo_name}/last': metrics}) # type: ignore[attr-defined]

# Save model checkpoint and optionally push
if not training_args.push_to_hub:
trainer.save_model(os.path.join(training_args.output_dir, 'last'))
else:
trainer.save_model(os.path.join(training_args.output_dir, 'last'))
if training_args.push_to_hub:
trainer.push_to_hub(
model_name=custom_repo_name,
dataset_name='Continual_PPO_' + clean_dataset_name + '_' + str(i),
Expand Down
Loading