Lm_head layer Problem in gemma2 : 2b-it #2244

Closed
2 of 4 tasks
OmarHelwe10 opened this issue Dec 2, 2024 · 5 comments

Comments

@OmarHelwe10

OmarHelwe10 commented Dec 2, 2024

System Info

All libraries are on their latest versions.
Python 3.11

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

import re

from peft import LoraConfig, get_peft_model


def get_linear_layers(model):
    """Extract the names of all Linear layers from the model repr for LoRA configuration."""
    model_modules = str(model.modules)
    pattern = r'\((\w+)\): Linear'
    return list(set(re.findall(pattern, model_modules)))


def _apply_lora(self, model):
    """Apply LoRA configuration to the model (method of the fine-tuning class)."""
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=get_linear_layers(model),  # this also matches lm_head on Gemma 2
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    return model

When I run the fine-tuning script, the first warning is:

UserWarning: Model with tie_word_embeddings=True and the tied_target_modules=['lm_head'] are part of the adapter. This can lead to complications, for example when merging the adapter or converting your model to formats other than safetensors. See for example #2018.

It works fine, though, and I get adapters that I can later load using PeftModel.from_pretrained(model, adapter)...

But when I want to convert the adapters to GGUF, so that I can merge them with the LLM and get a final GGUF model to run for inference, I run the following llama.cpp script:

python convert_lora_to_gguf.py --base "/models/gemma-2-2b-it" --outfile "/gemma-2b/adapters.gguf" --outtype auto "/ft_models/gemma-2b/adapters"

INFO:lora-to-gguf:Loading base model: gemma-2-2b-it
INFO:hf-to-gguf:choosing --outtype bf16 from first tensor type (torch.float32)
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:lora-to-gguf:Exporting model...
Traceback (most recent call last):
  File "/home/ec2-user/abdullah/projects/llama.cpp/convert_lora_to_gguf.py", line 432, in <module>
    model_instance.write()
  File "/home/ec2-user/abdullah/projects/llama.cpp/convert_hf_to_gguf.py", line 434, in write
    self.prepare_tensors()
  File "/home/ec2-user/abdullah/projects/llama.cpp/convert_hf_to_gguf.py", line 298, in prepare_tensors
    for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/abdullah/projects/llama.cpp/convert_lora_to_gguf.py", line 408, in modify_tensors
    raise ValueError("lm_head is present in adapter, but is ignored in base model")
ValueError: lm_head is present in adapter, but is ignored in base model

Expected behavior

What should I do about this kind of issue?

@BenjaminBossan
Member

When you do target_modules=get_linear_layers(model), you apply LoRA to the LM head as well, whose weights are tied to the embedding layer weights. This is most likely not what you want. To target all linear layers except the LM head, use target_modules="all-linear".
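
For example, a minimal sketch that keeps the hyperparameters from the reproduction above and only changes target_modules:

from peft import LoraConfig, get_peft_model

# Sketch: same settings as in the reproduction, but let PEFT select the
# linear layers; the special value "all-linear" skips the output (LM head) layer.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)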

@OmarHelwe10
Author

But I am noticing that when I target lm_head as well, the accuracy on the fine-tuned tasks is almost 10% higher.

@BenjaminBossan
Member

What do you mean by accuracy, are you working on a sequence classification task? Also, there are probably better ways to achieve that accuracy, maybe tinker a bit with the hyperparameters, like increasing the LoRA r.

If you absolutely want to target the LM head, you can create a copy of it first (after loading the base model, before applying PEFT and merging):

model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
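
In context, that ordering would look roughly like the sketch below (the base model path is taken from the conversion command above; the non-lm_head target modules are illustrative):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Sketch: untie lm_head right after loading the base model,
# before applying PEFT and before any later merge.
model = AutoModelForCausalLM.from_pretrained("/models/gemma-2-2b-it")
model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["lm_head", "q_proj", "v_proj"],  # illustrative choice
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)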

@ltoniazzi
Contributor

ltoniazzi commented Dec 3, 2024

@OmarHelwe10 If the embeddings are tied, lm_head is the embedding layer, so you can train both without altering the architecture by training only the embedding (which lm_head points to).

GGUF conversion currently does not allow lm_head to differ from the embedding, as this alters Gemma's architecture and is not supported by llama.cpp at the moment.

So, if you want to use GGUF conversion with a tied-embedding model, do not train lm_head.
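
For instance, a minimal sketch of that approach, assuming LoRA is applied to the embedding module (the attention module names are illustrative):

from peft import LoraConfig

# Sketch: target the tied embedding instead of lm_head. Because lm_head
# points to the same weight tensor, a merged update to embed_tokens.weight
# also shows up in the output projection, and the architecture is unchanged.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["embed_tokens", "q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)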

(If there is a strong case from the community for a Gemma version with untied embeddings, this might change in llama.cpp in the future.)

If you do not want to use GGUF, you can follow the approach in #2025 to train lm_head.

@BenjaminBossan
Member

GGUF conversion currently does not allow lm_head to differ from the embedding, as this alters Gemma's architecture and is not supported by llama.cpp at the moment.

Thanks for the info, I didn't know that.
