From f24281c88bdb6dfb7f21f2d7cd232626e0a904d5 Mon Sep 17 00:00:00 2001 From: Alessandro Pietro Bardelli Date: Tue, 27 Jun 2023 14:40:09 +0000 Subject: [PATCH 01/23] RLHF end2end example --- examples/rlhf/.gitignore | 4 + examples/rlhf/README.md | 45 ++++ examples/rlhf/config/train.yaml | 30 +++ examples/rlhf/config/train_reward.yaml | 32 +++ examples/rlhf/config/train_rlhf.yaml | 36 +++ examples/rlhf/data/__init__.py | 3 + examples/rlhf/models/__init__.py | 4 + examples/rlhf/models/actor_critic.py | 29 +++ examples/rlhf/models/reward.py | 34 +++ examples/rlhf/models/transformer.py | 44 ++++ examples/rlhf/requirements.txt | 11 + examples/rlhf/train.py | 155 +++++++++++ examples/rlhf/train_reward.py | 164 ++++++++++++ examples/rlhf/train_rlhf.py | 339 +++++++++++++++++++++++++ examples/rlhf/utils.py | 47 ++++ 15 files changed, 977 insertions(+) create mode 100644 examples/rlhf/.gitignore create mode 100644 examples/rlhf/README.md create mode 100644 examples/rlhf/config/train.yaml create mode 100644 examples/rlhf/config/train_reward.yaml create mode 100644 examples/rlhf/config/train_rlhf.yaml create mode 100644 examples/rlhf/data/__init__.py create mode 100644 examples/rlhf/models/__init__.py create mode 100644 examples/rlhf/models/actor_critic.py create mode 100644 examples/rlhf/models/reward.py create mode 100644 examples/rlhf/models/transformer.py create mode 100644 examples/rlhf/requirements.txt create mode 100644 examples/rlhf/train.py create mode 100644 examples/rlhf/train_reward.py create mode 100644 examples/rlhf/train_rlhf.py create mode 100644 examples/rlhf/utils.py diff --git a/examples/rlhf/.gitignore b/examples/rlhf/.gitignore new file mode 100644 index 00000000000..d8bad909a58 --- /dev/null +++ b/examples/rlhf/.gitignore @@ -0,0 +1,4 @@ +*.png +*.bin +*.pt +*.json diff --git a/examples/rlhf/README.md b/examples/rlhf/README.md new file mode 100644 index 00000000000..1ddca8dfb96 --- /dev/null +++ b/examples/rlhf/README.md @@ -0,0 +1,45 @@ +# RLHF example + +This example uses RLHF (Reinforcement Learning with Human Feedback) to train a language model to summarize Reddit posts. + +## Getting started + +Make sure you have PyTorch 2.0 installed. You can find installation instructions [here](https://pytorch.org/get-started/locally/). + +From this directory, you can install extra requirements for running these examples with + +```sh +pip install -r requirements.txt +``` + +## Training the models +### Training the transformer + +Once the data has been prepared, you can train the GPT model. 
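+
+> **_NOTE:_** The datasets used in these examples (`CarperAI/openai_summarize_tldr` for prompts and `CarperAI/openai_summarize_comparisons` for reward training) are loaded through TorchRL's RLHF dataloaders, which should download and cache them automatically on first use.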
+ +```sh +python train.py +``` + +Default configuration can be found in `config/train.yaml`, and any option can be overridden with command-line arguments, for example to run the training script with a different batch size + +```sh +python train.py --batch_size=128 +``` +> **_NOTE:_** Apple Silicon Macbooks users make sure to use `--device=mps` and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback + +### Training the reward model + +Next you can train the reward model with + +```sh +python train_reward.py +``` + +### Training the final model with RLHF + +To train the final model run + +```sh +python train_rlhf.py +``` diff --git a/examples/rlhf/config/train.yaml b/examples/rlhf/config/train.yaml new file mode 100644 index 00000000000..6d27088902f --- /dev/null +++ b/examples/rlhf/config/train.yaml @@ -0,0 +1,30 @@ +io: + eval_interval: 200 + log_interval: 50 + eval_iters: 100 +data: + batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size + block_size: 550 +model: + name_or_path: gpt2 # gpt2 for pre-trained, local path for checkpoint + out_dir: ./out + dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ +train: + grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 + max_iters: 5000 # total number of training iterations + gradient_accumulation_steps: 2 # used to simulate larger batch sizes + always_save_checkpoint: False # if True, always save a checkpoint after each evaluation in out_dir + decay_lr: True # whether to decay the learning rate + optimizer: + # keyword arguments for torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 1.0e-1 + betas: [0.9, 0.95] + scheduler: + # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 5000 # maximum number of iterations + eta_min: 1.0e-6 # minimum learning rate +sys: + device: cuda # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on macbooks + dtype: bfloat16 # float32, bfloat16, or float16, the latter will auto implement a GradScaler + compile: True # use PyTorch 2.0 to compile the model to be faster diff --git a/examples/rlhf/config/train_reward.yaml b/examples/rlhf/config/train_reward.yaml new file mode 100644 index 00000000000..a5523b75fe2 --- /dev/null +++ b/examples/rlhf/config/train_reward.yaml @@ -0,0 +1,32 @@ +io: + eval_interval: 200 + log_interval: 50 + eval_iters: 100 +data: + batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size + block_size: 550 +model: + name_or_path: ./out + dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ +reward_model: + out_dir: ./out_reward + init_from: scratch # 'scratch' or 'resume' - if "resume" model will be loaded from out_dir_reward +train: + grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 + max_iters: 20000 # total number of training iterations + gradient_accumulation_steps: 2 # used to simulate larger batch sizes + always_save_checkpoint: False # if True, always save a checkpoint after each eval + decay_lr: False # whether to decay the learning rate + optimizer: + # keyword arguments for torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 1.0e-1 + betas: [0.9, 0.95] + scheduler: + # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 20000 + eta_min: 1.0e-6 +sys: + device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks + dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler + compile: True # use PyTorch 2.0 to compile the model to be 
faster diff --git a/examples/rlhf/config/train_rlhf.yaml b/examples/rlhf/config/train_rlhf.yaml new file mode 100644 index 00000000000..0aac2d83acd --- /dev/null +++ b/examples/rlhf/config/train_rlhf.yaml @@ -0,0 +1,36 @@ +io: + eval_interval: 6 + log_interval: 1 + eval_iters: 10 +data: + batch_size: 4 # if gradient_accumulation_steps > 1, this is the micro-batch size + block_size: 550 +model: + name_or_path: ./out + out_dir: ./out_rlhf + dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ +reward_model: + name_or_path: ./out_reward +train: + grad_clip: 1.0 + max_epochs: 1000 # total number of training iterations + always_save_checkpoint: True # if True, always save a checkpoint after each eval + decay_lr: True + optimizer: + # keyword arguments for torch.optim.AdamW + lr: 5.0e-5 + weight_decay: 0.0 # 01 + betas: [0.9, 0.999] + scheduler: + # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 3000 # max_epochs * num_rollouts / ppo_batch_size + eta_min: 5.0e-6 + ppo: + episode_length: 50 + ppo_batch_size: 16 + ppo_num_epochs: 3 + num_rollouts_per_epoch: 32 +sys: + device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks + dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler + compile: True # use PyTorch 2.0 to compile the model to be faster diff --git a/examples/rlhf/data/__init__.py b/examples/rlhf/data/__init__.py new file mode 100644 index 00000000000..433c23452f2 --- /dev/null +++ b/examples/rlhf/data/__init__.py @@ -0,0 +1,3 @@ +from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr + +__all__ = ["get_prompt_dataloader_tldr"] diff --git a/examples/rlhf/models/__init__.py b/examples/rlhf/models/__init__.py new file mode 100644 index 00000000000..7bec24cb17b --- /dev/null +++ b/examples/rlhf/models/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/rlhf/models/actor_critic.py b/examples/rlhf/models/actor_critic.py new file mode 100644 index 00000000000..e514cf9b248 --- /dev/null +++ b/examples/rlhf/models/actor_critic.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from torchrl.modules.tensordict_module.actors import LMActorCritic +from torchrl.modules.tensordict_module.common import VmapModule + +from .transformer import init_transformer + +__all__ = ["init_actor_critic"] + + +def init_actor_critic(transformer_name_or_path, dropout, device, compile_): + base_model = init_transformer( + transformer_name_or_path, + dropout, + device, + as_tensordictmodule=False, + compile_=compile_, + inference=True, + ) + model = LMActorCritic(base_model) + model.to(device) + model.eval() + actor = model.get_policy_operator() + critic = model.get_value_operator() + critic_head = model.get_value_head() + + return actor, VmapModule(critic), critic_head, base_model diff --git a/examples/rlhf/models/reward.py b/examples/rlhf/models/reward.py new file mode 100644 index 00000000000..ce84c727dd4 --- /dev/null +++ b/examples/rlhf/models/reward.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
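+"""
+Helpers for building the reward model used in these examples: a GPT2-based
+reward model initialised either from a pretrained transformer checkpoint
+(``transformer_path``) or from a previously trained reward model
+(``reward_model_path``), wrapped in a ``TensorDictModule`` that reads
+``input_ids``/``attention_mask`` and writes ``rewards``/``end_scores``.
+"""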
+ +import torch +from tensordict.nn import TensorDictModule + +from torchrl.modules.models.rlhf import GPT2RewardModel + + +def init_reward_model( + transformer_path=None, reward_model_path=None, device=None, compile_=False +): + if not ((transformer_path is None) ^ (reward_model_path is None)): + raise ValueError( + "Exactly one of transformer_path or reward_model_path should be specified" + ) + if transformer_path is not None: + model = GPT2RewardModel(transformer_path) + else: + model = GPT2RewardModel.from_pretrained(reward_model_path) + + model.to(device) + if compile_: + print("Compiling the reward model...") + model = torch.compile(model) + + model = TensorDictModule( + model, + in_keys=["input_ids", "attention_mask"], + out_keys=["rewards", "end_scores"], + ) + return model diff --git a/examples/rlhf/models/transformer.py b/examples/rlhf/models/transformer.py new file mode 100644 index 00000000000..cde8ce568ae --- /dev/null +++ b/examples/rlhf/models/transformer.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +from tensordict.nn import TensorDictModule +from transformers import GPT2LMHeadModel + + +def init_transformer( + name_or_path, + dropout, + device, + compile_, + as_tensordictmodule=True, + inference=False, +): + model_kwargs = { + "resid_pdrop": dropout, + "embd_pdrop": dropout, + "attn_pdrop": dropout, + "summary_first_dropout": dropout, + } + model = GPT2LMHeadModel.from_pretrained( + name_or_path, return_dict=False, **model_kwargs + ) + model.to(device) + + if compile_: + # TODO: logging instead of printing? + print("Compiling transformer model...") + model = torch.compile(model) + + if as_tensordictmodule: + model = TensorDictModule( + model, + in_keys={ + "input_ids": "input_ids", + "attention_mask": "attention_mask", + "labels": "labels", + }, + out_keys=["logits"] if inference else ["loss", "logits"], + ) + return model diff --git a/examples/rlhf/requirements.txt b/examples/rlhf/requirements.txt new file mode 100644 index 00000000000..9bff1b48453 --- /dev/null +++ b/examples/rlhf/requirements.txt @@ -0,0 +1,11 @@ +datasets +hydra-core +matplotlib +numpy +PyYAML +requests +tiktoken +tqdm +transformers +git+https://github.com/pytorch/rl +git+https://github.com/pytorch-labs/tensordict diff --git a/examples/rlhf/train.py b/examples/rlhf/train.py new file mode 100644 index 00000000000..fe624213ada --- /dev/null +++ b/examples/rlhf/train.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Train the transformer model. Configurable via config/train.yaml, but any argument can +also be overridden at the command line. 
+ +To run on a single GPU, example: +$ python train.py --batch_size=32 --compile=False +""" +import time + +import hydra +import torch +from models.transformer import init_transformer +from torch.optim.lr_scheduler import CosineAnnealingLR + +from torchrl.data.rlhf.dataset import get_dataloader +from torchrl.data.rlhf.prompt import PromptData +from utils import get_file_logger, resolve_name_or_path, setup + + +def create_loss_estimator(eval_iters, ctx): + # helps estimate an arbitrarily accurate loss over either split using many batches + + @torch.no_grad() + def estimate_loss(model, dataloader): + model.eval() + losses = torch.zeros(eval_iters) + for k in range(eval_iters): + batch = next(dataloader) + batch.batch_size = [] + with ctx: + model(batch) + losses[k] = batch.loss.item() + model.train() + return losses.mean() + + return estimate_loss + + +@hydra.main(version_base="1.1", config_path="config", config_name="train") +def main(cfg): + loss_logger = get_file_logger("loss_logger", "transformer_loss_logger.log") + + data_cfg = cfg.data + model_cfg = cfg.model + train_cfg = cfg.train + + eval_interval = cfg.io.eval_interval + log_interval = cfg.io.log_interval + eval_iters = cfg.io.eval_iters + out_dir = model_cfg.out_dir + + grad_clip = train_cfg.grad_clip + max_iters = train_cfg.max_iters + always_save_checkpoint = train_cfg.always_save_checkpoint + gradient_accumulation_steps = train_cfg.gradient_accumulation_steps + + device = cfg.sys.device + dtype = cfg.sys.dtype + compile_ = cfg.sys.compile + + ctx = setup(device=device, dtype=dtype) + + train_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="train", + ) + val_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="valid", + ) + + model = init_transformer( + resolve_name_or_path(model_cfg.name_or_path), + model_cfg.dropout, + device, + compile_=compile_, + ) + optimizer = torch.optim.AdamW(model.parameters(), **train_cfg.optimizer) + scheduler = None + if train_cfg.decay_lr: + scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + + scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16")) + estimate_loss = create_loss_estimator(eval_iters, ctx) + + best_val_loss = float("inf") + + t0 = time.time() + next_batch = next(train_loader) # fetch the very first batch + for it in range(1, max_iters + 1): + for _ in range(gradient_accumulation_steps): + batch = next_batch + # TODO: can we handle this better with a differently structured tensorclass? 
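+            # emptying the batch size here (presumably) allows the scalar loss
+            # returned by the LM head to be written into the tensordict
+            # alongside the batched input tensors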
+ batch.batch_size = [] + with ctx: + model(batch) + # immediately async prefetch next batch while model is doing the forward pass on the GPU + next_batch = next(train_loader) + # backward pass, with gradient scaling if training in fp16 + scaler.scale(batch.loss).backward() + + # clip the gradient + if grad_clip != 0.0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + + # step the optimizer and scaler if training in fp16 + scaler.step(optimizer) + scaler.update() + # flush the gradients as soon as we can, no need for this memory anymore + optimizer.zero_grad(set_to_none=True) + + # update learning rate + if scheduler is not None: + scheduler.step() + + t1 = time.time() + dt = t1 - t0 + t0 = t1 + if it % eval_interval == 0: + # evaluate the loss on train/val sets and write checkpoints + train_loss = estimate_loss(model, train_loader) + val_loss = estimate_loss(model, val_loader) + msg = f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}" + print(msg) + loss_logger.info(msg) + if val_loss < best_val_loss or always_save_checkpoint: + best_val_loss = val_loss + if it > 0: + msg = f"saving checkpoint to {out_dir}" + print(msg) + loss_logger.info(msg) + model.module.save_pretrained(out_dir) + elif it % log_interval == 0: + # loss as float. note: this is a CPU-GPU sync point + loss = batch.loss.item() + msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt*1000:.2f}ms" + print(msg) + loss_logger.info(msg) + + +if __name__ == "__main__": + main() diff --git a/examples/rlhf/train_reward.py b/examples/rlhf/train_reward.py new file mode 100644 index 00000000000..850c6d92f1b --- /dev/null +++ b/examples/rlhf/train_reward.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
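+"""
+Train the reward model on pairwise comparison data
+(CarperAI/openai_summarize_comparisons). Configurable via config/train_reward.yaml,
+but any argument can also be overridden at the command line.
+
+To run:
+$ python train_reward.py
+"""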
+import time + +import hydra +import torch +from models.reward import init_reward_model +from torch.optim.lr_scheduler import CosineAnnealingLR +from torchrl.data.rlhf.dataset import get_dataloader +from torchrl.data.rlhf.reward import PairwiseDataset +from utils import get_file_logger, resolve_name_or_path, setup + + +def _accuracy(chosen_end_scores, rejected_end_scores): + return ( + sum(chosen_end_scores > rejected_end_scores) / len(rejected_end_scores) + ).item() + + +# TODO: eliminate redundant repeated definition +# helps estimate an arbitrarily accurate loss over either split using many batches +def create_loss_estimator(eval_iters, ctx): + @torch.no_grad() + def estimate_loss(model, dataloader): + model.eval() + losses = torch.zeros(eval_iters) + accs = torch.zeros(eval_iters) + for k in range(eval_iters): + batch = next(dataloader) + with ctx: + model(batch.chosen_data) + model(batch.rejected_data) + losses[k] = model.compute_reward_loss( + batch.chosen_data, batch.rejected_data + ).item() + accs[k] = _accuracy( + batch.chosen_data.end_scores, batch.rejected_data.end_scores + ) + model.train() + return losses.mean(), accs.mean() + + return estimate_loss + + +@hydra.main(version_base="1.1", config_path="config", config_name="train_reward") +def main(cfg): + loss_logger = get_file_logger("loss_logger", "reward_loss_logger.log") + + data_cfg = cfg.data + model_cfg = cfg.model + reward_model_cfg = cfg.reward_model + train_cfg = cfg.train + + eval_interval = cfg.io.eval_interval + log_interval = cfg.io.log_interval + eval_iters = cfg.io.eval_iters + reward_out_dir = reward_model_cfg.out_dir + + max_iters = train_cfg.max_iters + always_save_checkpoint = train_cfg.always_save_checkpoint + + device = cfg.sys.device + dtype = cfg.sys.dtype + compile_ = cfg.sys.compile + + ctx = setup(device=device, dtype=dtype) + + train_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PairwiseDataset, + device, + dataset_name="CarperAI/openai_summarize_comparisons", + split="train", + ) + val_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PairwiseDataset, + device, + dataset_name="CarperAI/openai_summarize_comparisons", + split="valid1", + ) + + if reward_model_cfg.init_from == "resume": + model = init_reward_model( + reward_model_path=resolve_name_or_path(reward_model_cfg.out_dir), + device=device, + compile_=compile_, + ) + else: + model = init_reward_model( + transformer_path=resolve_name_or_path(model_cfg.name_or_path), + device=device, + compile_=compile_, + ) + # Freeze the first 70% of the hidden layers of the reward model backbone + layers = model.transformer.h + num_layers = len(layers) + num_unfrozen = int(0.3 * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + # ######## INIT TRAINING FUNCTIONS ######## + + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], **train_cfg.optimizer + ) + scheduler = None + if train_cfg.decay_lr: + scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + + estimate_loss = create_loss_estimator(eval_iters, ctx) + + best_val_loss = float("inf") + + t0 = time.time() + for it in range(1, max_iters + 1): + batch = next(train_loader) + + with ctx: + model(batch.chosen_data) + model(batch.rejected_data) + optimizer.zero_grad(set_to_none=True) + loss = model.compute_reward_loss(batch.chosen_data, batch.rejected_data) + loss.backward() + optimizer.step() + if scheduler is not None: + scheduler.step() + + t1 = time.time() + dt = t1 - t0 + t0 = t1 + 
if it % eval_interval == 0: + val_loss, val_acc = estimate_loss(model, val_loader) + train_loss, train_acc = estimate_loss(model, train_loader) + + msg = ( + f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}, " + f"{train_acc=:.4f}, {val_acc=:.4f}" + ) + print(msg) + loss_logger.info(msg) + if val_loss < best_val_loss or always_save_checkpoint: + best_val_loss = val_loss + if it > 0: + msg = f"saving checkpoint to {reward_out_dir}" + print(msg) + loss_logger.info(msg) + model.module.save_pretrained(reward_out_dir) + elif it % log_interval == 0: + loss = loss.item() + acc = _accuracy( + batch.chosen_data.end_scores, batch.rejected_data.end_scores + ) + msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt*1000:.2f}ms" + print(msg) + loss_logger.info(msg) + + +if __name__ == "__main__": + main() diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py new file mode 100644 index 00000000000..4226bad3160 --- /dev/null +++ b/examples/rlhf/train_rlhf.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from copy import deepcopy + +import numpy as np +import torch + +import wandb +from models.actor_critic import init_actor_critic +from models.reward import init_reward_model + +from omegaconf import OmegaConf +from torch.optim.lr_scheduler import CosineAnnealingLR + +from torchrl.data import LazyTensorStorage +from torchrl.data.replay_buffers import ( + SamplerWithoutReplacement, + TensorDictReplayBuffer, +) +from torchrl.data.rlhf.dataset import get_dataloader +from torchrl.data.rlhf.prompt import PromptData +from torchrl.data.rlhf.utils import RolloutFromModel + +from torchrl.objectives import ClipPPOLoss +from torchrl.objectives.value import GAE +from tqdm import tqdm +from transformers import GenerationConfig, GPT2Tokenizer +from utils import get_file_logger, resolve_name_or_path, setup + + +def flatten_td(td): + # our tensordict has shape [B, T] where B = batch_size and T = trajectory length + # some trajectories may have stopped (reached EOS) before generating T tokens + # this function truncates and concatenates the trajectories, resulting in a + # tensordict that has shape [N] where N <= B * T. + done = td["next", "done"] + mask = torch.zeros_like(done) + mask[..., 1:, :] = done[..., :-1, :] # shift by one + mask = ~mask.cumsum(-2).bool().squeeze() + return td[mask] + + +class AdaptiveKLController: + """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences" + Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 + Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py + """ + + def __init__(self, init_kl_coef: float, target: float, horizon: int): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current: float, n_steps: int): + """Returns adaptively updated KL coefficient, βₜ₊₁. + Arguments: + current: The current KL value between the newest policy and the initial policy. + """ + proportional_error = np.clip(current / self.target - 1, -0.2, 0.2) # ϵₜ + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult # βₜ₊₁ + return self.value + + +def create_reward_estimator( + eval_iters, episode_length, reward_model, batch, ctx, logger=None, ref_model=None +): + """Create a function to estimate the reward via sampling. 
+ + This function creates a new function which, given a model and a dataloader, will + perform multiple rollouts using the model and data sampled from the dataloader then + average the accumulated rewards. + + For debugging purposes, we also generate responses to a fixed prompt so that the + quality of the model can be visually assessed during training. + """ + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + + test_rindex = batch.prompt_rindex[0] + test_prompt_ids = batch.input_ids[:1, :test_rindex] + test_label_ids = batch.input_ids[:1, test_rindex:] + generation_config = GenerationConfig( + pad_token_id=tokenizer.pad_token_id, max_new_tokens=episode_length + ) + test_prompt = tokenizer.decode(test_prompt_ids[0, :test_rindex].tolist()) + test_label = tokenizer.decode( + test_label_ids[0, test_label_ids[0] != tokenizer.pad_token_id].tolist() + ) + _, test_label_reward = reward_model( + input_ids=batch.input_ids[:1], attention_mask=batch.attention_mask[:1] + ) + + @torch.no_grad() + def estimate_reward(model, dataloader): + rollout_from_model = RolloutFromModel(model, ref_model, reward_model) + rewards = torch.zeros(eval_iters) + for k in range(eval_iters): + batch = next(dataloader) + # NOTE: disable kl for evaluation + td = rollout_from_model.rollout_from_data(batch, kl_coef=0.0) + rewards[k] = td.get(("next", "reward")).sum(dim=1).mean().item() + test_reward = rewards.mean() + + if logger: + response_ids = model.generate( + input_ids=test_prompt_ids, generation_config=generation_config + ) + with ctx: + _, response_reward = reward_model( + input_ids=response_ids, + attention_mask=(response_ids != tokenizer.pad_token_id).to( + torch.int64 + ), + ) + reward = (response_reward - test_label_reward).item() + response_ids = response_ids[0, test_rindex:] + response = tokenizer.decode( + response_ids[response_ids != tokenizer.eos_token_id].tolist() + ) + string_to_write = ( + f"Query:\n{test_prompt}\n" + f"Response:\n{response}\n" + f"Actual response:\n{test_label}\n" + f"{reward=:4.4f}, " + f"{test_reward=:4.4f}\n" + f"====================================================\n" + ) + logger.info(string_to_write) + + return test_reward + + return estimate_reward + + +# @hydra.main(version_base="1.1", config_path="config", config_name="train_rlhf") +def main(): + cfg = OmegaConf.load("config/train_rlhf.yaml") + wandb.init( + # set the wandb project where this run will be logged + project="rlhf-training", + # track hyperparameters and run metadata + config=cfg, + ) + query_logger = get_file_logger("query_logger", "rlhf_query_logger.log") + val_reward_logger = get_file_logger("val_reward_logger", "rlhf_valid_rewards.log") + + data_cfg = cfg.data + model_cfg = cfg.model + reward_model_cfg = cfg.reward_model + train_cfg = cfg.train + ppo_cfg = train_cfg.ppo + + eval_interval = cfg.io.eval_interval + log_interval = cfg.io.log_interval + eval_iters = cfg.io.eval_iters + + rlhf_out_dir = model_cfg.out_dir + transformer_name_or_path = model_cfg.name_or_path + dropout = model_cfg.dropout + + batch_size = data_cfg.batch_size + + grad_clip = train_cfg.grad_clip + max_epochs = train_cfg.max_epochs + always_save_checkpoint = train_cfg.always_save_checkpoint + + episode_length = ppo_cfg.episode_length + ppo_batch_size = ppo_cfg.ppo_batch_size + ppo_num_epochs = ppo_cfg.ppo_num_epochs + num_rollouts_per_epoch = ppo_cfg.num_rollouts_per_epoch + + device = cfg.sys.device + dtype = cfg.sys.dtype + compile_ = cfg.sys.compile + + ctx = setup(device, dtype) + + train_loader 
= get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="train", + ) + val_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="valid", + ) + + actor, critic, critic_head, model = init_actor_critic( + resolve_name_or_path(transformer_name_or_path), dropout, device, compile_ + ) + ref_model = deepcopy(model).to("cuda:1") + ref_model.requires_grad_(False) + layers = model.transformer.h + num_layers = len(layers) + num_unfrozen = int(0.3 * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + reward_model = init_reward_model( + reward_model_path=resolve_name_or_path(reward_model_cfg.name_or_path), + device=device, + compile_=compile_, + ) + reward_model.eval() + reward_model.requires_grad_(False) + + adv_fn = GAE( + value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True, shifted=True + ) + loss_fn = ClipPPOLoss(actor, critic_head) + + test_prompt = next(val_loader) + estimate_reward = create_reward_estimator( + eval_iters, + episode_length, + reward_model, + test_prompt, + ctx, + logger=query_logger, + ref_model=ref_model, + ) + + optimizer = torch.optim.AdamW( + [p for p in loss_fn.parameters() if p.requires_grad], **train_cfg.optimizer + ) + scheduler = None + if train_cfg.decay_lr: + scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(episode_length * num_rollouts_per_epoch), + batch_size=episode_length * batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + ) + rb_ppo = TensorDictReplayBuffer( + storage=LazyTensorStorage(episode_length * batch_size), + batch_size=ppo_batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + ) + + rollout_from_model = RolloutFromModel(model, ref_model, reward_model) + + best_val_reward = float("-inf") + it = 0 # it is equivalent to batch_size number of episodes + with tqdm(total=int(max_epochs * num_rollouts_per_epoch / batch_size)) as pbar: + for _epoch in range(1, max_epochs + 1): + rb.empty() + rollout_rewards = [] + rollout_kl = [] + kl_controller = AdaptiveKLController(0.1, 6, 10000) + for _ in range(0, num_rollouts_per_epoch, batch_size): + batch = next(train_loader) + td = rollout_from_model.rollout_from_data( + batch, kl_coef=kl_controller.value + ) + with torch.no_grad(), ctx: + # moving this to within epoch + adv_fn(td) + # it's possible we didn't fill the replay buffer in the last iteration if + # generation stopped early, so we empty first before repopulating + rb.extend(flatten_td(td)) + done = td.get(("next", "done")) + next_reward = td.get(("next", "reward_raw"))[done] + next_kl = td.get(("next", "reward_kl"))[done] + rollout_rewards.append(next_reward.mean().cpu().item()) + rollout_kl.append(next_kl.mean().cpu().item()) + rollout_reward = torch.tensor(rollout_rewards).mean().cpu().item() + rollout_kl_reward = torch.tensor(rollout_kl).mean().cpu().item() + # recover true kl + rollout_kl = -rollout_kl_reward / kl_controller.value + kl_controller.update(rollout_kl, num_rollouts_per_epoch / batch_size) + + # FIXME: THIS PPO CYCLE WAS DIFFERENT wrt trlx. 
@tcbegley please double check + # they sample batch_size from rb and then do minibatches ppo_batch_size within + if it % log_interval == 0: + val_reward_logger.info( + f"TRAIN: {it=}: {rollout_reward=:.4f} {rollout_kl_reward=:.4f} {rollout_kl=:.4f}" + ) + wandb.log( + { + "rollout_reward": rollout_reward, + "rollout_kl_reward": rollout_kl_reward, + "rollout_kl": rollout_kl, + }, + step=it, + ) + pbar.set_description(f"TRAIN: {it=}: {rollout_reward=:.4f}") + + for batch in rb: + rb_ppo.empty() + rb_ppo.extend(batch) + for _ in range(ppo_num_epochs): # PPO epochs + optimizer.zero_grad() + # why don't we optimize at each step? Is accumulating grads better? + # usually more small steps is better than a giant one + for minibatch in rb_ppo: # GO over RB + minibatch = minibatch.to(device, non_blocking=True) + with ctx: + loss_vals = loss_fn(minibatch) + loss_val = sum( + value + for key, value in loss_vals.items() + if key.startswith("loss") + ) + loss_val.backward() + torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), grad_clip) + optimizer.step() + if scheduler is not None: + scheduler.step() + it += 1 + pbar.update(1) + if it % eval_interval == 0: + val_reward = estimate_reward(model, val_loader) + val_reward_logger.info(f"VALID: {it=}: {val_reward=:.4f}") + wandb.log({"val_reward": val_reward}, step=it) + pbar.set_description(f"VALID: {it=}: {val_reward=:.4f}") + if val_reward > best_val_reward or always_save_checkpoint: + best_val_reward = val_reward + if it > 0: + val_reward_logger.info( + f"saving checkpoint to {rlhf_out_dir}" + ) + model.save_pretrained(rlhf_out_dir) + + +if __name__ == "__main__": + main() diff --git a/examples/rlhf/utils.py b/examples/rlhf/utils.py new file mode 100644 index 00000000000..aedb0501091 --- /dev/null +++ b/examples/rlhf/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +from contextlib import nullcontext + +import torch +import torch._dynamo +from hydra.utils import to_absolute_path + + +def resolve_name_or_path(name_or_path): + """Hydra changes the working directory, so we need to absolutify paths.""" + if name_or_path.startswith("./") or name_or_path.startswith("/"): + return to_absolute_path(name_or_path) + return name_or_path + + +def get_file_logger(name, filename, level=logging.DEBUG): + """ + Set up logger that will log to the given filename. + """ + logger = logging.getLogger(name) + handler = logging.FileHandler(filename) + handler.setFormatter( + # logging.Formatter("%(asctime)s, %(name)s %(levelname)s %(message)s") + logging.Formatter("%(asctime)s - %(message)s") + ) + logger.addHandler(handler) + logger.setLevel(level) + return logger + + +def setup(device, dtype): + """ + Set manual seed, configure backend and autocasting. 
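+
+    Returns an autocast context manager when running on CUDA (with the
+    requested dtype), and a null context otherwise.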
+ """ + torch.manual_seed(1337) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + torch._dynamo.config.cache_size_limit = 256 + + if "cuda" not in device: + return nullcontext() + + return torch.amp.autocast(device_type="cuda", dtype=getattr(torch, dtype)) From ef3f76f65205002863294194a8557a69f2ec9bee Mon Sep 17 00:00:00 2001 From: Alessandro Pietro Bardelli Date: Tue, 27 Jun 2023 09:52:26 +0000 Subject: [PATCH 02/23] add VmapModule and from_lmhead_model method --- test/test_actors.py | 64 +++++++++++++++++++++ test/test_tensordictmodules.py | 12 ++++ torchrl/modules/tensordict_module/actors.py | 47 +++++++++++++++ torchrl/modules/tensordict_module/common.py | 49 +++++++++++++++- 4 files changed, 171 insertions(+), 1 deletion(-) diff --git a/test/test_actors.py b/test/test_actors.py index f3591e20628..aef2a6bedb7 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -18,6 +18,7 @@ MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, ) +from torchrl.data.rlhf.dataset import _has_transformers from torchrl.modules import MLP, SafeModule from torchrl.modules.tensordict_module.actors import ( _process_action_space_spec, @@ -25,6 +26,7 @@ DistributionalQValueActor, DistributionalQValueHook, DistributionalQValueModule, + LMHeadActorValueOperator, ProbabilisticActor, QValueActor, QValueHook, @@ -561,6 +563,68 @@ def test_actorcritic(device): ) == len(policy_params) +@pytest.mark.skipif(not _has_transformers, reason="missing dependencies") +@pytest.mark.parametrize("device", get_default_devices()) +def test_lmhead_actorvalueoperator(device): + from transformers import AutoModelForCausalLM + + base_model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict=False) + aco = LMHeadActorValueOperator(base_model) + + # check common + assert aco.module[0][0].module is base_model.transformer + assert aco.module[0][1].in_keys == ["x"] + assert aco.module[0][1].out_keys == ["x"] + + # check actor + assert aco.module[1].in_keys == ["x"] + assert aco.module[1].out_keys == ["logits", "action", "sample_log_prob"] + assert aco.module[1][0].module is base_model.lm_head + + # check critic + assert aco.module[2].in_keys == ["x"] + assert aco.module[2].out_keys == ["state_value"] + assert isinstance(aco.module[2].module, nn.Linear) + assert aco.module[2].module.in_features == base_model.transformer.embed_dim + assert aco.module[2].module.out_features == 1 + + td = TensorDict( + source={ + "input_ids": torch.randint(50257, (4, 3)), + "attention_mask": torch.ones((4, 3)), + }, + batch_size=[ + 4, + ], + ).to(device) + td_total = aco(td.clone()) + policy_op = aco.get_policy_operator() + td_policy = policy_op(td.clone()) + value_op = aco.get_value_operator() + td_value = value_op(td) + torch.testing.assert_close(td_total.get("action"), td_policy.get("action")) + torch.testing.assert_close( + td_total.get("sample_log_prob"), td_policy.get("sample_log_prob") + ) + torch.testing.assert_close(td_total.get("state_value"), td_value.get("state_value")) + + value_params = set( + list(aco.get_value_operator().parameters()) + list(aco.module[0].parameters()) + ) + value_params2 = set(value_op.parameters()) + assert len(value_params.difference(value_params2)) == 0 and len( + value_params.intersection(value_params2) + ) == len(value_params) + + policy_params = set( + list(aco.get_policy_operator().parameters()) + list(aco.module[0].parameters()) + ) + policy_params2 = set(policy_op.parameters()) + assert 
len(policy_params.difference(policy_params2)) == 0 and len( + policy_params.intersection(policy_params2) + ) == len(policy_params) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 4639e0ffffb..9fca32a9b12 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -20,6 +20,7 @@ from torchrl.modules.tensordict_module.common import ( ensure_tensordict_compatible, is_tensordict_compatible, + VmapModule, ) from torchrl.modules.tensordict_module.probabilistic import ( SafeProbabilisticModule, @@ -27,6 +28,7 @@ ) from torchrl.modules.tensordict_module.sequence import SafeSequential + _has_functorch = False try: try: @@ -1727,6 +1729,16 @@ def test_multi_consecutive(self, shape): ) +def test_vmapmodule(): + lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) + sample_in = torch.ones((10, 3, 2)) + sample_in_td = TensorDict({"x": sample_in}, batch_size=[10]) + lam(sample_in) + vm = VmapModule(lam, 0) + vm(sample_in_td) + assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index efbddfb245a..720a0e50dfd 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -6,14 +6,17 @@ from typing import List, Optional, Sequence, Tuple, Union import torch + from tensordict import TensorDictBase from tensordict.nn import ( dispatch, TensorDictModule, TensorDictModuleBase, TensorDictModuleWrapper, + TensorDictSequential, ) from torch import nn +from torch.distributions import Categorical from torchrl.data.tensor_specs import CompositeSpec, TensorSpec from torchrl.modules.models.models import DistributionalDQNnet @@ -1748,3 +1751,47 @@ def forward(self, tensordict): feature = low + (high - low) * (feature + 1) / 2 tensordict.set(out_key, feature) return tensordict + + +class LMHeadActorValueOperator(ActorValueOperator): + """Builds an Actor-Value operator from an huggingface-like *LMHeadModel. + + This method: + - takes as input an huggingface-like *LMHeadModel + - extracts the final linear layer uses it as a base layer of the actor_head and + adds the sampling layer + - uses the common transformer as common model + - adds a linear critic + + Args: + base_model (nn.Module): a torch model composed by a `.transformer` model and `.lm_head` linear layer + + Note: for more details please refer to :class:`~.ActorValueOperator`. 
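+
+    A minimal usage sketch, mirroring the accompanying tests (assumes the
+    ``transformers`` package is installed):
+
+    Example:
+        >>> from transformers import AutoModelForCausalLM
+        >>> base_model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict=False)
+        >>> aco = LMHeadActorValueOperator(base_model)
+        >>> td = TensorDict(
+        ...     {
+        ...         "input_ids": torch.randint(50257, (4, 3)),
+        ...         "attention_mask": torch.ones((4, 3)),
+        ...     },
+        ...     batch_size=[4],
+        ... )
+        >>> td = aco(td)  # writes "action", "sample_log_prob" and "state_value"
+        >>> policy_op = aco.get_policy_operator()
+        >>> value_op = aco.get_value_operator()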
+ """ + + def __init__(self, base_model): + actor_head = base_model.lm_head + value_head = nn.Linear(actor_head.in_features, 1, bias=False) + common = TensorDictSequential( + TensorDictModule( + base_model.transformer, + in_keys={"input_ids": "input_ids", "attention_mask": "attention_mask"}, + out_keys=["x"], + ), + TensorDictModule(lambda x: x[:, -1, :], in_keys=["x"], out_keys=["x"]), + ) + actor_head = TensorDictModule(actor_head, in_keys=["x"], out_keys=["logits"]) + actor_head = SafeProbabilisticTensorDictSequential( + actor_head, + SafeProbabilisticModule( + in_keys=["logits"], + out_keys=["action"], + distribution_class=Categorical, + return_log_prob=True, + ), + ) + value_head = TensorDictModule( + value_head, in_keys=["x"], out_keys=["state_value"] + ) + + return super().__init__(common, actor_head, value_head) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 53e285e58f2..d8c22236af0 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -12,8 +12,9 @@ import torch -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.tensordict import TensorDictBase + from torch import nn from torchrl.data.tensor_specs import CompositeSpec, TensorSpec @@ -401,3 +402,49 @@ def ensure_tensordict_compatible( if out_keys is not None: kwargs["out_keys"] = out_keys return wrapper_type(module, **kwargs) + + +class VmapModule(TensorDictModuleBase): + """A TensorDictModule wrapper to vmap over the input. + + It is intended to be used with modules that accept data with one less batch + dimension than the one provided. By using this wrapper, one can hide a + batch dimension and satisfy the wrapped module. + + Args: + module (TensorDictModuleBase): the module to vmap over. + vmap_dim (int, optional): the vmap input and output dim. + If none is provided, the last dimension of the tensordict is + assumed. + + .. note:: + + Since vmap requires to have control over the batch size of the input + this module does not support dispatched arguments + + Example: + lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) + sample_in = torch.ones((10,3,2)) + sample_in_td = TensorDict({"x":sample_in}, batch_size=[10]) + lam(sample_in) + vm = VmapModule(lam, 0) + vm(sample_in_td) + assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() + """ + + def __init__(self, module: TensorDictModuleBase, vmap_dim=None): + super().__init__() + self.in_keys = module.in_keys + self.out_keys = module.out_keys + self.module = module + self.vmap_dim = vmap_dim + + def forward(self, tensordict): + # TODO: there is a risk of segfault if input is not a tensordict. + # We should investigate (possibly prevent it c++ side?) 
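+        # vmap over ``vmap_dim`` hides one batch dimension from the wrapped
+        # module; the vmapped outputs are then written back into the original
+        # tensordict via ``update`` below.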
+ vmap_dim = self.vmap_dim + if vmap_dim is None: + ndim = tensordict.ndim + vmap_dim = ndim - 1 + td = torch.vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict) + return tensordict.update(td) From 02a909b2e219ed833249950b9653f8f2a8227ad5 Mon Sep 17 00:00:00 2001 From: Alessandro Pietro Bardelli Date: Wed, 28 Jun 2023 17:18:25 +0200 Subject: [PATCH 03/23] Update examples/rlhf/train_rlhf.py Co-authored-by: Vincent Moens --- examples/rlhf/train_rlhf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index 4226bad3160..6372ad03a42 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -272,8 +272,9 @@ def main(): # generation stopped early, so we empty first before repopulating rb.extend(flatten_td(td)) done = td.get(("next", "done")) - next_reward = td.get(("next", "reward_raw"))[done] - next_kl = td.get(("next", "reward_kl"))[done] + td_done = td[done.view(td.shape)] + next_reward = td_done.get(("next", "reward_raw")) + next_kl = td_done.get(("next", "reward_kl")) rollout_rewards.append(next_reward.mean().cpu().item()) rollout_kl.append(next_kl.mean().cpu().item()) rollout_reward = torch.tensor(rollout_rewards).mean().cpu().item() From 953e4afa819a90d51eeda60c8aef04c44a267575 Mon Sep 17 00:00:00 2001 From: Alessandro Pietro Bardelli Date: Wed, 28 Jun 2023 16:12:25 +0000 Subject: [PATCH 04/23] addressing comments --- examples/rlhf/train_rlhf.py | 115 ++++++++++++------------------------ examples/rlhf/utils.py | 62 +++++++++++++++++++ 2 files changed, 101 insertions(+), 76 deletions(-) diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index 6372ad03a42..9b0461acc84 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -27,19 +27,7 @@ from torchrl.objectives.value import GAE from tqdm import tqdm from transformers import GenerationConfig, GPT2Tokenizer -from utils import get_file_logger, resolve_name_or_path, setup - - -def flatten_td(td): - # our tensordict has shape [B, T] where B = batch_size and T = trajectory length - # some trajectories may have stopped (reached EOS) before generating T tokens - # this function truncates and concatenates the trajectories, resulting in a - # tensordict that has shape [N] where N <= B * T. - done = td["next", "done"] - mask = torch.zeros_like(done) - mask[..., 1:, :] = done[..., :-1, :] # shift by one - mask = ~mask.cumsum(-2).bool().squeeze() - return td[mask] +from utils import flatten_td, get_file_logger, resolve_name_or_path, setup, TestPromptLogger class AdaptiveKLController: @@ -64,76 +52,49 @@ def update(self, current: float, n_steps: int): return self.value -def create_reward_estimator( - eval_iters, episode_length, reward_model, batch, ctx, logger=None, ref_model=None -): - """Create a function to estimate the reward via sampling. +class RewardEstimator: + """Create a class to estimate the reward via sampling. - This function creates a new function which, given a model and a dataloader, will + This class exposes a call method which, given a model and a dataloader, will perform multiple rollouts using the model and data sampled from the dataloader then average the accumulated rewards. For debugging purposes, we also generate responses to a fixed prompt so that the quality of the model can be visually assessed during training. 
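+
+    A minimal usage sketch, mirroring how the estimator is used in ``main()`` below:
+
+        >>> reward_estimator = RewardEstimator(
+        ...     eval_iters, episode_length, reward_model, ref_model
+        ... )
+        >>> val_reward = reward_estimator(model, val_loader)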
+ """ - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - tokenizer.pad_token = tokenizer.eos_token - - test_rindex = batch.prompt_rindex[0] - test_prompt_ids = batch.input_ids[:1, :test_rindex] - test_label_ids = batch.input_ids[:1, test_rindex:] - generation_config = GenerationConfig( - pad_token_id=tokenizer.pad_token_id, max_new_tokens=episode_length - ) - test_prompt = tokenizer.decode(test_prompt_ids[0, :test_rindex].tolist()) - test_label = tokenizer.decode( - test_label_ids[0, test_label_ids[0] != tokenizer.pad_token_id].tolist() - ) - _, test_label_reward = reward_model( - input_ids=batch.input_ids[:1], attention_mask=batch.attention_mask[:1] - ) + def __init__( + self, eval_iters, episode_length, reward_model, ref_model + ): + """ + Args: + eval_iters (int): number of batches on which we would like to estimate reward + + episode_length (int): max number of generated new tokens + + reward_model (GPT2RewardModel): reward model + + ref_model (GPT2LMHeadModel): original transformer model that it is used to + correctly compute kl component of reward. + """ + self.ref_model = ref_model + self.reward_model = reward_model + self.eval_iters = eval_iters + self.episode_length = episode_length @torch.no_grad() - def estimate_reward(model, dataloader): - rollout_from_model = RolloutFromModel(model, ref_model, reward_model) - rewards = torch.zeros(eval_iters) - for k in range(eval_iters): + def __call__(self, model, dataloader): + rollout_from_model = RolloutFromModel(model, self.ref_model, self.reward_model, max_new_tokens=self.episode_length) + rewards = torch.zeros(self.eval_iters) + for k in range(self.eval_iters): batch = next(dataloader) # NOTE: disable kl for evaluation td = rollout_from_model.rollout_from_data(batch, kl_coef=0.0) rewards[k] = td.get(("next", "reward")).sum(dim=1).mean().item() test_reward = rewards.mean() - if logger: - response_ids = model.generate( - input_ids=test_prompt_ids, generation_config=generation_config - ) - with ctx: - _, response_reward = reward_model( - input_ids=response_ids, - attention_mask=(response_ids != tokenizer.pad_token_id).to( - torch.int64 - ), - ) - reward = (response_reward - test_label_reward).item() - response_ids = response_ids[0, test_rindex:] - response = tokenizer.decode( - response_ids[response_ids != tokenizer.eos_token_id].tolist() - ) - string_to_write = ( - f"Query:\n{test_prompt}\n" - f"Response:\n{response}\n" - f"Actual response:\n{test_label}\n" - f"{reward=:4.4f}, " - f"{test_reward=:4.4f}\n" - f"====================================================\n" - ) - logger.info(string_to_write) - return test_reward - return estimate_reward - # @hydra.main(version_base="1.1", config_path="config", config_name="train_rlhf") def main(): @@ -220,22 +181,22 @@ def main(): loss_fn = ClipPPOLoss(actor, critic_head) test_prompt = next(val_loader) - estimate_reward = create_reward_estimator( + reward_estimator = RewardEstimator( eval_iters, episode_length, reward_model, - test_prompt, - ctx, - logger=query_logger, - ref_model=ref_model, + ref_model ) + prompt_logger = TestPromptLogger(test_prompt=test_prompt, reward_model=reward_model, logger=query_logger) + optimizer = torch.optim.AdamW( [p for p in loss_fn.parameters() if p.requires_grad], **train_cfg.optimizer ) scheduler = None if train_cfg.decay_lr: scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + kl_controller = AdaptiveKLController(0.1, 6, 10000) rb = TensorDictReplayBuffer( storage=LazyTensorStorage(episode_length * num_rollouts_per_epoch), @@ -255,11 +216,10 @@ 
def main(): best_val_reward = float("-inf") it = 0 # it is equivalent to batch_size number of episodes with tqdm(total=int(max_epochs * num_rollouts_per_epoch / batch_size)) as pbar: - for _epoch in range(1, max_epochs + 1): + for epoch in range(1, max_epochs + 1): rb.empty() rollout_rewards = [] rollout_kl = [] - kl_controller = AdaptiveKLController(0.1, 6, 10000) for _ in range(0, num_rollouts_per_epoch, batch_size): batch = next(train_loader) td = rollout_from_model.rollout_from_data( @@ -304,8 +264,6 @@ def main(): rb_ppo.extend(batch) for _ in range(ppo_num_epochs): # PPO epochs optimizer.zero_grad() - # why don't we optimize at each step? Is accumulating grads better? - # usually more small steps is better than a giant one for minibatch in rb_ppo: # GO over RB minibatch = minibatch.to(device, non_blocking=True) with ctx: @@ -323,7 +281,9 @@ def main(): it += 1 pbar.update(1) if it % eval_interval == 0: - val_reward = estimate_reward(model, val_loader) + with ctx: + val_reward = reward_estimator(model, val_loader) + prompt_logger.log(model) val_reward_logger.info(f"VALID: {it=}: {val_reward=:.4f}") wandb.log({"val_reward": val_reward}, step=it) pbar.set_description(f"VALID: {it=}: {val_reward=:.4f}") @@ -334,6 +294,9 @@ def main(): f"saving checkpoint to {rlhf_out_dir}" ) model.save_pretrained(rlhf_out_dir) + + + if __name__ == "__main__": diff --git a/examples/rlhf/utils.py b/examples/rlhf/utils.py index aedb0501091..54b4d519110 100644 --- a/examples/rlhf/utils.py +++ b/examples/rlhf/utils.py @@ -45,3 +45,65 @@ def setup(device, dtype): return nullcontext() return torch.amp.autocast(device_type="cuda", dtype=getattr(torch, dtype)) + + +def flatten_td(td): + # our tensordict has shape [B, T] where B = batch_size and T = trajectory length + # some trajectories may have stopped (reached EOS) before generating T tokens + # this function truncates and concatenates the trajectories, resulting in a + # tensordict that has shape [N] where N <= B * T. 
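+    # concretely: the done flags are shifted one step to the right and
+    # cumulatively summed so that every step strictly after a trajectory's
+    # terminating step is marked; inverting that mask keeps all steps up to and
+    # including the one where done was reached.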
+ done = td["next", "done"] + mask = torch.zeros_like(done) + mask[..., 1:, :] = done[..., :-1, :] # shift by one + mask = ~mask.cumsum(-2).bool().squeeze() + return td[mask] + +class TestPromptLogger(): + def __init__(self, test_prompt, reward_model, logger): + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + test_rindex = test_prompt.prompt_rindex[0] + test_prompt_ids = test_prompt.input_ids[:1, :test_rindex] + test_label_ids = test_prompt.input_ids[:1, test_rindex:] + test_prompt = tokenizer.decode(test_prompt_ids[0, :test_rindex].tolist()) + test_label = tokenizer.decode( + test_label_ids[0, test_label_ids[0] != tokenizer.pad_token_id].tolist() + ) + _, test_label_reward = reward_model( + input_ids=test_prompt.input_ids[:1], attention_mask=test_prompt.attention_mask[:1] + ) + self.generation_config = GenerationConfig( + pad_token_id=tokenizer.pad_token_id, max_new_tokens=episode_length + ) + self.test_prompt_ids = test_prompt_ids + self.reward_model = reward_model + self.tokenizer = tokenizer + self.test_label_reward = test_label_reward + self.test_rindex = test_rindex + self.test_prompt = test_prompt + self.test_label = test_label + self.logger = logger + + def log(self, model): + response_ids = model.generate( + input_ids=self.test_prompt_ids, generation_config=self.generation_config + ) + _, response_reward = self.reward_model( + input_ids=response_ids, + attention_mask=(response_ids != self.tokenizer.pad_token_id).to( + torch.int64 + ), + ) + reward = (response_reward - self.test_label_reward).item() + response_ids = response_ids[0, self.test_rindex:] + response = self.tokenizer.decode( + response_ids[response_ids != self.tokenizer.eos_token_id].tolist() + ) + string_to_write = ( + f"Query:\n{self.test_prompt}\n" + f"Response:\n{response}\n" + f"Actual response:\n{self.test_label}\n" + f"{reward=:4.4f}\n" + f"====================================================\n" + ) + self.logger.info(string_to_write) \ No newline at end of file From f43faeacc4fc7a55b09398622ee13b19ad0a5d3f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 28 Jun 2023 19:15:05 +0100 Subject: [PATCH 05/23] Update torchrl/modules/tensordict_module/common.py --- torchrl/modules/tensordict_module/common.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index d8c22236af0..cd68b3ddc5b 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -423,13 +423,13 @@ class VmapModule(TensorDictModuleBase): this module does not support dispatched arguments Example: - lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) - sample_in = torch.ones((10,3,2)) - sample_in_td = TensorDict({"x":sample_in}, batch_size=[10]) - lam(sample_in) - vm = VmapModule(lam, 0) - vm(sample_in_td) - assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() + >>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) + >>> sample_in = torch.ones((10,3,2)) + >>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10]) + >>> lam(sample_in) + >>> vm = VmapModule(lam, 0) + >>> vm(sample_in_td) + >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() """ def __init__(self, module: TensorDictModuleBase, vmap_dim=None): From 69b05887e6430b77a89bbcb98e6f2f5643aad057 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 28 Jun 2023 19:15:11 +0100 Subject: [PATCH 06/23] Update 
torchrl/modules/tensordict_module/actors.py --- torchrl/modules/tensordict_module/actors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 720a0e50dfd..20f201e84f9 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1766,7 +1766,7 @@ class LMHeadActorValueOperator(ActorValueOperator): Args: base_model (nn.Module): a torch model composed by a `.transformer` model and `.lm_head` linear layer - Note: for more details please refer to :class:`~.ActorValueOperator`. + .. note:: For more details regarding the class construction, please refer to :class:`~.ActorValueOperator`. """ def __init__(self, base_model): From b6fecbb6bc266968235503fea33614ba1ed31722 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Mon, 26 Jun 2023 10:35:12 +0000 Subject: [PATCH 07/23] Add RolloutFromModel class --- torchrl/data/rlhf/utils.py | 295 +++++++++++++++++++++++++++++++++++++ 1 file changed, 295 insertions(+) create mode 100644 torchrl/data/rlhf/utils.py diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py new file mode 100644 index 00000000000..ee75d8b8061 --- /dev/null +++ b/torchrl/data/rlhf/utils.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Tuple + +import torch + +from tensordict import TensorDict +from torch import Tensor +from torch.nn import functional as F + +from torchrl.data.rlhf.prompt import PromptData +from transformers import GenerationConfig + + +class RolloutFromModel: + """ + Args: + model (transformers.Transformer): the model to be used. Should have a + :meth:`generate` method. + ref_model (transformers.Transformer): a frozen version of ``model`` + where params are in their initial configuration. + reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given + input_ids and attention_mask, calculates rewards for each token and + end_scores (the reward for the final token in each sequence). + max_new_tokens (int, optional): the maximum length of the sequence. + Defaults to 50. + + """ + + EOS_TOKEN_ID = 50256 + + def __init__(self, model, ref_model, reward_model, max_new_tokens=50): + self.model = model + self.ref_model = ref_model + self.reward_model = reward_model + self.max_new_tokens = max_new_tokens + + def kl_step(self): + """Makes a step in the KL coefficient schedule.""" + pass + + @torch.no_grad() + def rollout_from_data(self, batch, kl_coef=0.1): + generated, log_probs, log_ratio = self.generate(batch) + return self.create_rollout_td( + batch, + generated, + self.reward_model, + log_probs, + log_ratio, + self.max_new_tokens, + kl_coef, + ) + + @torch.no_grad() + def create_rollout_td( + self, + batch, + generated, + reward_model, + log_probs, + log_ratio, + max_new_tokens=50, + kl_coef=0.1, + ): + """A TensorDict wrapper for generated data. + + This function takes a batch plus the generated tokens and replicates the tensordict + structure that would have been obtained from a rollout with a TorchRL env that + sampled one token each timestep. 
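+
+        The resulting tensordict has batch shape ``[B, max_new_tokens]`` (the last
+        dimension is named ``"time"``) and carries the keys ``"action"``,
+        ``"input_ids"``, ``"attention_mask"`` and ``"sample_log_prob"``, plus a
+        ``"next"`` sub-tensordict holding ``"input_ids"``, ``"attention_mask"``,
+        ``"done"``, ``"reward"``, ``"reward_raw"`` and ``"reward_kl"``.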
+ + Args: + batch: + """ + rollout_generated = [] + for rindex, row in zip(batch.prompt_rindex, generated): + arange = torch.arange(row.shape[0], device=generated.device) + tokens = [] + for i in range(max_new_tokens + 1): + tokens.append( + torch.where( + arange < rindex + i, + row, + self.EOS_TOKEN_ID, + ) + ) + rollout_generated.append(torch.stack(tokens)) + rollout_generated = torch.stack(rollout_generated) + rollout_attention_mask = (rollout_generated != self.EOS_TOKEN_ID).bool() + + # done is True when we either first sample an EOS token or reach the maximum number + # of generated tokens + done_idx = torch.minimum( + (generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex, + torch.tensor(max_new_tokens) - 1, + ) + done = torch.zeros( + done_idx.numel(), max_new_tokens, dtype=torch.bool, device=generated.device + ) + done = done.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1) + + # the sequence of actions for each trajectory is just the generated token ids + action_idx = torch.arange(max_new_tokens, device=generated.device) + action_idx = action_idx + batch.prompt_rindex.unsqueeze(-1) + action = generated.gather(-1, action_idx) + + # calculate the reward for the finished sequence + _, end_scores = reward_model( + input_ids=rollout_generated[:, -1], + attention_mask=rollout_attention_mask[:, -1], + ) + _, end_scores_labels = reward_model( + input_ids=batch.input_ids, + attention_mask=batch.attention_mask, + ) + # the reward is zero except for the timestep where we reached a stopping condition + clipped_scores = torch.clip(end_scores - end_scores_labels, -10, 10) + reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1) + reward_raw = reward_raw * done + reward_kl = -kl_coef * log_ratio.unsqueeze(-1) + reward = reward_raw + reward_kl + td = { + "action": action, + "input_ids": rollout_generated[:, :-1].clone(), + "attention_mask": rollout_attention_mask[:, :-1].clone(), + "sample_log_prob": log_probs, + "next": { + "input_ids": rollout_generated[:, 1:].clone(), + "attention_mask": rollout_attention_mask[:, 1:].clone(), + "done": done, + "reward": reward, + "reward_raw": reward_raw, + "reward_kl": reward_kl, + }, + } + return TensorDict( + td, batch_size=done.shape[:2], device=generated.device + ).refine_names(..., "time") + + @classmethod + def _padded_right_to_left(cls, tensor, *, eos_token_id=None, dim=1): + if eos_token_id is None: + eos_token_id = cls.EOS_TOKEN_ID + mask = tensor != eos_token_id + out = torch.full_like(tensor, eos_token_id) + out[mask.flip(dim)] = tensor[mask] + return out + + @classmethod + def _padded_left_to_right( + cls, tensor, *, sequence_length=None, eos_token_id=None, dim=1 + ): + # some care must be taken here, because generated sequences may have both left + # and right padding, and also may not terminated early if all sequences in the + # batch reached EOS before reaching the token limit + if sequence_length is None: + sequence_length = tensor.size(dim) + if dim < 0: + dim = tensor.ndim + dim + if eos_token_id is None: + eos_token_id = cls.EOS_TOKEN_ID + mask = tensor != eos_token_id + # convert [0, 0, 1, 1, 0] to [0, 0, 1, 1, 1] to avoid right eos + mask = ~((~mask).to(torch.uint8).cumprod(dim).bool()) + shape = list(mask.shape) + shape[dim] = sequence_length + out = torch.full(torch.Size(shape), eos_token_id, device=tensor.device) + index = (slice(None),) * dim + (slice(tensor.size(dim)),) + out[index][mask.flip(dim)] = tensor[mask] + return out + + @property + def _default_conf(self): + return GenerationConfig( + 
pad_token_id=self.EOS_TOKEN_ID, + max_new_tokens=self.max_new_tokens, + return_dict_in_generate=True, + output_scores=True, + do_sample=True, + ) + + def _get_scores( + self, scores: Tuple, generated_tokens: Tensor = None, use_max=False, pad_to=None + ): + scores = torch.stack(scores, 1) + if scores.shape[1] != self.max_new_tokens: + scores = F.pad( + scores, + (0, 0, 0, self.max_new_tokens - scores.shape[1]), + value=float("-inf"), + ) + scores = F.log_softmax(scores, dim=-1) + if use_max: + scores = scores.max(dim=-1).values + else: + index = generated_tokens.unsqueeze(-1) + scores = torch.gather(scores, dim=-1, index=index) + if pad_to is not None: + pad = pad_to - scores.shape[1] + return F.pad(scores, (0, pad), value=-float("inf")) + return scores + + @staticmethod + def logprobs_of_labels(logits, labels): + """Log probabilities of the labels. + + These are calculated from the logits.""" + logprobs = F.log_softmax(logits, dim=-1) + logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)) + return logprobs_labels.squeeze(-1) + + @torch.no_grad() + def _log_ratio(self, generated, prompt_rindex): + # get the scores and normalise for log probabilities + attention_mask = (generated != self.EOS_TOKEN_ID).bool() + logits = self.model( + input_ids=generated, attention_mask=attention_mask, return_dict=True + ).logits + logprobs = self.logprobs_of_labels(logits[:, :-1], generated[:, 1:]) + ref_logits = self.ref_model( + input_ids=generated.to(self.ref_model.device), + attention_mask=attention_mask.to(self.ref_model.device), + return_dict=True, + ).logits.to(logits.device) + ref_logprobs = self.logprobs_of_labels(ref_logits[:, :-1], generated[:, 1:]) + log_ratio = logprobs - ref_logprobs + log_ratio = log_ratio.masked_fill(~attention_mask[:, :-1], 0) + log_ratio = torch.stack( + [ + row[rindex - 1 : rindex + self.max_new_tokens - 1] + for row, rindex in zip(log_ratio, prompt_rindex) + ], + dim=0, + ) + return log_ratio + + def _get_generated_tokens(self, generated, rindex): + # extracts the generated tokens from the full sequence of prompt + generated + idx = torch.arange(generated.shape[1], device=generated.device) + rindex = rindex.unsqueeze(-1) + mask = (idx >= rindex) & (idx < rindex + self.max_new_tokens) + return generated[mask].reshape(-1, self.max_new_tokens) + + @torch.no_grad() + def generate(self, batch: PromptData, generation_config=None): + """Generates a sequence of tokens from a batch of data sampled from the data collector. + + Args: + batch (PromptData): the data to be used. Must have ``input_ids`` + and ``prompt_rindex`` fields. + generation_config (GenerationConfig, optional): the configuration for the + call to generate. + + Returns: + generated (torch.Tensor): a [B x (Ti +To)] sequence of integers (tokens), + where Ti is the length of the input sequence and To is the length + of the generated sequence. + log_probs_gen: the log-probabilities of the token generated. + log_ratio: the log ratio between probabilities under the generative + model and the frozen version. 
+ + """ + input_ids = batch.mask_label().input_ids + + # move padding tokens to left pad + # huggingface models expect left padding for generation + input_ids = self._padded_right_to_left(input_ids) + + # generate and capture scores + if generation_config is None: + generation_config = self._default_conf + + attention_mask = (input_ids != self.EOS_TOKEN_ID).bool() + outputs = self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + ) + samples = outputs.sequences + + # we'll insert generated tokens into a tensor prepopulated with padding tokens, + # thereby moving back to right padding for reward model + generated = self._padded_left_to_right( + samples, + input_ids.shape[1] + self.max_new_tokens, + eos_token_id=self.EOS_TOKEN_ID, + ) + generated_tokens = self._get_generated_tokens(generated, batch.prompt_rindex) + # get the scores and normalise for log probabilities + log_probs_gen = self._get_scores(outputs.scores, generated_tokens) + + log_ratio = self._log_ratio(generated, batch.prompt_rindex) + return generated, log_probs_gen, log_ratio From bd8fbb62a06e1fd9bae2eb291be6be86969f82c5 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Mon, 26 Jun 2023 14:22:10 +0000 Subject: [PATCH 08/23] Add rollout tests --- test/test_rlhf.py | 157 +++++++++++++++++++++++++++++++++++++ torchrl/data/rlhf/utils.py | 64 ++++++++------- 2 files changed, 188 insertions(+), 33 deletions(-) diff --git a/test/test_rlhf.py b/test/test_rlhf.py index 2060adaea86..7966d58d84f 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -24,6 +24,7 @@ from torchrl.data.rlhf.prompt import PromptData, PromptTensorDictTokenizer from torchrl.data.rlhf.reward import PairwiseDataset, pre_tokenization_hook from torchrl.modules.models.rlhf import GPT2RewardModel +from transformers import GPT2Config HERE = Path(__file__).parent @@ -385,6 +386,162 @@ def test_reward_model(tmpdir1, minidata_dir_comparison, batch_size, block_size, assert loss.shape == torch.Size([]) +class TestRollout: + kl_coef = 0.1 + + @staticmethod + def init_transformer( + dropout=0.1, + device="cpu", + as_tensordictmodule=True, + inference=False, + ): + from transformers import GPT2LMHeadModel + + model = GPT2LMHeadModel(GPT2Config()) + model.to(device) + + if as_tensordictmodule: + model = TensorDictModule( + model, + in_keys={ + "input_ids": "input_ids", + "attention_mask": "attention_mask", + "labels": "labels", + }, + out_keys=["logits"] if inference else ["loss", "logits"], + ) + return model + + def init_reward_model(self, device=None): + model = GPT2RewardModel() + model.to(device) + + model = TensorDictModule( + model, + in_keys=["input_ids", "attention_mask"], + out_keys=["rewards", "end_scores"], + ) + return model + + @property + def _dummy_batch(self): + return PromptData.from_tensordict( + TensorDict.load_memmap(f"{HERE}/datasets_mini/tldr_batch") + ) + + @property + def _model(self): + return self.init_transformer( + as_tensordictmodule=False, + inference=True, + ) + + @property + def _ref_model(self): + return self.init_transformer( + as_tensordictmodule=False, + inference=True, + ) + + @property + def _reward_model(self): + return self.init_reward_model() + + def _get_rollout_model(self, max_new_tokens=10): + return RolloutFromModel( + self._model, self._ref_model, self._reward_model, max_new_tokens + ) + + def test_padded_right_to_left(self): + x = torch.arange(12).view(3, 4) + x[0, -2:] = 100 + x[1, -1:] = 100 + x[2, -3:] = 100 + y = RolloutFromModel._padded_right_to_left(x, 
eos_token_id=100) + y_test = torch.tensor([[100, 100, 0, 1], [100, 4, 5, 6], [100, 100, 100, 8]]) + assert (y == y_test).all() + + @pytest.mark.parametrize("right_padded", [False, True]) + @pytest.mark.parametrize("sequence_length", [None, 5]) + def test_padded_left_to_right(self, right_padded, sequence_length): + x = torch.arange(12).view(3, 4) + x[0, :2] = 100 + x[1, :1] = 100 + x[2, :3] = 100 + if right_padded: + x[..., -1] = 100 + y = RolloutFromModel._padded_left_to_right( + x, eos_token_id=100, sequence_length=sequence_length + ) + if not right_padded: + y_test = torch.tensor( + [[2, 3, 100, 100], [5, 6, 7, 100], [11, 100, 100, 100]] + ) + else: + y_test = torch.tensor( + [[2, 100, 100, 100], [5, 6, 100, 100], [100, 100, 100, 100]] + ) + if sequence_length: + y_test = F.pad(y_test, (0, 1), value=100) + + assert (y == y_test).all() + + @pytest.mark.parametrize("batch_size", [2]) + @pytest.mark.parametrize("max_new_tokens", [10]) + @pytest.mark.parametrize("use_max", [True, False]) + def test_get_scores(self, batch_size, max_new_tokens, use_max): + scores = torch.arange(batch_size * max_new_tokens**2, dtype=torch.float).view( + batch_size, max_new_tokens, max_new_tokens + ) + gen_tokens = torch.arange(max_new_tokens).expand(1, max_new_tokens) + scores_comp = self._get_rollout_model( + max_new_tokens=max_new_tokens + )._get_scores(scores.unbind(1), generated_tokens=gen_tokens, use_max=use_max) + if not use_max: + assert ( + scores_comp.squeeze() + == torch.diagonal(scores.log_softmax(-1), 0, -2, -1).squeeze() + ).all() + else: + assert ( + scores_comp.squeeze() == scores.log_softmax(-1)[..., -1].squeeze() + ).all() + + def test_generate(self, max_new_tokens=10): + model = self._get_rollout_model(max_new_tokens) + batch = self._dummy_batch + generated, log_probs, log_ratio = model.generate(batch) + batch_size = batch.shape[0] + + assert generated.shape == torch.Size( + [batch_size, batch.input_ids.shape[1] + max_new_tokens] + ) + assert log_probs.shape == torch.Size([batch_size, max_new_tokens, 1]) + assert (log_probs <= 0).all().item() + assert log_ratio.shape == torch.Size([batch_size, max_new_tokens]) + + def test_rollout_from_data(self, max_new_tokens=10): + model = self._get_rollout_model(max_new_tokens) + batch = self._dummy_batch + td = model.rollout_from_data(batch) + batch_size = batch.shape[0] + + expected_keys = { + ("next", "attention_mask"), + ("next", "done"), + ("next", "input_ids"), + ("next", "reward"), + "action", + "attention_mask", + "input_ids", + "sample_log_prob", + } + keys = set(td.keys(True, True)) + assert all(key in keys for key in expected_keys) + assert td.batch_size == torch.Size([batch_size, max_new_tokens]) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index ee75d8b8061..90b082e254e 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -15,7 +15,8 @@ class RolloutFromModel: - """ + """A class for performing rollouts with language models. + Args: model (transformers.Transformer): the model to be used. Should have a :meth:`generate` method. @@ -26,7 +27,6 @@ class RolloutFromModel: end_scores (the reward for the final token in each sequence). max_new_tokens (int, optional): the maximum length of the sequence. Defaults to 50. 
- """ EOS_TOKEN_ID = 50256 @@ -44,41 +44,34 @@ def kl_step(self): @torch.no_grad() def rollout_from_data(self, batch, kl_coef=0.1): generated, log_probs, log_ratio = self.generate(batch) - return self.create_rollout_td( - batch, - generated, - self.reward_model, - log_probs, - log_ratio, - self.max_new_tokens, - kl_coef, - ) + return self.create_rollout_td(batch, generated, log_probs, log_ratio, kl_coef) @torch.no_grad() - def create_rollout_td( - self, - batch, - generated, - reward_model, - log_probs, - log_ratio, - max_new_tokens=50, - kl_coef=0.1, - ): + def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1): """A TensorDict wrapper for generated data. - This function takes a batch plus the generated tokens and replicates the tensordict - structure that would have been obtained from a rollout with a TorchRL env that - sampled one token each timestep. + This function takes a batch plus the generated tokens and replicates the + tensordict structure that would have been obtained from a rollout with a TorchRL + env that sampled one token each timestep. Args: - batch: + batch: A batch of data containing the original prompt together with a field + "rindex" indicating the right index of the prompt. + generated: The prompt together with generated tokens. This can be obtained + by calling the ``generate`` method. + log_probs: The log probabilities of the generated tokens. Can be obtained by + calling the ``generate`` method. + log_ratio: The log ratio of the probabilities of the generated tokens + according to the generative model and the reference model. Can be + obtained by calling the ``generate`` method. + kl_coef: Coefficient with which to multiply the KL term before subtracting + from the reward. """ rollout_generated = [] for rindex, row in zip(batch.prompt_rindex, generated): arange = torch.arange(row.shape[0], device=generated.device) tokens = [] - for i in range(max_new_tokens + 1): + for i in range(self.max_new_tokens + 1): tokens.append( torch.where( arange < rindex + i, @@ -94,24 +87,27 @@ def create_rollout_td( # of generated tokens done_idx = torch.minimum( (generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex, - torch.tensor(max_new_tokens) - 1, + torch.tensor(self.max_new_tokens) - 1, ) done = torch.zeros( - done_idx.numel(), max_new_tokens, dtype=torch.bool, device=generated.device + done_idx.numel(), + self.max_new_tokens, + dtype=torch.bool, + device=generated.device, ) done = done.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1) # the sequence of actions for each trajectory is just the generated token ids - action_idx = torch.arange(max_new_tokens, device=generated.device) + action_idx = torch.arange(self.max_new_tokens, device=generated.device) action_idx = action_idx + batch.prompt_rindex.unsqueeze(-1) action = generated.gather(-1, action_idx) # calculate the reward for the finished sequence - _, end_scores = reward_model( + _, end_scores = self.reward_model( input_ids=rollout_generated[:, -1], attention_mask=rollout_attention_mask[:, -1], ) - _, end_scores_labels = reward_model( + _, end_scores_labels = self.reward_model( input_ids=batch.input_ids, attention_mask=batch.attention_mask, ) @@ -206,7 +202,9 @@ def _get_scores( def logprobs_of_labels(logits, labels): """Log probabilities of the labels. - These are calculated from the logits.""" + These are calculated from the logits. The labels (token ids) are used to index + the logits along the relevant dimension. 
+ """ logprobs = F.log_softmax(logits, dim=-1) logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)) return logprobs_labels.squeeze(-1) @@ -284,7 +282,7 @@ def generate(self, batch: PromptData, generation_config=None): # thereby moving back to right padding for reward model generated = self._padded_left_to_right( samples, - input_ids.shape[1] + self.max_new_tokens, + sequence_length=input_ids.shape[1] + self.max_new_tokens, eos_token_id=self.EOS_TOKEN_ID, ) generated_tokens = self._get_generated_tokens(generated, batch.prompt_rindex) From 6fbb603756a7d2e7a1abc024ced36ee3fc09b2b4 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Mon, 26 Jun 2023 17:10:31 +0100 Subject: [PATCH 09/23] Apply suggestions from code review Co-authored-by: Alessandro Pietro Bardelli --- test/test_rlhf.py | 11 ++++++----- torchrl/data/rlhf/utils.py | 3 ++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_rlhf.py b/test/test_rlhf.py index 7966d58d84f..5235a09db1a 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -413,7 +413,8 @@ def init_transformer( ) return model - def init_reward_model(self, device=None): + @staticmethod + def init_reward_model(device=None): model = GPT2RewardModel() model.to(device) @@ -424,8 +425,8 @@ def init_reward_model(self, device=None): ) return model - @property - def _dummy_batch(self): + @staticmethod + def _get_dummy_batch(): return PromptData.from_tensordict( TensorDict.load_memmap(f"{HERE}/datasets_mini/tldr_batch") ) @@ -510,7 +511,7 @@ def test_get_scores(self, batch_size, max_new_tokens, use_max): def test_generate(self, max_new_tokens=10): model = self._get_rollout_model(max_new_tokens) - batch = self._dummy_batch + batch = self._get_dummy_batch() generated, log_probs, log_ratio = model.generate(batch) batch_size = batch.shape[0] @@ -523,7 +524,7 @@ def test_generate(self, max_new_tokens=10): def test_rollout_from_data(self, max_new_tokens=10): model = self._get_rollout_model(max_new_tokens) - batch = self._dummy_batch + batch = self._get_dummy_batch() td = model.rollout_from_data(batch) batch_size = batch.shape[0] diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 90b082e254e..723a47d08d4 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -15,7 +15,8 @@ class RolloutFromModel: - """A class for performing rollouts with language models. + """A class for performing rollouts with causal language models, i.e., a model that takes in input tokenized text + and whose task is to predicting the next word in a sentence having read the n previous words. Args: model (transformers.Transformer): the model to be used. Should have a From 3e80a55c582ea01ae44f811fe33d2235c8cbb2df Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Mon, 26 Jun 2023 16:34:15 +0000 Subject: [PATCH 10/23] Address comments --- torchrl/data/rlhf/utils.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 723a47d08d4..2103ed340de 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -22,7 +22,9 @@ class RolloutFromModel: model (transformers.Transformer): the model to be used. Should have a :meth:`generate` method. ref_model (transformers.Transformer): a frozen version of ``model`` - where params are in their initial configuration. + where params are in their initial configuration. 
This is used to compute a + KL penalty for the reward, to stop the model from straying too far from the + reference model during training. reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given input_ids and attention_mask, calculates rewards for each token and end_scores (the reward for the final token in each sequence). @@ -67,6 +69,28 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) obtained by calling the ``generate`` method. kl_coef: Coefficient with which to multiply the KL term before subtracting from the reward. + + Returns: + A TensorDict with the following keys: + - "action": the sequence of actions (generated tokens) + - "input_ids": the input_ids passed to the generative model at each time + step. + - "attention_mask": the attention_masks passed to the generative model at + each time step + - "sample_log_prob": the log probability of each token during generation + - ("next", "input_ids"): the sequence of tokens after generation. Makes up + part of the inputs that will be used for generating the next token. + - ("next", "attention_mask"): updated attention_mask after token has been + generated. Passed to the generative model on the next time step + - ("next", "done"): Boolean array indicating whether we've reached a + terminal state (either because we generated EOS token or because we + reached the token limit) + - ("next", "reward"): The reward received at each time step + - ("next", "reward_raw"): The raw reward from the reward model, without the + KL term. This is mainly for debugging and logging, it is not used in + training + - ("next", "reward_kl"): The KL term from the reward. This is mainly for + debugging and logging, it is not used in training. """ rollout_generated = [] for rindex, row in zip(batch.prompt_rindex, generated): From 385ac90945d4438904c48326c7e5321cd269fe3b Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Mon, 26 Jun 2023 16:42:46 +0000 Subject: [PATCH 11/23] Docstring lint --- torchrl/data/rlhf/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 2103ed340de..0462a84cdaa 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -15,8 +15,11 @@ class RolloutFromModel: - """A class for performing rollouts with causal language models, i.e., a model that takes in input tokenized text - and whose task is to predicting the next word in a sentence having read the n previous words. + """A class for performing rollouts with causal language models. + + It is assumed that the model this class wraps takes in input tokenized text and + whose task is to predicting the next word in a sentence having read the n previous + words. Args: model (transformers.Transformer): the model to be used. Should have a From 8d0a15267df4025537627c95c28eb347b549a52e Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 27 Jun 2023 14:17:22 +0100 Subject: [PATCH 12/23] Apply suggestions from code review Co-authored-by: Vincent Moens --- torchrl/data/rlhf/utils.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 0462a84cdaa..02ef36a131a 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -18,7 +18,7 @@ class RolloutFromModel: """A class for performing rollouts with causal language models. 
It is assumed that the model this class wraps takes in input tokenized text and - whose task is to predicting the next word in a sentence having read the n previous + whose task is to predict the next word in a sentence having read the n previous words. Args: @@ -29,7 +29,7 @@ class RolloutFromModel: KL penalty for the reward, to stop the model from straying too far from the reference model during training. reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given - input_ids and attention_mask, calculates rewards for each token and + ``input_ids`` and ``attention_mask``, calculates rewards for each token and end_scores (the reward for the final token in each sequence). max_new_tokens (int, optional): the maximum length of the sequence. Defaults to 50. @@ -61,38 +61,38 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) env that sampled one token each timestep. Args: - batch: A batch of data containing the original prompt together with a field + batch (TensorDict): A batch of data containing the original prompt together with a field "rindex" indicating the right index of the prompt. - generated: The prompt together with generated tokens. This can be obtained + generated (torch.Tensor): Tokenized prompt followed by generated tokens. This can be obtained by calling the ``generate`` method. - log_probs: The log probabilities of the generated tokens. Can be obtained by + log_probs (torch.Tensor): The log probabilities of the generated tokens. Can be obtained by calling the ``generate`` method. - log_ratio: The log ratio of the probabilities of the generated tokens + log_ratio (torch.Tensor): The log ratio of the probabilities of the generated tokens according to the generative model and the reference model. Can be obtained by calling the ``generate`` method. - kl_coef: Coefficient with which to multiply the KL term before subtracting - from the reward. + kl_coef (float, optional): Coefficient with which to multiply the KL term before subtracting + from the reward. Defaults to 0.1. Returns: - A TensorDict with the following keys: - - "action": the sequence of actions (generated tokens) - - "input_ids": the input_ids passed to the generative model at each time + A :class:`~tensordict.TensorDict` with the following keys: + - ``"action"``: the sequence of actions (generated tokens) + - ``"input_ids"``: the input_ids passed to the generative model at each time step. - - "attention_mask": the attention_masks passed to the generative model at + - ``"attention_mask"``: the attention_masks passed to the generative model at each time step - - "sample_log_prob": the log probability of each token during generation - - ("next", "input_ids"): the sequence of tokens after generation. Makes up + - ``"sample_log_prob"``: the log probability of each token during generation + - ``("next", "input_ids")``: the sequence of tokens after generation. Makes up part of the inputs that will be used for generating the next token. - - ("next", "attention_mask"): updated attention_mask after token has been + - ``("next", "attention_mask")``: updated attention_mask after token has been generated. 
Passed to the generative model on the next time step - - ("next", "done"): Boolean array indicating whether we've reached a + - ``("next", "done")``: Boolean array indicating whether we've reached a terminal state (either because we generated EOS token or because we reached the token limit) - - ("next", "reward"): The reward received at each time step - - ("next", "reward_raw"): The raw reward from the reward model, without the + - ``("next", "reward")``: The reward received at each time step + - ``("next", "reward_raw")``: The raw reward from the reward model, without the KL term. This is mainly for debugging and logging, it is not used in training - - ("next", "reward_kl"): The KL term from the reward. This is mainly for + - ``("next", "reward_kl")``: The KL term from the reward. This is mainly for debugging and logging, it is not used in training. """ rollout_generated = [] From fcddc97e00dfc0ae3bb0fe1e52b3bde540375cd7 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 27 Jun 2023 16:08:59 +0100 Subject: [PATCH 13/23] Address comments --- torchrl/data/rlhf/utils.py | 136 +++++++++++++++++++++++++++---------- 1 file changed, 99 insertions(+), 37 deletions(-) diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 02ef36a131a..5baeddf968a 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -17,7 +17,7 @@ class RolloutFromModel: """A class for performing rollouts with causal language models. - It is assumed that the model this class wraps takes in input tokenized text and + It is assumed that the model this class wraps takes as input tokenized text and whose task is to predict the next word in a sentence having read the n previous words. @@ -33,19 +33,69 @@ class RolloutFromModel: end_scores (the reward for the final token in each sequence). max_new_tokens (int, optional): the maximum length of the sequence. Defaults to 50. + score_clip (float, optional): Scores from the reward model are clipped to the + range ``(-score_clip, score_clip)``. Defaults to 10. + + Examples: + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules.models.rlhf import GPT2RewardModel + >>> from torchrl.data.rlhf.utils import RolloutFromModel + >>> from torchrl.data.rlhf.dataset import get_dataloader + >>> from torchrl.data.rlhf.prompt import PromptData + >>> from transformers import GPT2LMHeadModel + >>> + >>> dl = get_dataloader( + ... batch_size=4, + ... block_size=550, + ... tensorclass_type=PromptData, + ... device="cpu", + ... dataset_name="CarperAI/openai_summarize_tldr", + ... 
) + >>> model = GPT2LMHeadModel.from_pretrained("gpt2") + >>> # we load ref_model with random weights so it differs from model + >>> ref_model = GPT2LMHeadModel(GPT2LMHeadModel.config_class()) + >>> reward_model = GPT2RewardModel(model_path="gpt2") + >>> rollout_from_model = RolloutFromModel(model, ref_model, reward_model) + >>> + >>> batch = next(dl) + >>> rollout = rollout_from_model.rollout_from_data(batch) + >>> rollout + TensorDict( + fields={ + action: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False), + input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False), + next: TensorDict( + fields={ + attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False), + input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False), + reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), + reward_kl: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False), + reward_raw: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 50]), + device=cpu, + is_shared=False), + sample_log_prob: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4, 50]), + device=cpu, + is_shared=False) """ EOS_TOKEN_ID = 50256 - def __init__(self, model, ref_model, reward_model, max_new_tokens=50): + def __init__( + self, model, ref_model, reward_model, max_new_tokens=50, score_clip=10.0 + ): self.model = model self.ref_model = ref_model self.reward_model = reward_model self.max_new_tokens = max_new_tokens + self.score_clip = score_clip def kl_step(self): """Makes a step in the KL coefficient schedule.""" - pass + raise NotImplementedError @torch.no_grad() def rollout_from_data(self, batch, kl_coef=0.1): @@ -95,22 +145,54 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) - ``("next", "reward_kl")``: The KL term from the reward. This is mainly for debugging and logging, it is not used in training. 
""" + rollout_generated = self._get_rollout_generated(generated, batch) + rollout_attention_mask = (rollout_generated != self.EOS_TOKEN_ID).bool() + + done = self._get_done_status(generated, batch) + action = self._get_action(generated, batch) + end_scores, end_scores_labels = self._get_end_scores( + rollout_generated, rollout_attention_mask, batch + ) + + # the reward is zero except for the timestep where we reached a stopping condition + clipped_scores = torch.clip( + end_scores - end_scores_labels, -self.score_clip, self.score_clip + ) + reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1) + reward_raw = reward_raw * done + reward_kl = -kl_coef * log_ratio.unsqueeze(-1) + reward = reward_raw + reward_kl + td = { + "action": action, + "input_ids": rollout_generated[:, :-1].clone(), + "attention_mask": rollout_attention_mask[:, :-1].clone(), + "sample_log_prob": log_probs, + "next": { + "input_ids": rollout_generated[:, 1:].clone(), + "attention_mask": rollout_attention_mask[:, 1:].clone(), + "done": done, + "reward": reward, + "reward_raw": reward_raw, + "reward_kl": reward_kl, + }, + } + return TensorDict( + td, batch_size=done.shape[:2], device=generated.device + ).refine_names(..., "time") + + def _get_rollout_generated(self, generated, batch): + # stack the individual timesteps during generation into a single tensor rollout_generated = [] + arange = torch.arange(generated.shape[1], device=generated.device) for rindex, row in zip(batch.prompt_rindex, generated): - arange = torch.arange(row.shape[0], device=generated.device) tokens = [] for i in range(self.max_new_tokens + 1): - tokens.append( - torch.where( - arange < rindex + i, - row, - self.EOS_TOKEN_ID, - ) - ) + tokens.append(torch.where(arange < rindex + i, row, self.EOS_TOKEN_ID)) rollout_generated.append(torch.stack(tokens)) rollout_generated = torch.stack(rollout_generated) - rollout_attention_mask = (rollout_generated != self.EOS_TOKEN_ID).bool() + return rollout_generated + def _get_done_status(self, generated, batch): # done is True when we either first sample an EOS token or reach the maximum number # of generated tokens done_idx = torch.minimum( @@ -123,13 +205,15 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) dtype=torch.bool, device=generated.device, ) - done = done.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1) + return done.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1) + def _get_action(self, generated, batch): # the sequence of actions for each trajectory is just the generated token ids action_idx = torch.arange(self.max_new_tokens, device=generated.device) action_idx = action_idx + batch.prompt_rindex.unsqueeze(-1) - action = generated.gather(-1, action_idx) + return generated.gather(-1, action_idx) + def _get_end_scores(self, rollout_generated, rollout_attention_mask, batch): # calculate the reward for the finished sequence _, end_scores = self.reward_model( input_ids=rollout_generated[:, -1], @@ -139,29 +223,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) input_ids=batch.input_ids, attention_mask=batch.attention_mask, ) - # the reward is zero except for the timestep where we reached a stopping condition - clipped_scores = torch.clip(end_scores - end_scores_labels, -10, 10) - reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1) - reward_raw = reward_raw * done - reward_kl = -kl_coef * log_ratio.unsqueeze(-1) - reward = reward_raw + reward_kl - td = { - "action": action, - "input_ids": rollout_generated[:, :-1].clone(), - 
"attention_mask": rollout_attention_mask[:, :-1].clone(), - "sample_log_prob": log_probs, - "next": { - "input_ids": rollout_generated[:, 1:].clone(), - "attention_mask": rollout_attention_mask[:, 1:].clone(), - "done": done, - "reward": reward, - "reward_raw": reward_raw, - "reward_kl": reward_kl, - }, - } - return TensorDict( - td, batch_size=done.shape[:2], device=generated.device - ).refine_names(..., "time") + return end_scores, end_scores_labels @classmethod def _padded_right_to_left(cls, tensor, *, eos_token_id=None, dim=1): From 5c7c72e5663b5d6ae2d91f977e9a71d665ed0ecc Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Wed, 28 Jun 2023 12:06:02 +0100 Subject: [PATCH 14/23] Fix tests --- test/test_rlhf.py | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/test/test_rlhf.py b/test/test_rlhf.py index 5235a09db1a..8fb240aeb7e 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -11,9 +11,11 @@ import numpy as np import pytest import torch +import torch.nn.functional as F from _utils_internal import get_default_devices from tensordict import is_tensor_collection, MemmapTensor, TensorDict, TensorDictBase +from tensordict.nn import TensorDictModule from torchrl.data.rlhf import TensorDictTokenizer from torchrl.data.rlhf.dataset import ( _has_datasets, @@ -23,6 +25,7 @@ ) from torchrl.data.rlhf.prompt import PromptData, PromptTensorDictTokenizer from torchrl.data.rlhf.reward import PairwiseDataset, pre_tokenization_hook +from torchrl.data.rlhf.utils import RolloutFromModel from torchrl.modules.models.rlhf import GPT2RewardModel from transformers import GPT2Config @@ -36,20 +39,29 @@ def tmpdir1(tmp_path_factory): @pytest.fixture(scope="session") def minidata_dir_comparison(tmp_path_factory): - dest = tmp_path_factory.mktemp("tldr") - dataset_path = f"{HERE}/assets/openai_summarize_comparisons.zip" + dest = tmp_path_factory.mktemp("comparisons") + dataset_path = HERE / "assets" / "openai_summarize_comparisons.zip" with zipfile.ZipFile(dataset_path, "r") as zip_ref: zip_ref.extractall(dest) - yield dest / Path(dataset_path).name[:-4] + yield dest / Path(dataset_path).stem @pytest.fixture(scope="session") def minidata_dir_tldr(tmp_path_factory): dest = tmp_path_factory.mktemp("tldr") - dataset_path = f"{HERE}/assets/openai_summarize_tldr.zip" + dataset_path = HERE / "assets" / "openai_summarize_tldr.zip" + with zipfile.ZipFile(dataset_path, "r") as zip_ref: + zip_ref.extractall(dest) + yield dest / Path(dataset_path).stem + + +@pytest.fixture(scope="session") +def tldr_batch_dir(tmp_path_factory): + dest = tmp_path_factory.mktemp("tldr_batch") + dataset_path = HERE / "assets" / "tldr_batch.zip" with zipfile.ZipFile(dataset_path, "r") as zip_ref: zip_ref.extractall(dest) - yield dest / Path(dataset_path).name[:-4] + yield dest / Path(dataset_path).stem @pytest.mark.skipif( @@ -390,12 +402,7 @@ class TestRollout: kl_coef = 0.1 @staticmethod - def init_transformer( - dropout=0.1, - device="cpu", - as_tensordictmodule=True, - inference=False, - ): + def init_transformer(device="cpu", as_tensordictmodule=True, inference=False): from transformers import GPT2LMHeadModel model = GPT2LMHeadModel(GPT2Config()) @@ -426,10 +433,8 @@ def init_reward_model(device=None): return model @staticmethod - def _get_dummy_batch(): - return PromptData.from_tensordict( - TensorDict.load_memmap(f"{HERE}/datasets_mini/tldr_batch") - ) + def _get_dummy_batch(batch_dir): + return PromptData.from_tensordict(TensorDict.load_memmap(batch_dir)) @property def 
_model(self): @@ -509,9 +514,9 @@ def test_get_scores(self, batch_size, max_new_tokens, use_max): scores_comp.squeeze() == scores.log_softmax(-1)[..., -1].squeeze() ).all() - def test_generate(self, max_new_tokens=10): + def test_generate(self, tldr_batch_dir, max_new_tokens=10): model = self._get_rollout_model(max_new_tokens) - batch = self._get_dummy_batch() + batch = self._get_dummy_batch(tldr_batch_dir) generated, log_probs, log_ratio = model.generate(batch) batch_size = batch.shape[0] @@ -522,9 +527,9 @@ def test_generate(self, max_new_tokens=10): assert (log_probs <= 0).all().item() assert log_ratio.shape == torch.Size([batch_size, max_new_tokens]) - def test_rollout_from_data(self, max_new_tokens=10): + def test_rollout_from_data(self, tldr_batch_dir, max_new_tokens=10): model = self._get_rollout_model(max_new_tokens) - batch = self._get_dummy_batch() + batch = self._get_dummy_batch(tldr_batch_dir) td = model.rollout_from_data(batch) batch_size = batch.shape[0] From 92d57574344f42707ec98dfd1af0b3f59d16ac05 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Wed, 28 Jun 2023 12:12:24 +0100 Subject: [PATCH 15/23] Handle missing transformers import --- test/test_rlhf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_rlhf.py b/test/test_rlhf.py index 8fb240aeb7e..e85b9862e8f 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -27,7 +27,6 @@ from torchrl.data.rlhf.reward import PairwiseDataset, pre_tokenization_hook from torchrl.data.rlhf.utils import RolloutFromModel from torchrl.modules.models.rlhf import GPT2RewardModel -from transformers import GPT2Config HERE = Path(__file__).parent @@ -398,12 +397,15 @@ def test_reward_model(tmpdir1, minidata_dir_comparison, batch_size, block_size, assert loss.shape == torch.Size([]) +@pytest.mark.skipif( + not (_has_transformers and _has_datasets), reason="missing dependencies" +) class TestRollout: kl_coef = 0.1 @staticmethod def init_transformer(device="cpu", as_tensordictmodule=True, inference=False): - from transformers import GPT2LMHeadModel + from transformers import GPT2LMHeadModel, GPT2Config model = GPT2LMHeadModel(GPT2Config()) model.to(device) From eec0eafcf000a3b321ad8c481db04e4e933975b5 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Wed, 28 Jun 2023 12:43:23 +0100 Subject: [PATCH 16/23] Import transformers locally --- test/test_rlhf.py | 2 +- torchrl/data/rlhf/utils.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/test/test_rlhf.py b/test/test_rlhf.py index e85b9862e8f..2e4ba8d66c2 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -405,7 +405,7 @@ class TestRollout: @staticmethod def init_transformer(device="cpu", as_tensordictmodule=True, inference=False): - from transformers import GPT2LMHeadModel, GPT2Config + from transformers import GPT2Config, GPT2LMHeadModel model = GPT2LMHeadModel(GPT2Config()) model.to(device) diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 5baeddf968a..82a43c124c8 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import importlib from typing import Tuple import torch @@ -11,8 +12,8 @@ from torch.nn import functional as F from torchrl.data.rlhf.prompt import PromptData -from transformers import GenerationConfig +_has_transformers = importlib.util.find_spec("transformers") is not None class RolloutFromModel: """A class for performing rollouts with causal language models. @@ -87,6 +88,11 @@ class RolloutFromModel: def __init__( self, model, ref_model, reward_model, max_new_tokens=50, score_clip=10.0 ): + if not _has_transformers: + raise ImportError( + "transformers module couldn't be found. Make sure it is installed in your " + "environment." + ) self.model = model self.ref_model = ref_model self.reward_model = reward_model @@ -259,6 +265,8 @@ def _padded_left_to_right( @property def _default_conf(self): + from transformers import GenerationConfig + return GenerationConfig( pad_token_id=self.EOS_TOKEN_ID, max_new_tokens=self.max_new_tokens, From 87501ea32444de853b6c8563312beca30503b1c2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 28 Jun 2023 20:07:19 +0200 Subject: [PATCH 17/23] lint --- torchrl/data/rlhf/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 82a43c124c8..2b22a7347dd 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -15,6 +15,7 @@ _has_transformers = importlib.util.find_spec("transformers") is not None + class RolloutFromModel: """A class for performing rollouts with causal language models. @@ -90,9 +91,9 @@ def __init__( ): if not _has_transformers: raise ImportError( - "transformers module couldn't be found. Make sure it is installed in your " - "environment." - ) + "transformers module couldn't be found. Make sure it is installed in your " + "environment." + ) self.model = model self.ref_model = ref_model self.reward_model = reward_model From 8b69e41c4b64efeaf20367bbc859e07bcaa0a51e Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Thu, 29 Jun 2023 13:27:29 +0000 Subject: [PATCH 18/23] lint --- examples/rlhf/README.md | 4 +-- examples/rlhf/models/actor_critic.py | 4 +-- examples/rlhf/train_rlhf.py | 45 ++++++++++++++++------------ examples/rlhf/utils.py | 36 +++++++++++----------- torchrl/data/rlhf/dataset.py | 1 + 5 files changed, 50 insertions(+), 40 deletions(-) diff --git a/examples/rlhf/README.md b/examples/rlhf/README.md index 1ddca8dfb96..7c8347a3c1d 100644 --- a/examples/rlhf/README.md +++ b/examples/rlhf/README.md @@ -30,7 +30,7 @@ python train.py --batch_size=128 ### Training the reward model -Next you can train the reward model with +Once you have completed supervised fine-tuning, copy the desired model checkpoint to `./out` or update the config to point `model.name_or_path` at the relevant checkpoint in the timestamped working directory created by Hydra. You can then train the reward model with ```sh python train_reward.py @@ -38,7 +38,7 @@ python train_reward.py ### Training the final model with RLHF -To train the final model run +Once again, make sure you have either updated the configuration to point `reward_model.name_or_path` at the relevant timestamped working directory, or copy the checkpoint to `./out_reward`. 
You can then train the final model by running ```sh python train_rlhf.py diff --git a/examples/rlhf/models/actor_critic.py b/examples/rlhf/models/actor_critic.py index e514cf9b248..fff70deefb0 100644 --- a/examples/rlhf/models/actor_critic.py +++ b/examples/rlhf/models/actor_critic.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from torchrl.modules.tensordict_module.actors import LMActorCritic +from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator from torchrl.modules.tensordict_module.common import VmapModule from .transformer import init_transformer @@ -19,7 +19,7 @@ def init_actor_critic(transformer_name_or_path, dropout, device, compile_): compile_=compile_, inference=True, ) - model = LMActorCritic(base_model) + model = LMHeadActorValueOperator(base_model) model.to(device) model.eval() actor = model.get_policy_operator() diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index 9b0461acc84..6b733addd74 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -27,7 +27,13 @@ from torchrl.objectives.value import GAE from tqdm import tqdm from transformers import GenerationConfig, GPT2Tokenizer -from utils import flatten_td, get_file_logger, resolve_name_or_path, setup, TestPromptLogger +from utils import ( + flatten_td, + get_file_logger, + resolve_name_or_path, + setup, + TestPromptLogger, +) class AdaptiveKLController: @@ -63,19 +69,18 @@ class RewardEstimator: quality of the model can be visually assessed during training. """ - def __init__( - self, eval_iters, episode_length, reward_model, ref_model - ): + + def __init__(self, eval_iters, episode_length, reward_model, ref_model): """ - Args: - eval_iters (int): number of batches on which we would like to estimate reward + Args: + eval_iters (int): number of batches on which we would like to estimate reward - episode_length (int): max number of generated new tokens + episode_length (int): max number of generated new tokens - reward_model (GPT2RewardModel): reward model + reward_model (GPT2RewardModel): reward model - ref_model (GPT2LMHeadModel): original transformer model that it is used to - correctly compute kl component of reward. + ref_model (GPT2LMHeadModel): original transformer model that it is used to + correctly compute kl component of reward. 
""" self.ref_model = ref_model self.reward_model = reward_model @@ -84,7 +89,9 @@ def __init__( @torch.no_grad() def __call__(self, model, dataloader): - rollout_from_model = RolloutFromModel(model, self.ref_model, self.reward_model, max_new_tokens=self.episode_length) + rollout_from_model = RolloutFromModel( + model, self.ref_model, self.reward_model, max_new_tokens=self.episode_length + ) rewards = torch.zeros(self.eval_iters) for k in range(self.eval_iters): batch = next(dataloader) @@ -182,13 +189,15 @@ def main(): test_prompt = next(val_loader) reward_estimator = RewardEstimator( - eval_iters, - episode_length, - reward_model, - ref_model + eval_iters, episode_length, reward_model, ref_model ) - prompt_logger = TestPromptLogger(test_prompt=test_prompt, reward_model=reward_model, logger=query_logger) + prompt_logger = TestPromptLogger( + batch=test_prompt, + reward_model=reward_model, + logger=query_logger, + episode_length=episode_length, + ) optimizer = torch.optim.AdamW( [p for p in loss_fn.parameters() if p.requires_grad], **train_cfg.optimizer @@ -266,6 +275,7 @@ def main(): optimizer.zero_grad() for minibatch in rb_ppo: # GO over RB minibatch = minibatch.to(device, non_blocking=True) + import ipdb; ipdb.set_trace() with ctx: loss_vals = loss_fn(minibatch) loss_val = sum( @@ -294,9 +304,6 @@ def main(): f"saving checkpoint to {rlhf_out_dir}" ) model.save_pretrained(rlhf_out_dir) - - - if __name__ == "__main__": diff --git a/examples/rlhf/utils.py b/examples/rlhf/utils.py index 54b4d519110..637a6d4d8c6 100644 --- a/examples/rlhf/utils.py +++ b/examples/rlhf/utils.py @@ -8,6 +8,7 @@ import torch import torch._dynamo from hydra.utils import to_absolute_path +from transformers import GPT2Tokenizer, GenerationConfig def resolve_name_or_path(name_or_path): @@ -58,23 +59,24 @@ def flatten_td(td): mask = ~mask.cumsum(-2).bool().squeeze() return td[mask] -class TestPromptLogger(): - def __init__(self, test_prompt, reward_model, logger): + +class TestPromptLogger: + def __init__(self, batch, reward_model, logger, episode_length): tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - test_rindex = test_prompt.prompt_rindex[0] - test_prompt_ids = test_prompt.input_ids[:1, :test_rindex] - test_label_ids = test_prompt.input_ids[:1, test_rindex:] + test_rindex = batch.prompt_rindex[0] + test_prompt_ids = batch.input_ids[:1, :test_rindex] + test_label_ids = batch.input_ids[:1, test_rindex:] test_prompt = tokenizer.decode(test_prompt_ids[0, :test_rindex].tolist()) test_label = tokenizer.decode( - test_label_ids[0, test_label_ids[0] != tokenizer.pad_token_id].tolist() + test_label_ids[0, test_label_ids[0] != tokenizer.pad_token_id].tolist() ) _, test_label_reward = reward_model( - input_ids=test_prompt.input_ids[:1], attention_mask=test_prompt.attention_mask[:1] + input_ids=batch.input_ids[:1], attention_mask=batch.attention_mask[:1] ) self.generation_config = GenerationConfig( - pad_token_id=tokenizer.pad_token_id, max_new_tokens=episode_length - ) + pad_token_id=tokenizer.pad_token_id, max_new_tokens=episode_length + ) self.test_prompt_ids = test_prompt_ids self.reward_model = reward_model self.tokenizer = tokenizer @@ -83,19 +85,19 @@ def __init__(self, test_prompt, reward_model, logger): self.test_prompt = test_prompt self.test_label = test_label self.logger = logger - + def log(self, model): response_ids = model.generate( input_ids=self.test_prompt_ids, generation_config=self.generation_config ) _, response_reward = self.reward_model( - 
input_ids=response_ids, - attention_mask=(response_ids != self.tokenizer.pad_token_id).to( - torch.int64 - ), - ) + input_ids=response_ids, + attention_mask=(response_ids != self.tokenizer.pad_token_id).to( + torch.int64 + ), + ) reward = (response_reward - self.test_label_reward).item() - response_ids = response_ids[0, self.test_rindex:] + response_ids = response_ids[0, self.test_rindex :] response = self.tokenizer.decode( response_ids[response_ids != self.tokenizer.eos_token_id].tolist() ) @@ -106,4 +108,4 @@ def log(self, model): f"{reward=:4.4f}\n" f"====================================================\n" ) - self.logger.info(string_to_write) \ No newline at end of file + self.logger.info(string_to_write) diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index 7a8e094ae25..e2d10b19139 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -106,6 +106,7 @@ def __init__( self.pre_tokenization_hook = pre_tokenization_hook self.root_dir = root_dir self.from_disk = from_disk + self.valid_size = valid_size if num_workers is None: num_workers = max(os.cpu_count() // 2, 1) self.num_workers = num_workers From 24eaa3a9f63046fabb6e7b7ebb16a1cd065776f2 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Thu, 29 Jun 2023 13:45:57 +0000 Subject: [PATCH 19/23] Example bugfixes --- examples/rlhf/train_rlhf.py | 1 - torchrl/modules/tensordict_module/actors.py | 2 +- torchrl/objectives/ppo.py | 5 ----- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index 6b733addd74..17b12c6bcf2 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -275,7 +275,6 @@ def main(): optimizer.zero_grad() for minibatch in rb_ppo: # GO over RB minibatch = minibatch.to(device, non_blocking=True) - import ipdb; ipdb.set_trace() with ctx: loss_vals = loss_fn(minibatch) loss_val = sum( diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 20f201e84f9..6a795ba6d14 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1794,4 +1794,4 @@ def __init__(self, base_model): value_head, in_keys=["x"], out_keys=["state_value"] ) - return super().__init__(common, actor_head, value_head) + super().__init__(common, actor_head, value_head) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index a4d728e7d8d..e0a3e09f605 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -634,11 +634,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp() batch = log_weight.shape[0] - if not advantage.shape == log_weight.shape: - raise RuntimeError( - f"advantage.shape and log_weight.shape do not match (got {advantage.shape} " - f"and {log_weight.shape})" - ) gain1 = log_weight.exp() * advantage log_weight_clip = log_weight.clamp(*self._clip_bounds) From fba43a13f08947dc40d54115ca2c6b5b8978b91f Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Thu, 29 Jun 2023 14:24:21 +0000 Subject: [PATCH 20/23] Move KL controller logic --- examples/rlhf/train_rlhf.py | 49 +++++++------------- torchrl/data/rlhf/utils.py | 89 ++++++++++++++++++++++++++++++++++--- 2 files changed, 97 insertions(+), 41 deletions(-) diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index 17b12c6bcf2..e0fc3e20582 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -4,7 +4,6 @@ # LICENSE file in the 
root directory of this source tree. from copy import deepcopy -import numpy as np import torch import wandb @@ -21,12 +20,15 @@ ) from torchrl.data.rlhf.dataset import get_dataloader from torchrl.data.rlhf.prompt import PromptData -from torchrl.data.rlhf.utils import RolloutFromModel +from torchrl.data.rlhf.utils import ( + RolloutFromModel, + ConstantKLController, + AdaptiveKLController, +) from torchrl.objectives import ClipPPOLoss from torchrl.objectives.value import GAE from tqdm import tqdm -from transformers import GenerationConfig, GPT2Tokenizer from utils import ( flatten_td, get_file_logger, @@ -36,28 +38,6 @@ ) -class AdaptiveKLController: - """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences" - Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 - Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py - """ - - def __init__(self, init_kl_coef: float, target: float, horizon: int): - self.value = init_kl_coef - self.target = target - self.horizon = horizon - - def update(self, current: float, n_steps: int): - """Returns adaptively updated KL coefficient, βₜ₊₁. - Arguments: - current: The current KL value between the newest policy and the initial policy. - """ - proportional_error = np.clip(current / self.target - 1, -0.2, 0.2) # ϵₜ - mult = 1 + proportional_error * n_steps / self.horizon - self.value *= mult # βₜ₊₁ - return self.value - - class RewardEstimator: """Create a class to estimate the reward via sampling. @@ -90,13 +70,16 @@ def __init__(self, eval_iters, episode_length, reward_model, ref_model): @torch.no_grad() def __call__(self, model, dataloader): rollout_from_model = RolloutFromModel( - model, self.ref_model, self.reward_model, max_new_tokens=self.episode_length + model, + self.ref_model, + self.reward_model, + kl_controller=ConstantKLController(0.0), # disable KL for evaluation + max_new_tokens=self.episode_length, ) rewards = torch.zeros(self.eval_iters) for k in range(self.eval_iters): batch = next(dataloader) - # NOTE: disable kl for evaluation - td = rollout_from_model.rollout_from_data(batch, kl_coef=0.0) + td = rollout_from_model.rollout_from_data(batch) rewards[k] = td.get(("next", "reward")).sum(dim=1).mean().item() test_reward = rewards.mean() @@ -220,7 +203,7 @@ def main(): prefetch=10, ) - rollout_from_model = RolloutFromModel(model, ref_model, reward_model) + rollout_from_model = RolloutFromModel(model, ref_model, reward_model, kl_controller) best_val_reward = float("-inf") it = 0 # it is equivalent to batch_size number of episodes @@ -231,9 +214,7 @@ def main(): rollout_kl = [] for _ in range(0, num_rollouts_per_epoch, batch_size): batch = next(train_loader) - td = rollout_from_model.rollout_from_data( - batch, kl_coef=kl_controller.value - ) + td = rollout_from_model.rollout_from_data(batch) with torch.no_grad(), ctx: # moving this to within epoch adv_fn(td) @@ -249,8 +230,8 @@ def main(): rollout_reward = torch.tensor(rollout_rewards).mean().cpu().item() rollout_kl_reward = torch.tensor(rollout_kl).mean().cpu().item() # recover true kl - rollout_kl = -rollout_kl_reward / kl_controller.value - kl_controller.update(rollout_kl, num_rollouts_per_epoch / batch_size) + rollout_kl = -rollout_kl_reward / kl_controller.coef + rollout_from_model.kl_update(rollout_kl, num_rollouts_per_epoch / batch_size) # FIXME: THIS PPO CYCLE WAS DIFFERENT wrt trlx. 
@tcbegley please double check # they sample batch_size from rb and then do minibatches ppo_batch_size within diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 2b22a7347dd..c3d75d435c9 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -2,9 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import abc import importlib from typing import Tuple +import numpy as np import torch from tensordict import TensorDict @@ -16,6 +18,72 @@ _has_transformers = importlib.util.find_spec("transformers") is not None +class KLControllerBase(abc.ABC): + """Base class for KL controllers. + + Each controller must implement an update method that takes the current KL value and + the number of steps and updates the self.coef attribute, which will multiply + the KL during calculation of the reward. + """ + + @abc.abstractmethod + def update(self, kl_value: float, n_steps: int): + pass + + +class ConstantKLController(KLControllerBase): + """Constant KL Controller. + + This controller maintains a fixed coefficient no matter what values it is updated + with. + + Arguments: + coefficient (float): The coefficient to multiply KL with when calculating the + reward. + """ + + def __init__(self, coefficient): + self.coef = coefficient + + def update(self, kl_value: float, n_steps: int): + pass + + +class AdaptiveKLController(KLControllerBase): + """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences". + + Arguments: + init_kl_coef (float): The starting value of the coefficient. + target (float): The target KL value. When the observed KL is smaller, the + coefficient is decreased, thereby relaxing the KL penalty in the training + objective and allowing the model to stray further from the reference model. + When the observed KL is greater than the target, the KL coefficient is + increased, thereby pulling the model back towards the reference model. + horizon (int): Scaling factor to control how aggressively we update the + coefficient. + + Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 + Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py + """ + + def __init__(self, init_kl_coef: float, target: float, horizon: int): + self.coef = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, kl_value: float, n_steps: int): + """Update ``self.coef`` adaptively. + + Arguments: + kl_value: The current KL value between the newest policy and the initial + policy. + n_steps: The number of training steps taken since last update. + """ + proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ + mult = 1 + proportional_error * n_steps / self.horizon + self.coef *= mult # βₜ₊₁ + + class RolloutFromModel: """A class for performing rollouts with causal language models. 
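The adaptive schedule added just above is easier to read with concrete numbers. The sketch below is a standalone restatement of the same update rule, not part of the patch: it does not import torchrl, the coefficient/target/horizon of 0.1, 6 and 10000 are the values the example training script passes to `AdaptiveKLController`, and the observed KL values and step count are purely illustrative.

```python
import numpy as np

def adaptive_kl_update(coef, kl_value, n_steps, target=6.0, horizon=10_000):
    # Same rule as AdaptiveKLController.update above:
    # beta_{t+1} = beta_t * (1 + clip(kl/target - 1, -0.2, 0.2) * n_steps / horizon)
    proportional_error = np.clip(kl_value / target - 1, -0.2, 0.2)
    return coef * (1 + proportional_error * n_steps / horizon)

coef = 0.1  # init_kl_coef, as in the example script
for kl in (12.0, 12.0, 3.0, 3.0):  # made-up observed KL values
    coef = adaptive_kl_update(coef, kl, n_steps=32)
    print(f"observed KL={kl:5.1f} -> new coefficient={coef:.6f}")
```

Observed KL above the target nudges the coefficient up (a stronger penalty pulling the policy back towards the reference model), KL below the target relaxes it, and `horizon` keeps each individual step small, as the class docstring describes.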
@@ -87,7 +155,13 @@ class RolloutFromModel: EOS_TOKEN_ID = 50256 def __init__( - self, model, ref_model, reward_model, max_new_tokens=50, score_clip=10.0 + self, + model, + ref_model, + reward_model, + kl_controller, + max_new_tokens=50, + score_clip=10.0, ): if not _has_transformers: raise ImportError( @@ -99,18 +173,19 @@ def __init__( self.reward_model = reward_model self.max_new_tokens = max_new_tokens self.score_clip = score_clip + self.kl_controller = kl_controller - def kl_step(self): + def kl_update(self, kl_value, n_steps): """Makes a step in the KL coefficient schedule.""" - raise NotImplementedError + self.kl_controller.update(kl_value, n_steps) @torch.no_grad() - def rollout_from_data(self, batch, kl_coef=0.1): + def rollout_from_data(self, batch): generated, log_probs, log_ratio = self.generate(batch) - return self.create_rollout_td(batch, generated, log_probs, log_ratio, kl_coef) + return self.create_rollout_td(batch, generated, log_probs, log_ratio) @torch.no_grad() - def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1): + def create_rollout_td(self, batch, generated, log_probs, log_ratio): """A TensorDict wrapper for generated data. This function takes a batch plus the generated tokens and replicates the @@ -167,7 +242,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1) ) reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1) reward_raw = reward_raw * done - reward_kl = -kl_coef * log_ratio.unsqueeze(-1) + reward_kl = -self.kl_controller.coef * log_ratio.unsqueeze(-1) reward = reward_raw + reward_kl td = { "action": action, From c07ac939081b8f183c772ff2aaf18d3d4d76a4de Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 4 Jul 2023 09:33:04 +0100 Subject: [PATCH 21/23] amend --- examples/rlhf/train_rlhf.py | 10 ++++++---- examples/rlhf/utils.py | 2 +- torchrl/data/rlhf/utils.py | 3 --- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index e0fc3e20582..5e6c09ee367 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -21,9 +21,9 @@ from torchrl.data.rlhf.dataset import get_dataloader from torchrl.data.rlhf.prompt import PromptData from torchrl.data.rlhf.utils import ( - RolloutFromModel, - ConstantKLController, AdaptiveKLController, + ConstantKLController, + RolloutFromModel, ) from torchrl.objectives import ClipPPOLoss @@ -208,7 +208,7 @@ def main(): best_val_reward = float("-inf") it = 0 # it is equivalent to batch_size number of episodes with tqdm(total=int(max_epochs * num_rollouts_per_epoch / batch_size)) as pbar: - for epoch in range(1, max_epochs + 1): + for _ in range(1, max_epochs + 1): rb.empty() rollout_rewards = [] rollout_kl = [] @@ -231,7 +231,9 @@ def main(): rollout_kl_reward = torch.tensor(rollout_kl).mean().cpu().item() # recover true kl rollout_kl = -rollout_kl_reward / kl_controller.coef - rollout_from_model.kl_update(rollout_kl, num_rollouts_per_epoch / batch_size) + rollout_from_model.kl_update( + rollout_kl, num_rollouts_per_epoch / batch_size + ) # FIXME: THIS PPO CYCLE WAS DIFFERENT wrt trlx. 
@tcbegley please double check # they sample batch_size from rb and then do minibatches ppo_batch_size within diff --git a/examples/rlhf/utils.py b/examples/rlhf/utils.py index 637a6d4d8c6..766335badbf 100644 --- a/examples/rlhf/utils.py +++ b/examples/rlhf/utils.py @@ -8,7 +8,7 @@ import torch import torch._dynamo from hydra.utils import to_absolute_path -from transformers import GPT2Tokenizer, GenerationConfig +from transformers import GenerationConfig, GPT2Tokenizer def resolve_name_or_path(name_or_path): diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index a99919958b5..c3d75d435c9 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -2,9 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import importlib -from typing import Tuple - import abc import importlib from typing import Tuple From f463e0e273313c9eb361942b0fbf7e39e29acd1a Mon Sep 17 00:00:00 2001 From: Alessandro Pietro Bardelli Date: Tue, 4 Jul 2023 11:24:29 +0000 Subject: [PATCH 22/23] addressing comments about klcontroller --- examples/rlhf/train_rlhf.py | 8 ++++---- torchrl/data/rlhf/utils.py | 34 ++++++++++++++++++++-------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index 5e6c09ee367..f199c69aa0d 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -73,7 +73,7 @@ def __call__(self, model, dataloader): model, self.ref_model, self.reward_model, - kl_controller=ConstantKLController(0.0), # disable KL for evaluation + kl_coef=0, # disable KL for evaluation max_new_tokens=self.episode_length, ) rewards = torch.zeros(self.eval_iters) @@ -188,7 +188,6 @@ def main(): scheduler = None if train_cfg.decay_lr: scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) - kl_controller = AdaptiveKLController(0.1, 6, 10000) rb = TensorDictReplayBuffer( storage=LazyTensorStorage(episode_length * num_rollouts_per_epoch), @@ -203,7 +202,8 @@ def main(): prefetch=10, ) - rollout_from_model = RolloutFromModel(model, ref_model, reward_model, kl_controller) + rollout_from_model = RolloutFromModel(model, ref_model, reward_model) + kl_controller = AdaptiveKLController(rollout_from_model, 0.1, 6, 10000) best_val_reward = float("-inf") it = 0 # it is equivalent to batch_size number of episodes @@ -231,7 +231,7 @@ def main(): rollout_kl_reward = torch.tensor(rollout_kl).mean().cpu().item() # recover true kl rollout_kl = -rollout_kl_reward / kl_controller.coef - rollout_from_model.kl_update( + kl_controller.update( rollout_kl, num_rollouts_per_epoch / batch_size ) diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index c3d75d435c9..fab7cf420fb 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -22,8 +22,8 @@ class KLControllerBase(abc.ABC): """Base class for KL controllers. Each controller must implement an update method that takes the current KL value and - the number of steps and updates the self.coef attribute, which will multiply - the KL during calculation of the reward. + the number of steps and updates the kl_coef attribute of the wrapped model, + which will multiply the KL during calculation of the reward. """ @abc.abstractmethod @@ -38,21 +38,27 @@ class ConstantKLController(KLControllerBase): with. Arguments: - coefficient (float): The coefficient to multiply KL with when calculating the + model: wrapped model that needs to be controlled. 
Must have attribute 'kl_coef' + kl_coef (float): The coefficient to multiply KL with when calculating the reward. """ - def __init__(self, coefficient): - self.coef = coefficient + def __init__(self, model, kl_coef): + self.model = model + if not hasattr(model, "kl_coef"): + raise AttributeError("Model input to ConstantKLController doesn't have attribute 'kl_coef'") + self.coef = kl_coef + self.model.kl_coef = self.coef def update(self, kl_value: float, n_steps: int): - pass + self.model.kl_coef = self.coef class AdaptiveKLController(KLControllerBase): """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences". Arguments: + model: wrapped model that needs to be controlled. Must have attribute 'kl_coef' init_kl_coef (float): The starting value of the coefficient. target (float): The target KL value. When the observed KL is smaller, the coefficient is decreased, thereby relaxing the KL penalty in the training @@ -66,10 +72,12 @@ class AdaptiveKLController(KLControllerBase): Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py """ - def __init__(self, init_kl_coef: float, target: float, horizon: int): + def __init__(self, model, init_kl_coef: float, target: float, horizon: int): + self.model = model self.coef = init_kl_coef self.target = target self.horizon = horizon + self.model.kl_coef = self.coef def update(self, kl_value: float, n_steps: int): """Update ``self.coef`` adaptively. @@ -82,6 +90,7 @@ def update(self, kl_value: float, n_steps: int): proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ mult = 1 + proportional_error * n_steps / self.horizon self.coef *= mult # βₜ₊₁ + self.model.kl_coef = self.coef class RolloutFromModel: @@ -101,6 +110,7 @@ class RolloutFromModel: reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given ``input_ids`` and ``attention_mask``, calculates rewards for each token and end_scores (the reward for the final token in each sequence). + kl_coef: (float, optional): initial kl coefficient. max_new_tokens (int, optional): the maximum length of the sequence. Defaults to 50. 
score_clip (float, optional): Scores from the reward model are clipped to the @@ -159,7 +169,7 @@ def __init__( model, ref_model, reward_model, - kl_controller, + kl_coef=0.1, max_new_tokens=50, score_clip=10.0, ): @@ -173,11 +183,7 @@ def __init__( self.reward_model = reward_model self.max_new_tokens = max_new_tokens self.score_clip = score_clip - self.kl_controller = kl_controller - - def kl_update(self, kl_value, n_steps): - """Makes a step in the KL coefficient schedule.""" - self.kl_controller.update(kl_value, n_steps) + self.kl_coef = kl_coef @torch.no_grad() def rollout_from_data(self, batch): @@ -242,7 +248,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio): ) reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1) reward_raw = reward_raw * done - reward_kl = -self.kl_controller.coef * log_ratio.unsqueeze(-1) + reward_kl = -self.kl_coef * log_ratio.unsqueeze(-1) reward = reward_raw + reward_kl td = { "action": action, From a9b94f088149021174b93d647fc97b4a61e33868 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 2 Oct 2023 09:48:16 -0400 Subject: [PATCH 23/23] amend --- torchrl/data/rlhf/utils.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 942afa4de76..7db09959b74 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -238,7 +238,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio): rollout_generated = self._get_rollout_generated(generated, batch) rollout_attention_mask = (rollout_generated != self.EOS_TOKEN_ID).bool() - done = self._get_done_status(generated, batch) + done, terminated = self._get_done_status(generated, batch) action = self._get_action(generated, batch) end_scores, end_scores_labels = self._get_end_scores( rollout_generated, rollout_attention_mask, batch @@ -261,7 +261,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio): "input_ids": rollout_generated[:, 1:].clone(), "attention_mask": rollout_attention_mask[:, 1:].clone(), "done": done, - "terminated": done.clone(), + "terminated": terminated, "reward": reward, "reward_raw": reward_raw, "reward_kl": reward_kl, @@ -286,18 +286,11 @@ def _get_rollout_generated(self, generated, batch): def _get_done_status(self, generated, batch): # done is True when we either first sample an EOS token or reach the maximum number # of generated tokens - # TODO: differentiate truncated and terminal here - done_idx = torch.minimum( - (generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex, - torch.tensor(self.max_new_tokens) - 1, - ) - done = torch.zeros( - done_idx.numel(), - self.max_new_tokens, - dtype=torch.bool, - device=generated.device, - ) - return done.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1) + terminated = generated == self.EOS_TOKEN_ID + terminated = terminated.int().cumsum(-1).bool() + done = terminated.clone() + done[..., self.max_new_tokens-1] = 1 + return done, terminated def _get_action(self, generated, batch): # the sequence of actions for each trajectory is just the generated token ids
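The last patch in the series replaces the scatter-based `done` computation with a cumulative EOS mask and distinguishes `terminated` (an EOS token was sampled) from `done` (EOS was sampled or the generation budget ran out), addressing the earlier "TODO: differentiate truncated and terminal here". A toy illustration of the cumsum trick follows; it uses a made-up EOS id and a short budget, and the tensor stands in for just the generated region rather than the full prompt-plus-response sequences the class works with.

```python
import torch

EOS = 2             # toy EOS id (the real class uses EOS_TOKEN_ID = 50256)
max_new_tokens = 5  # toy generation budget

# Two toy trajectories: the first samples EOS at step 2, the second never does.
generated = torch.tensor([[7, 9, EOS, EOS, EOS],
                          [7, 9, 4, 6, 8]])

terminated = (generated == EOS).int().cumsum(-1).bool()  # True from the first EOS onwards
done = terminated.clone()
done[..., max_new_tokens - 1] = 1  # the final step is always done (truncation)

print(terminated)
# tensor([[False, False,  True,  True,  True],
#         [False, False, False, False, False]])
print(done)
# tensor([[False, False,  True,  True,  True],
#         [False, False, False, False,  True]])
```

A trajectory that never samples EOS becomes `done` only at its last step, which is precisely the truncated (rather than terminated) case.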