[Example] RLHF end to end example #1324

Closed · wants to merge 30 commits

Commits
f24281c
RLHF end2end example
apbard Jun 27, 2023
ef3f76f
add VmapModule and from_lmhead_model method
apbard Jun 27, 2023
02a909b
Update examples/rlhf/train_rlhf.py
apbard Jun 28, 2023
953e4af
addressing comments
apbard Jun 28, 2023
ffb8661
Merge remote-tracking branch 'origin/main' into rlhf-networks
vmoens Jun 28, 2023
f43faea
Update torchrl/modules/tensordict_module/common.py
vmoens Jun 28, 2023
69b0588
Update torchrl/modules/tensordict_module/actors.py
vmoens Jun 28, 2023
b6fecbb
Add RolloutFromModel class
tcbegley Jun 26, 2023
bd8fbb6
Add rollout tests
tcbegley Jun 26, 2023
6fbb603
Apply suggestions from code review
tcbegley Jun 26, 2023
3e80a55
Address comments
tcbegley Jun 26, 2023
385ac90
Docstring lint
tcbegley Jun 26, 2023
8d0a152
Apply suggestions from code review
tcbegley Jun 27, 2023
fcddc97
Address comments
tcbegley Jun 27, 2023
5c7c72e
Fix tests
tcbegley Jun 28, 2023
92d5757
Handle missing transformers import
tcbegley Jun 28, 2023
eec0eaf
Import transformers locally
tcbegley Jun 28, 2023
87501ea
lint
vmoens Jun 28, 2023
043fcf6
Merge branch 'rlhf-rollout' into rlhf-example
tcbegley Jun 29, 2023
3f53046
Merge branch 'rlhf-networks' into rlhf-example
tcbegley Jun 29, 2023
8b69e41
lint
tcbegley Jun 29, 2023
24eaa3a
Example bugfixes
tcbegley Jun 29, 2023
fba43a1
Move KL controller logic
tcbegley Jun 29, 2023
20fa920
Merge branch 'main' into rlhf-example
vmoens Jul 4, 2023
c07ac93
amend
vmoens Jul 4, 2023
f463e0e
addressing comments about klcontroller
apbard Jul 4, 2023
eac5374
Merge remote-tracking branch 'origin/main' into rlhf-example
vmoens Sep 5, 2023
8d2dde7
Merge remote-tracking branch 'origin/main' into rlhf-example
vmoens Oct 1, 2023
a2ba045
Merge branch 'main' into rlhf-example
vmoens Oct 2, 2023
a9b94f0
amend
vmoens Oct 2, 2023
4 changes: 4 additions & 0 deletions examples/rlhf/.gitignore
@@ -0,0 +1,4 @@
*.png
*.bin
*.pt
*.json
45 changes: 45 additions & 0 deletions examples/rlhf/README.md
@@ -0,0 +1,45 @@
# RLHF example

This example uses RLHF (Reinforcement Learning from Human Feedback) to train a language model to summarize Reddit posts.

## Getting started

Make sure you have PyTorch 2.0 installed. You can find installation instructions [here](https://pytorch.org/get-started/locally/).

From this directory, you can install extra requirements for running these examples with

```sh
pip install -r requirements.txt
```

## Training the models
### Training the transformer

Once the data has been prepared, you can train the GPT model.

```sh
python train.py
```

The default configuration can be found in `config/train.yaml`, and any option can be overridden from the command line. For example, to run the training script with a different batch size

```sh
python train.py data.batch_size=128
```
> **_NOTE:_** Users on Apple Silicon MacBooks should set `sys.device=mps` and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback for operations that are not yet supported on MPS.
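
For instance, a full command on an Apple Silicon machine might look like the following (the batch size shown is purely an illustrative value)

```sh
PYTORCH_ENABLE_MPS_FALLBACK=1 python train.py sys.device=mps data.batch_size=8
```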

### Training the reward model

Once you have completed supervised fine-tuning, copy the desired model checkpoint to `./out` or update the config to point `model.name_or_path` at the relevant checkpoint in the timestamped working directory created by Hydra. You can then train the reward model with

```sh
python train_reward.py
```
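
Alternatively, instead of copying files, the checkpoint location can be passed as a command-line override; the path below is purely illustrative

```sh
python train_reward.py model.name_or_path=./path/to/finetuned/checkpoint
```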

### Training the final model with RLHF

Once again, make sure you have either updated the configuration to point `reward_model.name_or_path` at the relevant timestamped working directory, or copied the checkpoint to `./out_reward`. You can then train the final model by running

```sh
python train_rlhf.py
```
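
As with the other scripts, any setting in `config/train_rlhf.yaml` can be overridden at the command line, for example (values are illustrative)

```sh
python train_rlhf.py reward_model.name_or_path=./out_reward sys.device=cuda:0
```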
30 changes: 30 additions & 0 deletions examples/rlhf/config/train.yaml
@@ -0,0 +1,30 @@
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16  # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: gpt2  # gpt2 for the pre-trained model, or a local path to a checkpoint
  out_dir: ./out
  dropout: 0.1  # for pretraining 0 is good, for finetuning try 0.1+
train:
  grad_clip: 1.0  # clip gradients at this value, or disable if == 0.0
  max_iters: 5000  # total number of training iterations
  gradient_accumulation_steps: 2  # used to simulate larger batch sizes
  always_save_checkpoint: False  # if True, always save a checkpoint after each evaluation in out_dir
  decay_lr: True  # whether to decay the learning rate
optimizer:
  # keyword arguments for torch.optim.AdamW
  lr: 1.0e-5
  weight_decay: 1.0e-1
  betas: [0.9, 0.95]
scheduler:
  # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
  T_max: 5000  # maximum number of iterations
  eta_min: 1.0e-6  # minimum learning rate
sys:
  device: cuda  # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on MacBooks
  dtype: bfloat16  # float32, bfloat16, or float16; float16 will automatically use a GradScaler
  compile: True  # use PyTorch 2.0 to compile the model for faster training
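
The `optimizer` and `scheduler` blocks are documented as keyword arguments for `torch.optim.AdamW` and `torch.optim.lr_scheduler.CosineAnnealingLR`. A minimal sketch of what that implies, assuming the sections are top-level as laid out above and are loaded with PyYAML from the example directory (this is an illustration, not the example's actual training script):

```python
import torch
import yaml

# Illustrative sketch: build the optimizer and LR scheduler directly from the
# keyword-argument sections of the config (assumes it is run from examples/rlhf).
with open("config/train.yaml") as f:
    cfg = yaml.safe_load(f)

model = torch.nn.Linear(8, 8)  # stand-in for the GPT-2 model
optimizer = torch.optim.AdamW(model.parameters(), **cfg["optimizer"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **cfg["scheduler"])

for _ in range(cfg["train"]["max_iters"]):
    # ... forward, backward and optimizer.step() would go here ...
    if cfg["train"]["decay_lr"]:
        scheduler.step()
```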
32 changes: 32 additions & 0 deletions examples/rlhf/config/train_reward.yaml
@@ -0,0 +1,32 @@
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16  # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: ./out
  dropout: 0.1  # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  out_dir: ./out_reward
  init_from: scratch  # 'scratch' or 'resume' - if 'resume', the model is loaded from reward_model.out_dir
train:
  grad_clip: 1.0  # clip gradients at this value, or disable if == 0.0
  max_iters: 20000  # total number of training iterations
  gradient_accumulation_steps: 2  # used to simulate larger batch sizes
  always_save_checkpoint: False  # if True, always save a checkpoint after each eval
  decay_lr: False  # whether to decay the learning rate
optimizer:
  # keyword arguments for torch.optim.AdamW
  lr: 1.0e-5
  weight_decay: 1.0e-1
  betas: [0.9, 0.95]
scheduler:
  # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
  T_max: 20000
  eta_min: 1.0e-6
sys:
  device: cuda  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on MacBooks
  dtype: bfloat16  # 'float32', 'bfloat16', or 'float16'; 'float16' will automatically use a GradScaler
  compile: True  # use PyTorch 2.0 to compile the model for faster training
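
The `reward_model.init_from` flag points at the usual scratch-or-resume choice. A hedged sketch of how it could be wired to the `init_reward_model` helper added in `examples/rlhf/models/reward.py` below (the function name `make_reward_model` and the import path are assumptions for illustration, not the example's actual script):

```python
# Illustrative only: choose where the reward model comes from based on the config.
def make_reward_model(cfg, device):
    # import path assumes the script runs from the examples/rlhf directory
    from models.reward import init_reward_model

    if cfg["reward_model"]["init_from"] == "resume":
        # resume from a previously saved reward-model checkpoint
        return init_reward_model(
            reward_model_path=cfg["reward_model"]["out_dir"], device=device
        )
    # otherwise start from the supervised fine-tuned transformer
    return init_reward_model(
        transformer_path=cfg["model"]["name_or_path"], device=device
    )
```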
36 changes: 36 additions & 0 deletions examples/rlhf/config/train_rlhf.yaml
@@ -0,0 +1,36 @@
io:
  eval_interval: 6
  log_interval: 1
  eval_iters: 10
data:
  batch_size: 4  # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: ./out
  out_dir: ./out_rlhf
  dropout: 0.1  # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  name_or_path: ./out_reward
train:
  grad_clip: 1.0
  max_epochs: 1000  # total number of training epochs
  always_save_checkpoint: True  # if True, always save a checkpoint after each eval
  decay_lr: True
optimizer:
  # keyword arguments for torch.optim.AdamW
  lr: 5.0e-5
  weight_decay: 0.0
  betas: [0.9, 0.999]
scheduler:
  # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
  T_max: 3000  # max_epochs * num_rollouts / ppo_batch_size
  eta_min: 5.0e-6
ppo:
  episode_length: 50
  ppo_batch_size: 16
  ppo_num_epochs: 3
  num_rollouts_per_epoch: 32
sys:
  device: cuda  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on MacBooks
  dtype: bfloat16  # 'float32', 'bfloat16', or 'float16'; 'float16' will automatically use a GradScaler
  compile: True  # use PyTorch 2.0 to compile the model for faster training
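
Read together, the `ppo` settings describe a collect-then-update cycle: each epoch gathers `num_rollouts_per_epoch` rollouts of up to `episode_length` generated tokens, then makes `ppo_num_epochs` passes over that data in mini-batches of `ppo_batch_size`. A small arithmetic sketch of that reading (an interpretation of the config values, not the training loop itself):

```python
# Illustrative arithmetic based on the values in this config.
episode_length = 50
ppo_batch_size = 16
ppo_num_epochs = 3
num_rollouts_per_epoch = 32

minibatches_per_collection = num_rollouts_per_epoch // ppo_batch_size  # 2
ppo_updates_per_epoch = ppo_num_epochs * minibatches_per_collection    # 6
print(minibatches_per_collection, ppo_updates_per_epoch)
```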
3 changes: 3 additions & 0 deletions examples/rlhf/data/__init__.py
@@ -0,0 +1,3 @@
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr

__all__ = ["get_prompt_dataloader_tldr"]
4 changes: 4 additions & 0 deletions examples/rlhf/models/__init__.py
@@ -0,0 +1,4 @@
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
29 changes: 29 additions & 0 deletions examples/rlhf/models/actor_critic.py
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator
from torchrl.modules.tensordict_module.common import VmapModule

from .transformer import init_transformer

__all__ = ["init_actor_critic"]


def init_actor_critic(transformer_name_or_path, dropout, device, compile_):
    base_model = init_transformer(
        transformer_name_or_path,
        dropout,
        device,
        as_tensordictmodule=False,
        compile_=compile_,
        inference=True,
    )
    model = LMHeadActorValueOperator(base_model)
    model.to(device)
    model.eval()
    actor = model.get_policy_operator()
    critic = model.get_value_operator()
    critic_head = model.get_value_head()

    return actor, VmapModule(critic), critic_head, base_model
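
A hedged usage sketch of the helper above (the checkpoint name and settings are illustrative):

```python
# Build the shared-backbone actor-critic from a GPT-2 checkpoint.
actor, critic, critic_head, base_model = init_actor_critic(
    "gpt2", dropout=0.1, device="cpu", compile_=False
)
# actor       -> the policy (actor) operator
# critic      -> the value operator, wrapped in VmapModule
# critic_head -> just the value head of the shared trunk
# base_model  -> the underlying GPT-2 language model
```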
34 changes: 34 additions & 0 deletions examples/rlhf/models/reward.py
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from tensordict.nn import TensorDictModule

from torchrl.modules.models.rlhf import GPT2RewardModel


def init_reward_model(
    transformer_path=None, reward_model_path=None, device=None, compile_=False
):
    if not ((transformer_path is None) ^ (reward_model_path is None)):
        raise ValueError(
            "Exactly one of transformer_path or reward_model_path should be specified"
        )
    if transformer_path is not None:
        model = GPT2RewardModel(transformer_path)
    else:
        model = GPT2RewardModel.from_pretrained(reward_model_path)

    model.to(device)
    if compile_:
        print("Compiling the reward model...")
        model = torch.compile(model)

    model = TensorDictModule(
        model,
        in_keys=["input_ids", "attention_mask"],
        out_keys=["rewards", "end_scores"],
    )
    return model
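
A hedged usage sketch of the wrapper above (the `gpt2` checkpoint, vocabulary size and tensor shapes are illustrative, and loading `gpt2` downloads weights from the Hugging Face hub):

```python
import torch
from tensordict import TensorDict

# Exactly one of transformer_path / reward_model_path may be given.
reward_model = init_reward_model(transformer_path="gpt2", device="cpu")

batch = TensorDict(
    {
        "input_ids": torch.randint(0, 50257, (2, 10)),
        "attention_mask": torch.ones(2, 10, dtype=torch.long),
    },
    batch_size=[2],
)
batch = reward_model(batch)  # writes "rewards" and "end_scores" into the TensorDict
print(batch["rewards"].shape, batch["end_scores"].shape)
```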
44 changes: 44 additions & 0 deletions examples/rlhf/models/transformer.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from tensordict.nn import TensorDictModule
from transformers import GPT2LMHeadModel


def init_transformer(
    name_or_path,
    dropout,
    device,
    compile_,
    as_tensordictmodule=True,
    inference=False,
):
    model_kwargs = {
        "resid_pdrop": dropout,
        "embd_pdrop": dropout,
        "attn_pdrop": dropout,
        "summary_first_dropout": dropout,
    }
    model = GPT2LMHeadModel.from_pretrained(
        name_or_path, return_dict=False, **model_kwargs
    )
    model.to(device)

    if compile_:
        # TODO: logging instead of printing?
        print("Compiling transformer model...")
        model = torch.compile(model)

    if as_tensordictmodule:
        model = TensorDictModule(
            model,
            in_keys={
                "input_ids": "input_ids",
                "attention_mask": "attention_mask",
                "labels": "labels",
            },
            out_keys=["logits"] if inference else ["loss", "logits"],
        )
    return model
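
A hedged usage sketch (arguments are illustrative): in training mode the returned `TensorDictModule` reads `input_ids`, `attention_mask` and `labels` and writes `loss` and `logits`, while `inference=True` keeps only `logits`:

```python
# Load GPT-2 as a TensorDict-aware module on CPU, without torch.compile.
model = init_transformer("gpt2", dropout=0.1, device="cpu", compile_=False)
```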
11 changes: 11 additions & 0 deletions examples/rlhf/requirements.txt
@@ -0,0 +1,11 @@
datasets
hydra-core
matplotlib
numpy
PyYAML
requests
tiktoken
tqdm
transformers
git+https://github.com/pytorch/rl
git+https://github.com/pytorch-labs/tensordict