FEAT Mixing different LoRA adapters in same batch #1558

Merged
63 changes: 63 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -219,3 +219,66 @@ model.unload()
# delete adapter
model.delete_adapter("dpo")
```

## Inference with different LoRA adapters in the same batch

Normally, each inference batch in PEFT has to use the same adapter(s). This can be inconvenient when a batch contains samples intended for different LoRA adapters. For example, we could have a base model that works well in English and two additional LoRA adapters, one for French and one for German. Usually, we would have to split our batches so that each batch only contains samples in one of the languages; we cannot combine different languages in the same batch.

Thankfully, it is possible to mix different LoRA adapters in the same batch using the `adapter_names` argument. Below, we show an example of how this works in practice. First, let's load the base model, English, and the two adapters, French and German, like this:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

model_id = ...
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(model_id)
# load the LoRA adapter for French
peft_model = PeftModel.from_pretrained(model, <path>, adapter_name="adapter_fr")
# next, load the LoRA adapter for German
peft_model.load_adapter(<path>, adapter_name="adapter_de")
```

Now, we want to generate text with a batch that contains samples in all three languages: the first three samples are in English, the next three are in French, and the last three are in German. We can use the `adapter_names` argument to specify which adapter to use for each sample. Since our base model is used for English, we use the special string `"__base__"` for these samples. For the next three samples, we indicate the adapter name of the French LoRA fine-tune, in this case `"adapter_fr"`. For the last three samples, we indicate the adapter name of the German LoRA fine-tune, in this case `"adapter_de"`. This way, we can use the base model and the two adapters in a single batch.

```python
inputs = tokenizer(
    [
        "Hello, my dog is cute",
        "Hello, my cat is awesome",
        "Hello, my fish is great",
        "Salut, mon chien est mignon",
        "Salut, mon chat est génial",
        "Salut, mon poisson est super",
        "Hallo, mein Hund ist süß",
        "Hallo, meine Katze ist toll",
        "Hallo, mein Fisch ist großartig",
    ],
    return_tensors="pt",
    padding=True,
)

adapter_names = [
    "__base__", "__base__", "__base__",
    "adapter_fr", "adapter_fr", "adapter_fr",
    "adapter_de", "adapter_de", "adapter_de",
]
output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_tokens=20)
```
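
To inspect the results, the generated ids can be decoded back into text as usual. Here is a minimal sketch that reuses the `tokenizer` loaded above:

```python
# decode the generated token ids into one string per sample
generated_texts = tokenizer.batch_decode(output, skip_special_tokens=True)
for text in generated_texts:
    print(text)
```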

Note that the order does not matter here, i.e. the samples in the batch don't need to be grouped by adapter as in the example above. We just need to ensure that the `adapter_names` argument is aligned correctly with the samples.
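
For instance, the following sketch (reusing `tokenizer` and `peft_model` from above; the interleaved prompts are purely illustrative) mixes the languages freely, and only the per-position alignment between samples and `adapter_names` matters:

```python
# samples in mixed language order; each adapter_names entry matches the sample at the same index
mixed_inputs = tokenizer(
    ["Hello, my dog is cute", "Salut, mon chien est mignon", "Hallo, mein Hund ist süß"],
    return_tensors="pt",
    padding=True,
)
mixed_adapter_names = ["__base__", "adapter_fr", "adapter_de"]
output = peft_model.generate(**mixed_inputs, adapter_names=mixed_adapter_names, max_new_tokens=20)
```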

### Caveats

Using this feature has some drawbacks, namely:

- It only works for inference, not for training.
- Disabling adapters using the `with model.disable_adapter()` context takes precedence over `adapter_names`.
- You cannot pass `adapter_names` if some adapter weights were merged with the base weights using the `merge_adapter` method. Please unmerge all adapters first by calling `model.unmerge_adapter()`.
- For obvious reasons, this cannot be used after calling `merge_and_unload()`, since all the LoRA adapters will be merged into the base weights in this case.
- This feature does not currently work with DoRA, so set `use_dora=False` in your `LoraConfig` if you want to use it.
- There is an expected overhead for inference with `adapter_names`, especially if the number of different adapters in the batch is high. This is because the batch size is effectively reduced to the number of samples per adapter. If runtime performance is your top priority, try the following:
  - Increase the batch size.
  - Try to avoid having a large number of different adapters in the same batch; prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only performing inference with a small handful of different adapters (see the sketch after this list).
  - Take a look at alternative implementations such as [LoRAX](https://github.com/predibase/lorax), [punica](https://github.com/punica-ai/punica), or [S-LoRA](https://github.com/S-LoRA/S-LoRA), which are specialized to work with a large number of different adapters.
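
As an illustration of the buffering suggestion above, here is a hypothetical sketch (the `group_by_adapter` helper is not part of PEFT) that groups samples by adapter so that each `generate` call only sees a homogeneous batch. It assumes `texts` and `adapter_names` are parallel lists like in the example above:

```python
from collections import defaultdict

def group_by_adapter(texts, adapter_names):
    """Group sample texts by their adapter so each group can be run as a homogeneous batch."""
    groups = defaultdict(list)
    for text, name in zip(texts, adapter_names):
        groups[name].append(text)
    return groups

# one generate call per adapter group, each group being a homogeneous batch
for name, grouped_texts in group_by_adapter(texts, adapter_names).items():
    batch = tokenizer(grouped_texts, return_tensors="pt", padding=True)
    output = peft_model.generate(**batch, adapter_names=[name] * len(grouped_texts), max_new_tokens=20)
```
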
189 changes: 119 additions & 70 deletions src/peft/peft_model.py
@@ -114,6 +114,9 @@ def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name
self.modules_to_save = None
self.active_adapter = adapter_name
self.peft_type = peft_config.peft_type
# These args are special PEFT arguments that users can pass. They need to be removed before passing them to
# forward.
self.special_peft_forward_args = {"adapter_names"}

self._is_prompt_learning = peft_config.is_prompt_learning
if self._is_prompt_learning:
@@ -537,11 +540,31 @@ def __getattr__(self, name: str):
except AttributeError:
return getattr(self.base_model, name)

@contextmanager
def _enable_peft_forward_hooks(self, *args, **kwargs):
# If the base model has a method called _enable_peft_forward_hooks, it is invoked as a context. Otherwise, this
# runs without any changes
if hasattr(self.base_model, "_enable_peft_forward_hooks"):
with self.base_model._enable_peft_forward_hooks(*args, **kwargs):
yield
return
else:
# nothing to enable
yield
return

def forward(self, *args: Any, **kwargs: Any):
"""
Forward pass of the model.
"""
return self.get_base_model()(*args, **kwargs)
with self._enable_peft_forward_hooks(*args, **kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
return self.get_base_model()(*args, **kwargs)

def generate(self, *args, **kwargs):
with self._enable_peft_forward_hooks(*args, **kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
return self.get_base_model().generate(*args, **kwargs)

def _get_base_model_class(self, is_prompt_tuning=False):
"""
@@ -595,6 +618,8 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
"""
Add an adapter to the model based on the passed configuration.

This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].

The name for the new adapter should be unique.

The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
@@ -904,18 +929,20 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
peft_config = self.active_peft_config
if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
with self._enable_peft_forward_hooks(**kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None:
@@ -1095,16 +1122,19 @@ def forward(

if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

with self._enable_peft_forward_hooks(**kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None:
@@ -1146,13 +1176,19 @@ def forward(
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

def generate(self, *args, **kwargs):
peft_config = self.active_peft_config
self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
if hasattr(self.base_model, "model"):
self.base_model.model.generation_config = self.generation_config
else:
self.base_model.generation_config = self.generation_config
try:
outputs = self.base_model.generate(*args, **kwargs)
if not peft_config.is_prompt_learning:
with self._enable_peft_forward_hooks(*args, **kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
outputs = self.base_model.generate(*args, **kwargs)
else:
outputs = self.base_model.generate(**kwargs)
except:
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
raise
@@ -1283,19 +1319,22 @@ def forward(
if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

with self._enable_peft_forward_hooks(**kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if decoder_attention_mask is not None:
@@ -1396,7 +1435,9 @@ def generate(self, **kwargs):
)
try:
if not peft_config.is_prompt_learning:
outputs = self.base_model.generate(**kwargs)
with self._enable_peft_forward_hooks(**kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
outputs = self.base_model.generate(**kwargs)
else:
if "input_ids" not in kwargs:
raise ValueError("input_ids must be provided for Peft model generation")
@@ -1541,18 +1582,20 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
with self._enable_peft_forward_hooks(**kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None:
@@ -1719,17 +1762,20 @@ def forward(
if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
start_positions=start_positions,
end_positions=end_positions,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

with self._enable_peft_forward_hooks(**kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
start_positions=start_positions,
end_positions=end_positions,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None:
@@ -1893,15 +1939,18 @@ def forward(
if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

with self._enable_peft_forward_hooks(**kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None: