diff --git a/setup.py b/setup.py
index 373cfeea06..765b65a66a 100644
--- a/setup.py
+++ b/setup.py
@@ -71,15 +71,15 @@
 EXTRAS = {
     "test": [
         "parameterized",
+        "peft>=0.8.0",
         "pytest",
         "pytest-xdist",
-        "accelerate",
         "pytest-cov",
         "pytest-xdist",
         "scikit-learn",
         "Pillow",
     ],
-    "peft": ["peft>=0.4.0"],
+    "peft": ["peft>=0.8.0"],
     "diffusers": ["diffusers>=0.18.0"],
     "deepspeed": ["deepspeed>=0.9.5"],
     "benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5", "requests", "deepspeed"],
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 5e5c3ec9c9..d2885f1d2b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,7 +2,15 @@
 
 import torch
 
-from trl.trainer.utils import pad
+from trl import is_peft_available
+from trl.trainer.model_config import ModelConfig
+from trl.trainer.utils import get_peft_config, pad
+
+
+if is_peft_available():
+    from peft import LoraConfig
+
+from .testing_utils import require_peft
 
 
 class TestPad(unittest.TestCase):
@@ -55,3 +63,40 @@ def test_pad_2_dim_right_multidim(self):
             ]
         )
         self.assertTrue(torch.equal(output, expected))
+
+
+@require_peft
+class TestGetPEFTConfig(unittest.TestCase):
+    """Check how `get_peft_config` maps `ModelConfig` fields onto the returned PEFT config."""
+
+    def test_create_peft_config_use_peft_false(self):
+        """Test that when use_peft is False, the function returns None."""
+        model_config = ModelConfig(use_peft=False)
+        peft_config = get_peft_config(model_config)
+        self.assertIsNone(peft_config)
+
+    def test_create_peft_config_use_peft_true(self):
+        """Test that when use_peft is True, the function returns a LoraConfig object."""
+        # Provide non-default values to the model config for testing
+        peft_kwargs = {
+            "lora_r": 8,
+            "lora_alpha": 16,
+            "lora_dropout": 0.1,
+            "lora_task_type": "SEQ_CLS",
+            "use_rslora": True,
+            "lora_target_modules": ["up_proj", "down_proj"],
+            "lora_modules_to_save": ["up_proj"],
+        }
+        model_config = ModelConfig(use_peft=True, **peft_kwargs)
+        peft_config = get_peft_config(model_config)
+        self.assertIsInstance(peft_config, LoraConfig)
+        for arg, value in peft_kwargs.items():
+            # Only target_modules is expected back as a set; modules_to_save is compared as the original list
+            if arg == "lora_target_modules":
+                value = set(value)
+            # These are passed to LoraConfig without the "lora_" prefix (lora_alpha / lora_dropout keep theirs)
+            if arg in ["lora_r", "lora_task_type", "lora_target_modules", "lora_modules_to_save"]:
+                arg = arg[len("lora_") :]
+
+            with self.subTest(attribute=arg):
+                self.assertEqual(getattr(peft_config, arg), value)
diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py
index b16a07421d..bc0caf93d2 100644
--- a/trl/trainer/model_config.py
+++ b/trl/trainer/model_config.py
@@ -64,6 +64,15 @@ class ModelConfig:
     lora_task_type: str = field(
         default="CAUSAL_LM", metadata={"help": "The task_type to pass for LoRA (use SEQ_CLS for reward modeling)"}
     )
+    use_rslora: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Use Rank-Stabilized LoRA (https://huggingface.co/papers/2312.03732), which sets the adapter "
+                "scaling factor to lora_alpha/√r, instead of the original default value of `lora_alpha/r`."
+            )
+        },
+    )
     load_in_8bit: bool = field(
         default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}
     )
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 12744c6370..0fb88fbaca 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -811,11 +811,12 @@ def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]":
         )
 
     peft_config = LoraConfig(
+        task_type=model_config.lora_task_type,
         r=model_config.lora_r,
+        target_modules=model_config.lora_target_modules,
         lora_alpha=model_config.lora_alpha,
         lora_dropout=model_config.lora_dropout,
         bias="none",
-        task_type=model_config.lora_task_type,
-        target_modules=model_config.lora_target_modules,
+        use_rslora=model_config.use_rslora,
         modules_to_save=model_config.lora_modules_to_save,
     )