
fixing multiple LoRA in the same batch or vit #1990

Merged · 10 commits · Sep 17, 2024

Conversation

@saeid93 (Contributor) commented Aug 5, 2024

@BenjaminBossan
This is the initial fix for the batched LoRA inference problem explained in #1960.
For now, this only supports the ViT model; it is an example and serves as a template for gradually adding support for other models. Since we need to change model-specific details, there are generally two options, following the solution we discussed in #1960 (comment):

  1. Change the signature of the model forward functions in the transformers library. The problem with this approach is that it puts PEFT-specific logic into transformers, which I'm not sure is the best fit for a general-purpose library like transformers.
  2. Change the forward function in PEFT and patch it dynamically when multiple LoRA requests are in the inference batch.

I'm taking the second approach, but each model needs different changes. Also, generate functions for generative models should be added. I'm happy to go through models one by one and also fix #1967, but it is better to review this first and then decide whether we want to go down this route of dynamically patching the forward functions or fix it in the transformers library.
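
For illustration, a minimal sketch of the dynamic-patching idea in option 2; the helper name patch_forward_with_adapter_names is made up for this sketch and is not code from the PR:

def patch_forward_with_adapter_names(module, adapter_names):
    # Monkey-patch `module.forward` so that every call also receives the
    # per-sample `adapter_names`, letting LoRA layers route each sample in the
    # batch to its own adapter.
    original_forward = module.forward

    def patched_forward(*args, **kwargs):
        kwargs["adapter_names"] = adapter_names
        return original_forward(*args, **kwargs)

    module.forward = patched_forward
    return original_forward  # keep a handle so the patch can be undone after the batch

The drawback, as the comment above notes, is that this patching (and its cleanup) has to be tailored to each model type.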

@BenjaminBossan (Member) left a comment


Thanks a lot for following up on your issue. I can see that you put some thought into finding an elegant approach, and I think it would be viable. However, I wonder if there is another way, as sketched in my comment; please take a look. I may well have missed something, so LMK if my suggestion doesn't work.

@saeid93 (Contributor, Author) commented Aug 6, 2024

No problem! Glad to be of any help.
Sorry, I couldn't find any comments on the pull request. Do you mean this comment on the MobileViT issue? If so, that problem is different from this one: this PR addresses multiple LoRA adapters in the same batch, whereas the other issue is MobileViT-specific.

@BenjaminBossan (Member) commented

> Sorry, I couldn't find any comments on the pull request.

Wow, it's gone! No idea what happened, I did write it for sure...

Okay, so a second time. I was referring to these lines:

https://github.com/huggingface/peft/pull/1990/files#diff-b700510ad2034b549511a969d85f89f9094243a7f3c740e311dc1eb83ace9a79R57-R61

Are those the only real changes to the forward function that are required? If yes, would it be possible to instead register a pre-forward hook for the classifier to inject the argument? This could be easily achieved here:

for module in self.modules():
    if isinstance(module, LoraLayer):
        # Register a pre-forward hook on every LoRA layer that injects
        # `adapter_names` into its kwargs, so no forward signature needs patching.
        pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
        handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
        hook_handles.append(handle)

But maybe I'm missing something and more changes are necessary, or will be in the future for the issue. WDYT?
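
For context, a minimal sketch of what such a pre-forward hook boils down to, assuming PyTorch's with_kwargs=True hook signature of (module, args, kwargs):

def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
    # Called before the wrapped module's forward; with `with_kwargs=True` the
    # hook may return modified (args, kwargs), here with the routing info added.
    kwargs["adapter_names"] = adapter_names
    return args, kwargs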

@saeid93 (Contributor, Author) commented Aug 6, 2024

Yes, your suggestion worked perfectly!
Please check the new commit.

@BenjaminBossan (Member) left a comment


Thanks for the update, I think this version looks really nice.

I have added some comments where I think the code needs further adjustment, please take a look. Also, please make sure to run make style on the PR. Apart from that, I have two more requests:

  1. Let's update the docs to mention that modules_to_save is supported, but add the necessary caveats (depending on what we end up having in the code).
  2. The unit tests should be updated to check for this use case. We don't need to test every possible model type, but maybe a case similar to the original one, with a classifier layer at the end, and then perhaps one with two modules_to_save, say the embedding and the LM head (see the usage sketch below). The tests could go here. LMK if you feel like giving this a try.
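
A rough, self-contained sketch of the use case such a test would exercise. TinyClassifier and the adapter name "other" are placeholders; the adapter_names argument and the "__base__" marker are the existing mixed-adapter-batch API for LoRA layers, which this PR extends to modules_to_save:

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class TinyClassifier(nn.Module):
    # Stand-in for a model with a backbone plus a classifier head that is
    # replaced per adapter via modules_to_save.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.classifier = nn.Linear(16, 4)

    def forward(self, x):
        return self.classifier(torch.relu(self.backbone(x)))

config = LoraConfig(target_modules=["backbone"], modules_to_save=["classifier"])
model = get_peft_model(TinyClassifier(), config)  # creates the "default" adapter
model.add_adapter("other", LoraConfig(target_modules=["backbone"], modules_to_save=["classifier"]))
model.eval()

x = torch.randn(4, 16)
# One forward pass, three routings in the same batch: the default adapter,
# the "other" adapter, and "__base__" for samples that bypass the adapters.
with torch.no_grad():
    out = model(x, adapter_names=["default", "default", "other", "__base__"])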

if "adapter_names" not in kwargs.keys():
    return self.modules_to_save[self.active_adapter](*args, **kwargs)
# Batches requests with similar LoRAs into microbatches
@BenjaminBossan (Member) commented

Let's move this to a sub-method, similar to how we do this for LoRA:

def _mixed_batch_forward(

Also, with this added, I think it makes sense to have a similar method as in LoRA to check the arguments:

def _check_forward_args(self, x, *args, **kwargs):

Of course, we have to be careful not to be too restrictive here, given the other issue that you raised, and since the underlying module could be of any type.
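
For orientation, a rough sketch of what such a sub-method on the modules_to_save wrapper could look like. This only illustrates the microbatching idea; it assumes torch is imported and that the wrapper exposes original_module and modules_to_save, and it is not the code that was merged:

def _mixed_batch_forward(self, input, *args, adapter_names=None, **kwargs):
    # Group sample indices by adapter, run each sub-batch through the matching
    # copy of the module, then scatter the results back into batch order.
    unique_adapters = set(adapter_names)
    sub_batch_indices_list = [
        [i for i, name in enumerate(adapter_names) if name == adapter]
        for adapter in unique_adapters
    ]

    results = [0 for _ in range(len(input))]
    for adapter, indices in zip(unique_adapters, sub_batch_indices_list):
        # "__base__" routes these samples through the original, unmodified module.
        module = self.original_module if adapter == "__base__" else self.modules_to_save[adapter]
        sub_batch_output = module(input[indices], *args, **kwargs)
        for idx, output in zip(indices, sub_batch_output):
            results[idx] = output
    return torch.stack(results)

A companion _check_forward_args could then validate that adapter_names has one entry per sample and is only accepted where supported, without being too restrictive about the type of the wrapped module.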

@saeid93 (Contributor, Author) commented

Both functions have been added in the new commit, please check.


results = [0 for i in range(len(batch))]
for i, active_adapter in enumerate(unique_adapters):
    sub_batch = batch[sub_batch_indices_list[i]]
@BenjaminBossan (Member) commented

Hmm, here we assume that there is only one positional arg, as any other args would be dropped, right? Also, what if other args or kwargs need to be sliced? We don't really know, so I think the best we can do is make a guess.

One suggestion that I have:

Check all args and kwargs: if an argument is a tensor and has the same length (i.e. the batch size), slice it too; otherwise, leave it as is. It's not perfect, but I'm not sure what else could be done. WDYT?
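
A sketch of this heuristic, assuming torch is in scope; the helper name _maybe_slice_batch_args is hypothetical:

def _maybe_slice_batch_args(args, kwargs, indices, batch_size):
    # Slice any tensor argument that shares the batch dimension; pass everything
    # else through unchanged, since we cannot know how non-batched args are used.
    def maybe_slice(value):
        if torch.is_tensor(value) and value.ndim > 0 and len(value) == batch_size:
            return value[indices]
        return value

    sliced_args = tuple(maybe_slice(a) for a in args)
    sliced_kwargs = {key: maybe_slice(value) for key, value in kwargs.items()}
    return sliced_args, sliced_kwargs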

@saeid93 (Contributor, Author) commented

In the new version I changed the signature so that x is the explicit input, to avoid the problems you mentioned.

@saeid93 (Contributor, Author) commented Aug 7, 2024

Sure, I'll be happy to add the tests. I'll add the updates when I get a chance.

@BenjaminBossan (Member) commented

@saeid93 Do you still plan on working on this?

@saeid93 (Contributor, Author) commented Aug 20, 2024

> @saeid93 Do you still plan on working on this?

Yes, sorry, I had some busy weeks. I'll work on it this weekend if there isn't a strict deadline.

@BenjaminBossan (Member) commented

> Yes, sorry, I had some busy weeks. I'll work on it this weekend if there isn't a strict deadline.

Thanks. No worries about the time, I just wanted to ensure that you're still on it. If not, that's also okay, just let me know.

@saeid93 (Contributor, Author) commented Aug 24, 2024

@BenjaminBossan I added the tests and also changed the code according to your comments; please let me know if further changes are required. About the docs, I wasn't sure whether anything still needs to be added, as I don't see any caveat regarding modules_to_save and the assumption is that it is supported out of the box. But please let me know if you still think the docs should be updated.

@BenjaminBossan (Member) left a comment


Thanks for making the adjustments. I think there is still a bit of an issue to figure out when it comes to the signature of the mixed batch forward call. Please check my comments.

Review threads (outdated, resolved): src/peft/utils/other.py, tests/test_custom_models.py
saeid93 marked this pull request as draft on September 1, 2024 at 14:41.
@BenjaminBossan (Member) commented

@saeid93 LMK when this is ready for review.

@BenjaminBossan (Member) commented

Gentle ping @saeid93

saeid93 marked this pull request as ready for review on September 15, 2024 at 12:07.
@saeid93 (Contributor, Author) commented Sep 15, 2024

@BenjaminBossan sorry for the delay, I was on annual leave. I made the changes you asked for; please let me know if anything else is needed.

@BenjaminBossan (Member) left a comment


Thanks so much for resuming the work and making the requested adjustments. Overall this looks good, I only have a few small comments. Could you please check?

Also, could you please run make style to silence the linter? Note that by now, we've reached ruff 0.6.5, so you may have to upgrade its version in your environment.

Review threads (outdated, resolved): docs/source/developer_guides/lora.md, src/peft/utils/other.py, tests/test_custom_models.py
@saeid93 (Contributor, Author) commented Sep 16, 2024

@BenjaminBossan sure, glad to be of any help. I think all the comments have now been addressed. I did run make style each time, so it was probably the version mismatch; hopefully it will go through this time. Let me know if further changes are needed.

@BenjaminBossan (Member) commented

Thanks for the latest changes @saeid93. The style check is still failing, did you successfully run make style? If that doesn't work for some reason, removing this import should be sufficient.

@saeid93 (Contributor, Author) commented Sep 17, 2024

@BenjaminBossan no problem! I did run make style; I'm not sure why it wasn't caught locally. I just removed the line. Hopefully it will go through this time.

@BenjaminBossan (Member) commented

The linter is happy now 🎉. However, now we get an error with Python 3.8 because it does not support list[str] etc. Could you please add the from __future__ import annotations import to the top of utils/other.py? That should fix it.
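
For reference, the suggested fix is a single line at the top of the module:

# Top of src/peft/utils/other.py: with postponed evaluation of annotations
# (PEP 563), built-in generics like list[str] in type hints no longer raise on Python 3.8.
from __future__ import annotations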

@saeid93 (Contributor, Author) commented Sep 17, 2024

Done 👍 I had to make a small change and use tuple instead of Tuple to avoid ruff complaints after adding from __future__ import annotations.

@BenjaminBossan (Member) left a comment


Thanks a lot for extending the functionality of having different adapters in the same batch to modules_to_save. The changes look good, are well covered by tests, and documented. Nothing more to add!

BenjaminBossan merged commit adf0a1d into huggingface:main on Sep 17, 2024 (14 checks passed).
BenjaminBossan pushed a commit to BenjaminBossan/peft that referenced this pull request on Sep 18, 2024:
Extend the functionality of having different adapters in the same batch to also work with `modules_to_save`.
Successfully merging this pull request may close these issues.

MobileViT does not work with Inference with different LoRA adapters in the same batch