fixing multiple LoRA in the same batch or vit #1990
Conversation
Thanks a lot for following up on your issue. I can see that you put some thought into finding an elegant solution, and I think it would be viable. However, I wonder if there is another way, as sketched in my comment, please check. But I may well have missed something, LMK if my suggestion is not working.
No problem! Glad to be of any help.
Wow, it's gone! No idea what happened, I did write it for sure... Okay, so a second time. I was referring to these lines:

peft/src/peft/tuners/lora/model.py, lines 434 to 438 in 4611034

Are those the only real change needed? But maybe I'm missing something and more changes are necessary, or will be in the future for the issue. WDYT?
Yes, your suggestion worked perfectly!
Thanks for the update, I think this version looks really nice.

I have added some comments where I think the code needs some further adjustments, please take a look. Also, please ensure to run `make style` on the PR. Apart from that, I have two more requests:

- Let's update the docs to mention that `modules_to_save` is supported, but add the necessary caveats (depending on what we end up having in the code).
- The unit tests should be updated to check for this use case. We don't need to test each type of possible model, but maybe a case similar to the original one with a classifier layer at the end, and then perhaps one with two `modules_to_save`, say embedding and LM head (a rough sketch follows below). The tests could go here. LMK if you feel like giving this a try.
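For illustration, here is a minimal sketch of what such a test could look like. The toy `TinyClassifier` model, the adapter names, and the assertions are assumptions made for this sketch, not the tests that were actually added in the PR:

```python
import torch
from torch import nn

from peft import LoraConfig, get_peft_model


class TinyClassifier(nn.Module):
    # Hypothetical toy model: LoRA targets `lin`, the classifier head goes into modules_to_save.
    def __init__(self, num_features=10, num_classes=3):
        super().__init__()
        self.lin = nn.Linear(num_features, 20)
        self.classifier = nn.Linear(20, num_classes)

    def forward(self, x):
        return self.classifier(torch.relu(self.lin(x)))


def test_mixed_adapter_batch_with_modules_to_save():
    torch.manual_seed(0)
    # init_lora_weights=False makes the two adapters differ right away; in a real test
    # you would also perturb or train each adapter's modules_to_save copy.
    config0 = LoraConfig(target_modules=["lin"], modules_to_save=["classifier"], init_lora_weights=False)
    config1 = LoraConfig(target_modules=["lin"], modules_to_save=["classifier"], init_lora_weights=False)
    model = get_peft_model(TinyClassifier(), config0)  # creates the "default" adapter
    model.add_adapter("other", config1)
    model.eval()

    x = torch.randn(4, 10)
    adapter_names = ["default", "other", "default", "other"]

    with torch.inference_mode():
        # Reference outputs: run each adapter on the full batch separately.
        model.set_adapter("default")
        out_default = model(x)
        model.set_adapter("other")
        out_other = model(x)
        # Mixed batch: each sample should be routed to its own adapter
        # (including its own modules_to_save copy of the classifier).
        out_mixed = model(x, adapter_names=adapter_names)

    assert torch.allclose(out_mixed[0], out_default[0], atol=1e-6)
    assert torch.allclose(out_mixed[1], out_other[1], atol=1e-6)
    assert torch.allclose(out_mixed[2], out_default[2], atol=1e-6)
    assert torch.allclose(out_mixed[3], out_other[3], atol=1e-6)
```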
src/peft/utils/other.py (outdated):

```python
if "adapter_names" not in kwargs.keys():
    return self.modules_to_save[self.active_adapter](*args, **kwargs)
# Batches requests with similar LoRAs into microbatches
```
Let's move this to a sub-method, similar to how we do this for LoRA:

peft/src/peft/tuners/lora/layer.py, line 327 in 4611034: `def _mixed_batch_forward(`

Also, with this added, I think it makes sense to have a similar method as in LoRA to check the arguments:

peft/src/peft/tuners/lora/layer.py, line 302 in 4611034: `def _check_forward_args(self, x, *args, **kwargs):`

Of course, we have to be careful not to be too restrictive here, given the other issue that you raised, and since the underlying module could be of any type.
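To make the suggestion concrete, here is a rough, self-contained sketch of how such a split could look on the wrapper for `modules_to_save`. This is only an illustration of the suggested structure under assumed names; it is neither PEFT's actual `ModulesToSaveWrapper` nor the code merged in this PR:

```python
import copy

import torch
from torch import nn


class ModulesToSaveWrapperSketch(nn.Module):
    """Illustrative sketch only: shows a _check_forward_args / _mixed_batch_forward split."""

    def __init__(self, original_module: nn.Module, adapter_name: str = "default"):
        super().__init__()
        self.original_module = original_module
        self.modules_to_save = nn.ModuleDict({adapter_name: copy.deepcopy(original_module)})
        self.active_adapter = adapter_name

    def _check_forward_args(self, x, *args, **kwargs):
        # Keep the check permissive: the wrapped module could be of any type.
        adapter_names = kwargs.get("adapter_names", None)
        if adapter_names is not None and len(adapter_names) != len(x):
            raise ValueError(
                f"Length of adapter_names ({len(adapter_names)}) must match the batch size ({len(x)})."
            )

    def _mixed_batch_forward(self, x, *args, adapter_names, **kwargs):
        results = [None] * len(x)
        for adapter in set(adapter_names):
            indices = [i for i, name in enumerate(adapter_names) if name == adapter]
            # "__base__" routes the sub-batch to the unmodified original module.
            module = self.original_module if adapter == "__base__" else self.modules_to_save[adapter]
            sub_output = module(x[indices], *args, **kwargs)
            for res_idx, row in zip(indices, sub_output):
                results[res_idx] = row
        return torch.stack(results)

    def forward(self, x, *args, **kwargs):
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)
        if adapter_names is None:
            return self.modules_to_save[self.active_adapter](x, *args, **kwargs)
        return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
```

The sub-batch grouping means each saved module copy runs once per adapter rather than once per sample, mirroring how the LoRA layers handle mixed batches.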
Both of the functions are added in the new commit, please check that.
src/peft/utils/other.py (outdated):

```python
results = [0 for i in range(len(batch))]
for i, active_adapter in enumerate(unique_adapters):
    sub_batch = batch[sub_batch_indices_list[i]]
```
Hmm, here we assume that there is only one positional arg, as any other `args` would be dropped, right? Also, what if other `args` or `kwargs` need to be sliced? We don't really know that, so I think the best we can do is make a guess.

One suggestion that I have: check all `args` and `kwargs`; if an argument is a tensor and has the same length (i.e. the batch size), slice it too. Otherwise, leave it as is. It's not perfect, but I'm not sure what else could be done. WDYT? (A sketch of this idea follows below.)
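A minimal sketch of that suggestion, assuming a hypothetical helper (`_slice_extra_args` is not a real PEFT function and not necessarily what the PR ended up implementing): only tensors whose first dimension matches the batch size get sliced; everything else is passed through unchanged.

```python
import torch


def _slice_extra_args(indices, x, args, kwargs):
    # Hypothetical helper: slice any arg/kwarg that is a tensor sharing x's batch size;
    # leave everything else untouched, since we can only guess which inputs are per-sample.
    batch_size = len(x)

    def maybe_slice(value):
        if isinstance(value, torch.Tensor) and len(value) == batch_size:
            return value[indices]
        return value

    sliced_args = tuple(maybe_slice(a) for a in args)
    sliced_kwargs = {key: maybe_slice(value) for key, value in kwargs.items()}
    return x[indices], sliced_args, sliced_kwargs
```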
I changed the input definition in the new version to take x as the input, to avoid the problems that you mentioned.
Sure, I'll be happy to add the tests. I'll add the updates when I get a chance.
@saeid93 Do you still plan on working on this?
Yes, sorry, I had some busy weeks. I'll work on it this weekend if there isn't a strict deadline.
Thanks. No worries about the time, I just wanted to ensure that you're still on it. If not, that's also okay, just let me know.
@BenjaminBossan I added the tests and also changed the code according to your comments, please let me know if further changes are required. About the docs, I wasn't sure if it is still necessary to add anything, as I don't see any caveat regarding using `modules_to_save` and the assumption is that `modules_to_save` is supported out of the box. But please let me know if you still think it should be updated.
Thanks for making the adjustments. I think there is still a bit of an issue to figure out when it comes to the signature of the mixed batch forward call. Please check my comments.
@saeid93 LMK when this is ready for review.
Gentle ping @saeid93
@BenjaminBossan sorry for the delay, I was on annual leave. I made the changes you asked for, please let me know if you need any more changes.
Thanks so much for resuming the work and making the requested adjustments. Overall this looks good, I only have a few small comments. Could you please check?
Also, could you please run `make style` to silence the linter? Note that by now, we've reached ruff 0.6.5, so you may have to upgrade its version in your environment.
@BenjaminBossan sure, glad to be of any help. I think all the comments have now been applied. I ran the style check every time; it was probably the version mismatch. Hopefully it will go through this time. Let me know if further changes are needed.
Thanks for the latest changes @saeid93. The style check is still failing, did you successfully run `make style`?
@BenjaminBossan no problem! I did run the style check, I'm not sure why it wasn't caught locally. I just removed the line. Hopefully it will go through this time.
The linter is happy now 🎉. However, now we get an error with Python 3.8 because it does not support …
Done 👍 I had to make a small change and use …
Thanks a lot for extending the functionality of having different adapters in the same batch to `modules_to_save`. The changes look good, are well covered by tests, and documented. Nothing more to add!
Extend the functionality of having different adapters in the same batch to also work with `modules_to_save`.
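For context, this feature is exercised through the `adapter_names` argument of the forward call. A minimal usage sketch under assumed placeholder names (`distilbert-base-uncased`, the adapter name `other`, and the target module names are illustrative choices, not taken from this PR):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import LoraConfig, get_peft_model

# Placeholder base model; any classification model whose head sits in modules_to_save works similarly.
base_model_id = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
base_model = AutoModelForSequenceClassification.from_pretrained(base_model_id, num_labels=2)

# The classification head is not a LoRA target, so it is wrapped via modules_to_save;
# with this change it also participates in mixed-adapter batches.
config = LoraConfig(target_modules=["q_lin", "v_lin"], modules_to_save=["classifier"])
model = get_peft_model(base_model, config)  # creates the "default" adapter
model.add_adapter("other", LoraConfig(target_modules=["q_lin", "v_lin"], modules_to_save=["classifier"]))

inputs = tokenizer(["first sample", "second sample"], return_tensors="pt", padding=True)
adapter_names = ["default", "other"]  # one adapter name per sample in the batch
with torch.inference_mode():
    logits = model(**inputs, adapter_names=adapter_names).logits
print(logits.shape)
```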
@BenjaminBossan
This is the initial fix for the batched LoRA inference problem explained in #1960.
For now, this only supports the ViT model; it is meant as an example and a template for gradually adding support for other models. Generally, there are two options to solve this, as discussed in #1960 (comment), since we need to change model-specific details. I'm taking the second approach, but each model needs different changes. Also, `generate` functions for generative models should be added. I'm happy to go through the models one by one and also fix #1967, but it is better to review this first and then decide whether we want to go down this route of dynamically patching the forward functions or fix it in the transformers library.