Adds experimental Liger support to SFT script #1992

Merged (13 commits, Aug 29, 2024)
1 change: 1 addition & 0 deletions setup.py
@@ -82,6 +82,7 @@
"llm-blender>=0.0.2",
],
"peft": ["peft>=0.8.0"],
"liger": ["liger-kernel-nightly==0.1.1.dev20240828063745"],
Collaborator (author) commented:

I am not a big fan of pinning this dev version, so we should update to the real package ASAP.

Contributor replied:

We will cut a release in a few hours.

"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.14.4"],
"benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5", "requests", "deepspeed"],
4 changes: 4 additions & 0 deletions trl/import_utils.py
@@ -30,6 +30,10 @@ def is_peft_available() -> bool:
return find_spec("peft") is not None


def is_liger_available() -> bool:
return find_spec("liger_kernel") is not None


def is_unsloth_available() -> bool:
return find_spec("unsloth") is not None

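For context, this helper follows the same pattern as the existing `is_peft_available`/`is_unsloth_available` checks. A minimal sketch of how a caller might use it to guard the optional import (mirroring the guard added to sft_trainer.py below):

```python
from trl.import_utils import is_liger_available

# Import the Liger wrapper only when the optional dependency is installed.
if is_liger_available():
    from liger_kernel.transformers import AutoLigerKernelForCausalLM
```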
3 changes: 3 additions & 0 deletions trl/trainer/sft_config.py
@@ -51,6 +51,8 @@ class SFTConfig(TrainingArguments):
The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the
stack-llama example:
[chars_token_ratio](https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53).
use_liger (`Optional[bool]`):
Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
"""

dataset_text_field: Optional[str] = None
@@ -64,3 +66,4 @@ class SFTConfig(TrainingArguments):
eval_packing: Optional[bool] = None
num_of_sequences: Optional[int] = 1024
chars_per_token: Optional[float] = 3.6
use_liger: Optional[bool] = False
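A minimal sketch of enabling the new flag; the `output_dir` value is a hypothetical placeholder, and `use_liger` keeps its default of `False` unless set explicitly:

```python
from trl import SFTConfig

config = SFTConfig(
    output_dir="./sft-liger-output",  # hypothetical output path
    use_liger=True,  # monkey patch the model with Liger kernels
)
```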
10 changes: 8 additions & 2 deletions trl/trainer/sft_trainer.py
@@ -39,7 +39,7 @@
from transformers.trainer_utils import EvalPrediction

from ..extras.dataset_formatting import get_formatting_func_from_dataset
from ..import_utils import is_peft_available
from ..import_utils import is_liger_available, is_peft_available
from .callbacks import RichProgressCallback
from .sft_config import SFTConfig
from .utils import (
@@ -54,6 +54,9 @@
if is_peft_available():
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_liger_available():
from liger_kernel.transformers import AutoLigerKernelForCausalLM


class SFTTrainer(Trainer):
r"""
@@ -183,7 +186,10 @@ def __init__(
"You passed a model_id to the SFTTrainer. This will automatically create an "
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
if args.use_liger:
model = AutoLigerKernelForCausalLM.from_pretrained(model, **model_init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

if packing:
warnings.warn(
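Putting it together, a sketch of how the flag reaches the branch above: passing a model id string makes the trainer call `from_pretrained` itself, and with `use_liger=True` it goes through `AutoLigerKernelForCausalLM`. The model id and dataset here are illustrative placeholders, not from this PR:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Illustrative dataset with a plain "text" column.
train_dataset = load_dataset("stanfordnlp/imdb", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2-0.5B",  # model id string: the trainer loads it via from_pretrained
    args=SFTConfig(
        output_dir="./sft-liger-output",  # hypothetical output path
        dataset_text_field="text",
        use_liger=True,
    ),
    train_dataset=train_dataset,
)
trainer.train()
```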