FEAT Mixing different LoRA adapters in same batch #1558
Changes from 3 commits
@@ -219,3 +219,66 @@ model.unload()
# delete adapter
model.delete_adapter("dpo")
```

## Inference with different LoRA adapters in the same batch
Normally, each inference batch in PEFT has to use the same adapter(s). This can be inconvenient, because a batch may contain samples intended for different LoRA adapters. For example, we could have a base model that works well in English plus two LoRA adapters, one for French and one for German. Usually, we would have to split our batches so that each batch only contains samples of a single language; we cannot combine different languages in the same batch.

Thankfully, it is possible to mix different LoRA adapters in the same batch using the `adapter_names` argument. Below, we show an example of how this works in practice. First, let's load the base model, English, and the two adapters, French and German, like this:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

model_id = ...
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(model_id)
# load the LoRA adapter for French
peft_model = PeftModel.from_pretrained(model, <path>, adapter_name="adapter_fr")
# next, load the LoRA adapter for German
peft_model.load_adapter(<path>, adapter_name="adapter_de")
```
Now, we want to generate text on a batch that contains all three languages: the first three samples are in English, the next three are in French, and the last three are in German. We can use the `adapter_names` argument to specify which adapter to use for each sample. Since our base model handles English, we use the special string `"__base__"` for those samples. For the next three samples, we indicate the adapter name of the French LoRA fine-tune, in this case `"adapter_fr"`. For the last three samples, we indicate the adapter name of the German LoRA fine-tune, in this case `"adapter_de"`. This way, we can use the base model and the two adapters in a single batch.
```python
inputs = tokenizer(
    [
        "Hello, my dog is cute",
        "Hello, my cat is awesome",
        "Hello, my fish is great",
        "Salut, mon chien est mignon",
        "Salut, mon chat est génial",
        "Salut, mon poisson est super",
        "Hallo, mein Hund ist süß",
        "Hallo, meine Katze ist toll",
        "Hallo, mein Fisch ist großartig",
    ],
    return_tensors="pt",
    padding=True,
)

adapter_names = [
    "__base__", "__base__", "__base__",
    "adapter_fr", "adapter_fr", "adapter_fr",
    "adapter_de", "adapter_de", "adapter_de",
]
output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_tokens=20)
```
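The generated token IDs can then be decoded as usual; this follow-up snippet is just standard Transformers usage and is not specific to the mixed-adapter feature:

```python
# decode the generations produced above back into text
decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
for text in decoded:
    print(text)
```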
Note that the order does not matter here, i.e. the samples in the batch don't need to be grouped by adapter as in the example above. We just need to ensure that the `adapter_names` argument is aligned correctly with the samples.
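For instance, an interleaved batch works just as well, as long as each entry of `adapter_names` lines up with the corresponding sample (a small illustrative variation of the example above):

```python
# samples from different languages are interleaved; adapter_names simply follows the same order
inputs = tokenizer(
    [
        "Hello, my dog is cute",
        "Salut, mon chien est mignon",
        "Hallo, mein Hund ist süß",
    ],
    return_tensors="pt",
    padding=True,
)
adapter_names = ["__base__", "adapter_fr", "adapter_de"]
output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_tokens=20)
```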
### Caveats

Using this feature has some drawbacks, namely:
- It only works for inference, not for training.
- Disabling adapters using the `with model.disable_adapter()` context manager takes precedence over `adapter_names`.
- You cannot pass `adapter_names` when some adapter weights were merged with the base weights using the `merge_adapter` method. Please unmerge all adapters first by calling `model.unmerge_adapter()`.
- For obvious reasons, this cannot be used after calling `merge_and_unload`, since all the LoRA adapters will be merged into the base weights in this case.
- This feature does not currently work with DoRA, so set `use_dora=False` in your `LoraConfig` if you want to use it.
- There is an expected overhead for inference with `adapter_names`, especially if the number of different adapters in the batch is high. This is because the batch size is effectively reduced to the number of samples per adapter. If runtime performance is your top priority, try the following:
  - Increase the batch size.
  - Try to avoid having a large number of different adapters in the same batch; prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only performing inference with a small handful of different adapters (see the sketch after this list).
  - Take a look at alternative implementations such as [punica](https://github.com/punica-ai/punica) or [S-LoRA](https://github.com/S-LoRA/S-LoRA), which are specialized to work with a large number of different adapters.
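To illustrate the buffering suggestion from the list above, here is a rough sketch (not part of PEFT; the helper name and the `(text, adapter_name)` input format are assumptions) that groups samples by adapter so that each `generate` call is homogeneous:

```python
from collections import defaultdict


def generate_grouped_by_adapter(samples, peft_model, tokenizer, max_new_tokens=20):
    """Group (text, adapter_name) pairs by adapter and run one homogeneous batch per adapter."""
    buckets = defaultdict(list)
    for text, adapter_name in samples:
        buckets[adapter_name].append(text)

    outputs = {}
    for adapter_name, texts in buckets.items():
        inputs = tokenizer(texts, return_tensors="pt", padding=True)
        # every sample in this call uses the same adapter, so the effective batch size stays high
        outputs[adapter_name] = peft_model.generate(
            **inputs,
            adapter_names=[adapter_name] * len(texts),
            max_new_tokens=max_new_tokens,
        )
    return outputs
```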
Review comments:

- Do you plan on having some benchmarking of this feature? (may not even need vs other methods, but at least what throughput per batch size looks like on a transformer with different LoRAs)
- Also worth mentioning https://github.com/predibase/lorax
- @Noezor, as per the previous PR below is a simple benchmarking experiment.
- Added the reference.
@@ -537,11 +537,33 @@ def __getattr__(self, name: str):
         except AttributeError:
             return getattr(self.base_model, name)

+    @contextmanager
+    def _enable_peft_forward_hooks(self, *args, **kwargs):
+        # If the base model has a method called _enable_peft_forward_hooks, it is invoked as a context. Otherwise, this
+        # runs without any changes
+        if hasattr(self.base_model, "_enable_peft_forward_hooks"):
+            with self.base_model._enable_peft_forward_hooks(*args, **kwargs):
+                yield
+            return
+        else:
+            # nothing to enable
+            yield
+            return
+
     def forward(self, *args: Any, **kwargs: Any):
         """
         Forward pass of the model.
         """
-        return self.get_base_model()(*args, **kwargs)
+        with self._enable_peft_forward_hooks(*args, **kwargs):
+            special_peft_args = {"adapter_names"}
+            kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+            return self.get_base_model()(*args, **kwargs)
+
+    def generate(self, *args, **kwargs):
+        with self._enable_peft_forward_hooks(*args, **kwargs):
+            special_peft_args = {"adapter_names"}
+            kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+            return self.get_base_model().generate(*args, **kwargs)

     def _get_base_model_class(self, is_prompt_tuning=False):
         """

Review comment on `special_peft_args = {"adapter_names"}`: Should we make that a class attribute? Reply: done.
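For illustration only, here is a minimal sketch of how a base tuner model might expose `_enable_peft_forward_hooks`, assuming it routes `adapter_names` to its sub-layers via PyTorch forward pre-hooks; the class and attribute names below are made up and this is not the actual PEFT implementation:

```python
from contextlib import contextmanager

import torch.nn as nn


class ToyTunerModel(nn.Module):
    """Hypothetical tuner wrapper used only to illustrate the hook mechanism."""

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    @contextmanager
    def _enable_peft_forward_hooks(self, *args, **kwargs):
        adapter_names = kwargs.get("adapter_names")
        if adapter_names is None:
            # nothing to route, run the forward pass unchanged
            yield
            return

        def pre_hook(module, args, kwargs):
            # inject the per-sample adapter names into each layer's forward call
            kwargs["adapter_names"] = adapter_names
            return args, kwargs

        handles = [
            layer.register_forward_pre_hook(pre_hook, with_kwargs=True)
            for layer in self.layers
        ]
        try:
            yield
        finally:
            # always remove the hooks again, even if forward raised
            for handle in handles:
                handle.remove()
```

In PEFT itself, the LoRA tuner provides this context manager so that the `adapter_names` keyword reaches the LoRA layers during `forward` and `generate`.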
@@ -903,19 +925,22 @@ def forward(
     ):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         peft_config = self.active_peft_config
+        special_peft_args = {"adapter_names"}
         if not peft_config.is_prompt_learning:
-            if peft_config.peft_type == PeftType.POLY:
-                kwargs["task_ids"] = task_ids
-            return self.base_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                labels=labels,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-                **kwargs,
-            )
+            with self._enable_peft_forward_hooks(**kwargs):
+                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                if peft_config.peft_type == PeftType.POLY:
+                    kwargs["task_ids"] = task_ids
+                return self.base_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    inputs_embeds=inputs_embeds,
+                    labels=labels,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                    **kwargs,
+                )

         batch_size = _get_batch_size(input_ids, inputs_embeds)
         if attention_mask is not None:
@@ -1095,16 +1120,20 @@ def forward(
             if peft_config.peft_type == PeftType.POLY:
                 kwargs["task_ids"] = task_ids
-            return self.base_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                labels=labels,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-                **kwargs,
-            )
+
+            with self._enable_peft_forward_hooks(**kwargs):
+                special_peft_args = {"adapter_names"}
+                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                return self.base_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    inputs_embeds=inputs_embeds,
+                    labels=labels,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                    **kwargs,
+                )

         batch_size = _get_batch_size(input_ids, inputs_embeds)
         if attention_mask is not None:
@@ -1146,13 +1175,20 @@ def forward(
             return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

     def generate(self, *args, **kwargs):
+        peft_config = self.active_peft_config
         self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
         if hasattr(self.base_model, "model"):
             self.base_model.model.generation_config = self.generation_config
         else:
             self.base_model.generation_config = self.generation_config
         try:
-            outputs = self.base_model.generate(*args, **kwargs)
+            if not peft_config.is_prompt_learning:
+                special_peft_args = {"adapter_names"}
+                with self._enable_peft_forward_hooks(*args, **kwargs):
+                    kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                    outputs = self.base_model.generate(*args, **kwargs)
+            else:
+                outputs = self.base_model.generate(**kwargs)
         except:
             self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
             raise
@@ -1283,19 +1319,23 @@ def forward(
         if not peft_config.is_prompt_learning:
             if peft_config.peft_type == PeftType.POLY:
                 kwargs["task_ids"] = task_ids
-            return self.base_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                decoder_input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
-                decoder_inputs_embeds=decoder_inputs_embeds,
-                labels=labels,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-                **kwargs,
-            )
+
+            with self._enable_peft_forward_hooks(**kwargs):
+                special_peft_args = {"adapter_names"}
+                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                return self.base_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    inputs_embeds=inputs_embeds,
+                    decoder_input_ids=decoder_input_ids,
+                    decoder_attention_mask=decoder_attention_mask,
+                    decoder_inputs_embeds=decoder_inputs_embeds,
+                    labels=labels,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                    **kwargs,
+                )

         batch_size = _get_batch_size(input_ids, inputs_embeds)
         if decoder_attention_mask is not None:
@@ -1396,7 +1436,10 @@ def generate(self, **kwargs):
         )
         try:
             if not peft_config.is_prompt_learning:
-                outputs = self.base_model.generate(**kwargs)
+                special_peft_args = {"adapter_names"}
+                with self._enable_peft_forward_hooks(**kwargs):
+                    kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                    outputs = self.base_model.generate(**kwargs)
             else:
                 if "input_ids" not in kwargs:
                     raise ValueError("input_ids must be provided for Peft model generation")
@@ -1541,18 +1584,21 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if not peft_config.is_prompt_learning:
-            if peft_config.peft_type == PeftType.POLY:
-                kwargs["task_ids"] = task_ids
-            return self.base_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                labels=labels,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-                **kwargs,
-            )
+            with self._enable_peft_forward_hooks(**kwargs):
+                special_peft_args = {"adapter_names"}
+                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                if peft_config.peft_type == PeftType.POLY:
+                    kwargs["task_ids"] = task_ids
+                return self.base_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    inputs_embeds=inputs_embeds,
+                    labels=labels,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                    **kwargs,
+                )

         batch_size = _get_batch_size(input_ids, inputs_embeds)
         if attention_mask is not None:
@@ -1719,17 +1765,21 @@ def forward(
         if not peft_config.is_prompt_learning:
             if peft_config.peft_type == PeftType.POLY:
                 kwargs["task_ids"] = task_ids
-            return self.base_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                start_positions=start_positions,
-                end_positions=end_positions,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-                **kwargs,
-            )
+
+            with self._enable_peft_forward_hooks(**kwargs):
+                special_peft_args = {"adapter_names"}
+                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                return self.base_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    inputs_embeds=inputs_embeds,
+                    start_positions=start_positions,
+                    end_positions=end_positions,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                    **kwargs,
+                )

         batch_size = _get_batch_size(input_ids, inputs_embeds)
         if attention_mask is not None:
@@ -1893,15 +1943,19 @@ def forward(
         if not peft_config.is_prompt_learning:
             if peft_config.peft_type == PeftType.POLY:
                 kwargs["task_ids"] = task_ids
-            return self.base_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-                **kwargs,
-            )
+
+            with self._enable_peft_forward_hooks(**kwargs):
+                special_peft_args = {"adapter_names"}
+                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                return self.base_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    inputs_embeds=inputs_embeds,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                    **kwargs,
+                )

         batch_size = _get_batch_size(input_ids, inputs_embeds)
         if attention_mask is not None:
Review comment: examle -> example