FEAT Mixing different LoRA adapters in same batch #1558

Merged

Changes from 3 commits

63 changes: 63 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -219,3 +219,66 @@ model.unload()
# delete adapter
model.delete_adapter("dpo")
```

## Inference with different LoRA adapters in the same batch

Normally, each inference batch in PEFT has to use the same adapter(s). This can sometimes be inconvenient, because we may have batches that contain samples intended to be used with different LoRA adapters. For example, we could have a base model that works well in English and two additional LoRA adapters, one for French and one for German. Usually, we would have to split our batches such that each batch only contains samples for one of the languages; we cannot combine different languages in the same batch.

Thankfully, it is possible to mix different LoRA adapters in the same batch using the `adapter_names` argument. Below, we show an example of how this works in practice. First, let's load the base model (English) and the two adapters (French and German) like this:

Review comment: examle -> example


```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

model_id = ...
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(model_id)
# load the LoRA adapter for French
peft_model = PeftModel.from_pretrained(model, <path>, adapter_name="adapter_fr")
# next, load the LoRA adapter for German
peft_model.load_adapter(<path>, adapter_name="adapter_de")
```
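
If you want to verify which adapters have been loaded before building a mixed batch, you can inspect the `peft_config` dict (keyed by adapter name) and the `active_adapter` attribute; this is just an optional sanity check:

```python
# optional sanity check: list the adapters loaded above
print(list(peft_model.peft_config.keys()))  # ['adapter_fr', 'adapter_de']
print(peft_model.active_adapter)            # 'adapter_fr', the adapter loaded first
```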

Now, we want to generate text for a batch that contains samples in all three languages: the first three samples are in English, the next three are in French, and the last three are in German. We can use the `adapter_names` argument to specify which adapter to use for each sample. Since our base model is used for English, we use the special string `"__base__"` for these samples. For the next three samples, we indicate the adapter name of the French LoRA fine-tune, in this case `"adapter_fr"`. For the last three samples, we indicate the adapter name of the German LoRA fine-tune, in this case `"adapter_de"`. This way, we can use the base model and the two adapters in a single batch.

```python
inputs = tokenizer(
[
"Hello, my dog is cute",
"Hello, my cat is awesome",
"Hello, my fish is great",
"Salut, mon chien est mignon",
"Salut, mon chat est génial",
"Salut, mon poisson est super",
"Hallo, mein Hund ist süß",
"Hallo, meine Katze ist toll",
"Hallo, mein Fisch ist großartig",
],
return_tensors="pt",
padding=True,
)

adapter_names = [
"__base__", "__base__", "__base__",
"adapter_fr", "adapter_fr", "adapter_fr",
"adapter_de", "adapter_de", "adapter_de",
]
output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_tokens=20)
```
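
The returned token IDs can then be decoded as usual; for instance, with the `tokenizer` from above:

```python
# decode the generated sequences back to text, one string per sample
decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
for text in decoded:
    print(text)
```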

Note that the order does not matter here, i.e. the samples in the batch don't need to be grouped by adapter as in the example above. We just need to ensure that the `adapter_names` argument is aligned correctly with the samples.
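
For instance, the following interleaved variant is equally valid; this is a small sketch reusing `peft_model` and `tokenizer` from above:

```python
# samples for different adapters interleaved in a single batch; each entry of
# adapter_names refers to the sample at the same position in the batch
inputs = tokenizer(
    ["Hello, my dog is cute", "Salut, mon chien est mignon", "Hallo, mein Hund ist süß"],
    return_tensors="pt",
    padding=True,
)
adapter_names = ["__base__", "adapter_fr", "adapter_de"]
output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_tokens=20)
```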

### Caveats

Using this feature has some drawbacks, namely:

- It only works for inference, not for training.
- Disabling adapters with the `model.disable_adapter()` context manager takes precedence over `adapter_names`.
- You cannot pass `adapter_names` when some adapter weights were merged into the base weights using the `merge_adapter` method. Please unmerge all adapters first by calling `model.unmerge_adapter()`.
- For obvious reasons, this cannot be used after calling `merge_and_unload`, since all the LoRA adapters will be merged into the base weights in this case.
- This feature does not currently work with DoRA, so set `use_dora=False` in your `LoraConfig` if you want to use mixed adapter batches.
- There is an expected overhead for inference with `adapter_names`, especially if the number of different adapters in the batch is high. This is because the batch size is effectively reduced to the number of samples per adapter. If runtime performance is your top priority, try the following:
  - Increase the batch size.
  - Try to avoid having a large number of different adapters in the same batch and prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only performing inference with a small handful of different adapters, as shown in the sketch after this list.
  - Take a look at alternative implementations such as [punica](https://github.com/punica-ai/punica) or [S-LoRA](https://github.com/S-LoRA/S-LoRA), which are specialized to work with a large number of different adapters.
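
To illustrate the buffering suggestion from the list above, here is a rough sketch (not part of PEFT; the grouping helper is hypothetical) that turns a mixed stream of samples into homogeneous per-adapter batches:

```python
from collections import defaultdict

def group_by_adapter(samples, adapter_names):
    """Buffer samples so that each generate() call only uses a single adapter."""
    buckets = defaultdict(list)
    for sample, adapter in zip(samples, adapter_names):
        buckets[adapter].append(sample)
    return buckets

samples = ["Hello, my dog is cute", "Salut, mon chien est mignon", "Hallo, mein Hund ist süß"]
adapter_names = ["__base__", "adapter_fr", "adapter_de"]

for adapter, batch in group_by_adapter(samples, adapter_names).items():
    inputs = tokenizer(batch, return_tensors="pt", padding=True)
    # each call now contains samples for one adapter only,
    # so the effective batch size is not reduced
    output = peft_model.generate(**inputs, adapter_names=[adapter] * len(batch), max_new_tokens=20)
```
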
Reviewer: Do you plan on having some benchmarking of this feature? (It may not even need a comparison vs. other methods, but at least what throughput per batch size looks like on a transformer with different LoRAs.)

Contributor: @Noezor, as per the previous PR, below is a simple benchmarking experiment.
[Screenshot: benchmarking results, 2024-03-18]

Member Author: Added the reference.

194 changes: 124 additions & 70 deletions src/peft/peft_model.py
@@ -537,11 +537,33 @@ def __getattr__(self, name: str):
except AttributeError:
return getattr(self.base_model, name)

@contextmanager
def _enable_peft_forward_hooks(self, *args, **kwargs):
# If the base model has a method called _enable_peft_forward_hooks, it is invoked as a context. Otherwise, this
# runs without any changes
if hasattr(self.base_model, "_enable_peft_forward_hooks"):
with self.base_model._enable_peft_forward_hooks(*args, **kwargs):
yield
return
else:
# nothing to enable
yield
return

def forward(self, *args: Any, **kwargs: Any):
"""
Forward pass of the model.
"""
return self.get_base_model()(*args, **kwargs)
with self._enable_peft_forward_hooks(*args, **kwargs):
special_peft_args = {"adapter_names"}
Contributor: Should we make that a class attribute?
Member Author: done
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
return self.get_base_model()(*args, **kwargs)

def generate(self, *args, **kwargs):
with self._enable_peft_forward_hooks(*args, **kwargs):
special_peft_args = {"adapter_names"}
Contributor: same comment as above
Member Author: done
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
return self.get_base_model().generate(*args, **kwargs)

def _get_base_model_class(self, is_prompt_tuning=False):
"""
@@ -903,19 +925,22 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
peft_config = self.active_peft_config
special_peft_args = {"adapter_names"}
Contributor: same
Member Author: done
if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
with self._enable_peft_forward_hooks(**kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None:
@@ -1095,16 +1120,20 @@ def forward(

if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

with self._enable_peft_forward_hooks(**kwargs):
special_peft_args = {"adapter_names"}
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None:
@@ -1146,13 +1175,20 @@ def forward(
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

def generate(self, *args, **kwargs):
peft_config = self.active_peft_config
self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
if hasattr(self.base_model, "model"):
self.base_model.model.generation_config = self.generation_config
else:
self.base_model.generation_config = self.generation_config
try:
outputs = self.base_model.generate(*args, **kwargs)
if not peft_config.is_prompt_learning:
special_peft_args = {"adapter_names"}
with self._enable_peft_forward_hooks(*args, **kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
outputs = self.base_model.generate(*args, **kwargs)
else:
outputs = self.base_model.generate(**kwargs)
except:
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
raise
@@ -1283,19 +1319,23 @@ def forward(
if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

with self._enable_peft_forward_hooks(**kwargs):
special_peft_args = {"adapter_names"}
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if decoder_attention_mask is not None:
@@ -1396,7 +1436,10 @@ def generate(self, **kwargs):
)
try:
if not peft_config.is_prompt_learning:
outputs = self.base_model.generate(**kwargs)
special_peft_args = {"adapter_names"}
Contributor: same comment as above
Member Author: done
with self._enable_peft_forward_hooks(**kwargs):
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
outputs = self.base_model.generate(**kwargs)
else:
if "input_ids" not in kwargs:
raise ValueError("input_ids must be provided for Peft model generation")
@@ -1541,18 +1584,21 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
with self._enable_peft_forward_hooks(**kwargs):
special_peft_args = {"adapter_names"}
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None:
@@ -1719,17 +1765,21 @@ def forward(
if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
start_positions=start_positions,
end_positions=end_positions,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

with self._enable_peft_forward_hooks(**kwargs):
special_peft_args = {"adapter_names"}
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
start_positions=start_positions,
end_positions=end_positions,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None:
@@ -1893,15 +1943,19 @@ def forward(
if not peft_config.is_prompt_learning:
if peft_config.peft_type == PeftType.POLY:
kwargs["task_ids"] = task_ids
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

with self._enable_peft_forward_hooks(**kwargs):
special_peft_args = {"adapter_names"}
kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None: