Possible to build a LoRA that doesn't inject into the transformer? #1523
Comments
It would be possible to build a LoRA adapter that works purely through forward hooks on the base model, but that's very different from how we implement this right now and also not as flexible. Regarding your issue, do you know that you can disable the LoRA adapters completely, so that the model behaves like the pure base model? This should allow you to avoid loading the model twice:

```python
with model.disable_adapter():
    # do inference on base model
    ...
```
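To illustrate what the hook-based approach mentioned above could look like in principle, here is a minimal conceptual sketch (not PEFT code; the helper name, rank, and initialization are made up for illustration) that adds a LoRA delta to a single linear layer via a forward hook without touching the layer's weights:

```python
import torch

def attach_lora_hook(linear: torch.nn.Linear, r: int = 8, alpha: int = 16):
    # hypothetical helper: the low-rank factors live outside the wrapped layer
    A = torch.nn.Parameter(torch.randn(linear.in_features, r) * 0.01)
    B = torch.nn.Parameter(torch.zeros(r, linear.out_features))
    scaling = alpha / r

    def hook(module, inputs, output):
        # add x @ A @ B on top of the unmodified layer output
        (x,) = inputs
        return output + (x @ A @ B) * scaling

    # keeping the handle lets the caller remove the hook to restore base behavior
    handle = linear.register_forward_hook(hook)
    return handle, (A, B)
```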
Interesting, I did not know that. However, wouldn't that give different weights to the prediction head, probably resulting in errors since it wasn't trained to recognize those weights? What I was hoping for was something where each additional use of peft lora only incurred a space cost of |A|+|B|, the same size as the saved weights, but it sounds like such a thing currently doesn't exist, is that correct? If it makes it any easier, I would not be expecting to train from that state, just run inference. Thanks for the fast reply!
I think I don't quite get your problem yet. My initial understanding was that you need to have the base model for inference and the LoRA-augmented model for inference and wanted to avoid loading the base model twice. Are the prediction heads trained with or without LoRA? If you train them with LoRA, you have probably added them to
I'm not sure I understand. When you load the base model, it takes 3GB as you mentioned. When you load LoRA on top, that should increase memory only by a tiny amount. My earlier understanding was that the additional memory comes from loading the base model twice.
We might be misunderstanding on our end, but our belief was that the underlying model gets changed by loading LoRA. So if we load two separate peft models on top of the same transformer, the two peft models will clobber each other:
We checked that the weights in
LoRA adapters do indeed mutate the base model, but they don't "clobber" it. We took care to ensure that, when LoRA is disabled, the base model itself can still make predictions as if LoRA were not there, and different LoRA adapters do play nicely with each other. Note, however, that I would load the adapter a bit differently than you do; I'm not sure if your code would work correctly:

```python
from peft import PeftModel

base_model = AutoModel...

# load adapter 0, it is automatically the active adapter
peft_model = PeftModel.from_pretrained(base_model, <path-to-adapter0>, adapter_name=<name0>)
# load adapter 1, adapter 0 is still active
peft_model.load_adapter(<path-to-adapter1>, adapter_name=<name1>)
# activate adapter 1, deactivate adapter 0
peft_model.set_adapter(<name1>)
```
The bolded part is a little worrisome - does that mean that ...? I can verify this myself later tonight, although I don't have time to do so this early afternoon.
Can confirm there is an issue with our software loading the transformer and then having that transformer overwritten when connecting it with LoRA using peft. We have a POS model and a sentiment model which both use electra-large, except the sentiment model uses a peft wrapper to get better results. (We weren't able to figure out how to get better POS tags with peft, but that's a story for another day.)

If I load the POS model by itself, it gets an overall score of 93.92 on the EWT UD dataset (those exact details aren't super relevant). If I load both the POS model and the Sentiment model, using two different invocations of electra-large, it still gets that score. If instead I load the POS model, then reuse the electra-large transformer to load the Sentiment model, the score drops to 92.12. Our current release uses the first scheme to keep the results consistent, but that results in 6GB of transformer instead of 3GB.

In terms of separation of concerns, I'd love to have a mechanism where the POS model can load its transformer, the Sentiment model can reuse the same transformer and wrap it in peft, and the two models wouldn't have to do anything specific to their transformer before using it.
Just so I understand correctly, the same base model (electra large) is used by the POS model and the sentiment model. You load LoRA weights on top of the sentiment model. In that case, indeed, it is expected that the POS model also behaves differently because PEFT will modify the base model by attaching the adapters. If you want to avoid having a copy of the base model in memory, I would recommend checking if disabling the adapters on the sentiment model restores the performance of the POS model.
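For example, a sketch with entirely hypothetical names (sentiment_model being the PeftModel that wraps the shared electra-large encoder, pos_tagger the POS annotator reusing the same encoder):

```python
# hypothetical objects: sentiment_model is the PeftModel wrapping the shared
# electra-large encoder; pos_tagger is the POS annotator reusing that encoder
with sentiment_model.disable_adapter():
    # inside this block the shared encoder behaves like the plain base model,
    # so the POS head should see the features it was trained on
    pos_tags = pos_tagger.predict(batch)
```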
It's actually giving me an error, saying that there is no adapter loaded:
I saw your earlier suggestion of using
Did you call ...? Also, as I'm reading your code again, how does the
Yes, definitely. This crash happens if I do this:
I'm not giving it a name as part of the
This works, actually, and we get back the expected scores! This is a bit challenging to work with, though, for a couple of reasons.

The simplest problem is: what if we have two Sentiment models, or another model that uses a peft adapter? We've tried it with POS, although we haven't yet found settings that consistently improve the POS scores, and there are definitely situations where we need multiple POS models built from the same transformer. They can't all be named ...

There's also the problem of thread safety and needing to modify the transformer's state (currently active adapter) in order to use it. That might actually be a problem if the proposal to remove the GIL is implemented in Python. I know there's a trick to call ... What I'm getting at is that
A possible issue here is that there are multiple applications for having several constituency models loaded at once, for example using those models in an ensemble to get better results. Again we would run into the limitation of how to name the adapters: if the first constituency parser loads its peft adapter with the name "constituency", then the second constituency parser would need to load it with the name "constituency-2", etc. Would that be something where the constituency parser itself knows that there have been N previously loaded models? Would the caller need to keep track of that? Either way, that seems a lot more complicated and less clean than simply having a fresh transformer object (hopefully sharing the underlying weights to save GPU space) which doesn't need to know anything about previously loaded adapters.
You can load in multiple adapters and give them any name you want (usually using the
That's still far in the future and will be opt-in.
The thread safety concerns may be further off than we'd like, but the unique names solution doesn't really address the issue of situations where we need more than one POS model or more than one constituency model. We could enforce uniqueness of names somehow, but I was hoping not to add more to the global state than necessary. I suppose a random 10 letter name would almost never repeat and wouldn't require any global state at all...

It would also be a little weird for each model to have the workflow of first turning on its own adapter, then running inference, then turning off its adapter (the latter step being necessary so that any model still using the raw version of the transformer doesn't have to figure out whether there even are any adapters to turn off).

Currently what does appear to work for loading time, but not GPU memory, is to load the transformer into memory and then clone it N times, once for each of the adapters we need. Then we get a situation where one annotator doesn't have to care at all how many other annotators of different types or the same type exist, but at the cost of increased GPU memory usage.

If it's not currently a feature, is it at all feasible to make a wrapper which is just a transformer and a single adapter, where that adapter does not affect inference for other users of the underlying transformer? If that's not anywhere on the current project roadmap, is it something where you'd consider merging a PR that implemented a feature like that?
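As a concrete version of the "random name" idea, something like this sketch would avoid any global registry of adapter names (the helper and its name are hypothetical; PeftModel.from_pretrained and adapter_name are standard PEFT API):

```python
import uuid

from peft import PeftModel

def load_with_unique_adapter(base_model, adapter_path):
    # a random 10-character hex suffix makes collisions effectively impossible,
    # so each annotator can load its adapter without consulting global state
    adapter_name = f"adapter-{uuid.uuid4().hex[:10]}"
    peft_model = PeftModel.from_pretrained(base_model, adapter_path, adapter_name=adapter_name)
    return peft_model, adapter_name
```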
I don't get this point. You should know which adapters exist beforehand, so you can just choose some static names like "sentiment", "pos0", "pos1", or not?
Yes, I agree that it's not super convenient and, if you have to switch each time a new sample comes in, this could add some overhead. As you mention, the alternative would be to have a copy of the model for each adapter in memory. I think the feature we planned to add in #903 would have helped in your situation, but we didn't pursue it further.
Unfortunately, that's not easy to achieve. The underlying PyTorch model does not have sufficient flexibility (say, via hooks) to add all the features that we need without mutating the model itself. It would probably be possible to achieve what you asked for when focusing on only the subset of features that you need, but you'd have to build that yourself.
Not necessarily true. We provide a way for people to make an annotation pipeline, and we have no control over how many times the user makes a new pipeline with the same base models without keeping some form of global state. Still, we did implement a cache system so that the same word vectors or transformer aren't loaded too many times. Perhaps that would be a reasonable place to keep track of which annotator names have been used in the past as well.
How much overhead? Is it rewriting tensors or just flipping a pointer?
It's flipping a flag, but on each module, so we have to iterate through all the modules to do this. It should still be cheap compared to inference, but just a heads up not to do it excessively.
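If the cost ever becomes a concern, it is easy to measure directly; a rough sketch (peft_model and the adapter names "pos" and "sentiment" are hypothetical placeholders):

```python
import time

# hypothetical setup: peft_model is a PeftModel with adapters "pos" and "sentiment" loaded
n_switches = 1000
start = time.perf_counter()
for _ in range(n_switches):
    peft_model.set_adapter("pos")
    peft_model.set_adapter("sentiment")
elapsed = time.perf_counter() - start
# compare this per-switch cost against the time of a single inference batch
print(f"average cost per set_adapter call: {elapsed / (2 * n_switches) * 1e3:.3f} ms")
```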
Thanks for the shout out! It looks like a useful feature if running things in small batches, or if we ran the adapters on a per-sentence basis rather than a per-annotator basis. Actually, in a single Pipeline we generally run inference with a runtime life cycle of
The other main use case is multiple models of the same type at the same time, where admittedly the switching becomes more frequent between batches. Still, I would think that, given your description above, 50 sentences and then switching to another wrapper would only be slightly more expensive than 50 sentences by themselves. So I'm coming around to the idea that, per Pipeline, there would be one copy of the transformer used by that Pipeline, and each annotator would know to set the adapter for itself and not worry about what any other annotator did. As previously whinged, that's not thread-safe, but maybe there will be another solution available by the time thread safety is an issue.

There is one issue I have, though - how to figure out whether a transformer even has adapters attached? As I mentioned above, if I do this to a transformer with no adapters, it throws an exception:
Is there a way to check first if there are adapters? Or maybe just not have it throw that exception? I don't think there is a downside to having a model that has no adapters just ignore a call to
You can check if
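The exact attribute to check was lost above, so this is stated as an assumption only: a PeftModel keeps its adapter configurations in peft_config, which suggests a guard along these lines (sketch, hypothetical helper name):

```python
# sketch: only switch adapters on models that actually have them
# (checking `peft_config` is an assumption about where PeftModel keeps adapter configs)
def maybe_set_adapter(model, adapter_name):
    peft_config = getattr(model, "peft_config", None)
    if peft_config and adapter_name in peft_config:
        model.set_adapter(adapter_name)
```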
Does calling the forward pass on a transformer respect the active adapter? If not, how do I go about getting back the same values (the transformer used as a featurizer) once there is an adapter loaded? I would have expected the following little program to output a few different weights, but it outputs the same ones each time. If it's not clear from the script where the mismatched expectations are occurring, I can point to the model files in question (they're both available on HF under Stanford's Stanza repos, FWIW). Is there something I need to do differently to ensure that I get weights with the transformer's adapter activated in this case?
OUTPUT
My thinking here was that I could load several adapters onto the same transformer and switch between them to get the needed encoding for each task, but the switching isn't actually doing what I'd expect. However, if I don't use an adapter name and just leave it to be the
I do note that I never called
In terms of calling
This doesn't work
If I do this:
Now I get seemingly random output from
If I try this instead
This keeps giving random output as well, and again there's no way to disable the adapters... but weirdly, my models give the same final results when run on test sets. Still, is there an example I can use, or a way to turn that short script above into something where I can easily switch between the two adapters or a no-adapter form of the transformer?
The issues with your code are a combination of a few things: some of it is using the PEFT API incorrectly, and some of it is the fact that a fresh LoRA adapter is a no-op by default, so it does not affect the result. Below is some code that shows how to use this correctly, I hope it helps to solve your issue:

```python
import torch
from peft import get_peft_model, TaskType, PeftModel, LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(0)
model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer.encode("This is a test", add_special_tokens=False, return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(model_id).eval()
outputs = model(inputs, output_hidden_states=True)
print("- base model output")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

# by setting init_lora_weights to False, we ensure that it's not a no-op
config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model = get_peft_model(model, config).eval()
outputs = model(inputs, output_hidden_states=True)
print("- peft model output default adapter")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

# add another adapter
config = LoraConfig(r=32, lora_alpha=32, task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model.add_adapter("adapter2", config)
model.set_adapter("adapter2")
outputs = model(inputs, output_hidden_states=True)
print("- peft model output adapter 2")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

print("saving and loading")
model.save_pretrained("/tmp/issue-1523")
del model

model = AutoModelForCausalLM.from_pretrained(model_id).eval()
outputs = model(inputs, output_hidden_states=True)
print("- loaded model output")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

model = PeftModel.from_pretrained(model, "/tmp/issue-1523")
outputs = model(inputs, output_hidden_states=True)
print("- loaded peft model output")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

model.load_adapter("/tmp/issue-1523/adapter2", adapter_name="adapter2")
model.set_adapter("adapter2")
outputs = model(inputs, output_hidden_states=True)
print("- loaded peft model output adapter 2")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())
```

This gives me:
As to the difference between ...
Regarding ...
Ah, great, I didn't see this with previous versions (or maybe I missed it), but the latest version of the peft integration has the ability to call
I will have to check that this allows for training to continue if ... I do have two minor complaints about the interface - not sure how easy it would be to fix at this point, seeing as how these modules are both publicly released. In
whereas in
so for
Another minor interface complaint: after doing this, I get back the original encoding for the text, which is great. That's exactly what we need. However, the following snippet:
It would be great to have utility methods on the model which indicated if any peft integrations are currently attached (you mentioned checking
Here's something weird, but possibly known / not important. When created with ... When loaded with ...

At any rate, I can confirm that they both get loaded as expected when put into a new instance of the transformer w/ peft, and furthermore I can train them after reloading a checkpoint, even if I loaded that checkpoint with ... Does that sound like a reasonable approach to take? There's no real reason to have different code paths for loading for training or loading for eval, is there?
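On the loading-for-training versus loading-for-eval question, one detail worth double-checking (stated here as my understanding of the PEFT API, not something confirmed in this thread): PeftModel.from_pretrained loads the adapter frozen unless is_trainable=True is passed, so a single loading code path would need that flag whenever training is intended. A sketch with hypothetical paths:

```python
from peft import PeftModel
from transformers import AutoModel

# hypothetical base model and adapter checkpoint locations
base_model = AutoModel.from_pretrained("google/electra-large-discriminator")
peft_model = PeftModel.from_pretrained(
    base_model,
    "path/to/saved/adapter",
    adapter_name="sentiment",
    is_trainable=True,  # without this, the adapter is loaded frozen (inference only)
)
```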
When building an optimizer for a transformer model with peft on it, is there a way to only get the optimizer state for the active adapter? Currently, when I call ...

Alternatively, a way to remove an adapter might be sufficient for my needs. To explain the problem I'm running into:
This statement confuses me. If I want to use the original model w/o adapters after having loaded an adapter, I think I need to call ...

However, I will say this doesn't work in the case of creating an adapter via ... In such a case, is it possible to turn off the peft adapter in any way? I really do think this is a bug, that there should be a way to
Indeed it would be better to have more consistency here, but as you mentioned, we can't change that without breaking existing code, so it is what it is.
Yes, indeed. It's not quite as easy as it sounds, because theoretically, some modules belonging to the adapter could be enabled and others disabled (although the current API does not expose this possibility, users could still do this manually). If I have some time on my hands, I'll think of something.
Using
Indeed, filtering by name is the way to go here. Directly after loading, you could also filter by
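To make the name-filtering idea concrete, here is a sketch (peft_model, the adapter name, and the hyperparameters are placeholders); LoRA parameter names contain the adapter name, so this picks up only that adapter's weights:

```python
import torch

adapter_name = "sentiment"  # hypothetical adapter name
params = [
    p
    for name, p in peft_model.named_parameters()
    # LoRA parameters look like "...lora_A.<adapter_name>.weight", so filtering
    # on the adapter name (plus requires_grad) selects just that adapter
    if p.requires_grad and f".{adapter_name}." in name
]
optimizer = torch.optim.AdamW(params, lr=1e-4)
```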
When you create the model with get_peft_model, the default adapter is active right away, and you can temporarily disable it for inference:

```python
peft_model = get_peft_model(...)
# default adapter is active
with peft_model.disable_adapter():
    # inference without adapters
    ...
# now, default adapter is active again
```
Ah, I can see how worrying about users can be a problem. You've treated this user quite nicely so far, at least. Thanks in advance for any progress you can make on adding an interface. My current working solution is to set the active adapter with each new batch, just in case a previous batch used a different annotator and therefore a different adapter. You mentioned this might add some overhead, but it doesn't make a noticeable difference in our annotation speed (14s with or without this change for the EWT dataset, for example).
Gotcha, thanks. The approach I've been working on is to make sure the transformer used for training is not used for anything else, e.g. a separate copy from the one used for the POS retagging in the constituency or dependency case. I think that should avoid all such problems.
Thanks. Is there a downside to setting the
Not off the top of my head, but I haven't really "mixed" the transformers use of PEFT with the PEFT use.
Thanks again for all your help! I think we're good with using multiple adapters on the same transformer now. If there are any updates to the interface that include shortcuts for checking whether or not any adapters are actually active (you mentioned it being a recursive call for now), that would make things a bit faster for us. Also, if there's ever a way to use a single adapter on a shallow copy of the transformer (e.g. no deep copy of the weights), that would also simplify our usage quite a bit.
I'm currently working on #1663, which goes somewhat in this direction, but I'm not sure if it 100% fits your use case. Maybe you can take a look.
That does look relevant in terms of making the switching between adapters faster / simpler. Thanks!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Feature request
Is it possible to build a LoRA that doesn't inject into the transformer? This would allow for reusing the same basic transformer with multiple adapters in the same process while saving on GPU memory (probably at the expense of some speed)
Motivation
We've started using PEFT with LoRA for tasks such as sentiment analysis and constituency parsing in Stanza, and one thing we found is that there is currently no memory savings compared to using a fully finetuned transformer.
For example, if the transformer loaded for sentiment analysis takes 3GB, with no finetuning we can reuse the same transformer weights when constituency parsing, making for a total of 3GB plus the prediction heads of the models. If we use fully FT transformers, obviously that increases to 6GB assuming those are our only two tasks.
PEFT with LoRA uses inject_adapter_in_model to update the model with the As and Bs, AFAIK, meaning that loading those two models still takes 6GB. If we could have a version of the transformer which does inference with the As and Bs not injected, but wrapping the base transformer's tensors, this would almost certainly be noticeably slower but would allow for a much smaller memory footprint.

Thanks for the extremely useful library, BTW
Your contribution
I probably don't have much time to investigate this in the next couple months, but in the long term it is something I could attempt with some guidance on where to look