📉 Add PEFT support for PPOTrainer #2344

Merged: 14 commits, Nov 18, 2024
Add ppo.py PEFT example
ccs96307 committed Nov 11, 2024

commit a99c7bf983d6eaabf3da87e5752d8825032c1eb1
38 changes: 34 additions & 4 deletions examples/scripts/ppo/ppo.py
@@ -14,6 +14,7 @@

import shutil

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
@@ -23,7 +24,15 @@
    HfArgumentParser,
)

from trl import ModelConfig, PPOConfig, PPOTrainer, ScriptArguments
from trl import (
    ModelConfig,
    PPOConfig,
    PPOTrainer,
    ScriptArguments,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


@@ -67,6 +76,20 @@
################
# Model & Tokenizer
################
torch_dtype = (
    model_config.torch_dtype
    if model_config.torch_dtype in ["auto", None]
    else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
    revision=model_config.model_revision,
    attn_implementation=model_config.attn_implementation,
    torch_dtype=torch_dtype,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_config.model_name_or_path,
    padding_side="left",
@@ -81,12 +104,18 @@
reward_model = AutoModelForSequenceClassification.from_pretrained(
    training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained(
    training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
policy = AutoModelForCausalLM.from_pretrained(
    training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code
)

peft_config = get_peft_config(model_config)
if peft_config is None:
    ref_policy = AutoModelForCausalLM.from_pretrained(
        training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code
    )
else:
    ref_policy = None

################
# Dataset
################
@@ -131,6 +160,7 @@ def tokenize(element):
    value_model=value_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)
trainer.train()
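Note on the new imports: the diff uses three TRL helpers whose bodies are not shown here. The sketch below is a rough, non-authoritative approximation of what they provide; the real implementations live in trl, and the model_config field names (use_peft, lora_r, load_in_4bit, ...) simply mirror the ModelConfig CLI flags exercised elsewhere in this PR.

# Hedged sketch of the TRL helpers imported above; actual implementations may differ.
from typing import Optional

import torch
from accelerate import PartialState
from peft import LoraConfig
from transformers import BitsAndBytesConfig


def sketch_get_peft_config(model_config) -> Optional[LoraConfig]:
    # ppo.py only loads a separate ref_policy when this returns None.
    if not model_config.use_peft:
        return None
    return LoraConfig(
        r=model_config.lora_r,
        lora_alpha=model_config.lora_alpha,
        lora_dropout=model_config.lora_dropout,
        target_modules=model_config.lora_target_modules,
        task_type="CAUSAL_LM",
    )


def sketch_get_quantization_config(model_config) -> Optional[BitsAndBytesConfig]:
    # 4-bit / 8-bit loading via bitsandbytes; None keeps full-precision weights.
    if model_config.load_in_4bit:
        return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    if model_config.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


def sketch_get_kbit_device_map() -> Optional[dict]:
    # Pin a quantized model to the local process's device so each rank loads
    # its own copy instead of sharding the weights across GPUs.
    return {"": PartialState().local_process_index}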

6 changes: 3 additions & 3 deletions examples/scripts/ppo/ppo_tldr.py
@@ -41,7 +41,7 @@
    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
    --dataset_test_split validation \
    --learning_rate 3e-6 \
    --output_dir models/minimal/ppo \
    --output_dir models/minimal/ppo_tldr \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 64 \
    --total_episodes 30000 \
@@ -52,7 +52,7 @@
    --stop_token eos \
    --response_length 53 \
    --eval_strategy steps \
    --eval_steps 10
    --eval_steps 100

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
    examples/scripts/ppo/ppo_tldr.py \
@@ -70,7 +70,7 @@
    --missing_eos_penalty 1.0 \
    --stop_token eos \
    --eval_strategy steps \
    --eval_steps 10
    --eval_steps 100
"""


30 changes: 30 additions & 0 deletions tests/test_ppo_trainer.py
@@ -65,3 +65,33 @@ def test_num_train_epochs():
        shell=True,
        check=True,
    )


def test_peft_support():
    command = """\
python examples/scripts/ppo/ppo.py \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --dataset_train_split descriptiveness \
    --learning_rate 3e-6 \
    --output_dir models/minimal/ppo \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --total_episodes 10 \
    --model_name_or_path EleutherAI/pythia-14m \
    --missing_eos_penalty 1.0 \
    --save_strategy no \
    --stop_token eos \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --lora_target_modules query_key_value dense
"""
    if platform.system() == "Windows":
        # windows CI does not work with subprocesses for some reason
        # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
        return
    subprocess.run(
        command,
        shell=True,
        check=True,
    )
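The test above only checks that ppo.py runs end to end with --use_peft. The idea that makes ref_policy = None viable is that a LoRA-wrapped policy still contains the frozen SFT weights, so the reference distribution can be recovered by disabling the adapter instead of keeping a second full model in memory. A minimal sketch of that idea, assuming the policy is a peft PeftModel (illustrative only; the trainer's internal mechanism may differ):

# Illustrative only: recovering reference-policy logits from a LoRA-wrapped
# policy without loading a separate copy of the SFT model.
import torch
from peft import PeftModel


def reference_logits(policy: PeftModel, input_ids: torch.Tensor) -> torch.Tensor:
    # With the adapter disabled, the forward pass uses only the frozen base
    # weights, i.e. the original SFT policy that ppo.py would otherwise load
    # as ref_policy.
    with policy.disable_adapter(), torch.no_grad():
        return policy(input_ids).logits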