RLHF with PPO #1005

Merged
merged 44 commits on Aug 5, 2024
Changes from 28 commits
Commits
44 commits
11d88a2
Refactoring TransformerDecoder and adding value-head transformers
SalmanMohammadi May 9, 2024
2849ec5
adding ppo config and recipe to registry
SalmanMohammadi May 10, 2024
f0c1410
Merge branch 'pytorch:main' into ppo
SalmanMohammadi May 12, 2024
57c67bf
implemented ppo recipe structure, advantage and return estimation, tr…
SalmanMohammadi May 15, 2024
03cba4b
finished first pass implementation of ppo. added tests for ppo loss
SalmanMohammadi May 15, 2024
f50f047
reverting changes
SalmanMohammadi May 15, 2024
b034af7
adding lora to ppo recipe, adding lora value head component and model…
SalmanMohammadi May 16, 2024
466b683
added lora training, added value head checkpointing and recipe resumi…
SalmanMohammadi May 19, 2024
928037d
removing test model builders, adding batched generation to ppo recipe…
SalmanMohammadi May 21, 2024
68b6162
fixing bug in _checkpointer.py
SalmanMohammadi May 21, 2024
65ca12a
Adding support for user-provided masks in attention
SalmanMohammadi May 30, 2024
9d8c5a8
Merge branch 'pytorch:main' into ppo
SalmanMohammadi May 31, 2024
b99102c
merging transformer custom masking, adding support for generation wit…
SalmanMohammadi Jun 4, 2024
a1cde1c
adding functionality for truncation in generation, and further tests …
SalmanMohammadi Jun 4, 2024
b032778
updated lora recipe to use custom generation
SalmanMohammadi Jun 6, 2024
f126e9a
Merge branch 'pytorch:main' into ppo
SalmanMohammadi Jun 6, 2024
04d514a
added support for correct truncation and padding of responses, added …
SalmanMohammadi Jun 7, 2024
4854908
added correct mask and position id trajectory generation, score rejec…
SalmanMohammadi Jun 8, 2024
c885833
bugfixing in ppo recipe. refactoring ppo_utils and tests to individua…
SalmanMohammadi Jun 8, 2024
57d57fa
updating ppo_utils namespace
SalmanMohammadi Jun 8, 2024
cce5548
fixing bug in collation, updating loss tests
SalmanMohammadi Jun 10, 2024
c289566
bugfixes in masking and indexing logprobs and values, added fixed kl …
SalmanMohammadi Jun 12, 2024
a3fa1ea
added loss and value masking
SalmanMohammadi Jun 14, 2024
c3db142
some refactoring, lots of testing and docs
SalmanMohammadi Jun 16, 2024
589bf7d
improved early training stability by adding value head init. from rew…
SalmanMohammadi Jun 16, 2024
346c30b
updating metrics
SalmanMohammadi Jun 18, 2024
2e9d779
reworking causal masking
SalmanMohammadi Jun 18, 2024
46b75be
freeing up memory after steps to avoid mem leaks
SalmanMohammadi Jun 18, 2024
0fd885e
Merge branch 'main' into ppo
SalmanMohammadi Jul 16, 2024
1942b0f
cleaning up; verifying results; switching to full finetune
SalmanMohammadi Jul 16, 2024
58d92ab
tidying up
SalmanMohammadi Jul 16, 2024
1fbb6dc
detaching losses for metric logging
SalmanMohammadi Jul 18, 2024
65ef9dc
removing 1b, merging main
SalmanMohammadi Jul 25, 2024
c7bbff1
merging
SalmanMohammadi Jul 25, 2024
1129f9e
deleting logits in loss
SalmanMohammadi Jul 29, 2024
fe87dfb
Merge branch 'main' into ppo
SalmanMohammadi Aug 2, 2024
662ab2c
cleaning conf
SalmanMohammadi Aug 2, 2024
76b124f
pYdOcLiNt
SalmanMohammadi Aug 2, 2024
dc4887c
downloading weights
SalmanMohammadi Aug 3, 2024
ef85dba
addressing comments
SalmanMohammadi Aug 5, 2024
fd87fe6
updating test
SalmanMohammadi Aug 5, 2024
ba365a8
let's finish this the way we started... together
SalmanMohammadi Aug 5, 2024
e76304c
Merge branch 'main' into ppo
SalmanMohammadi Aug 5, 2024
4e6be43
lInTiNG
SalmanMohammadi Aug 5, 2024
151 changes: 151 additions & 0 deletions recipes/configs/mistral/7B_lora_ppo.yaml
@@ -0,0 +1,151 @@
# Config for single-device RLHF with PPO using a Mistral 7B model with a LoRA
# policy and value head.
#
# This config uses hyperparameters based on a small set of experiments and on
# information available on various forums. It is not meant to replicate
# published results.
#
# This config assumes that you've run the following command before launching
# this run:
# tune download mistralai/Mistral-7B-v0.1 --hf-token <HF_TOKEN> --output-dir /tmp/Mistral-7B-v0.1
#
# To launch on a single device, run the following command from root, replacing
# <ppo_recipe> with the registered name of the PPO recipe added in this PR:
# tune run <ppo_recipe> --config mistral/7B_lora_ppo
#
# You can add specific overrides through the command line. For example, to
# override the checkpointer directory while launching training, you can run:
# tune run <ppo_recipe> --config mistral/7B_lora_ppo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

# Tokenizer
tokenizer:
  _component_: torchtune.models.mistral.mistral_tokenizer
  path: ./target/weights/mistral_base/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: nvidia/HelpSteer
  max_seq_len: 32
  split: train
  column: prompt

seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.mistral.lora_mistral_lm_with_value_head_7b
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

reward_model:
  _component_: torchtune.models.mistral.mistral_classifier_7b


checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: ./target/weights/mistral_base/
  checkpoint_files: [
    "pytorch_model-00001-of-00003.bin",
    "pytorch_model-00002-of-00003.bin",
    "pytorch_model-00003-of-00003.bin"
  ]

  recipe_checkpoint: null
  adapter_checkpoint: null
  output_dir: ${output_dir}/base/
  model_type: LM_MISTRAL

reward_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: ./target/weights/mistral_reward/
  checkpoint_files: [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors"
  ]
  output_dir: ${output_dir}/reward
  model_type: MISTRAL_REWARD

resume_from_checkpoint: False
output_dir: target/full_7b

initialise_value_head_from_reward_model: True
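# Initialising the value head from the reward model's output head was added to
# improve early-training stability (see the commit history above); when False,
# the value head presumably starts from its default random initialisation.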

# Fine-tuning arguments
batch_size: 64
num_steps: 10000
ppo_epochs: 2
ppo_batch_size: 2
ppo_backward_batch_size: 2
gradient_accumulation_steps: 1
whiten_rewards: False
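# How these values interact, assuming the usual PPO minibatching scheme (an
# assumption about this recipe, not confirmed from the code shown here): each
# PPO iteration collects batch_size=64 trajectories, then makes ppo_epochs=2
# passes over that batch in minibatches of ppo_batch_size=2; backward passes run
# on chunks of ppo_backward_batch_size=2, so with gradient_accumulation_steps=1
# every minibatch triggers an optimizer step.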

# Generation arguments
forward_batch_size: 2
max_generated_tokens: 32
temperature: 0.7
top_k: null

# Reward masking args
truncate_after_tokens: null
penalise_no_eos: False
reward_penalty: -1.0
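# Assumed behaviour (not confirmed from the recipe here): truncate_after_tokens
# can mark responses as ending at the given stop tokens, and when
# penalise_no_eos is True, responses that never produce an EOS token within
# max_generated_tokens receive reward_penalty instead of their reward-model
# score.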

# KL controller arguments
# kl_controller:
# _component_: torchtune.utils.ppo_utils.AdaptiveKLController
# init_kl_coef: 0.15
# kl_target: 6
# kl_horizon: 10000

# or
kl_controller:
  _component_: torchtune.utils.ppo_utils.FixedKLController
  kl_coef: 0.05
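# For reference, an adaptive controller such as the commented-out
# AdaptiveKLController above typically follows the update from Ziegler et al.
# (2019) (a sketch of the common scheme, not necessarily this implementation):
#   error = clip(observed_kl / kl_target - 1, -0.2, 0.2)
#   kl_coef <- kl_coef * (1 + error * n_steps / kl_horizon)
# whereas FixedKLController keeps kl_coef constant at the value above.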

optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 1.41e-5

loss:
  _component_: torchtune.modules.loss.PPOLoss
  gamma: 1
  lmbda: 0.95
  epsilon: 0.2
  value_coeff: 0.1
  value_clip_range: 0.1
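# In the standard formulation (see the runnable sketch after this config), these
# parameters configure a clipped policy loss,
#   L_policy = -min(ratio * A, clip(ratio, 1 - epsilon, 1 + epsilon) * A),
# plus value_coeff times a value loss clipped with value_clip_range, where the
# advantages A are estimated with GAE using gamma and lmbda.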

# Training env
device: mps

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16
Contributor

Just out of curiosity, does bf16 have different implications for training stability in a PPO training loop as compared to SFT?

@SalmanMohammadi (Collaborator, Author), Jul 16, 2024

I think so. I'm not sure I have the expertise to answer this based on experience, but intuitively I'd say that there are several factors in PPO optimisation which help stabilise training and reduce variance in gradient updates, so the impact of reduced precision may be smaller than in SFT scenarios.

The PPO loss also isn't a "distance" like in SFT, so you may not have the same loss landscapes, because gradient updates point in the direction of maximising the reward. This means the smoothness of the loss landscape is more related to the amount of variance in your trajectories, e.g. if your generations are significantly different from each other (due to, say, generation args), your reward model isn't well calibrated, or something as simple as your batch size being too small.

Empirically, the reference TRL results I compared against below used fp32, and my results were in bf16.


# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.StdoutLogger
  log_dir: ${output_dir}

log_every_n_steps: 1
log_peak_memory_stats: False


profiler:
  _component_: torchtune.utils.profiler
  enabled: False
  output_dir: ${output_dir}/torchtune_perf_tracing.json
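
For reference, here is a minimal, self-contained sketch of the clipped PPO policy-and-value loss that the loss block in this config parameterises (epsilon, value_coeff, value_clip_range). It illustrates only the standard formulation and is not torchtune's PPOLoss implementation; the tensor names, the omission of per-token masking, and the toy usage below are assumptions for illustration.

import torch


def ppo_loss(
    pi_logprobs: torch.Tensor,   # log-probs of generated tokens under the current policy
    old_logprobs: torch.Tensor,  # log-probs under the policy that generated the rollout
    advantages: torch.Tensor,    # GAE advantages (computed with gamma / lmbda)
    values: torch.Tensor,        # current value-head predictions
    old_values: torch.Tensor,    # value predictions recorded at rollout time
    returns: torch.Tensor,       # return targets (advantages + old_values)
    epsilon: float = 0.2,
    value_coeff: float = 0.1,
    value_clip_range: float = 0.1,
) -> torch.Tensor:
    # Clipped surrogate policy loss.
    ratio = torch.exp(pi_logprobs - old_logprobs)
    policy_loss = -torch.minimum(
        ratio * advantages,
        torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages,
    ).mean()

    # Value loss, clipped around the rollout-time value predictions.
    values_clipped = old_values + torch.clamp(
        values - old_values, -value_clip_range, value_clip_range
    )
    value_loss = 0.5 * torch.maximum(
        (values - returns) ** 2, (values_clipped - returns) ** 2
    ).mean()

    return policy_loss + value_coeff * value_loss


# Toy usage with random tensors, just to show the call shape.
if __name__ == "__main__":
    old_logprobs = torch.randn(8)
    pi_logprobs = old_logprobs + 0.01 * torch.randn(8)
    advantages = torch.randn(8)
    old_values = torch.randn(8)
    values = old_values + 0.05 * torch.randn(8)
    returns = advantages + old_values
    print(ppo_loss(pi_logprobs, old_logprobs, advantages, values, old_values, returns))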