diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md
index 33e26b8984..e4b6ec28e0 100644
--- a/docs/source/developer_guides/lora.md
+++ b/docs/source/developer_guides/lora.md
@@ -369,6 +369,8 @@ output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_toke
 
 Note that the order does not matter here, i.e. the samples in the batch don't need to be grouped by adapter as in the example above. We just need to ensure that the `adapter_names` argument is aligned correctly with the samples.
 
+Additionally, this approach also works with the `modules_to_save` feature, which allows saving and reusing specific neural network layers, such as custom heads for classification tasks, across different LoRA adapters.
+
 ### Caveats
 
 Using this features has some drawbacks, namely:
@@ -378,6 +380,7 @@ Using this features has some drawbacks, namely:
 - You cannot pass `adapter_names` when some adapter weights where merged with base weight using the `merge_adapter` method. Please unmerge all adapters first by calling `model.unmerge_adapter()`.
 - For obvious reasons, this cannot be used after calling `merge_and_unload()`, since all the LoRA adapters will be merged into the base weights in this case.
 - This feature does not currently work with DoRA, so set `use_dora=False` in your `LoraConfig` if you want to use it.
+- When used with `adapter_names`, the `modules_to_save` feature is currently only supported for layers of type `Linear`, `Embedding`, `Conv2d`, and `Conv1d`.
 - There is an expected overhead for inference with `adapter_names`, especially if the amount of different adapters in the batch is high. This is because the batch size is effectively reduced to the number of samples per adapter. If runtime performance is your top priority, try the following:
   - Increase the batch size.
   - Try to avoid having a large number of different adapters in the same batch, prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only perform inference with a small handfull of different adapters.
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index bea131bf91..107a2593c5 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -432,7 +432,7 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
 
         hook_handles = []
         for module in self.modules():
-            if isinstance(module, LoraLayer):
+            if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
                 pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
                 handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                 hook_handles.append(handle)
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index e94e086dce..4238988576 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
 import copy
 import inspect
 import os
 import warnings
 from contextlib import nullcontext
-from typing import Optional, Tuple
+from typing import Any, Optional
 
 import accelerate
 import torch
@@ -268,10 +270,62 @@ def _create_new_hook(self, old_hook):
         new_hook = old_hook_cls(**filtered_old_hook_attr)
         return new_hook
 
-    def forward(self, *args, **kwargs):
+    def _check_forward_args(self, x, *args, **kwargs):
+        """Check if the arguments are compatible with the configs and state of the model"""
+        adapter_names = kwargs.get("adapter_names", None)
+        if adapter_names is None:
+            return
+
+        if len(x) != len(adapter_names):
+            msg = (
+                "Length of `adapter_names` should be the same as the number of inputs, but got "
+                f"{len(adapter_names)} and {len(x)} respectively."
+            )
+            raise ValueError(msg)
+
+    def _mixed_batch_forward(
+        self, input: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
+    ) -> torch.Tensor:
+        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
+        # extra argument that allows mixing different adapters in the same batch at inference time.
+
+        SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, torch.nn.Conv1d)
+
+        module_names = ", ".join([module.__name__ for module in SUPPORTED_MODULES])
+
+        if not isinstance(self.original_module, SUPPORTED_MODULES):
+            raise TypeError(f"Mixed batching is only supported for the following modules: {module_names}.")
+
+        unique_adapters = set(adapter_names)
+        sub_batch_indices_list = []
+
+        for adapter in unique_adapters:
+            sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
+
+        results = [0 for _ in range(len(input))]
+
+        for i, active_adapter in enumerate(unique_adapters):
+            sub_batch = input[sub_batch_indices_list[i]]
+
+            if active_adapter == "__base__":
+                output = self.original_module(sub_batch, *args, **kwargs)
+            else:
+                output = self.modules_to_save[active_adapter](sub_batch, *args, **kwargs)
+
+            for index, j in enumerate(sub_batch_indices_list[i]):
+                results[j] = output[index]
+
+        return torch.stack(results)
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        self._check_forward_args(x, *args, **kwargs)
+        adapter_names = kwargs.pop("adapter_names", None)
+
         if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
-            return self.original_module(*args, **kwargs)
-        return self.modules_to_save[self.active_adapter](*args, **kwargs)
+            return self.original_module(x, *args, **kwargs)
+        if adapter_names is None:
+            return self.modules_to_save[self.active_adapter](x, *args, **kwargs)
+        return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
 
     def enable_adapters(self, enabled: bool):
         """Toggle the enabling and disabling of adapters
@@ -546,7 +600,7 @@ def get_auto_gptq_quant_linear(gptq_quantization_config):
     return None
 
 
-def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
+def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
     """
     Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
     example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 4e13d1e46a..100b2f4312 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -583,6 +583,29 @@ def forward(self, X):
         return X
 
 
+class MLPWithGRU(nn.Module):
+    def __init__(self, bias=True):
+        super().__init__()
+        self.lin0 = nn.Linear(10, 20, bias=bias)
+        self.relu = nn.ReLU()
+        self.drop = nn.Dropout(0.5)
+        self.gru = nn.GRU(input_size=20, hidden_size=20, num_layers=1, batch_first=True, bias=bias)
+        self.fc = nn.Linear(20, 2, bias=bias)
+        self.sm = nn.LogSoftmax(dim=-1)
+
+    def forward(self, X):
+        X = X.float()
+        X = self.lin0(X)
+        X = self.relu(X)
+        X = self.drop(X)
+        X = X.unsqueeze(1)
+        X, _ = self.gru(X)
+        X = X.squeeze(1)
+        X = self.fc(X)
+        X = self.sm(X)
+        return X
+
+
 class MLP_LayerNorm(nn.Module):
     def __init__(self, bias=True):
         super().__init__()
@@ -3326,15 +3349,36 @@ def test_mixed_adapter_batches_lora_mlp(self, mlp_lora):
 
     def test_mixed_adapter_batches_lora_different_target_layers(self, mlp_lora):
         base_model = MLP().to(self.torch_device).eval()
-        # target different lora layers
         config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
         config1 = LoraConfig(target_modules=["lin1"], init_lora_weights=False)
         peft_model = get_peft_model(base_model, config0, "adapter0").eval()
         peft_model.add_adapter("adapter1", config1)
+        inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
+        self.run_checks(peft_model, inputs)
 
+    def test_mixed_adapter_batches_lora_multiple_modules_to_save(self, mlp_lora):
+        base_model = MLP().to(self.torch_device).eval()
+        config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
+        config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
+        peft_model = get_peft_model(base_model, config0, "adapter0").eval()
+        peft_model.add_adapter("adapter1", config1)
         inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
         self.run_checks(peft_model, inputs)
 
+    def test_mixed_adapter_batches_lora_unsupported_layer_raises(self, mlp_lora):
+        base_model = MLPWithGRU().to(self.torch_device).eval()
+        config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False)
+        config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False)
+        peft_model = get_peft_model(base_model, config0, "adapter0").eval()
+        peft_model.add_adapter("adapter1", config1)
+        inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
+        SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, torch.nn.Conv1d)
+        module_names = ", ".join([module.__name__ for module in SUPPORTED_MODULES])
+        with pytest.raises(
+            TypeError, match=f"Mixed batching is only supported for the following modules: {module_names}."
+        ):
+            self.run_checks(peft_model, inputs)
+
     def test_mixed_adapter_batches_lora_partly_overlapping_target_layers(self, mlp_lora):
         base_model = MLP().to(self.torch_device).eval()
         # target different lora layers
@@ -3356,6 +3400,15 @@ def test_mixed_adapter_batches_lora_conv1d_emb(self):
         inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
         self.run_checks(peft_model, inputs)
 
+    def test_mixed_adapter_batches_lora_conv1d_emb_multiple_modules_to_save(self):
+        base_model = ModelEmbConv1D().to(self.torch_device).eval()
+        config0 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False)
+        config1 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False)
+        peft_model = get_peft_model(base_model, config0, "adapter0").eval()
+        peft_model.add_adapter("adapter1", config1)
+        inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
+        self.run_checks(peft_model, inputs)
+
     def test_mixed_adapter_batches_lora_conv2d(self):
         base_model = ModelConv2D().to(self.torch_device).eval()
         config0 = LoraConfig(target_modules=["conv2d"], init_lora_weights=False)
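
For reference, the behavior added by this patch can be exercised as in the following sketch. The toy model, adapter names, and shapes below are illustrative only and are not part of the patch; the PEFT calls (`LoraConfig`, `get_peft_model`, `add_adapter`, and the `adapter_names` keyword) are the ones used in the tests above.

```python
# Illustrative sketch: mixed adapter batches combined with `modules_to_save`.
import torch
from torch import nn

from peft import LoraConfig, get_peft_model


class SmallMLP(nn.Module):
    # Hypothetical toy model: `lin0` is LoRA-adapted, `lin1` acts as a head
    # that each adapter keeps its own full copy of via `modules_to_save`.
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 2)

    def forward(self, X):
        return self.lin1(self.relu(self.lin0(X)))


base_model = SmallMLP()
config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)

X = torch.randn(6, 10)
# One adapter name per sample; "__base__" routes a sample through the
# unmodified base layers, any other name through the corresponding adapter
# and its `modules_to_save` copy of `lin1`.
adapter_names = ["__base__", "adapter0", "adapter1", "adapter0", "adapter1", "__base__"]
with torch.no_grad():
    output = peft_model(X, adapter_names=adapter_names)
print(output.shape)  # torch.Size([6, 2])
```

Internally, `ModulesToSaveWrapper._mixed_batch_forward` splits the batch into per-adapter sub-batches, runs `original_module` for `"__base__"` and the adapter-specific copy for every other name, and stitches the per-sample results back together with `torch.stack`.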