FEAT: Add CLIs in TRL! #1419

Merged · 64 commits · Mar 18, 2024

Commits
f0d29ce
CLI V1
younesbelkada Mar 11, 2024
3a2283a
v1 CLI
younesbelkada Mar 12, 2024
15d166e
add rich enhancements
younesbelkada Mar 12, 2024
16463b8
revert unintended change
younesbelkada Mar 12, 2024
d20167f
some comments
younesbelkada Mar 12, 2024
38ee375
cleaner CLI
younesbelkada Mar 12, 2024
f83882c
fix
younesbelkada Mar 12, 2024
14911d2
fix
younesbelkada Mar 12, 2024
b7f96bc
remove print callback
younesbelkada Mar 12, 2024
a328d9b
move to cli instead of trl_cli
younesbelkada Mar 12, 2024
4ee1c8e
revert unneeded changes
younesbelkada Mar 12, 2024
55659ce
fix test
younesbelkada Mar 12, 2024
459c3eb
Update trl/commands/sft.py
younesbelkada Mar 12, 2024
171fd94
remove redundant strings
younesbelkada Mar 12, 2024
54662da
Merge branch 'add-cli' of https://github.com/lvwerra/trl into add-cli
younesbelkada Mar 12, 2024
fbec5ca
fix import issue
younesbelkada Mar 12, 2024
616ee60
fix other issues
younesbelkada Mar 12, 2024
553f898
add packing
younesbelkada Mar 12, 2024
d098262
add config parser
younesbelkada Mar 12, 2024
6110423
some refactor
younesbelkada Mar 12, 2024
e9d4f91
cleaner
younesbelkada Mar 12, 2024
265b488
add example config yaml file
younesbelkada Mar 12, 2024
355e57c
small refactor
younesbelkada Mar 13, 2024
8df993a
change a bit the logic
younesbelkada Mar 13, 2024
8cf0b05
Merge remote-tracking branch 'origin/main' into add-cli
younesbelkada Mar 13, 2024
f64618f
fix issues here and there
younesbelkada Mar 13, 2024
c33dead
add CLI in docs
younesbelkada Mar 13, 2024
0e45168
move to examples/sft
younesbelkada Mar 13, 2024
3119da6
remove redundant licenses
younesbelkada Mar 13, 2024
4d0da9b
make it work on dpo
younesbelkada Mar 13, 2024
cf2290f
set to None
younesbelkada Mar 13, 2024
bac1780
switch to accelerate and fix many things
younesbelkada Mar 14, 2024
086b37c
add docs
younesbelkada Mar 14, 2024
d49e5e8
more docs
younesbelkada Mar 14, 2024
c7f4c83
added tests
younesbelkada Mar 14, 2024
90526bf
doc clarification
younesbelkada Mar 14, 2024
c91513a
Merge remote-tracking branch 'origin/main' into add-cli
younesbelkada Mar 14, 2024
3e0e3c9
Merge remote-tracking branch 'origin/main' into add-cli
younesbelkada Mar 14, 2024
c260432
more docs
younesbelkada Mar 14, 2024
2ba16f3
fix CI for windows and python 3.8
younesbelkada Mar 14, 2024
a9d68b5
fix
younesbelkada Mar 14, 2024
755db1e
attempt to fix CI
younesbelkada Mar 15, 2024
61ee67b
fix?
younesbelkada Mar 15, 2024
89b594e
test
younesbelkada Mar 15, 2024
d93a8e1
Merge branch 'add-cli' of https://github.com/lvwerra/trl into add-cli
younesbelkada Mar 15, 2024
d5ab9d6
fix
younesbelkada Mar 15, 2024
026ceef
tweak?
younesbelkada Mar 15, 2024
be1ec61
fix
younesbelkada Mar 15, 2024
8600269
test
younesbelkada Mar 15, 2024
ac99f35
another test
younesbelkada Mar 15, 2024
7252ad0
fix
younesbelkada Mar 15, 2024
55eda92
test
younesbelkada Mar 15, 2024
76dbe94
fix
younesbelkada Mar 15, 2024
e6678f3
fix
younesbelkada Mar 15, 2024
45424b8
fix
younesbelkada Mar 15, 2024
2184b05
skip tests for windows
younesbelkada Mar 15, 2024
79a4074
test @lvwerra approach
younesbelkada Mar 15, 2024
c92dd3b
make dev
younesbelkada Mar 15, 2024
a477236
revert unneeded changes
younesbelkada Mar 15, 2024
a1d228f
fix sft dpo
younesbelkada Mar 15, 2024
ef144d0
optimize a bit
younesbelkada Mar 15, 2024
c85b8e4
address final comments
younesbelkada Mar 18, 2024
91a55ca
update docs
younesbelkada Mar 18, 2024
7754760
final comment
younesbelkada Mar 18, 2024
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -5,7 +5,7 @@
Before you start contributing make sure you installed all the dev tools:

```bash
pip install -e ".[dev]"
make dev
```

## Did you find a bug?
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -2,4 +2,4 @@ include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
recursive-exclude * __pycache__
5 changes: 5 additions & 0 deletions Makefile
@@ -5,6 +5,11 @@ check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands


dev:
[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
pip install -e ".[dev]"

test:
python -m pytest -n auto --dist=loadfile -s -v ./tests/

2 changes: 2 additions & 0 deletions commands/run_dpo.sh
@@ -3,6 +3,7 @@
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
DATASET_NAME="trl-internal-testing/Anthropic-hh-rlhf-processed"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
@@ -36,6 +37,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
--mixed_precision 'fp16' \
`pwd`/examples/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -5,6 +5,8 @@
title: Quickstart
- local: installation
title: Installation
- local: clis
title: Get started with Command Line Interfaces (CLIs)
- local: how_to_train
title: PPO Training FAQ
- local: use_model
87 changes: 87 additions & 0 deletions docs/source/clis.mdx
@@ -0,0 +1,87 @@
# Command Line Interfaces (CLIs)

You can use TRL to fine-tune your language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO) using the TRL CLIs.

Currently supported CLIs are:

- `trl sft`
- `trl dpo`
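
As a quick taste, a minimal `trl sft` invocation might look like the following sketch (the model and dataset names are the same toy values used in the example config later on this page):

```bash
trl sft \
    --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM \
    --dataset_name imdb \
    --dataset_text_field text \
    --output_dir test-trl-cli
```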

## Get started

Before getting started, pick a language model from the Hugging Face Hub. Supported models can be found by filtering for "text-generation" on the Hub. Also make sure to pick a relevant dataset for your task.

Then make sure to run:
```bash
accelerate config
```
and select the configuration that matches your training setup (single/multi-GPU, DeepSpeed, etc.). Complete all steps of `accelerate config` before running any CLI command.
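
If you prefer to keep the accelerate configuration alongside your project rather than in the default cache, `accelerate config` also accepts a custom path (an optional convenience; the default location under `~/.cache/huggingface/accelerate/` works just as well):

```bash
# Save the interactive answers to a project-local file instead of the default cache
accelerate config --config_file ./accelerate_config.yaml
```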

We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use to train your models with the `trl sft` command.

```yaml
model_name_or_path: HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name: imdb
dataset_text_field: text
report_to: none
learning_rate: 0.0001
lr_scheduler_type: cosine
```

Save that config in a `.yaml` file and get started right away! Note that you can override the arguments from the config file by explicitly passing them to the CLI, e.g.:

```bash
trl sft --config example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
```

This will force the use of `cosine_with_restarts` for `lr_scheduler_type`.

## Supported Arguments

We support all arguments from `transformers.TrainingArguments`. For loading your model, we support all arguments from `~trl.ModelConfig`:

[[autodoc]] ModelConfig

You can pass any of these arguments either to the CLI or the YAML file.
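
For instance, PEFT-related `ModelConfig` flags can be combined with a config file. A sketch (`--use_peft` is assumed from `ModelConfig`; `--lora_r` and `--lora_alpha` appear in the `examples/scripts/dpo.py` docstring further down this diff):

```bash
trl sft --config example_config.yaml \
    --output_dir test-trl-cli \
    --use_peft \
    --lora_r 16 \
    --lora_alpha 16
```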

### Supervised Fine-tuning (SFT)

Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:

```bash
trl sft --config config.yaml --output_dir your-output-dir
```

The SFT CLI is based on the `examples/scripts/sft.py` script.
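
Since the CLI wraps that script, an equivalent direct invocation would look roughly like this (a sketch reusing the toy values from the example config above):

```bash
python examples/scripts/sft.py \
    --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM \
    --dataset_name imdb \
    --dataset_text_field text \
    --output_dir test-trl-cli
```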

### Direct Preference Optimization (DPO)

First, follow the basic instructions above and run `trl dpo --output_dir <output_dir> <*args>`. Make sure to process your DPO dataset in the TRL format as follows:

1. Make sure to pre-tokenize the dataset using chat templates:

```bash
python examples/datasets/tokenize_ds.py --model gpt2 --dataset yourdataset
```

You might need to adapt `examples/datasets/tokenize_ds.py` to use your chat template.

2. Format the dataset into the TRL format (you can adapt `examples/datasets/anthropic_hh.py`):

```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```

Once your dataset is pushed, run the DPO CLI as follows:

```bash
trl dpo --config config.yaml --output_dir your-output-dir
```

The DPO CLI is based on the `examples/scripts/dpo.py` script.
20 changes: 20 additions & 0 deletions example_config.yaml
@@ -0,0 +1,20 @@
Review comment (Member) on this file: what do you think about adding it in the cli or examples folder?

# This is an example configuration file of TRL CLI, you can use it for
# SFT like that: `trl sft --config config.yaml --output_dir test-sft`
# The YAML file supports environment variables by adding an `env` field
# as below

# env:
#     CUDA_VISIBLE_DEVICES: 0

model_name_or_path: HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name: imdb
dataset_text_field: text
report_to: none
learning_rate: 1e-4
lr_scheduler_type: cosine
139 changes: 63 additions & 76 deletions examples/scripts/dpo.py
@@ -1,3 +1,4 @@
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -48,76 +49,47 @@
--lora_r=16 \
--lora_alpha=16
"""
from dataclasses import dataclass, field
from typing import Dict, Optional
import logging
import os
from contextlib import nullcontext

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

from trl import DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config


@dataclass
class ScriptArguments:
beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
max_length: int = field(default=512, metadata={"help": "max length of each sample"})
max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
max_target_length: int = field(
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
)
sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})
ignore_bias_buffers: bool = field(
default=False,
metadata={
"help": "debug argument for distributed training;"
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import DpoScriptArguments, init_zero_verbose, TrlParser

def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
return prompt_and_response[: search_term_idx + len(search_term)]
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"

from rich.console import Console
from rich.logging import RichHandler

def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))
from trl import (
DPOTrainer,
ModelConfig,
RichProgressCallback,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)

def split_prompt_and_responses(sample) -> Dict[str, str]:
prompt = extract_anthropic_prompt(sample["chosen"])
return {
"prompt": prompt,
"chosen": sample["chosen"][len(prompt) :],
"rejected": sample["rejected"][len(prompt) :],
}

return dataset.map(split_prompt_and_responses)
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()
parser = TrlParser((DpoScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()

# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()

################
# Model & Tokenizer
@@ -152,28 +124,43 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]

################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)

################
# Dataset
################
train_dataset = get_hh("train", sanity_check=args.sanity_check)
eval_dataset = get_hh("test", sanity_check=args.sanity_check)
train_dataset = load_dataset(args.dataset_name, split="train")
eval_dataset = load_dataset(args.dataset_name, split="test")

################
# Training
################
trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=args.max_length,
max_target_length=args.max_target_length,
max_prompt_length=args.max_prompt_length,
generate_during_eval=args.generate_during_eval,
peft_config=get_peft_config(model_config),
)
with init_context:
trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=args.max_length,
max_target_length=args.max_target_length,
max_prompt_length=args.max_prompt_length,
generate_during_eval=args.generate_during_eval,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)

trainer.train()
trainer.save_model(training_args.output_dir)

with save_context:
trainer.save_model(training_args.output_dir)
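
As the diff shows, the script gates its rich console output on the `TRL_USE_RICH` environment variable. A usage sketch, reusing the model and dataset names from `commands/run_dpo.sh` above:

```bash
# TRL_USE_RICH is read via os.environ.get, so any non-empty string enables
# the rich logging, progress callback, and status contexts in the script
TRL_USE_RICH=1 python examples/scripts/dpo.py \
    --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM \
    --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed \
    --output_dir test_dpo
```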