Quality of inference severely changed after merge_and_unload #1035

Closed
2 of 4 tasks
blaze7451 opened this issue Oct 18, 2023 · 5 comments

Comments

blaze7451 commented Oct 18, 2023

System Info

peft==0.5.0
transformers==4.34.0
accelerate==0.21.0
python==3.10.6

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Hi all, this is my first time raising an issue on a GitHub repo, so if I do something wrong, please let me know. Thanks.

I trained "meta-llama/Llama-2-7b-hf" with QLoRA, using an expanded vocabulary and tied embedding and lm_head layers. Note that besides the QKV layers and the MLP, I also apply LoRA to the embedding layer and lm_head. The resulting adapter_model.bin file is around 600 MB and works well when I use it for inference.
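For reference, the LoRA setup during training was roughly along the lines of the sketch below (the values here are illustrative, inferred from the model summary further down, not copied from my actual training script):

from peft import LoraConfig, TaskType

# Illustrative config only: LoRA on the attention/MLP projections plus the
# (tied) embedding and lm_head layers; the real hyperparameters may differ.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens", "lm_head",
    ],
)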

Here is rough code showing how I do inference:

from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import PeftModel
import torch

tokenizer = LlamaTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
base_model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map=device_map, torch_dtype=torch.bfloat16)
base_model.config.tie_word_embeddings = True
base_model.resize_token_embeddings(len(tokenizer))

# .pt file storing embedding vectors used to initialize the expanded vocab entries
embedding = torch.load(embedding_path)
weight = base_model.model.embed_tokens.weight.data
device = base_model.model.embed_tokens.weight.data.device
dtype = base_model.model.embed_tokens.weight.data.dtype
length = len(embedding)

for i in range(length):
    weight[32000 + i] = embedding[i].type(dtype).to(device)

model = PeftModel.from_pretrained(base_model, lora_dir, device_map="auto", torch_dtype=torch.bfloat16)

prompt = "some text"

with torch.no_grad():
    inputs = tokenizer(prompt, return_tensors="pt")
    generation_output = model.generate(
        input_ids=inputs.input_ids.to("cuda"),
        max_new_tokens=max_new_tokens,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    res = tokenizer.decode(generation_output[0], skip_special_tokens=True)
    print("\nReply: ", res)

As mentioned above, the model's replies look good and normal at this point. However, after I use the merge_and_unload function to produce the merged model and save it, the merged model's inference output becomes weird, total chaos.
Here is the rough code I wrote for merging:

from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import PeftModel
import torch

tokenizer = LlamaTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
base_model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map=device_map, torch_dtype=torch.bfloat16)
base_model.config.tie_word_embeddings = True
base_model.resize_token_embeddings(len(tokenizer))

# .pt file storing embedding vectors used to initialize the expanded vocab entries
embedding = torch.load(embedding_path)
weight = base_model.model.embed_tokens.weight.data
device = base_model.model.embed_tokens.weight.data.device
dtype = base_model.model.embed_tokens.weight.data.dtype
length = len(embedding)

for i in range(length):
    weight[32000 + i] = embedding[i].type(dtype).to(device)

model = PeftModel.from_pretrained(base_model, lora_dir, device_map="auto", torch_dtype=torch.bfloat16)
merged_model = model.merge_and_unload()
merged_model.save_pretrained(output_dir)

In the beginning I did this in a peft==0.3.x environment and got AttributeError: 'Embedding' object has no attribute 'bias', but that error went away after I updated to peft==0.5.0.

I checked the size of the model's bin files and didn't see a noticeable difference between the new model and Llama-2-7b-hf; the new model is actually very slightly smaller than the original Llama-2. I also printed the model to check whether the embedding layer size and the config file are correct.

Model summary (before merging):

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(
          57621, 4096
          (lora_dropout): ModuleDict(
            (default): Dropout(p=0.05, inplace=False)
          )
          (lora_A): ModuleDict()
          (lora_B): ModuleDict()
          (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 64x59241 (GPU 0)])
          (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 4096x64 (GPU 0)])
        )
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): Linear(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (v_proj): Linear(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (o_proj): Linear(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): Linear(
                in_features=4096, out_features=11008, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=11008, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (up_proj): Linear(
                in_features=4096, out_features=11008, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=11008, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (down_proj): Linear(
                in_features=11008, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=11008, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (act_fn): SiLUActivation()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
        (norm): LlamaRMSNorm()
      )
      (lm_head): Linear(
        in_features=4096, out_features=59241, bias=False
        (lora_dropout): ModuleDict(
          (default): Dropout(p=0.05, inplace=False)
        )
        (lora_A): ModuleDict(
          (default): Linear(in_features=4096, out_features=64, bias=False)
        )
        (lora_B): ModuleDict(
          (default): Linear(in_features=64, out_features=57621, bias=False)
        )
        (lora_embedding_A): ParameterDict()
        (lora_embedding_B): ParameterDict()
      )
    )
  )
)

After merging, I run inference with the new model, but now the output becomes something like "prompt脣銝餌����������"; it looks like maybe the merge was not performed correctly. I saw the comment under #868, so I changed the device map to CPU and merged again, but the result is still the same. Does anyone know the reason and can help me? Thanks.

Expected behavior

The model's output should be a normal sentence, as it was before merging.

@BenjaminBossan (Member)

Thanks for reporting this issue and giving all the details. Could you please test something:

Could you please produce an output right before merging and one right after merging, using the same input, and check whether the two outputs are the same? So something like:

with torch.no_grad():
    input = ...
    output_before = model.generate(...)
    ...
merged_model = model.merge_and_unload()
with torch.no_grad():
    input = ...
    output_after = merged_model.generate(...)
    ...
# compare before and after

This way, we can check if the merging process itself is responsible for the divergence or if it's something else.
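For example, a deterministic version of this check (greedy decoding, so sampling cannot explain a difference; prompt and token count are just placeholders) could look like:

import torch

prompt = "A rainbow is an optical phenomenon"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    out_before = model.generate(input_ids=inputs.input_ids.to("cuda"), do_sample=False, max_new_tokens=50)

merged_model = model.merge_and_unload()

with torch.no_grad():
    out_after = merged_model.generate(input_ids=inputs.input_ids.to("cuda"), do_sample=False, max_new_tokens=50)

# If the merge is lossless (up to dtype rounding), the two greedy outputs should match.
print(torch.equal(out_before, out_after))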

blaze7451 (Author) commented Oct 19, 2023

Hi @BenjaminBossan, thanks for your quick reply and sorry for the late response.

I checked again as you requested; here is the code (note that the model before merging is built with the same method I showed above):

from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
from peft import PeftModel
import torch

# Same setup as shown before
model = PeftModel.from_pretrained(base_model, lora_dir, device_map="auto", torch_dtype=torch.bfloat16)

# for name, param in model.named_parameters():
#     print(name, param.device)

generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    top_k=50,
)

prompt = "A rainbow is an optical phenomenon that can occur under certain meteorological conditions."

with torch.no_grad():
    inputs = tokenizer(prompt, return_tensors="pt")
    generation_output = model.generate(
        input_ids=inputs.input_ids.to("cuda"),
        generation_config=generation_config,
        max_new_tokens=200,
        no_repeat_ngram_size=2,
    )
    res = tokenizer.decode(generation_output[0], skip_special_tokens=True)
    print("\nReply of model before merging: ", res)

# start merging
model = model.merge_and_unload()

with torch.no_grad():
    inputs = tokenizer(prompt, return_tensors="pt")
    generation_output = model.generate(
        input_ids=inputs.input_ids.to("cuda"),
        generation_config=generation_config,
        max_new_tokens=200,
        no_repeat_ngram_size=2,
    )
    res = tokenizer.decode(generation_output[0], skip_special_tokens=True)
    print("\nReply of merged model: ", res)

Here is the result I get:

Reply of model before merging: A rainbow is an optical phenomenon that can occur under certain meteorological conditions. It is a light reflection from clouds, which are shaped and colorful. Some rainbows are visible from the ground, and some are invisible.
There are many different ways to capture a rain bird, but the most popular technique is to use a camera and a long-term, as shown in the picture above. This technique will allow you to take a picture of the rainbird and also a small piece of a snowfall. You can also use other methods such as using a microphone, a solar lamp or a flashlight. In the photo above, you can see that the sun is shining through the clouds and the shadows are scattered. The shots are taken with a standard camera with the position of camera up and with an adjustment of 10 to 15 m. A few other techniques are using the camera in a position in which the water is in full light, or using some other lighting device such like a sun-light or an
electric electric lamp. If

Reply of merged model: A rainbow is an optical phenomenon that can occur under certain meteorological conditions.瘙瘙
獬獬巂巂祏祏巂鴯鴯鶓鶓鴯巂惻惻巂壼壼巂睺睺巂獬瘙螫螫��窰窰������������������������������������螫巂澮澮
巂藶巂菰菰巂鶓巂鎣鎣巂巰巰巂螫睺���������������諴諴����������������������獬����������������������窰嘖嘖窰��
睺螫惻鴯繒繒鴯玕玕鴯桫

Looks like something really goes wrong after merging. I also checked the device map of each layer; all of the model's layers before merging are distributed across the 4 GPUs I use, and no layer is on the meta device. Do you know what the problem is? My guess is that it's because I tied the embedding layer and lm_head, so the merging process didn't work correctly, but I'm not sure about it. Thanks again for your reply.
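One way to check whether the two layers really share the same weight tensor (which is what I suspect confuses the merge) would be something like this, using the standard transformers accessors:

emb_weight = model.get_input_embeddings().weight
head_weight = model.get_output_embeddings().weight
# If this prints True, the two layers share one tensor, so merging the lm_head
# LoRA weights also overwrites the embedding matrix (and vice versa).
print(emb_weight.data_ptr() == head_weight.data_ptr())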

@jordan-benjamin

I'm also having issues with merging - it looks like the model's outputs are changing pre and post merge. Here's a minimal deterministic (no generation/sampling) example:

from transformers import AutoModelForCausalLM, PreTrainedModel
import torch
from peft import (
    LoraConfig,
    get_peft_model,
)

model_name = "facebook/opt-125m"

model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float32,
)
model.requires_grad_(False)


@torch.no_grad()
def model_out(model):
    input_ids = torch.arange(10, device=model.device, dtype=torch.long).unsqueeze(0)
    output_logits = model(input_ids).logits
    return output_logits


def relative_diff(a, b):
    diff = torch.abs(a - b) / (torch.abs(a) + 1e-2)
    return diff.mean(), diff.max()


def absolute_diff(a, b):
    diff = torch.abs(a - b)
    return diff.mean(), diff.max()


def trainable_params(model):
    return [p for p in model.parameters() if p.requires_grad]


@torch.no_grad()
def init_trainable_params(model):
    ps = trainable_params(model)
    for p in ps:
        p[:] = 3


lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=[
        "fc2",
    ],
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    # layers_to_transform=4
)

pmodel = get_peft_model(model, lora_config)

# re-init the LoRA params so that they don't multiply to 0
# and thus actually have an effect on the output
init_trainable_params(pmodel)

plogits = model_out(pmodel)

unloaded = pmodel.merge_and_unload()

ulogits = model_out(unloaded)

print("rdiff", relative_diff(plogits, ulogits))
print("adiff", absolute_diff(plogits, ulogits))

Running this gives me the output:

rdiff (tensor(0.1405), tensor(110.4522))
adiff (tensor(0.2044), tensor(2.6414))

I've tried this on:

peft==0.4.0 (also tried on latest github master with the same result)
transformers==4.34.1
accelerate==0.23.0

@blaze7451 (Author)

Dear all, I have solved this issue myself and will record the solution here for people who face the same issue.

My intuition was right: since I trained the model with tied embeddings and there are LoRA layers on both the embedding layer and lm_head, I need to untie the embedding layers before merging so that the LoRA layer on lm_head can be merged into lm_head correctly. After I untied the embedding layers and ran merge_and_unload again, the inference of the merged model became normal again, just like before merging. Since I haven't worked with OPT models, I'm not sure whether the problem mentioned above is the same as mine, but I guess the issue of @jordan-benjamin might be solved this way as well. Thank you all for the help and replies.
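The untying step is roughly the sketch below (simplified, not my exact script; variable names follow the merging code I posted above):

import torch
from peft import PeftModel

# Give lm_head its own, independent copy of the weights before loading the
# adapter, so the lm_head LoRA delta is merged into that copy instead of
# being written into the shared (tied) embedding matrix.
lm_head = base_model.get_output_embeddings()
lm_head.weight = torch.nn.Parameter(lm_head.weight.detach().clone())
base_model.config.tie_word_embeddings = False

model = PeftModel.from_pretrained(base_model, lora_dir, torch_dtype=torch.bfloat16)
merged_model = model.merge_and_unload()
merged_model.save_pretrained(output_dir)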

@ajgallard

@blaze7451
I am currently experiencing this issue and would like to ask if you could possibly share the change in code you made to resolve the issue.
