FEAT Mixing different LoRA adapters in same batch #1558
Conversation
This PR tries to revive the work by Sourab in huggingface#903. The core logic is the same between the two PRs; this one should be more complete.

The main idea is to allow the user to mix different LoRA adapters in the same batch. This is useful when the user wants to perform inference with a batch whose samples use different LoRA adapters. Without this, each batch would have to be restricted to the same LoRA adapter(s).

This PR should encompass:

- all task types
- all LoRA layer types
- bnb layers

Extensive tests were added, as well as documentation.
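For illustration, a minimal sketch of the call-time usage this PR enables, assuming a `PeftModel` and tokenizer that are already set up with two adapters loaded under the hypothetical names `"adapter_fr"` and `"adapter_de"`; the special name `"__base__"` for "no adapter" follows this PR's convention:

```python
import torch

# `peft_model` and `tokenizer` are assumed to exist, with the adapters
# "adapter_fr" and "adapter_de" already loaded (hypothetical names).
inputs = tokenizer(["Hello", "Bonjour", "Hallo"], return_tensors="pt", padding=True)

# One adapter name per sample in the batch; "__base__" selects the plain base model.
adapter_names = ["__base__", "adapter_fr", "adapter_de"]

with torch.no_grad():
    output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_tokens=20)
```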
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This is very clean and cool - thanks so much @BenjaminBossan ! I left a few comments that are not blockers - what do you think?
src/peft/peft_model.py
Outdated
def forward(self, *args: Any, **kwargs: Any):
    """
    Forward pass of the model.
    """
    return self.get_base_model()(*args, **kwargs)
    with self._enable_peft_forward_hooks(*args, **kwargs):
        special_peft_args = {"adapter_names"}
Should we make that a class attribute?
done
src/peft/peft_model.py
Outdated
def generate(self, *args, **kwargs):
    with self._enable_peft_forward_hooks(*args, **kwargs):
        special_peft_args = {"adapter_names"}
same comment as above
done
src/peft/peft_model.py
Outdated
@@ -903,19 +925,22 @@ def forward(
):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    peft_config = self.active_peft_config
    special_peft_args = {"adapter_names"}
same
done
src/peft/peft_model.py
Outdated
@@ -1396,7 +1436,10 @@ def generate(self, **kwargs):
)
try:
    if not peft_config.is_prompt_learning:
        outputs = self.base_model.generate(**kwargs)
        special_peft_args = {"adapter_names"}
same comment as above
done
src/peft/tuners/lora/model.py
Outdated
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
for key in key_list:
    _, target, _ = _get_submodules(self.model, key)
    if isinstance(target, LoraLayer):
        pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
        handle = target.register_forward_pre_hook(pre_forward, with_kwargs=True)
        hook_handles.append(handle)
Can't we just loop here over the named_modules directly?
done
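As a side note for readers, here is a toy, self-contained sketch of the hook mechanism being discussed (not the actual PEFT code): a forward pre-hook registered on each LoRA-like layer injects `adapter_names` into that layer's forward kwargs, and the loop iterates over the modules directly, as suggested above.

```python
from functools import partial

import torch
from torch import nn


def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
    # Simplified stand-in for the hook used in this PR: inject the per-sample
    # adapter names into the wrapped layer's forward kwargs.
    kwargs["adapter_names"] = adapter_names
    return args, kwargs


class ToyLoraLayer(nn.Module):
    # Stand-in for a LoRA layer that accepts an `adapter_names` keyword argument.
    def forward(self, x, adapter_names=None):
        print("adapter_names seen by this layer:", adapter_names)
        return x


model = nn.Sequential(ToyLoraLayer(), nn.Identity())
hook_handles = []

# Loop over the modules directly instead of building a key list first.
for module in model.modules():
    if isinstance(module, ToyLoraLayer):
        pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=["adapter_fr", "__base__"])
        hook_handles.append(module.register_forward_pre_hook(pre_forward, with_kwargs=True))

model(torch.zeros(2, 4))

# The hooks are removed again once the mixed-batch call is done.
for handle in hook_handles:
    handle.remove()
```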
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
- make special_peft_forward_args an instance attribute
- simplify loop

Also:
- make test asserts more terse
Thanks for the feedback, your comments should be addressed now.
Thanks so much @BenjaminBossan !
Cool feature! I left 2 significant comments about the use-case and comparison, and 3 nits on typos.
docs/source/developer_guides/lora.md
Outdated
- Try to avoid having a large number of different adapters in the same batch; prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only performing inference with a small handful of different adapters.
- Take a look at alternative implementations such as [punica](https://github.com/punica-ai/punica) or [S-LoRA](https://github.com/S-LoRA/S-LoRA), which are specialized to work with a large number of different adapters.
Do you plan on having some benchmarking of this feature? (may not even need vs other methods, but at least what throughput per batch size looks like on a transformer with different LoRAs)
Also worth mentioning https://github.com/predibase/lorax
@Noezor, as per the previous PR below is a simple benchmarking experiment.
Added the reference.
src/peft/tuners/lora/bnb.py
Outdated
if x.dtype != compute_dtype:
    x = x.to(compute_dtype)

# getting the sub-batch, passing it ot LoRA layers and updating the corresponding indices of the linear
Suggested change: "passing it ot LoRA layers" → "passing it to LoRA layers"
done
src/peft/tuners/lora/bnb.py
Outdated
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)

# getting the sub-batch, passing it ot LoRA layers and updating the corresponding indices of the linear
Suggested change: "passing it ot LoRA layers" → "passing it to LoRA layers"

I think the code is very clean, but if this comment exists, maybe we could add a comment on the sub-batching logic?
Not sure what you mean, could you please elaborate?
The sub-batch logic may not be trivial, and I feel that if we detail what is in this comment, we may also want to add a comment on the batching?
I see. Yes, it may take a minute to comprehend, but I have a hard time thinking of a comment that would make it easier vs just reading the code. Just paraphrasing what the code does would not really help IMO.
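For illustration, a toy sketch of the sub-batching idea (a simplification, not the PR's actual implementation): rows of the input are grouped by adapter name, each group is pushed through its adapter's branch, and the results are written back to the corresponding indices of the shared base output.

```python
import torch
from torch import nn


def mixed_lora_forward(x, base_layer, lora_branches, adapter_names):
    # x: (batch, in_features); adapter_names: one entry per row, "__base__" = base only.
    result = base_layer(x)
    for adapter in set(adapter_names):
        if adapter == "__base__":
            continue  # rows marked "__base__" keep the plain base output
        # Collect the row indices of this sub-batch and run them through the adapter's branch.
        sub_batch_idx = [i for i, name in enumerate(adapter_names) if name == adapter]
        sub_batch = x[sub_batch_idx]
        # Write the adapter's delta back only to the rows that requested it.
        result[sub_batch_idx] += lora_branches[adapter](sub_batch)
    return result


base = nn.Linear(8, 8)
branches = {"adapter_fr": nn.Linear(8, 8, bias=False), "adapter_de": nn.Linear(8, 8, bias=False)}
with torch.no_grad():
    out = mixed_lora_forward(
        torch.randn(4, 8), base, branches, ["adapter_fr", "__base__", "adapter_de", "adapter_fr"]
    )
```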
src/peft/tuners/lora/layer.py
Outdated
embedding_B = self.lora_embedding_B[active_adapter].T
scaling = self.scaling[active_adapter]

# getting the sub-batch, passing it ot LoRA layers and updating the corresponding indices of the linear
Suggested change: "passing it ot LoRA layers" → "passing it to LoRA layers"
Done
time_base, time0, time1, *time_mixed = logs
time_non_mixed = (time_base + time0 + time1) / 3
time_mixed = min(time_mixed)

factor = 2.0
assert time_mixed < factor * time_non_mixed
I'm a bit confused by this test, as I feel this is not the situation users would be interested in when using this feature. I think what we want is to compare the total computation without batching over 3 different LoRAs against the total computation with batching of 3 different LoRAs, and thus have a factor < 1.

Something like this instead:

with timed():
    output_base = base_model(**inputs[::3]).logits
with timed():
    output0 = peft_model(**inputs[1::3]).logits
with timed():
    output1 = peft_model(**inputs[2::3]).logits

and finally the comparison being:

time_base, time0, time1, *time_mixed = logs
time_non_mixed = (time_base + time0 + time1)
time_mixed = min(time_mixed)
factor = 1.0  # this is probably not very aggressive but test flakiness is also important
assert time_mixed < factor * time_non_mixed

WDYT? (Bonus: it gives a free benchmark to include in the changelog.)
So the purpose of this check was not really to ensure that this is faster than if the user did the calls separately. In fact, if the users are buffering the same adapter into batches of the same size, I would actually expect that to be faster. In my mind, this feature is more of a convenience feature to avoid that type of extra work. Therefore, this test is more of a regression test to ensure we don't do something unreasonably slow in the implementation.
What you suggest is to benchmark this new feature vs a "naive" alternative which effectively runs at a lower batch size and is thus expected to be slower. I agree that this is still useful to test in the sense that some users might otherwise do exactly that, but I don't see it as a replacement to the existing test. Therefore, I added it to the test but kept the existing benchmark.
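As a point of comparison, here is a rough sketch of the "buffering" alternative mentioned above, assuming a `peft_model` with the relevant adapters already loaded (the helper itself is hypothetical and not part of this PR):

```python
from collections import defaultdict

import torch


def run_buffered(peft_model, samples):
    # `samples` is a list of (adapter_name, input_ids) pairs; all input_ids are
    # assumed to be padded to the same length so they can be stacked.
    buckets = defaultdict(list)
    for adapter_name, input_ids in samples:
        buckets[adapter_name].append(input_ids)

    outputs = {}
    for adapter_name, rows in buckets.items():
        # One homogeneous batch per adapter: switch the active adapter, then run it.
        peft_model.set_adapter(adapter_name)
        with torch.no_grad():
            outputs[adapter_name] = peft_model(input_ids=torch.stack(rows)).logits
    return outputs
```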
> if the users are buffering the same adapter into batches of the same size, I would actually expect that to be faster

I don't think so? If you have 3 batches with 3 different adapters running separately (as suggested in my comment above) vs using the method implemented here, the "base" will run 3x, vs just once for the new method. Note that this is also the result that @pacman100 is showing below. This is super powerful.

The way I'm reading the test, its purpose is to make sure that optimized batched LoRA inference with 3 different adapters at batch size n is roughly no slower than the average of 3 inferences with 3 different LoRAs at batch size n, but I don't know why this is relevant. The quantities computed are quite different.
Also, feel free to merge as is if you disagree. But I believe this is an important feature, and having tests that clearly expose usage would help it become more known and used.
> I don't think so? If you have 3 batches with 3 different adapters running separately (as suggested in my comment above) vs using the method implemented here, the "base" will run 3x, vs just once for the new method

Yes, you are correct here, sorry for my misunderstanding. I still think there is a tradeoff, especially with a large batch size plus a large number of LoRA adapters, which would result in an effectively very small batch size for the LoRA part of the forward call (for example, a batch of 32 spread over 8 adapters leaves sub-batches of only 4 samples each); in that case this approach could actually be slower, but this may be more of an edge case.

> but I don't know why this is relevant. The quantities computed are quite different.

Yes, this is not comparing the same output. As mentioned, this is not a test from a user perspective; it's purely there to ensure that we don't have a speed regression in the implementation. The thought is that we expect mixed batches to be slower than pure batches, but we don't want them to be too much slower. Does that make sense?
Thank you @BenjaminBossan for improving upon the previous PR for enabling batched inference with mixed LoRA adapters, making it feature-complete with docs and tests! ✨

I agree with the comment by @Noezor on the timing comparison in the timing test.
- Typo
- Extend docs: lorax
- Add one more runtime performance test
The newly added benchmark turned out to be flaky. I'm not quite sure why, but on CPU, the timing differences are quite small. Therefore, I only run the test on GPU. Moreover, I run the timing 3 times and take an average to reduce variance.
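For reference, a minimal sketch of the repeat-and-average timing described here; the `timed` helper is hypothetical, not the one from the PEFT test suite:

```python
import time
from contextlib import contextmanager

logs = []


@contextmanager
def timed():
    # Append the elapsed wall-clock time of the wrapped block to `logs`.
    tic = time.perf_counter()
    try:
        yield
    finally:
        logs.append(time.perf_counter() - tic)


def averaged_runtime(fn, repeats=3):
    # Run `fn` several times and average the measured durations to reduce variance.
    start = len(logs)
    for _ in range(repeats):
        with timed():
            fn()
    return sum(logs[start:]) / repeats
```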
Normally, each inference batch has to use the same adapter(s) in PEFT. This can sometimes be annoying, because we may have batches that contain samples intended to be used with different LoRA adapters. For example, we could have a base model that works well in English and two more LoRA adapters, one for French and one for German. Usually, we would have to split our batches such that each batch only contains samples of one of the languages; we cannot combine different languages in the same batch.

Thankfully, it is possible to mix different LoRA adapters in the same batch using the `adapter_name` argument. Below, we show an examle of how this works in practice. First, let's load the base model, English, and the two adapters, French and German, like this:
examle -> example
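For illustration, a minimal sketch of the loading step the quoted docs describe, using hypothetical model and adapter IDs (the actual docs use their own checkpoints):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_id = "my-org/english-base-model"  # hypothetical IDs, for illustration only
model = AutoModelForCausalLM.from_pretrained(base_model_id)
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Load the French adapter first, then add the German one on top.
peft_model = PeftModel.from_pretrained(model, "my-org/lora-french", adapter_name="adapter_fr")
peft_model.load_adapter("my-org/lora-german", adapter_name="adapter_de")
```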