Quality of inference severely changed after merge_and_unload #1035
Comments
Thanks for reporting this issue and giving all the details. Could you please test something: could you produce one output exactly before merging and one exactly after merging, using the same input, and check whether the two outputs are the same? So something like:

with torch.no_grad():
    input = ...
    output_before = model.generate(...)
    ...

merged_model = model.merge_and_unload()

with torch.no_grad():
    input = ...
    output_after = merged_model.generate(...)
    ...

# compare before and after

This way, we can check if the merging process itself is responsible for the divergence or if it's something else.
Hi @BenjaminBossan, thanks for your quick reply, and sorry for replying late. I checked again as you asked; here is the code (note that the model before merging is obtained with the same method I show above):

from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
from peft import PeftModel
import torch

# Here I run the same process I showed before
model = PeftModel.from_pretrained(base_model, lora_dir, device_map="auto", torch_dtype=torch.bfloat16)
# for name, param in model.named_parameters():
#     print(name, param.device)

generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    top_k=50
)

prompt = "A rainbow is an optical phenomenon that can occur under certain meteorological conditions."

with torch.no_grad():
    inputs = tokenizer(prompt, return_tensors="pt")
    generation_output = model.generate(
        input_ids=inputs.input_ids.to("cuda"),
        generation_config=generation_config,
        max_new_tokens=200,
        no_repeat_ngram_size=2
    )
    res = tokenizer.decode(generation_output[0], skip_special_tokens=True)
    print("\nReply of model before merging: ", res)

# start merging
model = model.merge_and_unload()

with torch.no_grad():
    inputs = tokenizer(prompt, return_tensors="pt")
    generation_output = model.generate(
        input_ids=inputs.input_ids.to("cuda"),
        generation_config=generation_config,
        max_new_tokens=200,
        no_repeat_ngram_size=2
    )
    res = tokenizer.decode(generation_output[0], skip_special_tokens=True)
    print("\nReply of merged model: ", res)

Here is the result I get:

Reply of model before merging: A rainbow is an optical phenomenon that can occur under certain meteorological conditions. It is a light reflection from clouds, which are shaped and colorful. Some rainbows are visible from the ground, and some are invisible.

Reply of merged model: A rainbow is an optical phenomenon that can occur under certain meteorological conditions.瘙瘙

It looks like something really goes wrong after merging. I also checked the device map of each layer: before merging, all of the model's layers are distributed across the 4 GPUs I used, and no layer is on the meta device. Do you know what the problem is? My guess is that it's because I tie the embedding layer and lm_head, so the merging process didn't work correctly, but I'm not sure. Thanks again for your reply.
I'm also having issues with merging: it looks like the model's outputs are changing pre- and post-merge. Here's a minimal deterministic (no generation/sampling) example:
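(The snippet itself was not captured in this copy of the thread; a minimal deterministic pre/post-merge comparison along those lines might look like the sketch below. The model and adapter names are placeholders, not the commenter's actual setup.)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")   # placeholder base model
model = PeftModel.from_pretrained(base, "path/to/adapter")          # placeholder adapter path

tok = AutoTokenizer.from_pretrained("facebook/opt-350m")
inputs = tok("A rainbow is an optical phenomenon", return_tensors="pt")

with torch.no_grad():
    logits_before = model(**inputs).logits

merged = model.merge_and_unload()
with torch.no_grad():
    logits_after = merged(**inputs).logits

# With a correct merge, the logits should match up to numerical noise.
print(torch.allclose(logits_before, logits_after, atol=1e-4))
print((logits_before - logits_after).abs().max())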
Running this is giving me the output:
I've tried this on:
Dear all, I have solved this issue myself and will record the solution here for anyone who faces the same issue. My intuition was right: since I trained the model with tied embeddings and there are LoRA layers on both the embedding layer and lm_head, I need to untie the embedding layers again before merging, so that the LoRA layer on lm_head can be merged into lm_head correctly. After I untie the embedding layers and run merge_and_unload again, the inference of the merged model becomes normal again, just like before merging. Since I haven't worked with the OPT model before, I'm not sure whether the problem mentioned above is the same as mine, but I guess the issue of @jordan-benjamin can be solved this way as well. Thank you all for the help and replies.
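(For reference, a rough sketch of what this untying could look like. This is not the poster's exact code; the module paths assume a LLaMA-style base model, and the directory paths are placeholders.)

import torch
from transformers import LlamaForCausalLM
from peft import PeftModel

base_model = LlamaForCausalLM.from_pretrained("path/to/base", torch_dtype=torch.bfloat16)
# (resize_token_embeddings for the expanded vocabulary would go here, as in the original setup)

# Untie: give lm_head its own copy of the weight, so the LoRA delta on lm_head is
# merged into a separate tensor instead of the tensor shared with embed_tokens.
base_model.lm_head.weight = torch.nn.Parameter(
    base_model.model.embed_tokens.weight.detach().clone()
)
base_model.config.tie_word_embeddings = False

model = PeftModel.from_pretrained(base_model, "path/to/adapter", torch_dtype=torch.bfloat16)
merged_model = model.merge_and_unload()
merged_model.save_pretrained("path/to/merged-model")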
@blaze7451
System Info
peft==0.5.0
transformers==4.34.0
accelerate==0.21.0
python==3.10.6
Who can help?
No response
Reproduction
Hi all, this is my first time raising an issue on a GitHub repo, so if I do something wrong, please let me know. Thanks.
I have trained "meta-llama/Llama-2-7b-hf" by qlora with expanded vocabs and tying of embedding layer and lm_head layer. Note that besides qkv layer and mlp, I also employ lora layer on the embedding layer and lm_head. The resulting adapter_model.bin file is around 600 MB and works well when i use it to do the inference.
Here is the rough code for how I do inference:
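(The snippet is missing from this copy of the issue; judging from the fuller code shared later in the thread, it was presumably along these lines. The adapter path and prompt are placeholders.)

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import PeftModel

lora_dir = "path/to/adapter"   # placeholder

tokenizer = LlamaTokenizer.from_pretrained(lora_dir)   # tokenizer with the expanded vocabulary
base_model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16
)
base_model.resize_token_embeddings(len(tokenizer))     # match the expanded vocabulary
model = PeftModel.from_pretrained(base_model, lora_dir, device_map="auto", torch_dtype=torch.bfloat16)

prompt = "A rainbow is an optical phenomenon that can occur under certain meteorological conditions."
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output = model.generate(input_ids=inputs.input_ids.to("cuda"), max_new_tokens=200)
print(tokenizer.decode(output[0], skip_special_tokens=True))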
As mentioned above, the model's replies look good and normal at this point. However, after I use the merge_and_unload function to produce the merged model and save it, the inference of the new merged model becomes weird and completely chaotic.
Here is the rough code I wrote for merging:
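(This snippet is also missing from this copy; a sketch of a typical merge-and-save along these lines, with placeholder paths and the base model prepared as above:)

import torch
from peft import PeftModel

# Load the adapter on top of the base model, merge the LoRA weights into the
# base weights, and save the standalone merged model.
model = PeftModel.from_pretrained(base_model, lora_dir, torch_dtype=torch.bfloat16)
merged_model = model.merge_and_unload()
merged_model.save_pretrained("path/to/merged-model")   # placeholder output path
tokenizer.save_pretrained("path/to/merged-model")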
In the beginning I did this in a peft==0.3.x environment and it raised AttributeError: 'Embedding' object has no attribute 'bias', but the error disappeared and nothing fails anymore after I updated to peft==0.5.0.
I have checked the size of the model's bin files and didn't see a significant change between the new model and Llama-2-7b-hf; the new model is actually very slightly smaller than the original Llama-2. I also printed the model to check whether the embedding layer size is correct and whether the config file is correct.
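(A quick sanity check along those lines, as a sketch, assuming the merged model is a LlamaForCausalLM and `tokenizer` holds the expanded vocabulary:)

print(merged_model.get_input_embeddings().weight.shape)   # row count should equal len(tokenizer)
print(merged_model.lm_head.weight.shape)                   # should match the embedding shape
print(merged_model.config.vocab_size, merged_model.config.tie_word_embeddings)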
Model summary (before merging):
After merging, I run inference with the new model, but now the output becomes something like "prompt脣銝餌����������", so it looks like maybe the merging process was not performed correctly? I saw the comment under #868, so I changed the device map to CPU and did the merging again, but the result is still the same. Does anyone know the reason and can help me? Thanks.
Expected behavior
The model's output should be a normal sentence, as it was before merging.