Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions benchmarks/ppo/continual_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ class ContinualPPOConfig(PPOConfig):


class ContinualPPOTrainer(PPOTrainer):

def __init__(
self,
args: Optional[PPOConfig] = None,
Expand Down Expand Up @@ -142,7 +141,9 @@ def __init__(
self.shared_accelerator: Optional[Accelerator] = None
self.current_task_index: Optional[int] = None
self.policy_value_models: Any = None # the policy and value model wrapper
self.ds_wrapped_models: Any = None # TODO work with this after deepspeed is initialized
self.ds_wrapped_models: Any = (
None # TODO work with this after deepspeed is initialized
)
self.accelerator: Accelerator = None # now non-optional after creation

# Basic setup and validation
Expand Down Expand Up @@ -1192,13 +1193,12 @@ def mark_final_eval(self, is_final: bool = True) -> 'ContinualPPOTrainer':
return self

def save_model(self, output_dir: str, _internal_call=True) -> None:
"""
Manually save the model (and training state) to a specified directory.
"""Manually save the model (and training state) to a specified directory.
This follows a similar procedure as _save_checkpoint.
"""

# Save the model files to output_dir (marking _internal_call True)
from transformers import Trainer # ensure Trainer is imported

Trainer.save_model(self, output_dir, _internal_call=True)

# If not saving only the model, save optimizer, scheduler, and RNG state
Expand All @@ -1208,9 +1208,9 @@ def save_model(self, output_dir: str, _internal_call=True) -> None:
self._save_rng_state(output_dir)

# Save the trainer state
trainer_state_path = os.path.join(output_dir, "trainer_state.json")
trainer_state_path = os.path.join(output_dir, 'trainer_state.json')
self.state.save_to_json(trainer_state_path)

# Optionally push to hub if that option is enabled
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
self._push_from_checkpoint(output_dir)
28 changes: 13 additions & 15 deletions benchmarks/ppo/ppo_continual.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@

import torch
import wandb as wb
from continual_ppo_trainer import (
ContinualPPOArguments,
ContinualPPOConfig,
ContinualPPOTrainer,
)
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
Expand All @@ -21,10 +16,15 @@
get_kbit_device_map,
get_peft_config,
get_quantization_config,
setup_chat_format,
)
from trl import setup_chat_format

from benchmarks.dataloading import init_continual_dataset
from benchmarks.ppo.continual_ppo_trainer import (
ContinualPPOArguments,
ContinualPPOConfig,
ContinualPPOTrainer,
)


def main(
Expand Down Expand Up @@ -106,7 +106,9 @@ def main(
value_model_path = script_args.value_model_path
else:
model_path = os.path.join(training_args.output_dir, 'last')
value_model_path = os.path.join(training_args.output_dir, 'last', 'value_model')
value_model_path = os.path.join(
training_args.output_dir, 'last', 'value_model'
)
policy = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_path,
trust_remote_code=model_args.trust_remote_code,
Expand All @@ -126,7 +128,7 @@ def main(
value_model_path,
trust_remote_code=model_args.trust_remote_code,
num_labels=1,
from_tf=True, # or use `subfolder="safetensors"` if you saved a .safetensors file
from_tf=True, # or use `subfolder="safetensors"` if you saved a .safetensors file
)

# Build custom repository name for this task
Expand Down Expand Up @@ -173,9 +175,6 @@ def main(
peft_config=peft_config,
)

# if i == 0:
# trainer.save_model(os.path.join(training_args.output_dir, 'checkpoint-0'))

# Set current task in trainer for task-based logging
trainer.set_task(f'task_{i}')

Expand Down Expand Up @@ -208,9 +207,8 @@ def main(

value_model_dir = os.path.join(last_dir, 'value_model')
os.makedirs(value_model_dir, exist_ok=True)
value_model.save_pretrained(value_model_dir,
safe_serialization=False)

value_model.save_pretrained(value_model_dir, safe_serialization=False)

trainer.accelerator.wait_for_everyone()

if training_args.push_to_hub:
Expand All @@ -226,4 +224,4 @@ def main(
dataclass_types = (ContinualPPOArguments, ContinualPPOConfig, ModelConfig)
parser = TrlParser(dataclass_types)
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
main(script_args, training_args, model_args)
14 changes: 6 additions & 8 deletions benchmarks/ppo_ewc/continual_ppo_EWC_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ def __init__(
# Store EWC-specific parameters
self.ewc_lambda = args.ewc_lambda

# Track if we're on the first task
is_first_task = ContinualPPOTrainer.current_task_index == 0
if is_first_task:
if self.current_task_index == 0:
# Initialize empty dictionaries for first task
ContinualPPOEWCTrainer.class_fisher_information = {}
ContinualPPOEWCTrainer.class_old_params = {}
Expand Down Expand Up @@ -775,15 +773,15 @@ def repeat_generator() -> DataLoader:
if self.ref_model is None and original_ref_model is not None:
print('Reference model was cleared during training - restoring')
self.ref_model = original_ref_model
ContinualPPOTrainer.class_ref_model = original_ref_model
self.class_ref_model = original_ref_model

# Ensure the class variable is updated
ContinualPPOTrainer.class_ref_model = self.ref_model
self.class_ref_model = self.ref_model
if self.is_deepspeed_enabled:
ContinualPPOTrainer.ds_wrapped_models = self.deepspeed
self.ds_wrapped_models = self.deepspeed
else:
ContinualPPOTrainer.ds_wrapped_models = self.model
ContinualPPOTrainer.policy_value_models = self.model
self.ds_wrapped_models = self.model
self.policy_value_models = self.model

def update_fisher_and_params(self) -> None:
"""Explicitly update the Fisher information and parameter values.
Expand Down
107 changes: 67 additions & 40 deletions benchmarks/ppo_ewc/ppo_EWC_continual.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
get_kbit_device_map,
get_peft_config,
get_quantization_config,
setup_chat_format,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

from benchmarks.dataloading import init_continual_dataset
from benchmarks.ppo_ewc.continual_ppo_EWC_trainer import (
Expand Down Expand Up @@ -52,15 +52,7 @@ def main(
quantization_config=quantization_config,
)

# Load main model and (optionally) reference model
model = str(training_args.sft_model_path)
policy = AutoModelForCausalLM.from_pretrained(
training_args.sft_model_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)

# Configure PEFT if needed
peft_config = get_peft_config(model_args)
if peft_config is None:
ref_policy = AutoModelForCausalLM.from_pretrained(
Expand All @@ -71,32 +63,11 @@ def main(
else:
ref_policy = None

# Load value model
value_model = None
if script_args.value_model_path:
value_model = AutoModelForSequenceClassification.from_pretrained(
script_args.value_model_path,
trust_remote_code=model_args.trust_remote_code,
num_labels=1,
)

# Load tokenizer and set chat template if needed
tokenizer = AutoTokenizer.from_pretrained(
training_args.sft_model_path,
trust_remote_code=model_args.trust_remote_code,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

# EWC-specific: DDP distributed setup
if script_args.ignore_bias_buffers:
policy._ddp_params_and_buffers_to_ignore = [
name
for name, buffer in policy.named_buffers()
if buffer.dtype == torch.bool
]

# Initialize continual dataset
continual_dataset: list[dict[str, Dataset]] = init_continual_dataset(
Expand All @@ -112,6 +83,7 @@ def main(
if '.' in clean_dataset_name:
clean_dataset_name = clean_dataset_name.split('.')[0]

print(f'Training PPO-EWC on {len(continual_dataset)} tasks')
# check if the reward models are present either in the path or in the hub
if training_args.reward_model_path is not None:
for i in range(len(continual_dataset)):
Expand All @@ -128,6 +100,44 @@ def main(

# Task Loop
for i, dataset in enumerate(continual_dataset):
# Load main model and (optionally) reference model
if i == 0:
model_path = training_args.sft_model_path
value_model_path = script_args.value_model_path
else:
model_path = os.path.join(training_args.output_dir, 'last')
value_model_path = os.path.join(
training_args.output_dir, 'last', 'value_model'
)
policy = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
# EWC-specific: DDP distributed setup
if script_args.ignore_bias_buffers:
policy._ddp_params_and_buffers_to_ignore = [
name
for name, buffer in policy.named_buffers()
if buffer.dtype == torch.bool
]

# Load value model and policy model (main model)
try:
value_model = AutoModelForSequenceClassification.from_pretrained(
value_model_path,
trust_remote_code=model_args.trust_remote_code,
num_labels=1,
)
except OSError:
# Maybe it was saved as safetensors?
value_model = AutoModelForSequenceClassification.from_pretrained(
value_model_path,
trust_remote_code=model_args.trust_remote_code,
num_labels=1,
from_tf=True, # or use `subfolder="safetensors"` if you saved a .safetensors file
)

# Build custom repository name for this task
custom_repo_name = (
model.split('/')[-1] + '_' + clean_dataset_name + '_PPO_EWC_' + str(i)
Expand All @@ -141,6 +151,22 @@ def main(
training_args.reward_model_path + '_' + str(i), num_labels=1
)

for idx, _model in enumerate([policy, value_model, reward_model]):
# Align padding tokens between tokenizer and model
_model.config.pad_token_id = tokenizer.pad_token_id

# Use ChatML format if the tokenizer doesn't already have a chat template
if tokenizer.chat_template is None:
updated_model, updated_tokenizer = setup_chat_format(_model, tokenizer)
# Actually store the updated model
if idx == 0:
policy = updated_model
elif idx == 1:
value_model = updated_model
else:
reward_model = updated_model
tokenizer = updated_tokenizer

################
# Training and Evaluation
################
Expand Down Expand Up @@ -181,21 +207,22 @@ def main(
wb.log({f'task/{custom_repo_name}/last': metrics}) # type: ignore[attr-defined]

# Save model checkpoint and optionally push
if not training_args.push_to_hub:
trainer.save_model(os.path.join(training_args.output_dir, 'last'))
else:
last_dir = os.path.join(training_args.output_dir, 'last')
policy.save_pretrained(last_dir)
tokenizer.save_pretrained(last_dir)

value_model_dir = os.path.join(last_dir, 'value_model')
os.makedirs(value_model_dir, exist_ok=True)
value_model.save_pretrained(value_model_dir, safe_serialization=False)

trainer.accelerator.wait_for_everyone()

if training_args.push_to_hub:
trainer.push_to_hub(
model_name=custom_repo_name,
dataset_name='Continual_PPO_EWC_' + clean_dataset_name + '_' + str(i),
)

# Clean up for next task - EWC specific
if hasattr(trainer, 'deepspeed') and trainer.deepspeed is not None:
# Remove reference to the DeepSpeed engine to allow proper cleanup
del trainer.deepspeed
# Free cached GPU memory
torch.cuda.empty_cache()

print('Training completed for all tasks!')


Expand Down
40 changes: 40 additions & 0 deletions jobs/ppo_ewc/ppo_ewc_cppo_multi_gpu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash
# Slurm batch job: runs benchmarks/ppo_ewc/ppo_EWC_continual.py through
# `accelerate launch` (using the deepspeed_zero2.yaml accelerate config) on the
# CPPO-RL dataset, on a single node with 4 H100 GPUs.
#SBATCH --job-name=aif-gen-ppo-ewc-cppo
#SBATCH --nodes=1 # Request 1 node
#SBATCH --gpus-per-node=h100:4 # Request 4 H100 GPUs per node
#SBATCH --ntasks-per-node=4 # One task per GPU
#SBATCH --cpus-per-task=6
#SBATCH --mem=64G
#SBATCH --time=24:00:00
#SBATCH --output=out/%x.%j.out # Include job name + job ID
#SBATCH --error=out/%x.%j.err # Include job name + job ID
#SBATCH --mail-type=ALL
#SBATCH --account=aip-rrabba
#SBATCH --mail-user=shahrad_m@icloud.com # Update with your email

# NOTE(review): the log paths above are relative (out/...); presumably the out/
# directory must already exist in the submission directory - verify.

# Load environment variables from .env in the current working directory.
# All paths below are relative as well, so the job is presumably meant to be
# submitted from the repository root - confirm.
source .env

# Dataset identifier; reused below for the W&B project and run names, the
# --dataset_name argument, and the output directory name.
dataset_name='CPPO-RL'

# $SCRATCH is not defined in this script; it has to come from the environment
# (presumably the cluster or .env provides it - verify before submitting).
accelerate launch --config_file benchmarks/ppo/accelerate_configs/deepspeed_zero2.yaml \
benchmarks/ppo_ewc/ppo_EWC_continual.py \
--wandb_project "$dataset_name-post-May-19" \
--wandb_run_name "Qwen2-0.5B-PPO-EWC-${dataset_name}-multi-gpu" \
--dataset_name "$dataset_name" \
--sft_model_path Qwen/Qwen2-0.5B-Instruct \
--value_model_path LifelongAlignment/Qwen2-0.5B-Instruct_CPPO_REWARD_0 \
--reward_model_path LifelongAlignment/Qwen2-0.5B-Instruct_CPPO_REWARD \
--learning_rate 1.0e-6 \
--kl_coef 0.37 \
--cliprange 0.1 \
--response_length 256 \
--num_train_epochs 4 \
--gradient_checkpointing \
--per_device_train_batch_size 16 \
--logging_steps 10 \
--eval_strategy steps \
--eval_steps 200 \
--save_steps 300 \
--bf16 \
--output_dir "$SCRATCH/Qwen2-0.5B-PPO-EWC-${dataset_name}" \
--no_remove_unused_columns
Loading
Loading