Commit 93ddb10

FIX Use SFTConfig instead of SFTTrainer keyword args (huggingface#2150)

Update the training scripts and docs that use trl to fix deprecated argument usage.
qgallouedec authored Oct 15, 2024
1 parent c039b00 commit 93ddb10
Showing 6 changed files with 20 additions and 51 deletions.
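The change is the same across all six files: options such as `dataset_text_field`, `max_seq_length`, `packing`, and `dataset_kwargs` move off the `SFTTrainer` constructor and onto an `SFTConfig`, which is then handed to the trainer as `args=`. A minimal before/after sketch, assuming a trl release contemporary with this commit (one where `SFTConfig` exposes these fields and `SFTTrainer` still accepts `tokenizer=`); the model, dataset, and output directory below are only illustrative:

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train[:1%]")

# Before (deprecated): the dataset options were keyword arguments of SFTTrainer itself, e.g.
# SFTTrainer(model=model, train_dataset=dataset, dataset_text_field="text", max_seq_length=128, ...)

# After: the same options are fields of SFTConfig, trl's TrainingArguments subclass.
training_args = SFTConfig(
    output_dir="sft-opt-350m",   # illustrative output directory
    dataset_text_field="text",
    max_seq_length=128,
)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```
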
docs/source/accelerate/deepspeed.md (9 changes: 1 addition & 8 deletions)

````diff
@@ -128,7 +128,7 @@ Notice that we are using LoRA with rank=8, alpha=16 and targeting all linear la
 
 Let's dive a little deeper into the script so you can see what's going on, and understand how it works.
 
-The first thing to know is that the script uses DeepSpeed for distributed training as the DeepSpeed config has been passed. The `SFTTrainer` class handles all the heavy lifting of creating the PEFT model using the peft config that is passed. After that, when you call `trainer.train()`, `SFTTrainer` internally uses 🤗 Accelerate to prepare the model, optimizer and trainer using the DeepSpeed config to create DeepSpeed engine which is then trained. The main code snippet is below:
+The first thing to know is that the script uses DeepSpeed for distributed training as the DeepSpeed config has been passed. The [`~trl.SFTTrainer`] class handles all the heavy lifting of creating the PEFT model using the peft config that is passed. After that, when you call `trainer.train()`, [`~trl.SFTTrainer`] internally uses 🤗 Accelerate to prepare the model, optimizer and trainer using the DeepSpeed config to create DeepSpeed engine which is then trained. The main code snippet is below:
 
 ```python
 # trainer
@@ -139,13 +139,6 @@ trainer = SFTTrainer(
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    packing=data_args.packing,
-    dataset_kwargs={
-        "append_concat_token": data_args.append_concat_token,
-        "add_special_tokens": data_args.add_special_tokens,
-    },
-    dataset_text_field=data_args.dataset_text_field,
-    max_seq_length=data_args.max_seq_length,
 )
 trainer.accelerator.print(f"{trainer.model}")
````

docs/source/accelerate/fsdp.md (9 changes: 1 addition & 8 deletions)

````diff
@@ -108,7 +108,7 @@ Notice that we are using LoRA with rank=8, alpha=16 and targeting all linear la
 
 Let's dive a little deeper into the script so you can see what's going on, and understand how it works.
 
-The first thing to know is that the script uses FSDP for distributed training as the FSDP config has been passed. The `SFTTrainer` class handles all the heavy lifting of creating PEFT model using the peft config that is passed. After that when you call `trainer.train()`, Trainer internally uses 🤗 Accelerate to prepare model, optimizer and trainer using the FSDP config to create FSDP wrapped model which is then trained. The main code snippet is below:
+The first thing to know is that the script uses FSDP for distributed training as the FSDP config has been passed. The [`~trl.SFTTrainer`] class handles all the heavy lifting of creating PEFT model using the peft config that is passed. After that when you call `trainer.train()`, Trainer internally uses 🤗 Accelerate to prepare model, optimizer and trainer using the FSDP config to create FSDP wrapped model which is then trained. The main code snippet is below:
 
 ```python
 # trainer
@@ -119,13 +119,6 @@ trainer = SFTTrainer(
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    packing=data_args.packing,
-    dataset_kwargs={
-        "append_concat_token": data_args.append_concat_token,
-        "add_special_tokens": data_args.add_special_tokens,
-    },
-    dataset_text_field=data_args.dataset_text_field,
-    max_seq_length=data_args.max_seq_length,
 )
 trainer.accelerator.print(f"{trainer.model}")
 if model_args.use_peft_lora:
````

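The keyword arguments deleted from the two documentation snippets above are not dropped; in the updated `examples/sft/train.py` they are carried by the `SFTConfig` instance that reaches the trainer through `args=` (see the train.py diff further down). A small sketch of the mapping, with literal values standing in for the script's `data_args` fields and a hypothetical output directory:

```python
from trl import SFTConfig

# Each former SFTTrainer keyword becomes an SFTConfig field of the same name.
training_args = SFTConfig(
    output_dir="sft-output",    # hypothetical
    packing=False,              # was packing=data_args.packing
    dataset_text_field="text",  # was dataset_text_field=data_args.dataset_text_field
    max_seq_length=512,         # was max_seq_length=data_args.max_seq_length
    dataset_kwargs={            # was dataset_kwargs={...}
        "append_concat_token": False,
        "add_special_tokens": False,
    },
)
print(training_args.packing, training_args.dataset_kwargs)
```
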
examples/olora_finetuning/README.md (5 changes: 2 additions & 3 deletions)

````diff
@@ -8,7 +8,7 @@
 import torch
 from peft import LoraConfig, get_peft_model
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
 from datasets import load_dataset
 
 model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
@@ -18,11 +18,10 @@ lora_config = LoraConfig(
     init_lora_weights="olora"
 )
 peft_model = get_peft_model(model, lora_config)
+training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
 trainer = SFTTrainer(
     model=peft_model,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=512,
     tokenizer=tokenizer,
 )
 trainer.train()
````

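Note that the visible hunk creates `training_args` but does not show it being passed to the trainer; presumably the config goes in through `args=`, as in the other examples touched by this commit. A self-contained sketch under that assumption (the tokenizer and dataset lines sit outside the hunk and are stand-ins):

```python
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train[:1%]")  # stand-in for the README's dataset setup

lora_config = LoraConfig(init_lora_weights="olora")  # OLoRA initialization, as in the README
peft_model = get_peft_model(model, lora_config)

training_args = SFTConfig(output_dir="olora-opt-350m", dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,  # hand the config to the trainer
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```
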
examples/pissa_finetuning/README.md (7 changes: 3 additions & 4 deletions)

````diff
@@ -6,8 +6,7 @@ PiSSA represents a matrix $W\in\mathbb{R}^{m\times n}$ within the model by the p
 ```python
 import torch
 from peft import LoraConfig, get_peft_model
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from trl import SFTTrainer
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import SFTConfig, SFTTrainer
 from datasets import load_dataset
 
 model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
@@ -23,11 +22,11 @@ peft_model.print_trainable_parameters()
 
 dataset = load_dataset("imdb", split="train[:1%]")
 
+training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
 trainer = SFTTrainer(
     model=peft_model,
+    args=training_args,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=128,
     tokenizer=tokenizer,
 )
 trainer.train()
````

examples/pissa_finetuning/pissa_finetuning.py (17 changes: 5 additions & 12 deletions)

````diff
@@ -14,18 +14,18 @@
 
 import os
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Optional
 
 import torch
 from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
-from trl import SFTTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser
+from trl import SFTConfig, SFTTrainer
 
 from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
 
 
 @dataclass
-class TrainingArguments(TrainingArguments):
+class ScriptArguments(SFTConfig):
     # model configs
     base_model_name_or_path: Optional[str] = field(
         default=None, metadata={"help": "The name or path of the fp32/16 base model."}
@@ -46,14 +46,9 @@ class TrainingArguments(TrainingArguments):
     # dataset configs
     data_path: str = field(default="imdb", metadata={"help": "Path to the training data."})
     dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`):"})
-    dataset_field: List[str] = field(default=None, metadata={"help": "Fields of dataset input and output."})
-    max_seq_length: int = field(
-        default=512,
-        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
-    )
 
 
-parser = HfArgumentParser(TrainingArguments)
+parser = HfArgumentParser(ScriptArguments)
 script_args = parser.parse_args_into_dataclasses()[0]
 print(script_args)
 
@@ -133,8 +128,6 @@ class TrainingArguments(TrainingArguments):
     model=peft_model,
     args=script_args,
     train_dataset=dataset,
-    dataset_text_field="text",
-    max_seq_length=script_args.max_seq_length,
     tokenizer=tokenizer,
 )
 trainer.train()
````

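Because `ScriptArguments` now inherits from `SFTConfig`, fields such as `max_seq_length` and `dataset_text_field` no longer need to be redeclared on the script's dataclass; they come from the config class itself. A small sketch under that assumption (the extra fields shown are a hypothetical subset of the script's own):

```python
from dataclasses import dataclass, field
from typing import Optional

from trl import SFTConfig


@dataclass
class ScriptArguments(SFTConfig):
    # Hypothetical subset of the script's extra fields.
    base_model_name_or_path: Optional[str] = field(default=None)
    data_path: str = field(default="imdb")


# SFT options are inherited from SFTConfig rather than declared locally.
args = ScriptArguments(output_dir="pissa-output", dataset_text_field="text", max_seq_length=512)
print(args.data_path, args.dataset_text_field, args.max_seq_length)
```
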
examples/sft/train.py (24 changes: 8 additions & 16 deletions)

````diff
@@ -3,8 +3,8 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
-from transformers import HfArgumentParser, TrainingArguments, set_seed
-from trl import SFTTrainer
+from transformers import HfArgumentParser, set_seed
+from trl import SFTConfig, SFTTrainer
 from utils import create_and_prepare_model, create_datasets
 
 
@@ -79,12 +79,6 @@ class DataTrainingArguments:
         default="timdettmers/openassistant-guanaco",
         metadata={"help": "The preference dataset to use."},
     )
-    packing: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Use packing dataset creating."},
-    )
-    dataset_text_field: str = field(default="text", metadata={"help": "Dataset field to use as input text."})
-    max_seq_length: Optional[int] = field(default=512)
     append_concat_token: Optional[bool] = field(
         default=False,
         metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
@@ -112,6 +106,11 @@ def main(model_args, data_args, training_args):
     if training_args.gradient_checkpointing:
         training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}
 
+    training_args.dataset_kwargs = {
+        "append_concat_token": data_args.append_concat_token,
+        "add_special_tokens": data_args.add_special_tokens,
+    }
+
     # datasets
     train_dataset, eval_dataset = create_datasets(
         tokenizer,
@@ -128,13 +127,6 @@ def main(model_args, data_args, training_args):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         peft_config=peft_config,
-        packing=data_args.packing,
-        dataset_kwargs={
-            "append_concat_token": data_args.append_concat_token,
-            "add_special_tokens": data_args.add_special_tokens,
-        },
-        dataset_text_field=data_args.dataset_text_field,
-        max_seq_length=data_args.max_seq_length,
     )
     trainer.accelerator.print(f"{trainer.model}")
     if hasattr(trainer.model, "print_trainable_parameters"):
@@ -153,7 +145,7 @@ def main(model_args, data_args, training_args):
 
 
 if __name__ == "__main__":
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
````

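Since `SFTConfig` subclasses `transformers.TrainingArguments`, swapping it into the parser keeps the existing command-line surface (`--output_dir`, `--learning_rate`, ...) while adding the SFT-specific flags that were removed from `DataTrainingArguments`. A minimal parsing sketch with hypothetical stand-in dataclasses (the script's real ones carry more fields):

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser
from trl import SFTConfig


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="facebook/opt-350m")


@dataclass
class DataTrainingArguments:
    dataset_name: str = field(default="timdettmers/openassistant-guanaco")
    append_concat_token: bool = field(default=False)
    add_special_tokens: bool = field(default=False)


parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--dataset_text_field", "text", "--max_seq_length", "512", "--packing", "True"]
)
# As in the updated script, dataset_kwargs is filled in from the data arguments.
training_args.dataset_kwargs = {
    "append_concat_token": data_args.append_concat_token,
    "add_special_tokens": data_args.add_special_tokens,
}
print(training_args.max_seq_length, training_args.packing, training_args.dataset_kwargs)
```
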
