get_base_model() is returning the base model with the LoRA still applied. #430

Closed
ClayShoaf opened this issue May 10, 2023 · 17 comments
Labels: solved

Comments

@ClayShoaf

Technically, I'm just grabbing the .base_model.model directly, rather than using get_base_model(), but that should have the same effect, since that's all get_base_model() does if the active_peft_config is not PromptLearningConfig as seen here.

After loading a llama model with a LoRA, like so:

shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), **params)

The PeftModel loads fine and everything works as expected. However, I cannot figure out how to get the original model back without the LoRA still being active when I run inference.

The code I'm using is from here:

shared.model.disable_adapter()
shared.model = shared.model.base_model.model

This gives me the model back as a LlamaForCausalLM, but when I run inference, the LoRA is still applied. I made a couple of test LoRAs so that there would be no question as to whether the LoRA is still loaded. They can be found here: https://huggingface.co/clayshoaf/AB-Lora-Test

I am digging around right now, and I see this line: if isinstance(module, LoraLayer): from:

    def _set_adapter_layers(self, enabled=True):
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.disable_adapters = False if enabled else True

So I checked in the program and if I load a LoRA and do

[module for module in shared.model.base_model.model.modules() if hasattr(module, "disable_adapters")]

it returns a bunch of modules that are of the type Linear8bitLt (if loaded in 8bit) or Linear4bitLt (if loaded in 4bit).

Would it work to set the modules' disable_adapters attribute directly? I don't want to hack around too much in the code, because I don't have a deep enough understanding to be sure that I won't mess something else up in the process.

If that won't work, is there something else that I should be doing?
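
For what it's worth, here is a minimal, untested sketch of that idea, based only on the _set_adapter_layers snippet quoted above (attribute names may differ between PEFT versions):

# Untested sketch mirroring _set_adapter_layers(enabled=False) from the snippet
# above: setting disable_adapters = True on each adapter-carrying module should
# make its forward pass fall back to the original layer's output.
for module in shared.model.base_model.model.modules():
    if hasattr(module, "disable_adapters"):
        module.disable_adapters = True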

@ClayShoaf
Author

This issue still needs to be addressed.

@FartyPants

FartyPants commented Jun 13, 2023

I don't think disable_adapter() works... I made an extension for the webui and I can switch LoRAs fine, but disable_adapter() does nothing: it doesn't disable the LoRA, and the LoRA is still applied even after calling it.
I'd look for the problem there.
I think this should probably be filed as a new issue.

@ClayShoaf
Author

I think this should probably be filed as a new issue.

Why is that?

@huggingface deleted a comment from github-actions bot Jun 27, 2023
@pacman100
Contributor

shared.model.disable_adapter() disables LoRA because, in the forward pass of the LoraLayer, it just returns the original module's output. See the lines below:

https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#L672-L675

You can see the effect of disabling LoRA in this example notebook: https://github.com/huggingface/peft/blob/main/examples/multi_adapter_examples/PEFT_Multi_LoRA_Inference.ipynb
[Screenshot (2023-06-27): notebook output showing the generations with the adapter enabled vs. disabled]

I don't understand what the issue is; there should be a difference in the outputs when enabling and disabling LoRA adapters.
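
For reference, the pattern from the notebook looks roughly like this (a sketch; it assumes model is the PeftModel and tokenizer is its tokenizer, and the prompt is a placeholder):

# disable_adapter() is a context manager, so the LoRA is only bypassed
# inside the with block and is active again afterwards.
inputs = tokenizer("Tell me a joke", return_tensors="pt").to(model.device)

with model.disable_adapter():
    base_out = model.generate(**inputs, max_new_tokens=50)  # base weights only

lora_out = model.generate(**inputs, max_new_tokens=50)  # adapter applied again

print(tokenizer.decode(base_out[0], skip_special_tokens=True))
print(tokenizer.decode(lora_out[0], skip_special_tokens=True))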

@pacman100 added the solved label Jun 27, 2023
@fozziethebeat

Using with model.disable_adapter(): works, but there are still some subtle unwanted behaviors going on.

When loading an adapter, should the original base model be permanently changed as well?

I'm loading adapters with code like this:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel

# peft_model_id and inf_args are defined elsewhere in my script
config = PeftConfig.from_pretrained(peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    device_map=inf_args.device,
    trust_remote_code=True,
)

If I do inference with base_model I get different results depending on when I load the PEFT model. For example:

print(generate(base_model, query))

model = PeftModel.from_pretrained(base_model, peft_model_id)
print(generate(model, query))
with model.disable_adapter():
    print(generate(base_model, query))
print(generate(base_model, query))

The first base model result matches the base model output inside the with model.disable_adapter(): block.

The first PEFT model result matches the last base model output (the one outside the with block).

This feels wrong. I would expect to be able to do inference on model and base_model independently without the two affecting each other.

I filed a similar bug in #515 and this doesn't seem to be fixed.

This issue also prevents anyone from loading two Peft models that share the same base model at the same time.

When I add in

model_2 = PeftModel.from_pretrained(base_model, peft_model_id_2)
print(generate(model_2, query))
print(generate(model, query))

I get the same result from both models.
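
A quick identity check (a sketch reusing the variables from the snippets above) makes the sharing explicit:

# Both wrappers hold the very same, already-mutated base model object,
# which is why their generations match and why base_model is affected too.
print(model.base_model.model is base_model)    # expected: True
print(model_2.base_model.model is base_model)  # expected: True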

@fozziethebeat

Is anyone looking into this? It still feels very unexpected for the base model to be modified when creating a PEFT model on top of it.

@fozziethebeat

After reading the code very closely, I don't think this is going to get fixed. The only way to access the base model without any adapter replacements is to do

with model.disable_adapter():
    # stuff with base model

When loading LoRA adapter weights, PEFT walks through the base model's modules and swaps some of them out for LoRA replacements. This is why, whether you access the base model directly or keep a separate reference to it as in my example, inference always picks up the adapter's influence. My guess is that changing this would be too big an architectural change.

My request to the PEFT authors, however, is to be much more explicit that this is how LoRA adapter inference is implemented. I found this behavior very surprising, and it isn't explicitly documented anywhere that I saw.
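
One way to see the in-place replacement (a sketch reusing base_model and peft_model_id from my earlier snippet; lora_A is the attribute that PEFT's LoRA layers carry):

# After PeftModel.from_pretrained, some of base_model's own submodules carry
# LoRA attributes, showing they were swapped in place rather than copied.
wrapped = [name for name, m in base_model.named_modules() if hasattr(m, "lora_A")]
print(len(wrapped), wrapped[:3])  # non-empty, typically the attention projections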

@fozziethebeat

Relatedly, given how deeply this architecture is baked into PEFT, I've made a request to vLLM to implement this in a way that lets you load multiple LoRA adapters independently while still being able to access the base model.

@ClayShoaf
Author

@oobabooga Something to be aware of, if you aren't already

@BenjaminBossan
Member

When loading LoRA adapter weights, PEFT walks through the base model's modules and swaps some of them out for LoRA replacements. This is why, whether you access the base model directly or keep a separate reference to it as in my example, inference always picks up the adapter's influence. My guess is that changing this would be too big an architectural change.

Yes, this is correct: the base model is mutated when it is converted into a PEFT model. If you need the original model, at the moment you would have to create a copy of it before passing it to PEFT. If you have a suggestion for where we could update the docs to make this more obvious, please let us know.
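
A minimal sketch of that workaround (note that copy.deepcopy of a large or quantized model can be expensive or unsupported):

import copy

# Keep a pristine copy before PEFT mutates the base model in place.
pristine_base = copy.deepcopy(base_model)
model = PeftModel.from_pretrained(base_model, peft_model_id)

# model serves adapter inference; pristine_base keeps the original behavior.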

@gegallego

Hi! The unload() method seems to do what you are looking for, no? Maybe get_base_model() should call it when the model is a LoraModel?
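
For example, reusing the generate helper and variables from the earlier comment (a sketch; depending on the PEFT version this may need to be called as model.base_model.unload(), and note that it discards the LoRA weights rather than merging them):

# unload() restores the original layers that PEFT swapped out, returning
# the base model without any LoRA applied.
base_model = model.unload()
print(generate(base_model, query))  # should match the pre-LoRA output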

@fozziethebeat

If you have a suggestion for where we could update the docs to make this more obvious, please let us know.

I'd probably put this in the docstring for PeftModel.from_pretrained, so that it's clearly called out that if model is a PreTrainedModel it will be modified in place, and the original behavior can only be recovered via unload() or by disabling the adapter.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Aug 7, 2023
1. Addresses huggingface#430 (comment)
2. Reword docstring to not be LoRA-specific

BenjaminBossan added a commit that referenced this issue Aug 8, 2023
1. Addresses #430 (comment)
2. Reword docstring to not be LoRA-specific
@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@sunjunlishi

@fozziethebeat Does PEFT support vLLM speedup?

@fozziethebeat

I think you mean: does vLLM support LoRAs? I believe vLLM now supports running models with LoRAs, but I haven't tried it personally.

@fozziethebeat

Also, Predibase's LoRAX project is basically the right solution to this problem now.

@sunjunlishi

sunjunlishi commented Mar 12, 2024

@fozziethebeat Oh yes.
I trained a LoRA on a quantized model, and now I want to load it and make inference faster.

model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=False)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
solved solved
Projects
None yet
Development

No branches or pull requests

7 participants