
FEAT Mixing different LoRA adapters in same batch #1558

Merged

Conversation

BenjaminBossan
Member

This PR tries to revive the work by Sourab in #903. The core logic is the same between the two PRs. This one should be more complete.

The main idea is to allow the user to mix different LoRA adapters in the same batch. This is useful when the user wants to perform inference with a batch whose samples use different LoRA adapters. Without this, each batch would have to be restricted to the same LoRA adapter(s).

This PR should encompass:

  • all task types
  • all LoRA layer types
  • bnb layers

Extensive tests were added, as well as documentation.
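A minimal usage sketch of the new feature, assuming `model` is a `PeftModel` that already has two LoRA adapters loaded and `inputs` is a tokenized batch; the adapter names below are placeholders, and `"__base__"` marks samples that should use only the base model:

```python
# one adapter name per sample in the batch; order must match the batch order
adapter_names = ["adapter_fr", "adapter_de", "__base__"]

# works for forward as well as generate; the extra keyword is stripped out
# before the call reaches the base model
outputs = model.generate(**inputs, adapter_names=adapter_names)
```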


Contributor

@younesbelkada younesbelkada left a comment

This is very clean and cool - thanks so much @BenjaminBossan! I left a few comments that are not blockers - what do you think?

def forward(self, *args: Any, **kwargs: Any):
"""
Forward pass of the model.
"""
return self.get_base_model()(*args, **kwargs)
with self._enable_peft_forward_hooks(*args, **kwargs):
special_peft_args = {"adapter_names"}
Contributor

Should we make that a class attribute?

Member Author

done


def generate(self, *args, **kwargs):
with self._enable_peft_forward_hooks(*args, **kwargs):
special_peft_args = {"adapter_names"}
Contributor

same comment as above

Member Author

done

@@ -903,19 +925,22 @@ def forward(
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
peft_config = self.active_peft_config
special_peft_args = {"adapter_names"}
Contributor

same

Member Author

done

@@ -1396,7 +1436,10 @@ def generate(self, **kwargs):
)
try:
if not peft_config.is_prompt_learning:
outputs = self.base_model.generate(**kwargs)
special_peft_args = {"adapter_names"}
Contributor

same comment as above

Member Author

done

Comment on lines 387 to 393
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
for key in key_list:
    _, target, _ = _get_submodules(self.model, key)
    if isinstance(target, LoraLayer):
        pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
        handle = target.register_forward_pre_hook(pre_forward, with_kwargs=True)
        hook_handles.append(handle)
Contributor

Can't we just loop here over the named_modules directly?

Member Author

done
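For context, a minimal sketch of what the simplified loop could look like, assuming the same scope as the snippet above (`functools.partial`, `LoraLayer`, `_adapter_names_pre_forward_hook`, `adapter_names`, and `hook_handles`); this is an illustration, not necessarily the exact merged code:

```python
# iterate over the modules directly instead of resolving keys via _get_submodules
for module in self.model.modules():
    if isinstance(module, LoraLayer):
        pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
        handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
        hook_handles.append(handle)
```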

BenjaminBossan and others added 2 commits March 15, 2024 11:54
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
- make special_peft_forward_args an instance attribute
- simplify loop

Also:

- make test asserts more terse
Member Author

@BenjaminBossan BenjaminBossan left a comment

Thanks for the feedback, your comments should be addressed now.


Contributor

@younesbelkada younesbelkada left a comment

Thanks so much @BenjaminBossan!

@Noezor Noezor left a comment

Cool feature! I left 2 significant comments about the use case and the comparison, and 3 nits on typos.

Comment on lines 283 to 284
- Try to avoid having a large number of different adapters in the same batch and prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only performing inference with a small handful of different adapters.
- Take a look at alternative implementations such as [punica](https://github.com/punica-ai/punica) or [S-LoRA](https://github.com/S-LoRA/S-LoRA), which are specialized to work with a large number of different adapters.

Do you plan on having some benchmarking of this feature? (It may not even need a comparison vs. other methods, but it should at least show what throughput per batch size looks like on a transformer with different LoRAs.)

Contributor

@Noezor, as per the previous PR, below is a simple benchmarking experiment.
[Screenshot: benchmark results, 2024-03-18]

Member Author

Added the reference.

if x.dtype != compute_dtype:
    x = x.to(compute_dtype)

# getting the sub-batch, passing it ot LoRA layers and updating the corresponding indices of the linear

Suggested change
# getting the sub-batch, passing it ot LoRA layers and updating the corresponding indices of the linear
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear

Member Author

done

expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)

# getting the sub-batch, passing it ot LoRA layers and updating the corresponding indices of the linear

Suggested change
# getting the sub-batch, passing it ot LoRA layers and updating the corresponding indices of the linear
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear

I think the code is very clean, but if this comment exists, maybe we could also add a comment on the sub-batching logic?

Member Author

Not sure what you mean, could you please elaborate?


The sub-batch logic may not be trivial, and I feel that if we detail what is in this comment, we may also want to add a comment on the batching?

Member Author

I see. Yes, it may take a minute to comprehend, but I have a hard time thinking of a comment that would make it easier vs just reading the code. Just paraphrasing what the code does would not really help IMO.
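For readers of this thread, here is a rough sketch of what the sub-batching amounts to, with simplified and partly hypothetical names (`base_layer`, `lora_A`, `lora_B`, and `scaling` stand in for the layer's actual attributes); it is not the exact merged code:

```python
# compute the base layer output for the whole batch once
result = self.base_layer(x)

# for each adapter requested in the batch, select its sub-batch, run only that
# adapter's LoRA weights on it, and add the delta back at the matching indices
for adapter in set(adapter_names):
    if adapter == "__base__":  # these samples use the base model only
        continue
    indices = [i for i, name in enumerate(adapter_names) if name == adapter]
    sub_batch = x[indices]
    lora_A = self.lora_A[adapter]
    lora_B = self.lora_B[adapter]
    scaling = self.scaling[adapter]
    result[indices] += lora_B(lora_A(sub_batch)) * scaling
```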

embedding_B = self.lora_embedding_B[active_adapter].T
scaling = self.scaling[active_adapter]

# getting the sub-batch, passing it ot LoRA layers and updating the corresponding indices of the linear

Suggested change
# getting the sub-batch, passing it ot LoRA layers and updating the corresponding indices of the linear
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear

Member Author

Done

Comment on lines +2238 to +2243
time_base, time0, time1, *time_mixed = logs
time_non_mixed = (time_base + time0 + time1) / 3
time_mixed = min(time_mixed)

factor = 2.0
assert time_mixed < factor * time_non_mixed
@Noezor Noezor Mar 18, 2024

I'm a bit confused by this test, as I feel this is not the situation users would be interested in when using this feature. I think what we want is to verify that the total computation without batching, with 3 different LoRAs, is slower than the total computation with batching of 3 different LoRAs, and thus have a factor < 1.

Something like instead

        with timed():
            output_base = base_model(**inputs[::3]).logits
        with timed():
            output0 = peft_model(**inputs[1::3]).logits
        with timed():
            output1 = peft_model(**inputs[2::3]).logits

and finally the comparison being

        time_base, time0, time1, *time_mixed = logs
        time_non_mixed = (time_base + time0 + time1)
        time_mixed = min(time_mixed)

        factor = 1.0 # this is probably not very aggressive but test flakiness is also important
        assert time_mixed < factor * time_non_mixed

WDYT? (bonus is that it does a free benchmark to include in the changelog)

Member Author

So the purpose of this check was not really to ensure that this is faster than if the user did the calls separately. In fact, if the users are buffering the same adapter into batches of the same size, I would actually expect that to be faster. In my mind, this feature is more of a convenience feature to avoid that type of extra work. Therefore, this test is more of a regression test to ensure we don't do something unreasonably slow in the implementation.

What you suggest is to benchmark this new feature vs a "naive" alternative which effectively runs at a lower batch size and is thus expected to be slower. I agree that this is still useful to test in the sense that some users might otherwise do exactly that, but I don't see it as a replacement to the existing test. Therefore, I added it to the test but kept the existing benchmark.


if the users are buffering the same adapter into batches of the same size, I would actually expect that to be faster

I don't think so? If you have 3 batches with 3 different adapters running separately (as suggested in my comment above) vs. using the method implemented here, the "base" will run 3x, vs. just once with the new method. Note that this is also the result that @pacman100 is showing below. This is super powerful.

The way I'm reading the test, its purpose is to make sure that optimized batched LoRA inference with 3 different adapters at batch size n is roughly no slower than the average of 3 inferences with 3 different LoRAs at batch size n, but I don't know why this is relevant. The quantities computed are quite different.


Also, feel free to merge as is if you disagree. But I believe this is an important feature, and having tests that clearly expose usage would help it become more widely known and used.

Member Author

I don't think so? If you have 3 batches with 3 different adapters running separatly (as suggested in my comment above) vs using the method implemented here, the "base" will run 3x, vs just once for the new method

Yes, you are correct here, sorry for my misunderstanding. I still think there is a tradeoff, especially with a large batch size and a large number of LoRA adapters, which would result in an effectively very small batch size for the LoRA part of the forward call; there this approach could actually be slower, but this may be more of an edge case.

but I don't know why this is relevant. The quantities computed are quite different.

Yes, this is not comparing the same output. As mentioned, this is not a test from a user perspective; it's purely there to ensure that we don't have a speed regression in the implementation. The thought is that we expect mixed batches to be slower than pure batches, but we don't want them to be too much slower. Does that make sense?

Contributor

@pacman100 pacman100 left a comment

Thank you @BenjaminBossan for improving upon the previous PR for enabling batched inference with mixed LoRA adapters, making it feature complete with docs and tests! ✨

I agree with @Noezor's comment on the timing comparison in the timing test.


- Typo
- Extend docs: lorax
- Add one more runtime performance test
@BenjaminBossan
Member Author

The newly added benchmark turned out to be flaky. I'm not quite sure why, but on CPU, the timing differences are quite small. Therefore, I only run the test on GPU. Moreover, I run the timing 3 times and take an average to reduce variance.

@BenjaminBossan BenjaminBossan merged commit 91e4b08 into huggingface:main Mar 18, 2024
14 checks passed
@BenjaminBossan BenjaminBossan deleted the feat-multiple-loras-in-batch-2 branch March 18, 2024 14:50

Normally, each inference batch has to use the same adapter(s) in PEFT. This can sometimes be annoying, because we may have batches that contain samples intended to be used with different LoRA adapters. For example, we could have a base model that works well in English and two more LoRA adapters, one for French and one for German. Usually, we would have to split our batches such that each batch only contains samples of one of the languages; we cannot combine different languages in the same batch.

Thankfully, it is possible to mix different LoRA adapters in the same batch using the `adapter_names` argument. Below, we show an examle of how this works in practice. First, let's load the base model, English, and the two adapters, French and German, like this:


examle -> example
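The quoted docs excerpt above ends with "like this:", where a loading example would follow. A sketch of what such a snippet could look like; the model ID, adapter paths, and adapter names are placeholders, not taken from the PR:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("base-model-id")  # English base model
tokenizer = AutoTokenizer.from_pretrained("base-model-id")
tokenizer.pad_token = tokenizer.eos_token  # some causal LM tokenizers lack a pad token

# load the French adapter, then add the German one to the same PeftModel
model = PeftModel.from_pretrained(base_model, "path/to/adapter_fr", adapter_name="adapter_fr")
model.load_adapter("path/to/adapter_de", adapter_name="adapter_de")

# mix the adapters in a single batch, one adapter name per sample
inputs = tokenizer(["Hello", "Bonjour", "Hallo"], return_tensors="pt", padding=True)
output = model.generate(**inputs, adapter_names=["__base__", "adapter_fr", "adapter_de"])
```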
