Clean up DPO example #2043

Merged 9 commits on Sep 11, 2024
@@ -65,7 +65,6 @@ class ScriptArguments:
)

# instrumentation
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
report_to: Optional[str] = field(
default="wandb",
metadata={
@@ -89,7 +88,6 @@ class ScriptArguments:

def get_stack_exchange_paired(
data_dir: str = "data/rl",
sanity_check: bool = False,
cache_dir: Optional[str] = None,
num_proc=24,
) -> Dataset:
@@ -114,9 +112,6 @@ def get_stack_exchange_paired(
)
original_columns = dataset.column_names

if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))

def return_prompt_and_responses(samples) -> Dict[str, str]:
return {
"prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
@@ -164,15 +159,15 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
tokenizer.pad_token = tokenizer.eos_token

# 2. Load the Stack-exchange paired dataset
train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
train_dataset = get_stack_exchange_paired(data_dir="data/rl")
train_dataset = train_dataset.filter(
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
num_proc=script_args.num_proc,
)

# 3. Load evaluation dataset
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation")
eval_dataset = eval_dataset.filter(
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
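Now that `sanity_check` is gone from `get_stack_exchange_paired`, a quick smoke-test run of this script can still be had by truncating the dataset at the call site. A minimal sketch, reusing the names from the hunk above (the 1000-row cap is simply the value the old flag used):

# Debug-run sketch: reproduce the old `sanity_check` behaviour outside the loader.
train_dataset = get_stack_exchange_paired(data_dir="data/rl")
train_dataset = train_dataset.select(range(min(len(train_dataset), 1000)))  # cap at ~1000 samples
train_dataset = train_dataset.filter(
    lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
    and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
    num_proc=script_args.num_proc,
)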
67 changes: 31 additions & 36 deletions examples/scripts/dpo.py
@@ -13,41 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
# Full training
python examples/scripts/dpo.py \
--dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="dpo_anthropic_hh" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--dataset_name trl-lib/ultrafeedback_binarized \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--learning_rate 5.0e-7 \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--logging_steps 25 \
--eval_strategy steps \
--eval_steps 50 \
--output_dir Qwen2-0.5B-DPO \
--no_remove_unused_columns

# peft:
# LoRA:
python examples/scripts/dpo.py \
--dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="dpo_anthropic_hh" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--dataset_name trl-lib/ultrafeedback_binarized \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--learning_rate 5.0e-6 \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--logging_steps 25 \
--eval_strategy steps \
--eval_steps 50 \
--output_dir Qwen2-0.5B-DPO \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
--lora_r 32 \
--lora_alpha 16
"""

from trl.commands.cli_utils import DPOScriptArguments, TrlParser
@@ -72,7 +69,7 @@

################
# Model & Tokenizer
################
###################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
@@ -114,25 +111,20 @@
# Dataset
################
ds = load_dataset(args.dataset_name)
if args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(50))

def process(row):
row["prompt"] = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
row["chosen"] = tokenizer.apply_chat_template([row["chosen"][-1]], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
return row

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=training_args.dataset_num_proc)

train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

################
##########
# Training
################
trainer = DPOTrainer(
@@ -146,4 +138,7 @@ def process(row):
)

trainer.train()
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
trainer.save_model(training_args.output_dir)
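For readers of the new `process` helper: it flattens each conversational preference pair into plain-text `prompt`/`chosen`/`rejected` columns by rendering the messages with the tokenizer's chat template. A small illustrative sketch of the same transformation on a toy row (the toy contents are made up; the exact rendered strings depend on the Qwen2 chat template named in the docstring):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# Toy row in the conversational format the script expects.
row = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}

# Same mapping as `process`: prompt = everything before the last turn,
# chosen / rejected = only the final assistant turn, rendered as plain text.
prompt = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
chosen = tokenizer.apply_chat_template([row["chosen"][-1]], tokenize=False)
rejected = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
print(prompt, chosen, rejected, sep="\n---\n")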
5 changes: 0 additions & 5 deletions examples/scripts/dpo_visual.py
@@ -105,11 +105,6 @@
################
# Dataset
################
ds = load_dataset(args.dataset_name)
if args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(50))

def process(row):
row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
7 changes: 1 addition & 6 deletions examples/scripts/ppo/ppo_tldr.py
@@ -25,8 +25,7 @@
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--missing_eos_penalty 1.0 \
--stop_token eos \
--response_length 53 \
--sanity_check
--response_length 53

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/ppo/ppo_tldr.py \
@@ -77,9 +76,6 @@
# Dataset
################
raw_datasets = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
if config.sanity_check:
for key in raw_datasets:
raw_datasets[key] = raw_datasets[key].select(range(1000))
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["validation"]

@@ -97,7 +93,6 @@ def tokenize(element):
return dataset.map(
tokenize,
remove_columns=dataset.column_names,
load_from_cache_file=not config.sanity_check,
num_proc=config.dataset_num_proc,
)

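The `tokenize` helper whose `map` call changes here lives outside the hunk; only its `load_from_cache_file` toggle is dropped. Purely as an illustration of the kind of mapping it performs (the helper below is a sketch, not the script's actual code; it assumes the TL;DR dataset's conversational "messages" column):

def prepare_dataset(dataset, tokenizer, num_proc=None):
    """Illustrative sketch: turn each row's prompt into input_ids for on-policy training."""

    def tokenize(element):
        # Render only the first (user) message with the chat template and tokenize it.
        input_ids = tokenizer.apply_chat_template(
            element["messages"][:1],
            padding=False,
            add_generation_prompt=True,
        )
        return {"input_ids": input_ids, "lengths": len(input_ids)}

    return dataset.map(
        tokenize,
        remove_columns=dataset.column_names,
        num_proc=num_proc,
    )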
7 changes: 1 addition & 6 deletions examples/scripts/rloo/rloo_tldr.py
@@ -26,8 +26,7 @@
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--missing_eos_penalty 1.0 \
--stop_token eos \
--response_length 53 \
--sanity_check
--response_length 53

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/rloo/rloo_tldr.py \
@@ -77,9 +76,6 @@
# Dataset
################
raw_datasets = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
if config.sanity_check:
for key in raw_datasets:
raw_datasets[key] = raw_datasets[key].select(range(1000))
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["validation"]

@@ -97,7 +93,6 @@ def tokenize(element):
return dataset.map(
tokenize,
remove_columns=dataset.column_names,
load_from_cache_file=not config.sanity_check,
num_proc=config.dataset_num_proc,
)

16 changes: 1 addition & 15 deletions examples/scripts/xpo.py
@@ -30,8 +30,6 @@
--push_to_hub
"""

from dataclasses import dataclass, field

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
@@ -49,16 +47,8 @@
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


@dataclass
class ExtendedDPOScriptArguments(DPOScriptArguments):
max_samples: int = field(
default=None,
metadata={"help": "Maximum number of samples to use for training and evaluation. Use for sanity checking."},
)


if __name__ == "__main__":
parser = TrlParser((ExtendedDPOScriptArguments, XPOConfig, ModelConfig))
parser = TrlParser((DPOScriptArguments, XPOConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
args.gradient_checkpointing_kwargs = {"use_reentrant": True}

@@ -105,10 +95,6 @@ def prepare_dataset(row):
with PartialState().local_main_process_first():
dataset = dataset.map(prepare_dataset, num_proc=training_args.dataset_num_proc)

if args.max_samples is not None:
Review comment (Member Author): FYI @kashif @qgallouedec @edbeeching we should not add this logic into the example scripts IMO - it's best solved by adding support for something like the dataset mixer we have in the handbook or H4 repo

Reply (Collaborator): ah yes my bad!

for split in dataset:
dataset[split] = dataset[split].select(range(min(args.max_samples, len(dataset[split]))))

prompts = dataset[args.dataset_test_split]["prompt"][:8]

trainer = XPOTrainer(
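Until something like the handbook's dataset mixer is available, the removed `--max_samples` behaviour is easy to reproduce outside the script by slicing the splits at load time. A sketch, reusing the script's `args.dataset_name` (the 64-row cap is arbitrary, for illustration only):

from datasets import load_dataset

# Split slicing keeps only the first N rows of each split, so no script-level flag is needed.
dataset = {
    "train": load_dataset(args.dataset_name, split="train[:64]"),
    "test": load_dataset(args.dataset_name, split="test[:64]"),
}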
2 changes: 1 addition & 1 deletion tests/test_cli.py
@@ -32,7 +32,7 @@ def test_sft_cli():
def test_dpo_cli():
try:
subprocess.run(
"trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check",
"trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-lib/ultrafeedback_binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
shell=True,
check=True,
)
1 change: 0 additions & 1 deletion trl/commands/cli_utils.py
@@ -111,7 +111,6 @@ class DPOScriptArguments:
dataset_name: str = field(default=None, metadata={"help": "the dataset name"})
dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to use for training"})
dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to use for evaluation"})
sanity_check: bool = field(default=False, metadata={"help": "only train on 1000 samples"})
Review comment (Member Author): This type of debugging arg shouldn't live in the lib IMO

Reply (Member): Let's remove them all

Reply (Member Author): Good idea! Done in ddf30cb

ignore_bias_buffers: bool = field(
default=False,
metadata={
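For context on how `DPOScriptArguments` is consumed: the example scripts parse it together with a trainer config and a `ModelConfig` via `TrlParser` (as in dpo.py and xpo.py above), so removing the field only narrows the CLI surface. A sketch of the usual parsing pattern, assuming TRL's `DPOConfig` as the trainer-config class:

from trl import DPOConfig, ModelConfig
from trl.commands.cli_utils import DPOScriptArguments, TrlParser

# Parse the three argument groups from the command line (or a YAML config).
parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
print(args.dataset_name, args.dataset_train_split, args.dataset_test_split)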
12 changes: 4 additions & 8 deletions trl/trainer/sft_trainer.py
@@ -543,7 +543,6 @@ def _prepare_non_packed_dataloader(
remove_unused_columns=True,
):
use_formatting_func = formatting_func is not None and dataset_text_field is None
self._dataset_sanity_checked = False

# Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
def tokenize(element):
@@ -557,13 +556,10 @@ def tokenize(element):
return_length=False,
)

if use_formatting_func and not self._dataset_sanity_checked:
if not isinstance(formatting_func(element), list):
raise ValueError(
"The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
)
else:
self._dataset_sanity_checked = True
if use_formatting_func and not isinstance(formatting_func(element), list):
raise ValueError(
"The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
)

return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

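The relocated check enforces the same contract as before: a `formatting_func` given to `SFTTrainer` must take a batched dataset element and return a list of formatted strings (one per example), not a single string. A minimal conforming example, assuming illustrative "question"/"answer" columns:

def formatting_func(element):
    # `element` is batched: each column maps to a list of values.
    # Returning a list of strings (one per example) satisfies the check above.
    return [
        f"### Question: {q}\n### Answer: {a}"
        for q, a in zip(element["question"], element["answer"])
    ]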
3 changes: 0 additions & 3 deletions trl/trainer/utils.py
@@ -903,8 +903,6 @@ class OnPolicyConfig(TrainingArguments):
Parameters:
run_name (`Optional[str]`, *optional*, defaults to `None`):
Name of the run.
sanity_check (`bool`, *optional*, defaults to `False`):
Whether to run in debug mode.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
num_mini_batches (`int`, *optional*, defaults to `1`):
@@ -946,7 +944,6 @@ class OnPolicyConfig(TrainingArguments):
"""

run_name: Optional[str] = None
sanity_check: bool = False
dataset_num_proc: Optional[int] = None
num_mini_batches: int = 1
total_episodes: Optional[int] = None
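With the field gone, configs built on `OnPolicyConfig` no longer accept `sanity_check`, and any debug subsampling is the caller's responsibility. A rough sketch, assuming `RLOOConfig` (used by the RLOO script above) derives from `OnPolicyConfig`:

from trl import RLOOConfig

config = RLOOConfig(
    output_dir="rloo_tldr_debug",
    dataset_num_proc=4,       # still available on OnPolicyConfig
    num_mini_batches=1,
    total_episodes=10_000,
    # sanity_check=True,      # would now raise a TypeError: the field was removed
)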