diff --git a/setup.py b/setup.py
index b795257029..064d324674 100644
--- a/setup.py
+++ b/setup.py
@@ -82,6 +82,7 @@
         "llm-blender>=0.0.2",
     ],
     "peft": ["peft>=0.8.0"],
+    "liger": ["liger-kernel-nightly==0.1.1.dev20240828063745"],
     "diffusers": ["diffusers>=0.18.0"],
     "deepspeed": ["deepspeed>=0.14.4"],
     "benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5", "requests", "deepspeed"],
diff --git a/trl/import_utils.py b/trl/import_utils.py
index b0810d12ef..4530e9f65c 100644
--- a/trl/import_utils.py
+++ b/trl/import_utils.py
@@ -30,6 +30,10 @@ def is_peft_available() -> bool:
     return find_spec("peft") is not None
 
 
+def is_liger_available() -> bool:
+    return find_spec("liger_kernel") is not None
+
+
 def is_unsloth_available() -> bool:
     return find_spec("unsloth") is not None
 
diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py
index 64e2b405a3..132a0c69d9 100644
--- a/trl/trainer/sft_config.py
+++ b/trl/trainer/sft_config.py
@@ -51,6 +51,8 @@ class SFTConfig(TrainingArguments):
             The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check
             how this is computed in the stack-llama example:
             [chars_token_ratio](https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53).
+        use_liger (`Optional[bool]`):
+            Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
     """
 
     dataset_text_field: Optional[str] = None
@@ -64,3 +66,4 @@ class SFTConfig(TrainingArguments):
     eval_packing: Optional[bool] = None
     num_of_sequences: Optional[int] = 1024
     chars_per_token: Optional[float] = 3.6
+    use_liger: Optional[bool] = False
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 054d20aeb5..b06e463c64 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -39,7 +39,7 @@
 from transformers.trainer_utils import EvalPrediction
 
 from ..extras.dataset_formatting import get_formatting_func_from_dataset
-from ..import_utils import is_peft_available
+from ..import_utils import is_liger_available, is_peft_available
 from .callbacks import RichProgressCallback
 from .sft_config import SFTConfig
 from .utils import (
@@ -54,6 +54,9 @@
 if is_peft_available():
     from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
 
+if is_liger_available():
+    from liger_kernel.transformers import AutoLigerKernelForCausalLM
+
 
 class SFTTrainer(Trainer):
     r"""
@@ -183,7 +186,10 @@ def __init__(
                 "You passed a model_id to the SFTTrainer. This will automatically create an "
                 "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
             )
-            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+            if args.use_liger:
+                model = AutoLigerKernelForCausalLM.from_pretrained(model, **model_init_kwargs)
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
 
         if packing:
             warnings.warn(
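
Below is a minimal usage sketch of the new `use_liger` flag. The model id, dataset, and output directory are illustrative placeholders, and it assumes the optional dependency is installed (e.g. `pip install "trl[liger]"`); the patched code path only triggers when a model id string, not a model instance, is passed to `SFTTrainer`.

# Hypothetical example of the new `use_liger` flag; model id, dataset, and
# output_dir are placeholders, and liger-kernel must be importable.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

train_dataset = load_dataset("stanfordnlp/imdb", split="train")  # any dataset with a "text" column

config = SFTConfig(
    output_dir="./sft-liger",
    dataset_text_field="text",
    max_seq_length=512,
    use_liger=True,  # load the model via AutoLigerKernelForCausalLM instead of AutoModelForCausalLM
)

trainer = SFTTrainer(
    model="meta-llama/Meta-Llama-3-8B",  # passing a model_id string triggers the from_pretrained branch above
    args=config,
    train_dataset=train_dataset,
)
trainer.train()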