From a68acf85e4f846ec0d6f63a3999b78e4b1a9acdb Mon Sep 17 00:00:00 2001 From: Alex Muzio Date: Tue, 7 Mar 2023 05:48:21 +0000 Subject: [PATCH 1/8] [WIP] MRT --- .../mrt_summarize_daily_cnn_t5.py | 143 +++++ .../t5_summarize_daily_cnn.py | 8 +- trlx/data/mrt_types.py | 65 +++ trlx/trainer/accelerate_mrt_trainer.py | 506 ++++++++++++++++++ trlx/utils/loading.py | 1 + 5 files changed, 719 insertions(+), 4 deletions(-) create mode 100644 examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py create mode 100644 trlx/data/mrt_types.py create mode 100644 trlx/trainer/accelerate_mrt_trainer.py diff --git a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py new file mode 100644 index 000000000..b03f46bcf --- /dev/null +++ b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py @@ -0,0 +1,143 @@ +from typing import List + +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +import trlx +from trlx.data.configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.models.modeling_ppo import PPOConfig + +try: + import evaluate +except ImportError: + raise ImportError( + "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`" + ) + +config = TRLConfig( + train=TrainConfig( + seq_length=612, + epochs=100, + total_steps=100000, + batch_size=12, + checkpoint_interval=10000, + eval_interval=500, + pipeline="PromptPipeline", + trainer="AccelerateMRTTrainer", + tracker=None, + ), + model=ModelConfig( + model_path="google/flan-t5-small", + model_arch_type="seq2seq", + num_layers_unfrozen=2, + ), + tokenizer=TokenizerConfig( + tokenizer_path="google/flan-t5-small", + truncation_side="right", + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs={ + "lr": 1.0e-5, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + "weight_decay": 1.0e-6, + }, + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs={ + "T_max": 10000, + "eta_min": 1.0e-6, + }, + ), + method=PPOConfig( + name="PPOConfig", + num_rollouts=512, + chunk_size=12, + ppo_epochs=4, + init_kl_coef=0.05, + target=6, + horizon=10000, + gamma=0.99, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + vf_coef=1.0, + scale_reward=None, + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs={ + "max_new_tokens": 100, + }, + gen_experience_kwargs={ + "max_new_tokens": 100, + "do_sample": True, + "temperature": 1.0, + "top_k": 50, + "top_p": 0.95, + }, + ), +) + + +meteor = evaluate.load("meteor") # use meteor as the reward function + +if __name__ == "__main__": + + def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): + original_summaries = [prompt_label[prompt.strip()] for prompt in prompts] + scores = [ + meteor.compute(predictions=[output.strip()], references=[original])["meteor"] + for (original, output) in zip(original_summaries, outputs) + ] + return scores + + dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data") + + # take 20,000 samples from the training set as prompts for training + prompts = dataset["train"]["article"][0:1200] + summaries = dataset["train"]["highlights"][0:1200] + prompts = ["Summarize: " + prompt for prompt in prompts] + + # take 1,000 samples from the validation set as prompts for evaluation + val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]] + val_summaries = 
dataset["validation"]["highlights"][0:1000] + + # make dictionary of prompts and labels to use for reward function + tokenizer = AutoTokenizer.from_pretrained(config.model.model_path) + tokenizer.padding_side = "left" + tokenizer.truncation_side = "right" + tokenizer.sep_token = "" + prompt_label = {} + max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] + + for i in tqdm(range(len(prompts))): + key = tokenizer.decode( + tokenizer(prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + prompt_label[key.strip()] = summaries[i] + + for i in tqdm(range(len(val_prompts))): + key = tokenizer.decode( + tokenizer(val_prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + prompt_label[key.strip()] = val_summaries[i] + + trlx.train( + reward_fn=reward_fn, + prompts=prompts, + eval_prompts=val_prompts, + config=config, + ) diff --git a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py index 4c3a56758..8520e76cf 100755 --- a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py +++ b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py @@ -34,12 +34,12 @@ trainer="AcceleratePPOTrainer", ), model=ModelConfig( - model_path="google/flan-t5-large", + model_path="google/flan-t5-small", model_arch_type="seq2seq", num_layers_unfrozen=2, ), tokenizer=TokenizerConfig( - tokenizer_path="google/flan-t5-large", + tokenizer_path="google/flan-t5-small", truncation_side="right", ), optimizer=OptimizerConfig( @@ -104,8 +104,8 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data") # take 20,000 samples from the training set as prompts for training - prompts = dataset["train"]["article"][0:20000] - summaries = dataset["train"]["highlights"][0:20000] + prompts = dataset["train"]["article"][0:1200] + summaries = dataset["train"]["highlights"][0:1200] prompts = ["Summarize: " + prompt for prompt in prompts] # take 1,000 samples from the validation set as prompts for evaluation diff --git a/trlx/data/mrt_types.py b/trlx/data/mrt_types.py new file mode 100644 index 000000000..e3fd44f29 --- /dev/null +++ b/trlx/data/mrt_types.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass + +from torchtyping import TensorType + + +@dataclass +class MRTRLElement: + """ + :param query_tensor: The query tensor i.e. the prompt tokens. + Should be a long tensor. + :type query_tensor: torch.Tensor + + :param response_tensor: The response tensor i.e. the output tokens. + Should be a long tensor. + :type response_tensor: torch.Tensor + + :param logprobs: The log probabilities over all tokens in the vocabulary for + each token generated from the policy network + (i.e. the autoregressive model). + Should be a float tensor of same size as tokens, + with a dimension across the vocabulary. + :type logprobs: torch.Tensor + + :param values: The values for each token generated from the value network or value head. + Should be a float tensor of same size as tokens. + :type values: torch.Tensor + + :param rewards: The rewards for each token outputted in response. + Should be a float tensor of same size as tokens. 
+ :type rewards: torch.Tensor + """ + + query_tensor: TensorType["query_size"] + response_tensor: TensorType["response_size"] + logprobs: TensorType["response_size", "vocab_size"] + values: TensorType["response_size"] + rewards: TensorType["response_size"] + + +@dataclass +class MRTRLBatch: + """ + A batched version of the PPORLElement. See PPORLElement for more details on individual fields. + + :param query_tensors: A batch of query tensors. Should be a long tensor. + :type query_tensors: torch.Tensor + + :param response_tensors: A batch of response tensors. Should be a long tensor. + :type response_tensors: torch.Tensor + + :param logprobs: A batch of log probabilities from policy + :type logprobs: torch.Tensor + + :param values: A batch of values from value network + :type values: torch.Tensor + + :param rewards: A batch of rewards + :type rewards: torch.Tensor + """ + + query_tensors: TensorType["batch_size", "query_size"] + response_tensors: TensorType["batch_size", "response_size"] + logprobs: TensorType["batch_size", "response_size", "vocab_size"] + values: TensorType["batch_size", "response_size"] + rewards: TensorType["batch_size", "response_size"] diff --git a/trlx/trainer/accelerate_mrt_trainer.py b/trlx/trainer/accelerate_mrt_trainer.py new file mode 100644 index 000000000..cb306db62 --- /dev/null +++ b/trlx/trainer/accelerate_mrt_trainer.py @@ -0,0 +1,506 @@ +import json +import os +import uuid +from time import time +from typing import Callable, List + +import ray +import torch +import torch.nn.functional as F +import transformers +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +import trlx.utils.logging as logging +from trlx.data.accelerate_base_datatypes import PromptBatch +from trlx.data.configs import TRLConfig +from trlx.data.mrt_types import MRTRLBatch, MRTRLElement +from trlx.models.modeling_ppo import ( + AdaptiveKLController, + AutoModelForCausalLMWithHydraValueHead, + AutoModelForSeq2SeqLMWithHydraValueHead, + FixedKLController, +) +from trlx.pipeline.offline_pipeline import PromptPipeline +from trlx.pipeline.ppo_pipeline import PPORolloutStorage +from trlx.trainer import register_trainer +from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer +from trlx.utils import Clock, infinite_dataloader +from trlx.utils.modeling import RunningMoments, logprobs_of_labels + +logger = logging.get_logger(__name__) + + +@register_trainer +class AccelerateMRTTrainer(AccelerateRLTrainer): + """PPO Accelerate Trainer""" + + reward_fn: Callable[[List[str], List[str], List[str]], List[float]] + tokenizer: AutoTokenizer + + def __init__(self, config: TRLConfig, **kwargs): + """PPO Accelerate Trainer initialization + + Args: + config: Config + """ + super().__init__(config, **kwargs) + + # Setup rollout logging + if config.train.rollout_logging_dir is not None: + self.log_rollouts = True + self.setup_rollout_logging(config) + else: + self.log_rollouts = False + + # Setup the rollout store + # Rollouts contain the prompt & response, log probs, values and rewards - from each rollout + self.store = PPORolloutStorage(self.tokenizer.pad_token_id) + + # Create the rollout store dataloader (for batching up rollouts) + # TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future + rollout_loader: DataLoader = self.store.create_loader(self.config.train.batch_size, shuffle=True) + + # Prepare multi-GPU acceleration + self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare( + 
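+            # The prepared loader itself is then discarded; `prepare_learning`
+            # later re-creates the real training loader from the rollout store.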
self.model, self.opt, self.scheduler, rollout_loader + ) + + self.store.clear_history() # Clear the rollout store + + # Setup a reference model when hydra heads are not used + if not hasattr(self.model, "frozen_head"): + self.ref_model = self.get_arch(self.config) + self.ref_model.to(self.accelerator.device) + self.ref_model.eval() + + # Setup the KL controller + # This helps prevent large divergences in the controller (policy) + if config.method.target is not None: + self.kl_ctl = AdaptiveKLController(config.method.init_kl_coef, config.method.target, config.method.horizon) + else: + self.kl_ctl = FixedKLController(config.method.init_kl_coef) + + # Create the parameters for the Hugging Face language model's generator + # method (that generates new tokens from a prompt). + # https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/text_generation#transformers.GenerationMixin.generate + if config.model.model_arch_type == "seq2seq": + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + if config.method.gen_experience_kwargs is not None: + self.generate_experience_kwargs = dict( + config.method.gen_experience_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + else: + self.generate_experience_kwargs = None + else: + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + ) + if config.method.gen_experience_kwargs is not None: + self.generate_experience_kwargs = dict( + config.method.gen_experience_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + ) + else: + self.generate_experience_kwargs = None + + # Setup stats tracker + self.running_moments = RunningMoments() + self.ref_mean = self.config.method.ref_mean + self.ref_std = self.config.method.ref_std + + def get_arch(self, config: TRLConfig): + """Get the model""" + model_class = AutoModelForCausalLMWithHydraValueHead + if config.model.model_arch_type == "seq2seq": + model_class = AutoModelForSeq2SeqLMWithHydraValueHead + + from_fn = model_class.from_pretrained + # backward-compat: Try to create a randomly initialized architecture from a config + if issubclass(type(config.model.model_path), transformers.PretrainedConfig): + from_fn = model_class.from_config + + return from_fn( + config.model.model_path, + num_layers_unfrozen=config.model.num_layers_unfrozen, + ) + + def loss(self, batch: MRTRLBatch): + """Forward pass & loss + + Args: + batch: Previous batch of episodes + """ + # Move `batch` data to `accelerator` device + query_tensors = batch.query_tensors.to(self.accelerator.device) + response_tensors = batch.response_tensors.to(self.accelerator.device) + old_logprobs = batch.logprobs.to(self.accelerator.device) + old_values = batch.values.to(self.accelerator.device) + old_rewards = batch.rewards.to(self.accelerator.device) + response_length = old_rewards.shape[1] + + advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length) + + if self.config.model.model_arch_type == "seq2seq": + input_ids = query_tensors + decoder_input_ids = response_tensors + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device) + decoder_attention_mask = ( + decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device) + ) + decoder_attention_mask[:, 0] = 1 + + # 
Forward pass + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + logits = outputs.logits + values_pred = outputs.value + logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:]) + mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device) + start = 0 + end = start + response_length + logprobs, values_pred, mask = ( + logprobs[:, start:end], + values_pred[:, start:end], + mask[:, start:end], + ) + else: + tokens = torch.cat((query_tensors, response_tensors), dim=1) + attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device) + outputs = self.model(tokens, attention_mask, return_dict=True) + logits = outputs.logits + values_pred = outputs.value + values_pred = values_pred[:, :-1] + logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:]) + + start = query_tensors.shape[1] - 1 + end = start + response_length + logprobs, values_pred, mask = ( + logprobs[:, start:end], + values_pred[:, start:end], + attention_mask[:, start:end], + ) + + loss, stats = self.config.method.loss( + logprobs=logprobs, + values=values_pred, + old_logprobs=old_logprobs, + old_values=old_values, + advantages=advantages, + returns=returns, + mask=mask, + ) + self.approx_kl = stats["policy/approx_kl"] # Update kl controller stats + return loss, stats + + def setup_rollout_logging(self, config): + # Make rollout logging dir for this run and store config + exists = os.path.exists(config.train.rollout_logging_dir) + isdir = os.path.isdir(config.train.rollout_logging_dir) + assert exists and isdir + + self.run_id = f"run-{uuid.uuid4()}" + self.rollout_logging_dir = os.path.join(config.train.rollout_logging_dir, self.run_id) + os.mkdir(self.rollout_logging_dir) + + with open(os.path.join(self.rollout_logging_dir, "config.json"), "w") as f: + f.write(json.dumps(config.to_dict(), indent=2)) + + def post_epoch_callback(self): + """Post epoch callback + + Clears the store and creates `num_rollouts` new episodes. 
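+
+        Rollouts must be regenerated each epoch because training is on-policy:
+        once the policy has been updated, the stored log-probabilities no
+        longer describe the current model.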
+ """ + if self.log_rollouts: + self.store.export_history(location=self.rollout_logging_dir) + self.store.clear_history() + # Collect more rollouts for training + self.make_experience(self.config.method.num_rollouts, self.iter_count) + + def post_backward_callback(self): + self.kl_ctl.update(self.approx_kl, n_steps=self.config.train.batch_size) + + def prepare_learning(self): + eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) + self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader) + self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=True) + + self.n_updates_per_batch = self.config.method.ppo_epochs + self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) + self.total_steps = min(self.total_steps, self.config.train.total_steps) + + def add_prompt_pipeline(self, pipeline: PromptPipeline): + """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" + prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) + prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) + self.prompt_iterator = infinite_dataloader(prompt_dataloader) + + def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa: + """Make experiences + + Takes `chunk_size` number of prompts from `prompt_iterator`, samples + from the model and then computes the KL against a reference model. Finally it + then appends PPOElements to trainer's `store`. + + Args: + num_rollouts: Number of rollouts to generate + iter_count: Total number of updates run (i.e. number of updates run for all batches & epochs) + """ + logger.info("Collecting rollouts") + tbar = logging.tqdm( + total=num_rollouts, + disable=os.environ.get("RANK", 0) != "0", + desc=f"[rollout 0 / {num_rollouts}]", + # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress + # bars (e.g. 
loss progress in trainers) + position=logging.get_verbosity() >= logging.WARNING, + # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels + leave=logging.get_verbosity() < logging.WARNING, + ) + + ppo_rl_elements = [] + stats = {} + clock = Clock() + + while len(ppo_rl_elements) < num_rollouts: + # Get next batch in prompt dataset + batch: PromptBatch = next(self.prompt_iterator) + + exp_generate_time = time() + + # Generate samples from the language model (similar to using HuggingFace `generate` method) + samples = self.generate(**batch) + stats["time/exp_generate"] = time() - exp_generate_time + + prompt_tensors = batch.input_ids + device = samples.device + + prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) + padded_samples = self.accelerator.pad_across_processes( + samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False + ) + padded_prompts = self.accelerator.pad_across_processes( + prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False + ) + gathered_samples = self.accelerator.gather(padded_samples) + gathered_prompts = self.accelerator.gather(padded_prompts) + gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) + + if self.accelerator.is_main_process: + all_str_samples, all_str_prompts, all_str_outputs = self.decode( + gathered_prompts, gathered_samples, gathered_prompt_sizes + ) + + exp_score_time = time() + all_scores = torch.tensor( + self.reward_fn( + samples=all_str_samples, + prompts=all_str_prompts, + outputs=all_str_outputs, + ), + dtype=torch.float, + device=device, + ) + stats["time/exp_score"] = time() - exp_score_time + + all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind()) + else: + all_scores = None + + if torch.distributed.is_initialized(): + scores = torch.empty(len(samples), device=device) + torch.distributed.scatter(scores, all_scores) + else: + scores = torch.tensor(all_scores[0]) + + str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples) + + # Pad the sample outputs + outputs = self.tokenizer(str_outputs).input_ids + if self.config.model.model_arch_type == "seq2seq": + # add to the start of the output + for i in range(len(outputs)): + outputs[i] = [self.tokenizer.pad_token_id] + outputs[i] + + outputs = list(map(torch.LongTensor, outputs)) + maxsize = max(map(len, outputs)) + outputs = [ + F.pad( + output, + (0, maxsize - len(output)), + value=self.tokenizer.pad_token_id, + ) + for output in outputs + ] + sample_outputs = torch.vstack(outputs).to(device) + + # store statistics of the initial rollout as reference + if self.ref_mean is None: + self.ref_mean, self.ref_std = scores.mean(), scores.std() + all_scores_mean, all_scores_std = self.running_moments.update(scores) + stats["exp_scores/mean"] = all_scores_mean + stats["exp_scores/std"] = all_scores_std + stats["exp_scores/running_mean"] = self.running_moments.mean + stats["exp_scores/running_std"] = self.running_moments.std + + if self.config.method.scale_reward == "running": + scores /= self.running_moments.std + elif self.config.method.scale_reward == "ref": + scores /= self.ref_std + + clip_reward = self.config.method.cliprange_reward + if clip_reward: + scores = torch.clip(scores, -clip_reward, clip_reward) + + # Precompute logprobs, values + if self.config.model.model_arch_type == "seq2seq": + attention_mask = batch.attention_mask.to(device) + prompt_tensors = batch.input_ids.to(device) + decoder_attention_mask = 
sample_outputs.not_equal(self.tokenizer.pad_token_id) + decoder_attention_mask[:, 0] = 1 + with torch.no_grad(): + outputs = self.model( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + decoder_attention_mask=decoder_attention_mask, + ) + logits = outputs.logits + values = outputs.value + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + decoder_attention_mask=decoder_attention_mask, + return_dict=True, + ).logits + else: + ref_logits = self.ref_model( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + decoder_attention_mask=decoder_attention_mask, + return_dict=True, + ).logits + else: + all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) + attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) + with torch.no_grad(): + logits, *_, values = self.model( + all_tokens, + attention_mask=attention_mask, + ) + # TODO(dahoas): When hydra model works need to also support generation on hydra head + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + all_tokens, + attention_mask=attention_mask, + return_dict=True, + ).logits + else: + ref_logits = self.ref_model( + all_tokens, + attention_mask=attention_mask, + return_dict=True, + ).logits + ref_logits = ref_logits.to(device) + + if self.config.model.model_arch_type == "seq2seq": + logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) + else: + logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) + + n_samples: int = samples.shape[0] + logprobs = logprobs.cpu() + ref_logprobs = ref_logprobs.cpu() + prompt_tensors = prompt_tensors.cpu() + sample_outputs = sample_outputs.cpu() + + # Estimate the KL divergence between the model and reference model + if self.config.model.model_arch_type == "seq2seq": + values = values.cpu()[:, :-1] + start = 0 + + # Get the number of non-padding tokens for each sample + # This assumes all padding is on the right side + padding_token: int = 0 + ends = (sample_outputs[:, start:] != padding_token).sum(1) + + # Get the logprobs and values, for tokens that are not padding + # or beginning of sequences tokens. These are from the model + # (not the reference model) + all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] + all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] + + kl_divergence_estimate: List[torch.Tensor] = [ + -self.kl_ctl.value + * ( + logprobs[sample_idx, start : ends[sample_idx]] + - ref_logprobs[sample_idx, start : ends[sample_idx]] + ) + for sample_idx in range(n_samples) + ] + + # Else if not seq2seq (i.e. 
causal) + else: + values = values.cpu()[:, :-1] + start = prompt_tensors.shape[1] - 1 + ends = start + attention_mask[:, start:].sum(1) + all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] + all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] + + kl_divergence_estimate = -self.kl_ctl.value * (logprobs - ref_logprobs) + kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)] + + rollout_count = 0 + + for sample_idx in range(n_samples): + sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx] + + if len(sample_kl_divergence_estimate) == 0 or len(all_logprobs[sample_idx]) == 0: + continue + + rewards = sample_kl_divergence_estimate + rewards[-1] += scores[sample_idx].cpu() + + ppo_rl_elements.append( + MRTRLElement( + query_tensor=prompt_tensors[sample_idx], + response_tensor=sample_outputs[sample_idx], + logprobs=all_logprobs[sample_idx], + values=all_values[sample_idx], + rewards=rewards, + ) + ) + + rollout_count += 1 + exp_time = clock.tick() + tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]") + tbar.update(min(rollout_count, num_rollouts)) + tbar.close() + + stats["kl_ctl_value"] = self.kl_ctl.value + stats["time/exp"] = exp_time + + if not ray.is_initialized(): + self.accelerator.log(stats, step=iter_count) + + # Push samples and rewards to trainer's rollout storage + self.push_to_store(ppo_rl_elements) diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py index cdb49926c..d0fff30c3 100644 --- a/trlx/utils/loading.py +++ b/trlx/utils/loading.py @@ -7,6 +7,7 @@ # Register load trainers via module import from trlx.trainer import _TRAINERS, register_trainer from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer +from trlx.trainer.accelerate_mrt_trainer import AccelerateMRTTrainer from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer From db629e336cc7bdedf41771ff70a1c8d793e19a11 Mon Sep 17 00:00:00 2001 From: Alex Muzio Date: Wed, 15 Mar 2023 07:13:31 +0000 Subject: [PATCH 2/8] [WIP] make_experience working --- .../mrt_summarize_daily_cnn_t5.py | 68 ++++++--- trlx/models/modeling_mrt.py | 142 ++++++++++++++++++ trlx/trainer/accelerate_mrt_trainer.py | 120 ++++++++------- 3 files changed, 259 insertions(+), 71 deletions(-) create mode 100644 trlx/models/modeling_mrt.py diff --git a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py index b03f46bcf..7357496d2 100644 --- a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py +++ b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py @@ -13,7 +13,7 @@ TrainConfig, TRLConfig, ) -from trlx.models.modeling_ppo import PPOConfig +from trlx.models.modeling_mrt import MRTConfig try: import evaluate @@ -27,7 +27,7 @@ seq_length=612, epochs=100, total_steps=100000, - batch_size=12, + batch_size=4, checkpoint_interval=10000, eval_interval=500, pipeline="PromptPipeline", @@ -40,8 +40,8 @@ num_layers_unfrozen=2, ), tokenizer=TokenizerConfig( - tokenizer_path="google/flan-t5-small", - truncation_side="right", + tokenizer_path="google/flan-t5-small", #### change to reasonable value + truncation_side="right", # what is this? 
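+        # truncation_side picks which end the tokenizer drops tokens from when an
+        # input exceeds max_length; "right" discards the article's tail and keeps
+        # its beginning, where news articles usually put summary-relevant content.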
), optimizer=OptimizerConfig( name="adamw", @@ -59,36 +59,64 @@ "eta_min": 1.0e-6, }, ), - method=PPOConfig( - name="PPOConfig", + method=MRTConfig( + name="MRTConfig", + # n_updates_per_batch=1, #### MRT num_rollouts=512, - chunk_size=12, + chunk_size=4, ppo_epochs=4, - init_kl_coef=0.05, - target=6, - horizon=10000, - gamma=0.99, - lam=0.95, - cliprange=0.2, - cliprange_value=0.2, - vf_coef=1.0, + # init_kl_coef=0.05, + # target=6, + # horizon=10000, + # gamma=0.99, + # lam=0.95, + # cliprange=0.2, + # cliprange_value=0.2, + # vf_coef=1.0, + num_candidates=16, scale_reward=None, ref_mean=None, ref_std=None, cliprange_reward=10, - gen_kwargs={ + gen_kwargs={ # for evaluation "max_new_tokens": 100, + # TODO: what should the defaults here be }, - gen_experience_kwargs={ + gen_experience_kwargs={ # for rollouts "max_new_tokens": 100, - "do_sample": True, + "num_beams": 16, # should be same as nb_candidates + "num_return_sequences": 16, # should be same as nb_candidates + "do_sample": False, "temperature": 1.0, - "top_k": 50, - "top_p": 0.95, + # "top_k": 50, + # "top_p": 0.95, }, ), ) + # gen_kwargs = { + # "min_length":-1, + # "top_k": config['top_k'], + # "top_p": 1.0, + # "temperature": config["temperature"], + # "do_sample": config['do_sample'], + # "num_beams": config['num_beams'], + # "max_length": config['max_length'], + # # "pad_token_id": model.eos_token_id, + # "num_return_sequences": config['candidate_size'], + # } + # eval_kwargs = { + # # "early_stopping": True, + # # "length_penalty": 2.0, + # "min_length":-1, + # "top_k": 0.0, + # # "top_p": 1.0, + # "do_sample": False, + # "num_beams": config['eval_num_beams'], + # # "no_repeat_ngram_size": 3, + # "max_length": config['max_length'], + # } + meteor = evaluate.load("meteor") # use meteor as the reward function diff --git a/trlx/models/modeling_mrt.py b/trlx/models/modeling_mrt.py new file mode 100644 index 000000000..e8fc3e922 --- /dev/null +++ b/trlx/models/modeling_mrt.py @@ -0,0 +1,142 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torchtyping import TensorType + +from trlx.data.method_configs import MethodConfig, register_method +from trlx.models.modeling_base import PreTrainedModelWrapper +from trlx.utils.modeling import ( + flatten_dict, + get_tensor_stats, +) + + +@dataclass +@register_method +class MRTConfig(MethodConfig): + """ + Config for MRT method + + :param ppo_epochs: Number of updates per batch + :type ppo_epochs: int + + :param num_rollouts: Number of experiences to observe before learning + :type num_rollouts: int + + :param init_kl_coef: Initial value for KL coefficient + :type init_kl_coef: float + + :param target: Target value for KL coefficient + :type target: float + + :param horizon: Number of steps for KL coefficient to reach target + :type horizon: int + + :param gamma: Discount factor + :type gamma: float + + :param lam: GAE lambda + :type lam: float + + :param cliprange: Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange) + :type cliprange: float + + :param cliprange_value: Clipping range for predicted values + (observed values - cliprange_value, observed values + cliprange_value) + :type cliprange_value: float + + :param vf_coef: Value loss scale w.r.t policy loss + :type vf_coef: float + + :param gen_kwargs: Additioanl kwargs for the generation + :type gen_kwargs: Dict[str, Any] + + :param gen_experience_kwargs: if this is not None, then the experience is 
generated using this + :type gen_experience_kwargs: Dict[str, Any] + """ + + ppo_epochs: int + num_rollouts: int + chunk_size: int + # init_kl_coef: float + # target: float + # horizon: int + # gamma: float + # lam: float + # cliprange: float + # cliprange_value: float + # vf_coef: float + num_candidates: int + scale_reward: Optional[str] + ref_mean: Optional[float] + ref_std: Optional[float] + cliprange_reward: float + gen_kwargs: dict + gen_experience_kwargs: Optional[dict] = None + + def loss( + self, + logprobs: TensorType["batch_size", "response_size"], + values: TensorType["batch_size", "response_size"], + old_logprobs: TensorType["batch_size", "response_size"], + old_values: TensorType["batch_size", "response_size"], + advantages: TensorType["batch_size", "response_size"], + returns: TensorType["batch_size", "response_size"], + mask: TensorType["batch_size", "response_size"], + ): + """PPO objective function. + References: + - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html + """ + values_clipped = torch.clamp( + values, + old_values - self.cliprange_value, + old_values + self.cliprange_value, + ) + n = mask.sum() + + vf_loss1 = (values - returns) ** 2 + vf_loss2 = (values_clipped - returns) ** 2 + vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / n + vf_clipfrac = torch.sum((vf_loss2 > vf_loss1).float() * mask) / n + + log_ratio = (logprobs - old_logprobs) * mask + ratio = torch.exp(log_ratio) + # Unbiased KL-div estimates (`k3`). Ref: http://joschu.net/blog/kl-approx.html + with torch.no_grad(): + approx_kl = torch.mean((ratio - 1) - log_ratio) + + pg_loss1 = -advantages * ratio + pg_loss2 = -advantages * torch.clamp( + ratio, + 1.0 - self.cliprange, + 1.0 + self.cliprange, + ) + pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / n + pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * mask) / n + + loss = pg_loss + self.vf_coef * vf_loss + + stats = dict( + losses=dict( + total_loss=loss.item(), + policy_loss=pg_loss.item(), + value_loss=vf_loss.item(), + ), + values=dict( + get_tensor_stats(values, mask, n), + values_error=torch.sum(((values - returns) * mask) ** 2) / n, + clipfrac=vf_clipfrac, + ), + old_values=get_tensor_stats(old_values, mask, n), + returns=get_tensor_stats(returns, mask, n), + policy=dict(approx_kl=approx_kl.item(), clipfrac=pg_clipfrac.item()), + ratio=(ratio * mask).sum() / n, + padding_percentage=n / mask.numel(), + ) + + return loss, flatten_dict(stats) diff --git a/trlx/trainer/accelerate_mrt_trainer.py b/trlx/trainer/accelerate_mrt_trainer.py index cb306db62..f958d2843 100644 --- a/trlx/trainer/accelerate_mrt_trainer.py +++ b/trlx/trainer/accelerate_mrt_trainer.py @@ -25,7 +25,7 @@ from trlx.pipeline.ppo_pipeline import PPORolloutStorage from trlx.trainer import register_trainer from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer -from trlx.utils import Clock, infinite_dataloader +from trlx.utils import Clock from trlx.utils.modeling import RunningMoments, logprobs_of_labels logger = logging.get_logger(__name__) @@ -76,10 +76,10 @@ def __init__(self, config: TRLConfig, **kwargs): # Setup the KL controller # This helps prevent large divergences in the controller (policy) - if config.method.target is not None: - self.kl_ctl = AdaptiveKLController(config.method.init_kl_coef, config.method.target, config.method.horizon) - else: - self.kl_ctl = FixedKLController(config.method.init_kl_coef) + # if config.method.target is not None: + # self.kl_ctl = 
AdaptiveKLController(config.method.init_kl_coef, config.method.target, config.method.horizon) + # else: + # self.kl_ctl = FixedKLController(config.method.init_kl_coef) # Create the parameters for the Hugging Face language model's generator # method (that generates new tokens from a prompt). @@ -204,6 +204,8 @@ def loss(self, batch: MRTRLBatch): returns=returns, mask=mask, ) + + # TODO update this self.approx_kl = stats["policy/approx_kl"] # Update kl controller stats return loss, stats @@ -232,7 +234,8 @@ def post_epoch_callback(self): self.make_experience(self.config.method.num_rollouts, self.iter_count) def post_backward_callback(self): - self.kl_ctl.update(self.approx_kl, n_steps=self.config.train.batch_size) + ... + # self.kl_ctl.update(self.approx_kl, n_steps=self.config.train.batch_size) def prepare_learning(self): eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) @@ -246,8 +249,8 @@ def prepare_learning(self): def add_prompt_pipeline(self, pipeline: PromptPipeline): """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) - prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) - self.prompt_iterator = infinite_dataloader(prompt_dataloader) + self.prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) + self.prompt_iterator = iter(self.prompt_dataloader) def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa: """Make experiences @@ -272,22 +275,35 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq leave=logging.get_verbosity() < logging.WARNING, ) - ppo_rl_elements = [] + mrt_rl_elements = [] stats = {} clock = Clock() - while len(ppo_rl_elements) < num_rollouts: - # Get next batch in prompt dataset - batch: PromptBatch = next(self.prompt_iterator) + while len(mrt_rl_elements) < num_rollouts: + # Get next batch in prompt dataset and refresh if exhausted + # TOOD (jon-tow): Make `prompt_dataloader` a cyclic/infinite DataLoader to not require manually + # "refreshing" the contents of the `prompt_iterator` + try: + batch: PromptBatch = next(self.prompt_iterator) + except StopIteration: + self.prompt_iterator = iter(self.prompt_dataloader) + batch = next(self.prompt_iterator) exp_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) + # For MRT, this should generate num_candidates samples for each prompt in the batch samples = self.generate(**batch) + device = samples.device + + # Expand queries and mask + copied_idxs = torch.tensor([i for i in range(batch.input_ids.shape[0]) for _ in range(self.config.method.num_candidates)], device=device) + batch.input_ids = torch.index_select(batch.input_ids, 0, copied_idxs) # [batch_size * candidate_size, query_length] + batch.attention_mask = torch.index_select(batch.attention_mask, 0, copied_idxs) # [batch_size * candidate_size, query_length] + stats["time/exp_generate"] = time() - exp_generate_time prompt_tensors = batch.input_ids - device = samples.device prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) padded_samples = self.accelerator.pad_across_processes( @@ -325,7 +341,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq scores = torch.empty(len(samples), device=device) torch.distributed.scatter(scores, all_scores) else: - scores = 
torch.tensor(all_scores[0]) + scores = all_scores[0].clone() # torch.tensor(all_scores[0]) str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples) @@ -398,34 +414,36 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq return_dict=True, ).logits else: - all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) - attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) - with torch.no_grad(): - logits, *_, values = self.model( - all_tokens, - attention_mask=attention_mask, - ) - # TODO(dahoas): When hydra model works need to also support generation on hydra head - if hasattr(self.model, "frozen_head"): - ref_logits = self.model.forward_hydra( - all_tokens, - attention_mask=attention_mask, - return_dict=True, - ).logits - else: - ref_logits = self.ref_model( - all_tokens, - attention_mask=attention_mask, - return_dict=True, - ).logits - ref_logits = ref_logits.to(device) + assert False + # else: + # all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) + # attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) + # with torch.no_grad(): + # logits, *_, values = self.model( + # all_tokens, + # attention_mask=attention_mask, + # ) + # # TODO(dahoas): When hydra model works need to also support generation on hydra head + # if hasattr(self.model, "frozen_head"): + # ref_logits = self.model.forward_hydra( + # all_tokens, + # attention_mask=attention_mask, + # return_dict=True, + # ).logits + # else: + # ref_logits = self.ref_model( + # all_tokens, + # attention_mask=attention_mask, + # return_dict=True, + # ).logits + # ref_logits = ref_logits.to(device) if self.config.model.model_arch_type == "seq2seq": logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) - else: - logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) + # else: + # logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) + # ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) n_samples: int = samples.shape[0] logprobs = logprobs.cpu() @@ -450,7 +468,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] kl_divergence_estimate: List[torch.Tensor] = [ - -self.kl_ctl.value + 1.0 #-self.kl_ctl.value * ( logprobs[sample_idx, start : ends[sample_idx]] - ref_logprobs[sample_idx, start : ends[sample_idx]] @@ -459,15 +477,15 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ] # Else if not seq2seq (i.e. 
causal) - else: - values = values.cpu()[:, :-1] - start = prompt_tensors.shape[1] - 1 - ends = start + attention_mask[:, start:].sum(1) - all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] - all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] + # else: + # values = values.cpu()[:, :-1] + # start = prompt_tensors.shape[1] - 1 + # ends = start + attention_mask[:, start:].sum(1) + # all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] + # all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] - kl_divergence_estimate = -self.kl_ctl.value * (logprobs - ref_logprobs) - kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)] + # kl_divergence_estimate = 1.0 * (logprobs - ref_logprobs) + # kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)] rollout_count = 0 @@ -480,7 +498,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rewards = sample_kl_divergence_estimate rewards[-1] += scores[sample_idx].cpu() - ppo_rl_elements.append( + mrt_rl_elements.append( MRTRLElement( query_tensor=prompt_tensors[sample_idx], response_tensor=sample_outputs[sample_idx], @@ -492,15 +510,15 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_count += 1 exp_time = clock.tick() - tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]") + tbar.set_description(f"[rollout {len(mrt_rl_elements)} / {num_rollouts}]") tbar.update(min(rollout_count, num_rollouts)) tbar.close() - stats["kl_ctl_value"] = self.kl_ctl.value + # stats["kl_ctl_value"] = self.kl_ctl.value stats["time/exp"] = exp_time if not ray.is_initialized(): self.accelerator.log(stats, step=iter_count) # Push samples and rewards to trainer's rollout storage - self.push_to_store(ppo_rl_elements) + self.push_to_store(mrt_rl_elements) From 7f2478e53823452e64559d6056e5b77017943d60 Mon Sep 17 00:00:00 2001 From: Alex Muzio Date: Thu, 16 Mar 2023 02:15:19 +0000 Subject: [PATCH 3/8] [WIP] MRT make_experience working with batch --- .../mrt_summarize_daily_cnn_t5.py | 1 + trlx/data/mrt_types.py | 20 +-- trlx/models/modeling_mrt.py | 160 ++++++++++++------ trlx/pipeline/mrt_pipeline.py | 80 +++++++++ trlx/trainer/accelerate_base_trainer.py | 6 +- trlx/trainer/accelerate_mrt_trainer.py | 71 ++++---- 6 files changed, 247 insertions(+), 91 deletions(-) create mode 100644 trlx/pipeline/mrt_pipeline.py diff --git a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py index 7357496d2..12b68625f 100644 --- a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py +++ b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py @@ -74,6 +74,7 @@ # cliprange_value=0.2, # vf_coef=1.0, num_candidates=16, + ce_loss_weight=0.0, scale_reward=None, ref_mean=None, ref_std=None, diff --git a/trlx/data/mrt_types.py b/trlx/data/mrt_types.py index e3fd44f29..ac1ffe46c 100644 --- a/trlx/data/mrt_types.py +++ b/trlx/data/mrt_types.py @@ -30,11 +30,11 @@ class MRTRLElement: :type rewards: torch.Tensor """ - query_tensor: TensorType["query_size"] - response_tensor: TensorType["response_size"] - logprobs: TensorType["response_size", "vocab_size"] - values: TensorType["response_size"] - rewards: TensorType["response_size"] + query_tensor: TensorType["num_candidates", "query_size"] + response_tensor: TensorType["num_candidates", "response_size"] + logprobs: 
TensorType["num_candidates", "response_size", "vocab_size"] + values: TensorType["num_candidates", "response_size"] + rewards: TensorType["num_candidates", "response_size"] @dataclass @@ -58,8 +58,8 @@ class MRTRLBatch: :type rewards: torch.Tensor """ - query_tensors: TensorType["batch_size", "query_size"] - response_tensors: TensorType["batch_size", "response_size"] - logprobs: TensorType["batch_size", "response_size", "vocab_size"] - values: TensorType["batch_size", "response_size"] - rewards: TensorType["batch_size", "response_size"] + query_tensors: TensorType["batch_size", "num_candidates", "query_size"] + response_tensors: TensorType["batch_size", "num_candidates", "response_size"] + logprobs: TensorType["batch_size", "num_candidates", "response_size", "vocab_size"] + values: TensorType["batch_size", "num_candidates", "response_size"] + rewards: TensorType["batch_size", "num_candidates", "response_size"] diff --git a/trlx/models/modeling_mrt.py b/trlx/models/modeling_mrt.py index e8fc3e922..2944b4e15 100644 --- a/trlx/models/modeling_mrt.py +++ b/trlx/models/modeling_mrt.py @@ -5,6 +5,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from torchtyping import TensorType from trlx.data.method_configs import MethodConfig, register_method @@ -71,6 +72,7 @@ class MRTConfig(MethodConfig): # cliprange_value: float # vf_coef: float num_candidates: int + ce_loss_weight: float scale_reward: Optional[str] ref_mean: Optional[float] ref_std: Optional[float] @@ -81,62 +83,124 @@ class MRTConfig(MethodConfig): def loss( self, logprobs: TensorType["batch_size", "response_size"], - values: TensorType["batch_size", "response_size"], - old_logprobs: TensorType["batch_size", "response_size"], - old_values: TensorType["batch_size", "response_size"], - advantages: TensorType["batch_size", "response_size"], - returns: TensorType["batch_size", "response_size"], + rewards: TensorType["batch_size", "response_size"], mask: TensorType["batch_size", "response_size"], ): - """PPO objective function. + """MRT objective function. References: - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html """ - values_clipped = torch.clamp( - values, - old_values - self.cliprange_value, - old_values + self.cliprange_value, - ) - n = mask.sum() - - vf_loss1 = (values - returns) ** 2 - vf_loss2 = (values_clipped - returns) ** 2 - vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / n - vf_clipfrac = torch.sum((vf_loss2 > vf_loss1).float() * mask) / n - - log_ratio = (logprobs - old_logprobs) * mask - ratio = torch.exp(log_ratio) - # Unbiased KL-div estimates (`k3`). 
Ref: http://joschu.net/blog/kl-approx.html - with torch.no_grad(): - approx_kl = torch.mean((ratio - 1) - log_ratio) - - pg_loss1 = -advantages * ratio - pg_loss2 = -advantages * torch.clamp( - ratio, - 1.0 - self.cliprange, - 1.0 + self.cliprange, + + loss = torch.tensor(0.0) + costs = 1 - rewards + + # Reward component + if self.ce_loss_weight < 1.0: # if ce_loss_weight is 1.0, then we only use the ce loss + # We make the assumption here that rewards are scaled to [0,1] + # lengths = response_masks.sum(dim=-1).float() + lengths = mask.sum(dim=-1).float() + +# model_outputs = self.model( +# input_ids=queries, +# decoder_input_ids=responses, # response tokens are already shifted right (start token = pad token) +# attention_mask=query_masks +# return_dict=True) +# , + avg_scores = logprobs.sum(dim=-1) / lengths + + # [batch_size, candidate_size] + avg_scores = avg_scores.view(-1, self.num_candidates) + costs = costs.view(-1, self.num_candidates) + + probs = F.softmax(avg_scores, dim=1).squeeze(-1) + loss = (probs * costs).sum() + + # Cross entropy component + ce_loss = torch.tensor(0.0) + if self.ce_loss_weight > 0.0: + assert False, 'ce_loss_weight should be 0.0' + # if parallel_mask is not None: + # queries = queries[parallel_mask] + # query_masks = query_masks[parallel_mask] + # refs = refs[parallel_mask] + # ref_masks = ref_masks[parallel_mask] + + # # We should compute the cross entropy with the reference response and not with the generated response + # model_outputs = self.model( + # input_ids=queries, + # decoder_input_ids=shift_tokens_right(refs, self.model.config.pad_token_id, self.model.config.decoder_start_token_id), + # attention_mask=query_masks, + # return_dict=True) + # ce_loss = F.cross_entropy( + # model_outputs.logits.reshape(-1, model_outputs.logits.size(-1)), + # refs.reshape(-1), + # ignore_index=self.model.config.pad_token_id + # ) + + combined_loss = self.ce_loss_weight * ce_loss + (1 - self.ce_loss_weight) * loss + + stats = dict( + loss=dict(combined_loss=combined_loss, ce_loss=ce_loss, loss=loss, costs=costs), ) - pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / n - pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * mask) / n + # stats = utils.apply_to_sample(lambda t: t.detach().cpu(), stats) + + return combined_loss, flatten_dict(stats) + + + +""" + def loss(self, rewards, queries, responses, query_masks, response_masks, refs, ref_masks, parallel_mask=None): + loss = torch.tensor(0.0) + costs = 1 - rewards + + # Reward component + if self.params['ce_loss_weight'] < 1.0: # if ce_loss_weight is 1.0, then we only use the ce loss + # We make the assumption here that rewards are scaled to [0,1] + lengths = response_masks.sum(dim=-1).float() + + model_outputs = self.model( + input_ids=queries, + decoder_input_ids=responses, # response tokens are already shifted right (start token = pad token) + attention_mask=query_masks, + return_dict=True) + + logprobs = logprobs_from_logits(model_outputs.logits[:,:-1,:], responses[:, 1:], mask=response_masks[:, 1:]) + avg_scores = logprobs.sum(dim=-1) / lengths + + # [batch_size, candidate_size] + avg_scores = avg_scores.view(-1, self.params['candidate_size']) + costs = costs.view(-1, self.params['candidate_size']) + + probs = F.softmax(avg_scores, dim=1).squeeze(-1) + loss = (probs * costs).sum() + + # Cross entropy component + ce_loss = torch.tensor(0.0) + if self.params['ce_loss_weight'] > 0.0: + if parallel_mask is not None: + queries = queries[parallel_mask] + query_masks = query_masks[parallel_mask] + refs = 
refs[parallel_mask] + ref_masks = ref_masks[parallel_mask] + + # We should compute the cross entropy with the reference response and not with the generated response + model_outputs = self.model( + input_ids=queries, + decoder_input_ids=shift_tokens_right(refs, self.model.config.pad_token_id, self.model.config.decoder_start_token_id), + attention_mask=query_masks, + return_dict=True) + ce_loss = F.cross_entropy( + model_outputs.logits.reshape(-1, model_outputs.logits.size(-1)), + refs.reshape(-1), + ignore_index=self.model.config.pad_token_id + ) - loss = pg_loss + self.vf_coef * vf_loss + combined_loss = self.params['ce_loss_weight'] * ce_loss + (1 - self.params['ce_loss_weight']) * loss stats = dict( - losses=dict( - total_loss=loss.item(), - policy_loss=pg_loss.item(), - value_loss=vf_loss.item(), - ), - values=dict( - get_tensor_stats(values, mask, n), - values_error=torch.sum(((values - returns) * mask) ** 2) / n, - clipfrac=vf_clipfrac, - ), - old_values=get_tensor_stats(old_values, mask, n), - returns=get_tensor_stats(returns, mask, n), - policy=dict(approx_kl=approx_kl.item(), clipfrac=pg_clipfrac.item()), - ratio=(ratio * mask).sum() / n, - padding_percentage=n / mask.numel(), + loss=dict(combined_loss=combined_loss, ce_loss=ce_loss, loss=loss, costs=costs), ) + stats = utils.apply_to_sample(lambda t: t.detach().cpu(), stats) - return loss, flatten_dict(stats) + return combined_loss, flatten_dict(stats) +""" diff --git a/trlx/pipeline/mrt_pipeline.py b/trlx/pipeline/mrt_pipeline.py new file mode 100644 index 000000000..bab21356f --- /dev/null +++ b/trlx/pipeline/mrt_pipeline.py @@ -0,0 +1,80 @@ +import json +import os +import time +from typing import Iterable + +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader + +from trlx.data.mrt_types import MRTRLBatch, MRTRLElement +from trlx.pipeline import BaseRolloutStore + + +class MRTRolloutStorage(BaseRolloutStore): + """ + Rollout storage for training MRT + """ + + def __init__(self, pad_token_id): + super().__init__() + + self.pad_token_id = pad_token_id + self.history: Iterable[MRTRLElement] = [None] + + def push(self, exps: Iterable[MRTRLElement]): + self.history += exps + + def clear_history(self): + self.history = [] + + def export_history(self, location: str): + assert os.path.exists(location) + + fpath = os.path.join(location, f"epoch-{str(time.time())}.json") + + def exp_to_dict(exp): + {k: v.cpu().tolist() for k, v in exp.__dict__.items()} + + data = [exp_to_dict(exp) for exp in self.history] + with open(fpath, "w") as f: + f.write(json.dumps(data, indent=2)) + + def __getitem__(self, index: int) -> MRTRLElement: + return self.history[index] + + def __len__(self) -> int: + return len(self.history) + + def create_loader( + self, + batch_size: int, + shuffle: bool, + ) -> DataLoader: + def collate_fn(elems: Iterable[MRTRLElement]): + return MRTRLBatch( + # Left padding of already left-padded queries + pad_sequence( + [elem.query_tensor.flip(0) for elem in elems], + padding_value=self.pad_token_id, + batch_first=True, + ).flip(1), + # Right pad the rest, to have a single horizontal query/response split + pad_sequence( + [elem.response_tensor for elem in elems], + padding_value=self.pad_token_id, + batch_first=True, + ), + pad_sequence( + [elem.logprobs for elem in elems], + padding_value=0.0, + batch_first=True, + ), + pad_sequence([elem.values for elem in elems], padding_value=0.0, batch_first=True), + pad_sequence( + [elem.rewards for elem in elems], + padding_value=0.0, + batch_first=True, + 
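+                    # 0.0 padding is deliberate: the MRT loss sums rewards over the
+                    # sequence dimension, so padded positions contribute nothing.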
), + ) + + return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 0e92efc90..587215074 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -451,8 +451,10 @@ def learn(self): # noqa: C901 state = json.load(f) self.iter_count = state["iter_count"] else: - results = self.evaluate() - self.accelerator.log(results, step=self.iter_count) + # TODO: no eval for now: + # results = self.evaluate() + # self.accelerator.log(results, step=self.iter_count) + ... tbar = logging.tqdm( initial=self.iter_count, diff --git a/trlx/trainer/accelerate_mrt_trainer.py b/trlx/trainer/accelerate_mrt_trainer.py index f958d2843..6e39c5701 100644 --- a/trlx/trainer/accelerate_mrt_trainer.py +++ b/trlx/trainer/accelerate_mrt_trainer.py @@ -8,6 +8,7 @@ import torch import torch.nn.functional as F import transformers +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -22,7 +23,7 @@ FixedKLController, ) from trlx.pipeline.offline_pipeline import PromptPipeline -from trlx.pipeline.ppo_pipeline import PPORolloutStorage +from trlx.pipeline.mrt_pipeline import MRTRolloutStorage from trlx.trainer import register_trainer from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer from trlx.utils import Clock @@ -55,7 +56,7 @@ def __init__(self, config: TRLConfig, **kwargs): # Setup the rollout store # Rollouts contain the prompt & response, log probs, values and rewards - from each rollout - self.store = PPORolloutStorage(self.tokenizer.pad_token_id) + self.store = MRTRolloutStorage(self.tokenizer.pad_token_id) # Create the rollout store dataloader (for batching up rollouts) # TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future @@ -143,12 +144,11 @@ def loss(self, batch: MRTRLBatch): # Move `batch` data to `accelerator` device query_tensors = batch.query_tensors.to(self.accelerator.device) response_tensors = batch.response_tensors.to(self.accelerator.device) - old_logprobs = batch.logprobs.to(self.accelerator.device) - old_values = batch.values.to(self.accelerator.device) - old_rewards = batch.rewards.to(self.accelerator.device) - response_length = old_rewards.shape[1] + logprobs = batch.logprobs.to(self.accelerator.device) + rewards = batch.rewards.to(self.accelerator.device) + response_length = rewards.shape[1] - advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length) + # advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length) if self.config.model.model_arch_type == "seq2seq": input_ids = query_tensors @@ -197,16 +197,12 @@ def loss(self, batch: MRTRLBatch): loss, stats = self.config.method.loss( logprobs=logprobs, - values=values_pred, - old_logprobs=old_logprobs, - old_values=old_values, - advantages=advantages, - returns=returns, + rewards=rewards, mask=mask, ) # TODO update this - self.approx_kl = stats["policy/approx_kl"] # Update kl controller stats + # self.approx_kl = stats["policy/approx_kl"] # Update kl controller stats return loss, stats def setup_rollout_logging(self, config): @@ -257,7 +253,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq Takes `chunk_size` number of prompts from `prompt_iterator`, samples from the model and then computes the KL against a 
reference model. Finally it - then appends PPOElements to trainer's `store`. + then appends MRTRLElements to trainer's `store`. Args: num_rollouts: Number of rollouts to generate @@ -275,11 +271,13 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq leave=logging.get_verbosity() < logging.WARNING, ) + num_candidates = self.config.method.num_candidates + mrt_rl_elements = [] stats = {} clock = Clock() - while len(mrt_rl_elements) < num_rollouts: + while len(mrt_rl_elements) * num_candidates < num_rollouts: # Get next batch in prompt dataset and refresh if exhausted # TOOD (jon-tow): Make `prompt_dataloader` a cyclic/infinite DataLoader to not require manually # "refreshing" the contents of the `prompt_iterator` @@ -293,13 +291,15 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # Generate samples from the language model (similar to using HuggingFace `generate` method) # For MRT, this should generate num_candidates samples for each prompt in the batch + # So in total: [batch_size * num_candidates, response_len] samples = self.generate(**batch) device = samples.device # Expand queries and mask - copied_idxs = torch.tensor([i for i in range(batch.input_ids.shape[0]) for _ in range(self.config.method.num_candidates)], device=device) - batch.input_ids = torch.index_select(batch.input_ids, 0, copied_idxs) # [batch_size * candidate_size, query_length] - batch.attention_mask = torch.index_select(batch.attention_mask, 0, copied_idxs) # [batch_size * candidate_size, query_length] + copied_idxs = torch.tensor([i for i in range(batch.input_ids.shape[0]) for _ in range(num_candidates)], device=device) + # TODO change this part over here + batch.input_ids = torch.index_select(batch.input_ids, 0, copied_idxs) # [batch_size, candidate_size, query_length] + batch.attention_mask = torch.index_select(batch.attention_mask, 0, copied_idxs) # [batch_size, candidate_size, query_length] stats["time/exp_generate"] = time() - exp_generate_time @@ -489,28 +489,37 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_count = 0 - for sample_idx in range(n_samples): - sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx] + for idx in range(n_samples // num_candidates): + sample_idxs = torch.arange( + idx * num_candidates, + (idx + 1) * num_candidates) + # k + # sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx] - if len(sample_kl_divergence_estimate) == 0 or len(all_logprobs[sample_idx]) == 0: - continue + # if len(sample_kl_divergence_estimate) == 0 or len(all_logprobs[sample_idx]) == 0: + # continue + + # not used for MRT: + # rewards = sample_kl_divergence_estimate + #rewards[-1] += scores[sample_idx].cpu() - rewards = sample_kl_divergence_estimate - rewards[-1] += scores[sample_idx].cpu() + ends_cands = ends[idx * num_candidates: (idx+1) * num_candidates] + rewards = torch.zeros_like(logprobs, dtype=torch.float32) + rewards[torch.arange(num_candidates), ends_cands - 1] = scores[sample_idxs].cpu() mrt_rl_elements.append( MRTRLElement( - query_tensor=prompt_tensors[sample_idx], - response_tensor=sample_outputs[sample_idx], - logprobs=all_logprobs[sample_idx], - values=all_values[sample_idx], - rewards=rewards, + query_tensor=prompt_tensors[sample_idxs].view(num_candidates, -1), + response_tensor=sample_outputs[sample_idxs].view(num_candidates, -1), + logprobs=logprobs[sample_idxs].view(num_candidates, -1), + values=values.contiguous().view(num_candidates, -1), + rewards=rewards ) ) - 
rollout_count += 1 + rollout_count += num_candidates exp_time = clock.tick() - tbar.set_description(f"[rollout {len(mrt_rl_elements)} / {num_rollouts}]") + tbar.set_description(f"[rollout {rollout_count} / {num_rollouts}]") tbar.update(min(rollout_count, num_rollouts)) tbar.close() From 3f4729b1ca831ba71a622f9259af35bf253f2525 Mon Sep 17 00:00:00 2001 From: Alex Muzio Date: Fri, 17 Mar 2023 01:06:00 +0000 Subject: [PATCH 4/8] [WIP] MRT Training now actually training --- .../mrt_summarize_daily_cnn_t5.py | 6 ++-- trlx/models/modeling_mrt.py | 14 ++++++++- trlx/pipeline/mrt_pipeline.py | 30 +++++++++---------- trlx/trainer/accelerate_base_trainer.py | 5 ++-- trlx/trainer/accelerate_mrt_trainer.py | 27 +++++++++++------ trlx/trlx.py | 2 +- 6 files changed, 51 insertions(+), 33 deletions(-) diff --git a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py index 12b68625f..7d18467ee 100644 --- a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py +++ b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py @@ -29,10 +29,10 @@ total_steps=100000, batch_size=4, checkpoint_interval=10000, - eval_interval=500, + eval_interval=100, pipeline="PromptPipeline", trainer="AccelerateMRTTrainer", - tracker=None, + tracker="wandb", ), model=ModelConfig( model_path="google/flan-t5-small", @@ -64,7 +64,7 @@ # n_updates_per_batch=1, #### MRT num_rollouts=512, chunk_size=4, - ppo_epochs=4, + ppo_epochs=1, # init_kl_coef=0.05, # target=6, # horizon=10000, diff --git a/trlx/models/modeling_mrt.py b/trlx/models/modeling_mrt.py index 2944b4e15..90f09634b 100644 --- a/trlx/models/modeling_mrt.py +++ b/trlx/models/modeling_mrt.py @@ -91,7 +91,14 @@ def loss( - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html """ + # TODO: check if masking is correct + + n = mask.sum() + loss = torch.tensor(0.0) + + # we make the assumption here that we only care about sequence level rewards only + rewards = rewards.sum(dim=-1) costs = 1 - rewards # Reward component @@ -140,7 +147,12 @@ def loss( combined_loss = self.ce_loss_weight * ce_loss + (1 - self.ce_loss_weight) * loss stats = dict( - loss=dict(combined_loss=combined_loss, ce_loss=ce_loss, loss=loss, costs=costs), + losses=dict( + total_loss=combined_loss.item(), + ce_loss=ce_loss.item(), + mrt_loss=loss.item(), + ), + padding_percentage=n / mask.numel(), ) # stats = utils.apply_to_sample(lambda t: t.detach().cpu(), stats) diff --git a/trlx/pipeline/mrt_pipeline.py b/trlx/pipeline/mrt_pipeline.py index bab21356f..e433121e4 100644 --- a/trlx/pipeline/mrt_pipeline.py +++ b/trlx/pipeline/mrt_pipeline.py @@ -51,30 +51,28 @@ def create_loader( shuffle: bool, ) -> DataLoader: def collate_fn(elems: Iterable[MRTRLElement]): - return MRTRLBatch( - # Left padding of already left-padded queries + return MRTRLBatch( # TODO: make sure this is expected pad_sequence( - [elem.query_tensor.flip(0) for elem in elems], + [elem.query_tensor.transpose(0, 1) for elem in elems], padding_value=self.pad_token_id, - batch_first=True, - ).flip(1), + ).transpose(0, 1).transpose(1, 2), # Right pad the rest, to have a single horizontal query/response split pad_sequence( - [elem.response_tensor for elem in elems], + [elem.response_tensor.transpose(0,1) for elem in elems], padding_value=self.pad_token_id, - batch_first=True, - ), + ).transpose(0, 1).transpose(1, 2), pad_sequence( - [elem.logprobs for elem in elems], + [elem.logprobs.transpose(0, 1) for elem in elems], padding_value=0.0, - batch_first=True, - ), - 
pad_sequence([elem.values for elem in elems], padding_value=0.0, batch_first=True), + ).transpose(0, 1).transpose(1, 2), pad_sequence( - [elem.rewards for elem in elems], - padding_value=0.0, - batch_first=True, - ), + [elem.values.transpose(0, 1) for elem in elems], + padding_value=0.0 + ).transpose(0, 1).transpose(1, 2), + pad_sequence( + [elem.rewards.transpose(0, 1) for elem in elems], + padding_value=0.0 + ).transpose(0, 1).transpose(1, 2) ) return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 587215074..a8c1a3a01 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -451,9 +451,8 @@ def learn(self): # noqa: C901 state = json.load(f) self.iter_count = state["iter_count"] else: - # TODO: no eval for now: - # results = self.evaluate() - # self.accelerator.log(results, step=self.iter_count) + results = self.evaluate() + self.accelerator.log(results, step=self.iter_count) ... tbar = logging.tqdm( diff --git a/trlx/trainer/accelerate_mrt_trainer.py b/trlx/trainer/accelerate_mrt_trainer.py index 6e39c5701..482ba3dce 100644 --- a/trlx/trainer/accelerate_mrt_trainer.py +++ b/trlx/trainer/accelerate_mrt_trainer.py @@ -146,7 +146,16 @@ def loss(self, batch: MRTRLBatch): response_tensors = batch.response_tensors.to(self.accelerator.device) logprobs = batch.logprobs.to(self.accelerator.device) rewards = batch.rewards.to(self.accelerator.device) - response_length = rewards.shape[1] + + # remove middle dimension + batch_size = len(query_tensors) + num_candidates = self.config.method.num_candidates + query_tensors = query_tensors.reshape(batch_size * num_candidates, -1) + response_tensors = response_tensors.reshape(batch_size * num_candidates, -1) + logprobs = logprobs.reshape(batch_size * num_candidates, -1) + rewards = rewards.reshape(batch_size * num_candidates, -1) + response_length = rewards.shape[-1] + # advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length) @@ -238,7 +247,8 @@ def prepare_learning(self): self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader) self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=True) - self.n_updates_per_batch = self.config.method.ppo_epochs + # This should always be 1 for PPO + self.n_updates_per_batch = 1 self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) @@ -489,6 +499,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_count = 0 + rewards = torch.zeros_like(logprobs, dtype=torch.float32) + rewards[torch.arange(len(rewards)), ends - 1] = scores.cpu() + for idx in range(n_samples // num_candidates): sample_idxs = torch.arange( idx * num_candidates, @@ -503,23 +516,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # rewards = sample_kl_divergence_estimate #rewards[-1] += scores[sample_idx].cpu() - ends_cands = ends[idx * num_candidates: (idx+1) * num_candidates] - rewards = torch.zeros_like(logprobs, dtype=torch.float32) - rewards[torch.arange(num_candidates), ends_cands - 1] = scores[sample_idxs].cpu() - mrt_rl_elements.append( MRTRLElement( query_tensor=prompt_tensors[sample_idxs].view(num_candidates, -1), response_tensor=sample_outputs[sample_idxs].view(num_candidates, -1), 
logprobs=logprobs[sample_idxs].view(num_candidates, -1), - values=values.contiguous().view(num_candidates, -1), - rewards=rewards + values=values[sample_idxs].view(num_candidates, -1), + rewards=rewards[sample_idxs].view(num_candidates, -1) ) ) rollout_count += num_candidates exp_time = clock.tick() - tbar.set_description(f"[rollout {rollout_count} / {num_rollouts}]") + tbar.set_description(f"[rollout {num_candidates * len(mrt_rl_elements)} / {num_rollouts}]") tbar.update(min(rollout_count, num_rollouts)) tbar.close() diff --git a/trlx/trlx.py b/trlx/trlx.py index f50753d14..c50fe739a 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -100,7 +100,7 @@ def train( # noqa: C901 trainer.make_experience(config.method.num_rollouts) - # Offline training from the collected samples (e.g. SFT, ILQL) + # Offline training from the collected samples (e.g. SFT, ILQL) elif samples: if rewards: if len(samples) != len(rewards): From 1b3cc441cd387ecaa8af5e550c46ab792d761618 Mon Sep 17 00:00:00 2001 From: Alex Muzio Date: Tue, 21 Mar 2023 21:00:05 +0000 Subject: [PATCH 5/8] Creating debug file separately for MRT T5 --- .../debug_mrt_summarize_daily_cnn_t5.py | 173 ++++++++++++++++++ .../mrt_summarize_daily_cnn_t5.py | 12 +- .../t5_summarize_daily_cnn.py | 8 +- 3 files changed, 183 insertions(+), 10 deletions(-) create mode 100644 examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py diff --git a/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py b/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py new file mode 100644 index 000000000..98911fa1d --- /dev/null +++ b/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py @@ -0,0 +1,173 @@ +from typing import List + +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +import trlx +from trlx.data.configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.models.modeling_mrt import MRTConfig + +try: + import evaluate +except ImportError: + raise ImportError( + "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`" + ) + +config = TRLConfig( + train=TrainConfig( + seq_length=612, + epochs=100, + total_steps=100000, + batch_size=4, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateMRTTrainer", + tracker=None + # tracker="wandb", + ), + model=ModelConfig( + model_path="google/flan-t5-small", + model_arch_type="seq2seq", + num_layers_unfrozen=2, + ), + tokenizer=TokenizerConfig( + tokenizer_path="google/flan-t5-small", #### change to reasonable value + truncation_side="right", # what is this? 
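+        # truncation_side tells the Hugging Face tokenizer which end to cut when an
+        # encoded prompt is longer than max_length; "right" keeps the start of the article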
+ ), + optimizer=OptimizerConfig( + name="adamw", + kwargs={ + "lr": 1.0e-5, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + "weight_decay": 1.0e-6, + }, + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs={ + "T_max": 10000, + "eta_min": 1.0e-6, + }, + ), + method=MRTConfig( + name="MRTConfig", + # n_updates_per_batch=1, #### MRT + num_rollouts=512, + chunk_size=4, + ppo_epochs=1, + # init_kl_coef=0.05, + # target=6, + # horizon=10000, + # gamma=0.99, + # lam=0.95, + # cliprange=0.2, + # cliprange_value=0.2, + # vf_coef=1.0, + num_candidates=16, + ce_loss_weight=0.0, + scale_reward=None, + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs={ # for evaluation + "max_new_tokens": 100, + # TODO: what should the defaults here be + }, + gen_experience_kwargs={ # for rollouts + "max_new_tokens": 100, + "num_beams": 16, # should be same as nb_candidates + "num_return_sequences": 16, # should be same as nb_candidates + "do_sample": False, + "temperature": 1.0, + # "top_k": 50, + # "top_p": 0.95, + }, + ), +) + + # gen_kwargs = { + # "min_length":-1, + # "top_k": config['top_k'], + # "top_p": 1.0, + # "temperature": config["temperature"], + # "do_sample": config['do_sample'], + # "num_beams": config['num_beams'], + # "max_length": config['max_length'], + # # "pad_token_id": model.eos_token_id, + # "num_return_sequences": config['candidate_size'], + # } + # eval_kwargs = { + # # "early_stopping": True, + # # "length_penalty": 2.0, + # "min_length":-1, + # "top_k": 0.0, + # # "top_p": 1.0, + # "do_sample": False, + # "num_beams": config['eval_num_beams'], + # # "no_repeat_ngram_size": 3, + # "max_length": config['max_length'], + # } + + +meteor = evaluate.load("meteor") # use meteor as the reward function + +if __name__ == "__main__": + + def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): + original_summaries = [prompt_label[prompt.strip()] for prompt in prompts] + scores = [ + meteor.compute(predictions=[output.strip()], references=[original])["meteor"] + for (original, output) in zip(original_summaries, outputs) + ] + return scores + + dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data") + + # take 20,000 samples from the training set as prompts for training + prompts = dataset["train"]["article"][0:1200] + summaries = dataset["train"]["highlights"][0:1200] + prompts = ["Summarize: " + prompt for prompt in prompts] + + # take 1,000 samples from the validation set as prompts for evaluation + val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]] + val_summaries = dataset["validation"]["highlights"][0:1000] + + # make dictionary of prompts and labels to use for reward function + tokenizer = AutoTokenizer.from_pretrained(config.model.model_path) + tokenizer.padding_side = "left" + tokenizer.truncation_side = "right" + tokenizer.sep_token = "" + prompt_label = {} + max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] + + for i in tqdm(range(len(prompts))): + key = tokenizer.decode( + tokenizer(prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + prompt_label[key.strip()] = summaries[i] + + for i in tqdm(range(len(val_prompts))): + key = tokenizer.decode( + tokenizer(val_prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + prompt_label[key.strip()] = 
val_summaries[i] + + trlx.train( + reward_fn=reward_fn, + prompts=prompts, + eval_prompts=val_prompts, + config=config, + ) diff --git a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py index 7d18467ee..ab2dc56e8 100644 --- a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py +++ b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py @@ -27,7 +27,7 @@ seq_length=612, epochs=100, total_steps=100000, - batch_size=4, + batch_size=1, checkpoint_interval=10000, eval_interval=100, pipeline="PromptPipeline", @@ -35,12 +35,12 @@ tracker="wandb", ), model=ModelConfig( - model_path="google/flan-t5-small", + model_path="google/flan-t5-large", model_arch_type="seq2seq", num_layers_unfrozen=2, ), tokenizer=TokenizerConfig( - tokenizer_path="google/flan-t5-small", #### change to reasonable value + tokenizer_path="google/flan-t5-large", #### change to reasonable value truncation_side="right", # what is this? ), optimizer=OptimizerConfig( @@ -63,7 +63,7 @@ name="MRTConfig", # n_updates_per_batch=1, #### MRT num_rollouts=512, - chunk_size=4, + chunk_size=1, ppo_epochs=1, # init_kl_coef=0.05, # target=6, @@ -134,8 +134,8 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data") # take 20,000 samples from the training set as prompts for training - prompts = dataset["train"]["article"][0:1200] - summaries = dataset["train"]["highlights"][0:1200] + prompts = dataset["train"]["article"][0:20000] + summaries = dataset["train"]["highlights"][0:20000] prompts = ["Summarize: " + prompt for prompt in prompts] # take 1,000 samples from the validation set as prompts for evaluation diff --git a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py index 8520e76cf..4c3a56758 100755 --- a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py +++ b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py @@ -34,12 +34,12 @@ trainer="AcceleratePPOTrainer", ), model=ModelConfig( - model_path="google/flan-t5-small", + model_path="google/flan-t5-large", model_arch_type="seq2seq", num_layers_unfrozen=2, ), tokenizer=TokenizerConfig( - tokenizer_path="google/flan-t5-small", + tokenizer_path="google/flan-t5-large", truncation_side="right", ), optimizer=OptimizerConfig( @@ -104,8 +104,8 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data") # take 20,000 samples from the training set as prompts for training - prompts = dataset["train"]["article"][0:1200] - summaries = dataset["train"]["highlights"][0:1200] + prompts = dataset["train"]["article"][0:20000] + summaries = dataset["train"]["highlights"][0:20000] prompts = ["Summarize: " + prompt for prompt in prompts] # take 1,000 samples from the validation set as prompts for evaluation From fd7ba2f5d4541d3a7d3ddfee9f3d76500d1e0bf9 Mon Sep 17 00:00:00 2001 From: Alex Muzio Date: Wed, 22 Mar 2023 07:11:43 +0000 Subject: [PATCH 6/8] Adding T5 translation task with ppo --- examples/ppo_translation_t5.py | 164 ++++++++++++++++++++++++ trlx/trainer/accelerate_base_trainer.py | 3 + 2 files changed, 167 insertions(+) create mode 100644 examples/ppo_translation_t5.py diff --git a/examples/ppo_translation_t5.py b/examples/ppo_translation_t5.py new file mode 100644 index 000000000..ba6f120c5 --- /dev/null +++ b/examples/ppo_translation_t5.py @@ -0,0 +1,164 @@ +from typing import 
List + +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +import trlx +from trlx.data.configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.models.modeling_ppo import PPOConfig + +try: + import evaluate +except ImportError: + raise ImportError( + "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`" + ) + +import torch + +config = TRLConfig( + train=TrainConfig( + seq_length=612, + epochs=100, + total_steps=100000, + batch_size=12, + checkpoint_interval=10000, + eval_interval=500, + pipeline="PromptPipeline", + trainer="AcceleratePPOTrainer", + tracker='wandb' + # tracker=None + ), + model=ModelConfig( + model_path="google/flan-t5-large", + model_arch_type="seq2seq", + num_layers_unfrozen=2, + ), + tokenizer=TokenizerConfig( + tokenizer_path="google/flan-t5-large", + truncation_side="right", + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs={ + "lr": 1.0e-5, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + "weight_decay": 1.0e-6, + }, + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs={ + "T_max": 10000, + "eta_min": 1.0e-6, + }, + ), + method=PPOConfig( + name="PPOConfig", + num_rollouts=512, + chunk_size=12, + ppo_epochs=4, + init_kl_coef=0.05, + target=6, + horizon=10000, + gamma=0.99, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + vf_coef=1.0, + scale_reward=None, + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs={ + "max_new_tokens": 100, + }, + gen_experience_kwargs={ + "max_new_tokens": 100, + "do_sample": True, + "temperature": 1.0, + "top_k": 50, + "top_p": 0.95, + }, + ), +) + + +comet_metric = evaluate.load('comet', 'wmt20-comet-da') + + +if __name__ == "__main__": + + def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): + # WHAT should samples be for translation? 
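+        # trlx calls reward_fn with three aligned lists: samples, prompts and outputs
+        # (same signature as in the summarization example); only `prompts` (to look up the
+        # source/reference pair in translation_map) and `outputs` (the hypotheses scored by
+        # COMET) are used here, so `samples` can be ignored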
+ original_sents = [translation_map[prompt.strip()] for prompt in prompts] + + scores = comet_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original['tgt'] for original in original_sents], + sources=[original['src'] for original in original_sents])["scores"] + + return scores + + + train_dataset = load_dataset("wmt16", "de-en", split='train', cache_dir="/home/aiscuser/dev/trlx_mrt/examples/notebooks/data", streaming=True) + valid_dataset = load_dataset("wmt16", "de-en", split='validation', cache_dir="/home/aiscuser/dev/trlx_mrt/examples/notebooks/data", streaming=False) + + + src_lang = 'en' + tgt_lang = 'de' + PREFIX = "translate English to German: " + + + # take 20,000 samples from the training set as prompts for training + original_src_dataset = [sent_pair['translation'][src_lang] for sent_pair in train_dataset.take(20000)] + tgt_dataset = [sent_pair['translation'][tgt_lang] for sent_pair in train_dataset.take(20000)] + src_dataset = [PREFIX + src_sent for src_sent in original_src_dataset] + + + # take 1,000 samples from the validation set as prompts for evaluation + val_original_src_dataset = [sent_pair[src_lang] for sent_pair in valid_dataset['translation'][0:1000]] + val_tgt_dataset = [sent_pair[tgt_lang] for sent_pair in valid_dataset['translation'][0:1000]] + val_src_dataset = [PREFIX + src_sent for src_sent in val_original_src_dataset] + + + # make dictionary of prompts and labels to use for reward function + tokenizer = AutoTokenizer.from_pretrained(config.model.model_path) + tokenizer.padding_side = "left" + tokenizer.truncation_side = "right" + tokenizer.sep_token = "" + max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] + translation_map = {} + + + for i in tqdm(range(len(original_src_dataset))): + key = tokenizer.decode( + tokenizer(src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + translation_map[key.strip()] = {'src': original_src_dataset[i], 'tgt': tgt_dataset[i]} + + + for i in tqdm(range(len(val_original_src_dataset))): + key = tokenizer.decode( + tokenizer(val_src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + translation_map[key.strip()] = {'src': val_original_src_dataset[i], 'tgt': val_tgt_dataset[i]} + + + trlx.train( + reward_fn=reward_fn, + prompts=src_dataset, + eval_prompts=val_src_dataset, + config=config, + ) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index a8c1a3a01..b6b6ac92c 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -475,12 +475,15 @@ def learn(self): # noqa: C901 # gradient update per batch, PPO for example commonly performs # multiple gradient updates on the same batch of data. 
# https://arxiv.org/pdf/1707.06347.pdf + # TODO: remove + torch.use_deterministic_algorithms(False, warn_only=True) forward_time = time() loss, stats = self.loss(batch) forward_time = time() - forward_time backward_time = time() self.accelerator.backward(loss) backward_time = time() - backward_time + torch.use_deterministic_algorithms(True, warn_only=True) self.opt.step() self.opt.zero_grad() From 642a48f8d8660bce4139ab7ca3043d73e28fcea2 Mon Sep 17 00:00:00 2001 From: Alex Muzio Date: Thu, 23 Mar 2023 00:37:05 +0000 Subject: [PATCH 7/8] Adding metric_fn to ppo translation example --- examples/ppo_translation_t5.py | 82 +++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 26 deletions(-) diff --git a/examples/ppo_translation_t5.py b/examples/ppo_translation_t5.py index ba6f120c5..befcceb99 100644 --- a/examples/ppo_translation_t5.py +++ b/examples/ppo_translation_t5.py @@ -22,7 +22,6 @@ "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`" ) -import torch config = TRLConfig( train=TrainConfig( @@ -34,7 +33,7 @@ eval_interval=500, pipeline="PromptPipeline", trainer="AcceleratePPOTrainer", - tracker='wandb' + tracker="wandb" # tracker=None ), model=ModelConfig( @@ -93,44 +92,75 @@ ) -comet_metric = evaluate.load('comet', 'wmt20-comet-da') +comet_metric = evaluate.load("comet", "wmt20-comet-da") +bleu_metric = evaluate.load("bleu") +chrf_metric = evaluate.load("chrf") if __name__ == "__main__": - def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): + def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]: # WHAT should samples be for translation? original_sents = [translation_map[prompt.strip()] for prompt in prompts] scores = comet_metric.compute( - predictions=[output.strip() for output in outputs], - references=[original['tgt'] for original in original_sents], - sources=[original['src'] for original in original_sents])["scores"] - + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + sources=[original["src"] for original in original_sents], + )["scores"] return scores + def metric_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]: + # Compute BLEU and CHRF + original_sents = [translation_map[prompt.strip()] for prompt in prompts] - train_dataset = load_dataset("wmt16", "de-en", split='train', cache_dir="/home/aiscuser/dev/trlx_mrt/examples/notebooks/data", streaming=True) - valid_dataset = load_dataset("wmt16", "de-en", split='validation', cache_dir="/home/aiscuser/dev/trlx_mrt/examples/notebooks/data", streaming=False) - + comet_score = comet_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + sources=[original["src"] for original in original_sents], + )['mean_score'] + + bleu_score = bleu_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + )['bleu'] + + chrf_score = bleu_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + )['score'] + + return { + 'bleu': bleu_score, + 'chrf': chrf_score, + 'comet': comet_score + } + + train_dataset = load_dataset( + "wmt16", "de-en", split="train", cache_dir="/home/aiscuser/dev/trlx_mrt/examples/notebooks/data", streaming=True + ) + valid_dataset = load_dataset( + "wmt16", + "de-en", + 
split="validation", + cache_dir="/home/aiscuser/dev/trlx_mrt/examples/notebooks/data", + streaming=False, + ) - src_lang = 'en' - tgt_lang = 'de' + src_lang = "en" + tgt_lang = "de" PREFIX = "translate English to German: " - # take 20,000 samples from the training set as prompts for training - original_src_dataset = [sent_pair['translation'][src_lang] for sent_pair in train_dataset.take(20000)] - tgt_dataset = [sent_pair['translation'][tgt_lang] for sent_pair in train_dataset.take(20000)] + original_src_dataset = [sent_pair["translation"][src_lang] for sent_pair in train_dataset.take(20000)] + tgt_dataset = [sent_pair["translation"][tgt_lang] for sent_pair in train_dataset.take(20000)] src_dataset = [PREFIX + src_sent for src_sent in original_src_dataset] - # take 1,000 samples from the validation set as prompts for evaluation - val_original_src_dataset = [sent_pair[src_lang] for sent_pair in valid_dataset['translation'][0:1000]] - val_tgt_dataset = [sent_pair[tgt_lang] for sent_pair in valid_dataset['translation'][0:1000]] + val_original_src_dataset = [sent_pair[src_lang] for sent_pair in valid_dataset["translation"][0:1000]] + val_tgt_dataset = [sent_pair[tgt_lang] for sent_pair in valid_dataset["translation"][0:1000]] val_src_dataset = [PREFIX + src_sent for src_sent in val_original_src_dataset] - # make dictionary of prompts and labels to use for reward function tokenizer = AutoTokenizer.from_pretrained(config.model.model_path) tokenizer.padding_side = "left" @@ -139,25 +169,25 @@ def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] translation_map = {} - for i in tqdm(range(len(original_src_dataset))): key = tokenizer.decode( tokenizer(src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], skip_special_tokens=True, ) # get prompt like trlx's prompt - translation_map[key.strip()] = {'src': original_src_dataset[i], 'tgt': tgt_dataset[i]} - + translation_map[key.strip()] = {"src": original_src_dataset[i], "tgt": tgt_dataset[i]} for i in tqdm(range(len(val_original_src_dataset))): key = tokenizer.decode( - tokenizer(val_src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + tokenizer(val_src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)[ + "input_ids" + ], skip_special_tokens=True, ) # get prompt like trlx's prompt - translation_map[key.strip()] = {'src': val_original_src_dataset[i], 'tgt': val_tgt_dataset[i]} - + translation_map[key.strip()] = {"src": val_original_src_dataset[i], "tgt": val_tgt_dataset[i]} trlx.train( reward_fn=reward_fn, + metric_fn=metric_fn, prompts=src_dataset, eval_prompts=val_src_dataset, config=config, From 945f344d51d6f04b17a4346faa9d1945339b33f9 Mon Sep 17 00:00:00 2001 From: alexandremuzio Date: Mon, 10 Apr 2023 05:24:29 +0000 Subject: [PATCH 8/8] Removing some unused stuff + some improvements + fixing formatting --- examples/mrt_translation_t5.py | 210 ++++++++++++++++++ .../debug_mrt_summarize_daily_cnn_t5.py | 57 ++--- .../mrt_summarize_daily_cnn_t5.py | 172 -------------- trlx/data/mrt_types.py | 2 +- trlx/models/modeling_mrt.py | 132 ++--------- trlx/pipeline/mrt_pipeline.py | 30 +-- trlx/reference.py | 3 +- trlx/sweep.py | 3 +- trlx/trainer/accelerate_mrt_trainer.py | 138 ++++-------- 9 files changed, 325 insertions(+), 422 deletions(-) create mode 100644 examples/mrt_translation_t5.py delete mode 100644 
examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py diff --git a/examples/mrt_translation_t5.py b/examples/mrt_translation_t5.py new file mode 100644 index 000000000..da9fff647 --- /dev/null +++ b/examples/mrt_translation_t5.py @@ -0,0 +1,210 @@ +"""Example of using PPO to train a T5 model for translation. +Based on examples/summarize_daily_cnn/t5_summarize_daily_cnn.py""" + +import json +import os +import sys +from typing import List + +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +import trlx +from trlx.data.configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.models.modeling_mrt import MRTConfig + +try: + import comet + import evaluate + + if comet.__version__ != "1.1.3": + raise ImportError +except ImportError: + raise ImportError( + "To run this example, please install `evaluate`, `nltk` and `comet==1.1.3` packages by " + "running `pip install evaluate unbabel-comet==1.1.3`" + ) + + +default_config = TRLConfig( + train=TrainConfig( + seq_length=612, + epochs=100, + total_steps=100000, + batch_size=4, + checkpoint_interval=10000, + eval_interval=64, + pipeline="PromptPipeline", + trainer="AccelerateMRTTrainer", + # tracker=None + tracker="wandb", + ), + model=ModelConfig( + model_path="t5-small", + model_arch_type="seq2seq", + num_layers_unfrozen=-1, + ), + tokenizer=TokenizerConfig( + tokenizer_path="t5-small", + truncation_side="right", + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs={ + "lr": 2.0e-6, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + "weight_decay": 1.0e-6, + }, + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs={ + "T_max": 10000, + "eta_min": 1.0e-6, + }, + ), + method=MRTConfig( + name="MRTConfig", + num_rollouts=512, + chunk_size=4, + num_candidates=16, + ce_loss_weight=0.0, + scale_reward=None, + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs={ # for evaluation + "max_new_tokens": 100, + # TODO: what should the defaults here be + }, + gen_experience_kwargs={ # for rollouts + "max_new_tokens": 100, + "num_beams": 16, # should be same as nb_candidates + "num_return_sequences": 16, # should be same as nb_candidates + "do_sample": False, + "temperature": 1.0, + # "top_k": 50, + # "top_p": 0.95, + }, + ), +) + + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + + # COMET is the metric we are optimizng for + comet_metric = evaluate.load("comet", "wmt20-comet-da", progress_bar=False) + bleu_metric = evaluate.load("bleu") + chrf_metric = evaluate.load("chrf") + + os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]: + original_sents = [translation_map[prompt.strip()] for prompt in prompts] + + scores = comet_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + sources=[original["src"] for original in original_sents], + )["scores"] + + # TODO: This is needed since there seems to be a bug in the comet metric + # that changes torch's determinism setting. Remove this once the bug is fixed. 
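+        # accelerate_base_trainer.learn() flips the same flag back to True right after the
+        # backward pass, so the rest of training keeps deterministic algorithms enabled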
+ torch.use_deterministic_algorithms(False, warn_only=True) + return scores + + def metric_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]: + """Compute COMET, BLEU and CHRF for evaluation""" + original_sents = [translation_map[prompt.strip()] for prompt in prompts] + + comet_score = comet_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + sources=[original["src"] for original in original_sents], + )["mean_score"] + + bleu_score = bleu_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + )["bleu"] + + chrf_score = chrf_metric.compute( + predictions=[output.strip() for output in outputs], + references=[original["tgt"] for original in original_sents], + )["score"] + + # TODO: This is needed since there seems to be a bug in the comet metric + # that changes torch's determinism setting. Remove this once the bug is fixed. + # Same issue as in `reward_fn` + torch.use_deterministic_algorithms(False, warn_only=True) + + # For corpus-level metrics, it's better to ignore the sentence-level scores + return {"bleu": bleu_score, "chrf": chrf_score, "comet": comet_score} + + # The WMT16 is large so we can benefit with using it as a streaming dataset + train_dataset = load_dataset("wmt16", "de-en", split="train", streaming=True) + valid_dataset = load_dataset("wmt16", "de-en", split="validation", streaming=True) + + src_lang = "en" + tgt_lang = "de" + PREFIX = "translate English to German: " + + # take 20,000 samples from the training set as prompts for training + # TODO: update to 20k + original_src_dataset = [sent_pair["translation"][src_lang] for sent_pair in train_dataset.take(1200)] + tgt_dataset = [sent_pair["translation"][tgt_lang] for sent_pair in train_dataset.take(1200)] + src_dataset = [PREFIX + src_sent for src_sent in original_src_dataset] + + # take 1,000 samples from the validation set as prompts for evaluation + val_original_src_dataset = [sent_pair["translation"][src_lang] for sent_pair in valid_dataset.take(1000)] + val_tgt_dataset = [sent_pair["translation"][tgt_lang] for sent_pair in valid_dataset.take(1000)] + val_src_dataset = [PREFIX + src_sent for src_sent in val_original_src_dataset] + + # make dictionary of prompts and labels to use for reward function + tokenizer = AutoTokenizer.from_pretrained(config.model.model_path) + tokenizer.padding_side = "left" + tokenizer.truncation_side = "right" + tokenizer.sep_token = "" + max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] + translation_map = {} + + for i in tqdm(range(len(original_src_dataset))): + key = tokenizer.decode( + tokenizer(src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + translation_map[key.strip()] = {"src": original_src_dataset[i], "tgt": tgt_dataset[i]} + + for i in tqdm(range(len(val_original_src_dataset))): + key = tokenizer.decode( + tokenizer(val_src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)[ + "input_ids" + ], + skip_special_tokens=True, + ) # get prompt like trlx's prompt + translation_map[key.strip()] = {"src": val_original_src_dataset[i], "tgt": val_tgt_dataset[i]} + + trlx.train( + reward_fn=reward_fn, + metric_fn=metric_fn, + prompts=src_dataset, + eval_prompts=val_src_dataset, + config=config, + ) + + +if __name__ == "__main__": + hparams 
= {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py b/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py index 98911fa1d..ab9c8e1d8 100644 --- a/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py +++ b/examples/summarize_daily_cnn/debug_mrt_summarize_daily_cnn_t5.py @@ -1,3 +1,4 @@ +# DO NOT REVIEW, WILL BE DELETED from typing import List from datasets import load_dataset @@ -41,8 +42,8 @@ num_layers_unfrozen=2, ), tokenizer=TokenizerConfig( - tokenizer_path="google/flan-t5-small", #### change to reasonable value - truncation_side="right", # what is this? + tokenizer_path="google/flan-t5-small", # change to reasonable value + truncation_side="right", # what is this? ), optimizer=OptimizerConfig( name="adamw", @@ -80,14 +81,14 @@ ref_mean=None, ref_std=None, cliprange_reward=10, - gen_kwargs={ # for evaluation + gen_kwargs={ # for evaluation "max_new_tokens": 100, # TODO: what should the defaults here be }, - gen_experience_kwargs={ # for rollouts + gen_experience_kwargs={ # for rollouts "max_new_tokens": 100, - "num_beams": 16, # should be same as nb_candidates - "num_return_sequences": 16, # should be same as nb_candidates + "num_beams": 16, # should be same as nb_candidates + "num_return_sequences": 16, # should be same as nb_candidates "do_sample": False, "temperature": 1.0, # "top_k": 50, @@ -96,28 +97,28 @@ ), ) - # gen_kwargs = { - # "min_length":-1, - # "top_k": config['top_k'], - # "top_p": 1.0, - # "temperature": config["temperature"], - # "do_sample": config['do_sample'], - # "num_beams": config['num_beams'], - # "max_length": config['max_length'], - # # "pad_token_id": model.eos_token_id, - # "num_return_sequences": config['candidate_size'], - # } - # eval_kwargs = { - # # "early_stopping": True, - # # "length_penalty": 2.0, - # "min_length":-1, - # "top_k": 0.0, - # # "top_p": 1.0, - # "do_sample": False, - # "num_beams": config['eval_num_beams'], - # # "no_repeat_ngram_size": 3, - # "max_length": config['max_length'], - # } +# gen_kwargs = { +# "min_length":-1, +# "top_k": config['top_k'], +# "top_p": 1.0, +# "temperature": config["temperature"], +# "do_sample": config['do_sample'], +# "num_beams": config['num_beams'], +# "max_length": config['max_length'], +# # "pad_token_id": model.eos_token_id, +# "num_return_sequences": config['candidate_size'], +# } +# eval_kwargs = { +# # "early_stopping": True, +# # "length_penalty": 2.0, +# "min_length":-1, +# "top_k": 0.0, +# # "top_p": 1.0, +# "do_sample": False, +# "num_beams": config['eval_num_beams'], +# # "no_repeat_ngram_size": 3, +# "max_length": config['max_length'], +# } meteor = evaluate.load("meteor") # use meteor as the reward function diff --git a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py b/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py deleted file mode 100644 index ab2dc56e8..000000000 --- a/examples/summarize_daily_cnn/mrt_summarize_daily_cnn_t5.py +++ /dev/null @@ -1,172 +0,0 @@ -from typing import List - -from datasets import load_dataset -from tqdm import tqdm -from transformers import AutoTokenizer - -import trlx -from trlx.data.configs import ( - ModelConfig, - OptimizerConfig, - SchedulerConfig, - TokenizerConfig, - TrainConfig, - TRLConfig, -) -from trlx.models.modeling_mrt import MRTConfig - -try: - import evaluate -except ImportError: - raise ImportError( - "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install 
evaluate`" - ) - -config = TRLConfig( - train=TrainConfig( - seq_length=612, - epochs=100, - total_steps=100000, - batch_size=1, - checkpoint_interval=10000, - eval_interval=100, - pipeline="PromptPipeline", - trainer="AccelerateMRTTrainer", - tracker="wandb", - ), - model=ModelConfig( - model_path="google/flan-t5-large", - model_arch_type="seq2seq", - num_layers_unfrozen=2, - ), - tokenizer=TokenizerConfig( - tokenizer_path="google/flan-t5-large", #### change to reasonable value - truncation_side="right", # what is this? - ), - optimizer=OptimizerConfig( - name="adamw", - kwargs={ - "lr": 1.0e-5, - "betas": [0.9, 0.999], - "eps": 1.0e-8, - "weight_decay": 1.0e-6, - }, - ), - scheduler=SchedulerConfig( - name="cosine_annealing", - kwargs={ - "T_max": 10000, - "eta_min": 1.0e-6, - }, - ), - method=MRTConfig( - name="MRTConfig", - # n_updates_per_batch=1, #### MRT - num_rollouts=512, - chunk_size=1, - ppo_epochs=1, - # init_kl_coef=0.05, - # target=6, - # horizon=10000, - # gamma=0.99, - # lam=0.95, - # cliprange=0.2, - # cliprange_value=0.2, - # vf_coef=1.0, - num_candidates=16, - ce_loss_weight=0.0, - scale_reward=None, - ref_mean=None, - ref_std=None, - cliprange_reward=10, - gen_kwargs={ # for evaluation - "max_new_tokens": 100, - # TODO: what should the defaults here be - }, - gen_experience_kwargs={ # for rollouts - "max_new_tokens": 100, - "num_beams": 16, # should be same as nb_candidates - "num_return_sequences": 16, # should be same as nb_candidates - "do_sample": False, - "temperature": 1.0, - # "top_k": 50, - # "top_p": 0.95, - }, - ), -) - - # gen_kwargs = { - # "min_length":-1, - # "top_k": config['top_k'], - # "top_p": 1.0, - # "temperature": config["temperature"], - # "do_sample": config['do_sample'], - # "num_beams": config['num_beams'], - # "max_length": config['max_length'], - # # "pad_token_id": model.eos_token_id, - # "num_return_sequences": config['candidate_size'], - # } - # eval_kwargs = { - # # "early_stopping": True, - # # "length_penalty": 2.0, - # "min_length":-1, - # "top_k": 0.0, - # # "top_p": 1.0, - # "do_sample": False, - # "num_beams": config['eval_num_beams'], - # # "no_repeat_ngram_size": 3, - # "max_length": config['max_length'], - # } - - -meteor = evaluate.load("meteor") # use meteor as the reward function - -if __name__ == "__main__": - - def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): - original_summaries = [prompt_label[prompt.strip()] for prompt in prompts] - scores = [ - meteor.compute(predictions=[output.strip()], references=[original])["meteor"] - for (original, output) in zip(original_summaries, outputs) - ] - return scores - - dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data") - - # take 20,000 samples from the training set as prompts for training - prompts = dataset["train"]["article"][0:20000] - summaries = dataset["train"]["highlights"][0:20000] - prompts = ["Summarize: " + prompt for prompt in prompts] - - # take 1,000 samples from the validation set as prompts for evaluation - val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]] - val_summaries = dataset["validation"]["highlights"][0:1000] - - # make dictionary of prompts and labels to use for reward function - tokenizer = AutoTokenizer.from_pretrained(config.model.model_path) - tokenizer.padding_side = "left" - tokenizer.truncation_side = "right" - tokenizer.sep_token = "" - prompt_label = {} - max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] - - for i in 
tqdm(range(len(prompts))): - key = tokenizer.decode( - tokenizer(prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], - skip_special_tokens=True, - ) # get prompt like trlx's prompt - prompt_label[key.strip()] = summaries[i] - - for i in tqdm(range(len(val_prompts))): - key = tokenizer.decode( - tokenizer(val_prompts[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], - skip_special_tokens=True, - ) # get prompt like trlx's prompt - prompt_label[key.strip()] = val_summaries[i] - - trlx.train( - reward_fn=reward_fn, - prompts=prompts, - eval_prompts=val_prompts, - config=config, - ) diff --git a/trlx/data/mrt_types.py b/trlx/data/mrt_types.py index ac1ffe46c..e8c70b0d3 100644 --- a/trlx/data/mrt_types.py +++ b/trlx/data/mrt_types.py @@ -40,7 +40,7 @@ class MRTRLElement: @dataclass class MRTRLBatch: """ - A batched version of the PPORLElement. See PPORLElement for more details on individual fields. + A batched version of the MRTRLElement. See MRTRLElement for more details on individual fields. :param query_tensors: A batch of query tensors. Should be a long tensor. :type query_tensors: torch.Tensor diff --git a/trlx/models/modeling_mrt.py b/trlx/models/modeling_mrt.py index 90f09634b..42d61a48a 100644 --- a/trlx/models/modeling_mrt.py +++ b/trlx/models/modeling_mrt.py @@ -1,19 +1,12 @@ -from copy import deepcopy from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Optional -import numpy as np import torch -import torch.nn as nn import torch.nn.functional as F from torchtyping import TensorType from trlx.data.method_configs import MethodConfig, register_method -from trlx.models.modeling_base import PreTrainedModelWrapper -from trlx.utils.modeling import ( - flatten_dict, - get_tensor_stats, -) +from trlx.utils.modeling import flatten_dict @dataclass @@ -22,37 +15,12 @@ class MRTConfig(MethodConfig): """ Config for MRT method - :param ppo_epochs: Number of updates per batch - :type ppo_epochs: int - :param num_rollouts: Number of experiences to observe before learning :type num_rollouts: int - :param init_kl_coef: Initial value for KL coefficient - :type init_kl_coef: float - - :param target: Target value for KL coefficient - :type target: float - - :param horizon: Number of steps for KL coefficient to reach target - :type horizon: int - :param gamma: Discount factor :type gamma: float - :param lam: GAE lambda - :type lam: float - - :param cliprange: Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange) - :type cliprange: float - - :param cliprange_value: Clipping range for predicted values - (observed values - cliprange_value, observed values + cliprange_value) - :type cliprange_value: float - - :param vf_coef: Value loss scale w.r.t policy loss - :type vf_coef: float - :param gen_kwargs: Additioanl kwargs for the generation :type gen_kwargs: Dict[str, Any] @@ -60,17 +28,8 @@ class MRTConfig(MethodConfig): :type gen_experience_kwargs: Dict[str, Any] """ - ppo_epochs: int num_rollouts: int chunk_size: int - # init_kl_coef: float - # target: float - # horizon: int - # gamma: float - # lam: float - # cliprange: float - # cliprange_value: float - # vf_coef: float num_candidates: int ce_loss_weight: float scale_reward: Optional[str] @@ -86,33 +45,30 @@ def loss( rewards: TensorType["batch_size", "response_size"], mask: TensorType["batch_size", "response_size"], ): - """MRT objective function. 
- References: - - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html - """ + """MRT objective function.""" # TODO: check if masking is correct n = mask.sum() loss = torch.tensor(0.0) - + # we make the assumption here that we only care about sequence level rewards only rewards = rewards.sum(dim=-1) costs = 1 - rewards # Reward component - if self.ce_loss_weight < 1.0: # if ce_loss_weight is 1.0, then we only use the ce loss + if self.ce_loss_weight < 1.0: # if ce_loss_weight is 1.0, then we only use the ce loss # We make the assumption here that rewards are scaled to [0,1] # lengths = response_masks.sum(dim=-1).float() lengths = mask.sum(dim=-1).float() -# model_outputs = self.model( -# input_ids=queries, -# decoder_input_ids=responses, # response tokens are already shifted right (start token = pad token) -# attention_mask=query_masks -# return_dict=True) -# , + # model_outputs = self.model( + # input_ids=queries, + # decoder_input_ids=responses, # response tokens are already shifted right + # attention_mask=query_masks + # return_dict=True) + # , avg_scores = logprobs.sum(dim=-1) / lengths # [batch_size, candidate_size] @@ -125,7 +81,8 @@ def loss( # Cross entropy component ce_loss = torch.tensor(0.0) if self.ce_loss_weight > 0.0: - assert False, 'ce_loss_weight should be 0.0' + # TODO: for this to work we need to have some sort of reference + assert False, "ce_loss_weight should be 0.0" # if parallel_mask is not None: # queries = queries[parallel_mask] # query_masks = query_masks[parallel_mask] @@ -135,7 +92,8 @@ def loss( # # We should compute the cross entropy with the reference response and not with the generated response # model_outputs = self.model( # input_ids=queries, - # decoder_input_ids=shift_tokens_right(refs, self.model.config.pad_token_id, self.model.config.decoder_start_token_id), + # decoder_input_ids= + # shift_tokens_right(refs, self.model.config.pad_token_id, self.model.config.decoder_start_token_id), # attention_mask=query_masks, # return_dict=True) # ce_loss = F.cross_entropy( @@ -154,65 +112,5 @@ def loss( ), padding_percentage=n / mask.numel(), ) - # stats = utils.apply_to_sample(lambda t: t.detach().cpu(), stats) - - return combined_loss, flatten_dict(stats) - - - -""" - def loss(self, rewards, queries, responses, query_masks, response_masks, refs, ref_masks, parallel_mask=None): - loss = torch.tensor(0.0) - costs = 1 - rewards - - # Reward component - if self.params['ce_loss_weight'] < 1.0: # if ce_loss_weight is 1.0, then we only use the ce loss - # We make the assumption here that rewards are scaled to [0,1] - lengths = response_masks.sum(dim=-1).float() - - model_outputs = self.model( - input_ids=queries, - decoder_input_ids=responses, # response tokens are already shifted right (start token = pad token) - attention_mask=query_masks, - return_dict=True) - - logprobs = logprobs_from_logits(model_outputs.logits[:,:-1,:], responses[:, 1:], mask=response_masks[:, 1:]) - avg_scores = logprobs.sum(dim=-1) / lengths - - # [batch_size, candidate_size] - avg_scores = avg_scores.view(-1, self.params['candidate_size']) - costs = costs.view(-1, self.params['candidate_size']) - - probs = F.softmax(avg_scores, dim=1).squeeze(-1) - loss = (probs * costs).sum() - - # Cross entropy component - ce_loss = torch.tensor(0.0) - if self.params['ce_loss_weight'] > 0.0: - if parallel_mask is not None: - queries = queries[parallel_mask] - query_masks = query_masks[parallel_mask] - refs = refs[parallel_mask] - ref_masks = ref_masks[parallel_mask] - - # We should compute 
the cross entropy with the reference response and not with the generated response - model_outputs = self.model( - input_ids=queries, - decoder_input_ids=shift_tokens_right(refs, self.model.config.pad_token_id, self.model.config.decoder_start_token_id), - attention_mask=query_masks, - return_dict=True) - ce_loss = F.cross_entropy( - model_outputs.logits.reshape(-1, model_outputs.logits.size(-1)), - refs.reshape(-1), - ignore_index=self.model.config.pad_token_id - ) - - combined_loss = self.params['ce_loss_weight'] * ce_loss + (1 - self.params['ce_loss_weight']) * loss - - stats = dict( - loss=dict(combined_loss=combined_loss, ce_loss=ce_loss, loss=loss, costs=costs), - ) - stats = utils.apply_to_sample(lambda t: t.detach().cpu(), stats) return combined_loss, flatten_dict(stats) -""" diff --git a/trlx/pipeline/mrt_pipeline.py b/trlx/pipeline/mrt_pipeline.py index e433121e4..6c2948d17 100644 --- a/trlx/pipeline/mrt_pipeline.py +++ b/trlx/pipeline/mrt_pipeline.py @@ -51,28 +51,32 @@ def create_loader( shuffle: bool, ) -> DataLoader: def collate_fn(elems: Iterable[MRTRLElement]): - return MRTRLBatch( # TODO: make sure this is expected + return MRTRLBatch( # TODO: make sure this is expected pad_sequence( [elem.query_tensor.transpose(0, 1) for elem in elems], padding_value=self.pad_token_id, - ).transpose(0, 1).transpose(1, 2), + ) + .transpose(0, 1) + .transpose(1, 2), # Right pad the rest, to have a single horizontal query/response split pad_sequence( - [elem.response_tensor.transpose(0,1) for elem in elems], + [elem.response_tensor.transpose(0, 1) for elem in elems], padding_value=self.pad_token_id, - ).transpose(0, 1).transpose(1, 2), + ) + .transpose(0, 1) + .transpose(1, 2), pad_sequence( [elem.logprobs.transpose(0, 1) for elem in elems], padding_value=0.0, - ).transpose(0, 1).transpose(1, 2), - pad_sequence( - [elem.values.transpose(0, 1) for elem in elems], - padding_value=0.0 - ).transpose(0, 1).transpose(1, 2), - pad_sequence( - [elem.rewards.transpose(0, 1) for elem in elems], - padding_value=0.0 - ).transpose(0, 1).transpose(1, 2) + ) + .transpose(0, 1) + .transpose(1, 2), + pad_sequence([elem.values.transpose(0, 1) for elem in elems], padding_value=0.0) + .transpose(0, 1) + .transpose(1, 2), + pad_sequence([elem.rewards.transpose(0, 1) for elem in elems], padding_value=0.0) + .transpose(0, 1) + .transpose(1, 2), ) return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn) diff --git a/trlx/reference.py b/trlx/reference.py index dab6b6d97..c4f4612a7 100644 --- a/trlx/reference.py +++ b/trlx/reference.py @@ -4,9 +4,10 @@ import os import subprocess -import wandb import wandb.apis.reports as wb +import wandb + parser = argparse.ArgumentParser() parser.add_argument("branch", type=str, help="Git branch in the format `origin:branch`") parser.add_argument("--against", type=str, default="CarperAI/trlx:main", help="Reference git branch") diff --git a/trlx/sweep.py b/trlx/sweep.py index 615cb7361..9bfc07495 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -5,7 +5,6 @@ from datetime import datetime import ray -import wandb import wandb.apis.reports as wb import yaml from ray import tune @@ -13,6 +12,8 @@ from ray.train.huggingface.accelerate import AccelerateTrainer from ray.tune.logger import CSVLoggerCallback +import wandb + def get_param_space(config: dict): # noqa: C901 """Get the param space from the config file.""" diff --git a/trlx/trainer/accelerate_mrt_trainer.py b/trlx/trainer/accelerate_mrt_trainer.py index 482ba3dce..c03bb4338 100644 --- 
-            return MRTRLBatch( # TODO: make sure this is expected
+            return MRTRLBatch(  # TODO: make sure this is expected
                 pad_sequence(
                     [elem.query_tensor.transpose(0, 1) for elem in elems],
                     padding_value=self.pad_token_id,
-                ).transpose(0, 1).transpose(1, 2),
+                )
+                .transpose(0, 1)
+                .transpose(1, 2),
                 # Right pad the rest, to have a single horizontal query/response split
                 pad_sequence(
-                    [elem.response_tensor.transpose(0,1) for elem in elems],
+                    [elem.response_tensor.transpose(0, 1) for elem in elems],
                     padding_value=self.pad_token_id,
-                ).transpose(0, 1).transpose(1, 2),
+                )
+                .transpose(0, 1)
+                .transpose(1, 2),
                 pad_sequence(
                     [elem.logprobs.transpose(0, 1) for elem in elems],
                     padding_value=0.0,
-                ).transpose(0, 1).transpose(1, 2),
-                pad_sequence(
-                    [elem.values.transpose(0, 1) for elem in elems],
-                    padding_value=0.0
-                ).transpose(0, 1).transpose(1, 2),
-                pad_sequence(
-                    [elem.rewards.transpose(0, 1) for elem in elems],
-                    padding_value=0.0
-                ).transpose(0, 1).transpose(1, 2)
+                )
+                .transpose(0, 1)
+                .transpose(1, 2),
+                pad_sequence([elem.values.transpose(0, 1) for elem in elems], padding_value=0.0)
+                .transpose(0, 1)
+                .transpose(1, 2),
+                pad_sequence([elem.rewards.transpose(0, 1) for elem in elems], padding_value=0.0)
+                .transpose(0, 1)
+                .transpose(1, 2),
             )
 
         return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn)
diff --git a/trlx/reference.py b/trlx/reference.py
index dab6b6d97..c4f4612a7 100644
--- a/trlx/reference.py
+++ b/trlx/reference.py
@@ -4,9 +4,10 @@
 import os
 import subprocess
 
-import wandb
 import wandb.apis.reports as wb
 
+import wandb
+
 parser = argparse.ArgumentParser()
 parser.add_argument("branch", type=str, help="Git branch in the format `origin:branch`")
 parser.add_argument("--against", type=str, default="CarperAI/trlx:main", help="Reference git branch")
diff --git a/trlx/sweep.py b/trlx/sweep.py
index 615cb7361..9bfc07495 100644
--- a/trlx/sweep.py
+++ b/trlx/sweep.py
@@ -5,7 +5,6 @@
 from datetime import datetime
 
 import ray
-import wandb
 import wandb.apis.reports as wb
 import yaml
 from ray import tune
@@ -13,6 +12,8 @@
 from ray.train.huggingface.accelerate import AccelerateTrainer
 from ray.tune.logger import CSVLoggerCallback
 
+import wandb
+
 
 def get_param_space(config: dict):  # noqa: C901
     """Get the param space from the config file."""
diff --git a/trlx/trainer/accelerate_mrt_trainer.py b/trlx/trainer/accelerate_mrt_trainer.py
index 482ba3dce..c03bb4338 100644
--- a/trlx/trainer/accelerate_mrt_trainer.py
+++ b/trlx/trainer/accelerate_mrt_trainer.py
@@ -8,7 +8,6 @@
 import torch
 import torch.nn.functional as F
 import transformers
-from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
 
@@ -16,14 +15,12 @@
 from trlx.data.accelerate_base_datatypes import PromptBatch
 from trlx.data.configs import TRLConfig
 from trlx.data.mrt_types import MRTRLBatch, MRTRLElement
-from trlx.models.modeling_ppo import (
-    AdaptiveKLController,
+from trlx.models.modeling_ppo import (  # TODO: do we need to update this to MRT?
     AutoModelForCausalLMWithHydraValueHead,
     AutoModelForSeq2SeqLMWithHydraValueHead,
-    FixedKLController,
 )
-from trlx.pipeline.offline_pipeline import PromptPipeline
 from trlx.pipeline.mrt_pipeline import MRTRolloutStorage
+from trlx.pipeline.offline_pipeline import PromptPipeline
 from trlx.trainer import register_trainer
 from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer
 from trlx.utils import Clock
@@ -34,13 +31,13 @@
 
 @register_trainer
 class AccelerateMRTTrainer(AccelerateRLTrainer):
-    """PPO Accelerate Trainer"""
+    """MRT Accelerate Trainer"""
 
     reward_fn: Callable[[List[str], List[str], List[str]], List[float]]
     tokenizer: AutoTokenizer
 
     def __init__(self, config: TRLConfig, **kwargs):
-        """PPO Accelerate Trainer initialization
+        """MRT Accelerate Trainer initialization
 
         Args:
             config: Config
@@ -75,13 +72,6 @@ def __init__(self, config: TRLConfig, **kwargs):
         self.ref_model.to(self.accelerator.device)
         self.ref_model.eval()
 
-        # Setup the KL controller
-        # This helps prevent large divergences in the controller (policy)
-        # if config.method.target is not None:
-        #     self.kl_ctl = AdaptiveKLController(config.method.init_kl_coef, config.method.target, config.method.horizon)
-        # else:
-        #     self.kl_ctl = FixedKLController(config.method.init_kl_coef)
-
         # Create the parameters for the Hugging Face language model's generator
         # method (that generates new tokens from a prompt).
         # https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/text_generation#transformers.GenerationMixin.generate
@@ -156,9 +146,6 @@ def loss(self, batch: MRTRLBatch):
         rewards = rewards.reshape(batch_size * num_candidates, -1)
 
         response_length = rewards.shape[-1]
-
-        # advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length)
-
         if self.config.model.model_arch_type == "seq2seq":
             input_ids = query_tensors
             decoder_input_ids = response_tensors
@@ -210,8 +197,6 @@ def loss(self, batch: MRTRLBatch):
             mask=mask,
         )
 
-        # TODO update this
-        # self.approx_kl = stats["policy/approx_kl"]  # Update kl controller stats
         return loss, stats
 
     def setup_rollout_logging(self, config):
@@ -239,15 +224,13 @@ def post_epoch_callback(self):
         self.make_experience(self.config.method.num_rollouts, self.iter_count)
 
     def post_backward_callback(self):
-        ...
-        # self.kl_ctl.update(self.approx_kl, n_steps=self.config.train.batch_size)
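+        # MRT does not adapt a KL coefficient, so (assuming the base trainer only
+        # requires this hook to exist) there is nothing to update after backward;
+        # PPO's kl_ctl.update used to happen here.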
+        pass
 
     def prepare_learning(self):
         eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
         self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader)
 
         self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=True)
-        # This should always be 1 for PPO
         self.n_updates_per_batch = 1
         self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
@@ -306,10 +289,16 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             device = samples.device
 
             # Expand queries and mask
-            copied_idxs = torch.tensor([i for i in range(batch.input_ids.shape[0]) for _ in range(num_candidates)], device=device)
+            copied_idxs = torch.tensor(
+                [i for i in range(batch.input_ids.shape[0]) for _ in range(num_candidates)], device=device
+            )
             # TODO change this part over here
-            batch.input_ids = torch.index_select(batch.input_ids, 0, copied_idxs) # [batch_size, candidate_size, query_length]
-            batch.attention_mask = torch.index_select(batch.attention_mask, 0, copied_idxs) # [batch_size, candidate_size, query_length]
+            batch.input_ids = torch.index_select(
+                batch.input_ids, 0, copied_idxs
+            )  # [batch_size, candidate_size, query_length]
+            batch.attention_mask = torch.index_select(
+                batch.attention_mask, 0, copied_idxs
+            )  # [batch_size, candidate_size, query_length]
 
             stats["time/exp_generate"] = time() - exp_generate_time
 
@@ -351,7 +340,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                     scores = torch.empty(len(samples), device=device)
                     torch.distributed.scatter(scores, all_scores)
                 else:
-                    scores = all_scores[0].clone() # torch.tensor(all_scores[0])
+                    scores = all_scores[0].clone()  # torch.tensor(all_scores[0])
 
             str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples)
 
@@ -424,36 +413,34 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                         return_dict=True,
                     ).logits
             else:
-                assert False
-                # else:
-                #     all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1)
-                #     attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device)
-                #     with torch.no_grad():
-                #         logits, *_, values = self.model(
-                #             all_tokens,
-                #             attention_mask=attention_mask,
-                #         )
-                #         # TODO(dahoas): When hydra model works need to also support generation on hydra head
-                #         if hasattr(self.model, "frozen_head"):
-                #             ref_logits = self.model.forward_hydra(
-                #                 all_tokens,
-                #                 attention_mask=attention_mask,
-                #                 return_dict=True,
-                #             ).logits
-                #         else:
-                #             ref_logits = self.ref_model(
-                #                 all_tokens,
-                #                 attention_mask=attention_mask,
-                #                 return_dict=True,
-                #             ).logits
-                #     ref_logits = ref_logits.to(device)
+                all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1)
+                attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device)
+                with torch.no_grad():
+                    logits, *_, values = self.model(
+                        all_tokens,
+                        attention_mask=attention_mask,
+                    )
+                    # TODO(dahoas): When hydra model works need to also support generation on hydra head
+                    if hasattr(self.model, "frozen_head"):
+                        ref_logits = self.model.forward_hydra(
+                            all_tokens,
+                            attention_mask=attention_mask,
+                            return_dict=True,
+                        ).logits
+                    else:
+                        ref_logits = self.ref_model(
+                            all_tokens,
+                            attention_mask=attention_mask,
+                            return_dict=True,
+                        ).logits
+                    ref_logits = ref_logits.to(device)
 
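+            # Below, logprobs_of_labels gathers the per-token log-probabilities of
+            # the realized tokens: seq2seq models are scored on the decoder outputs
+            # only, while causal models are scored on the full prompt+response.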
             if self.config.model.model_arch_type == "seq2seq":
                 logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
                 ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
-            # else:
-            #     logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
-            #     ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
+            else:
+                logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
+                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
 
             n_samples: int = samples.shape[0]
 
             logprobs = logprobs.cpu()
@@ -471,31 +458,16 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 padding_token: int = 0
                 ends = (sample_outputs[:, start:] != padding_token).sum(1)
 
-                # Get the logprobs and values, for tokens that are not padding
-                # or beginning of sequences tokens. These are from the model
-                # (not the reference model)
-                all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
-                all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
-
-                kl_divergence_estimate: List[torch.Tensor] = [
-                    1.0  # -self.kl_ctl.value
-                    * (
-                        logprobs[sample_idx, start : ends[sample_idx]]
-                        - ref_logprobs[sample_idx, start : ends[sample_idx]]
-                    )
-                    for sample_idx in range(n_samples)
-                ]
-            # Else if not seq2seq (i.e. causal)
-            # else:
-            #     values = values.cpu()[:, :-1]
-            #     start = prompt_tensors.shape[1] - 1
-            #     ends = start + attention_mask[:, start:].sum(1)
-            #     all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
-            #     all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
+            else:
+                values = values.cpu()[:, :-1]
+                start = prompt_tensors.shape[1] - 1
+                ends = start + attention_mask[:, start:].sum(1)
+                # all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
+                # all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
 
-            # kl_divergence_estimate = 1.0 * (logprobs - ref_logprobs)
-            # kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)]
+                # kl_divergence_estimate = 1.0 * (logprobs - ref_logprobs)
+                # kl_divergence_estimate = [rs[start : ends[ix]] for ix, rs in enumerate(kl_divergence_estimate)]
 
             rollout_count = 0
 
@@ -503,18 +475,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             rewards[torch.arange(len(rewards)), ends - 1] = scores.cpu()
 
             for idx in range(n_samples // num_candidates):
-                sample_idxs = torch.arange(
-                    idx * num_candidates,
-                    (idx + 1) * num_candidates)
-                # k
-                # sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx]
-
-                # if len(sample_kl_divergence_estimate) == 0 or len(all_logprobs[sample_idx]) == 0:
-                #     continue
-
-                # not used for MRT:
-                # rewards = sample_kl_divergence_estimate
-                # rewards[-1] += scores[sample_idx].cpu()
+                sample_idxs = torch.arange(idx * num_candidates, (idx + 1) * num_candidates)
 
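+                # Group the num_candidates samples drawn for the same prompt into a
+                # single MRTRLElement, so the loss can renormalize log-probabilities
+                # across the candidate set.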
                 mrt_rl_elements.append(
                     MRTRLElement(
@@ -522,7 +483,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                         response_tensor=sample_outputs[sample_idxs].view(num_candidates, -1),
                         logprobs=logprobs[sample_idxs].view(num_candidates, -1),
                         values=values[sample_idxs].view(num_candidates, -1),
-                        rewards=rewards[sample_idxs].view(num_candidates, -1)
+                        rewards=rewards[sample_idxs].view(num_candidates, -1),
                     )
                 )
 
@@ -532,7 +493,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 tbar.update(min(rollout_count, num_rollouts))
             tbar.close()
 
-            # stats["kl_ctl_value"] = self.kl_ctl.value
             stats["time/exp"] = exp_time
 
         if not ray.is_initialized():