Bug: Gemma2 adapter weights lm_head skipped on gguf conversion #9065

Closed
ltoniazzi opened this issue Aug 17, 2024 · 10 comments · Fixed by #9103
Labels
bug-unconfirmed, medium severity (Used to report medium severity bugs in llama.cpp, e.g. Malfunctioning Features but still useable)

Comments

@ltoniazzi
Contributor

ltoniazzi commented Aug 17, 2024

What happened?

The lm_head layer of a Gemma2 LoRA adapter is not converted by convert_lora_to_gguf.py, and is therefore not applied at inference (which ruins the adapter's performance).


How to reproduce:

  1. LoRA fine-tune Gemma2 with pytorch/peft, including lm_head in the target_modules param (see the sketch after these steps):
    config = LoraConfig(target_modules=["lm_head"], ...)
  2. Save the adapter.
  3. Convert the adapter:
    python convert_lora_to_gguf.py <adapter folder> --base <base model folder> --outtype f32
    During conversion, the lm_head layer is skipped by this line in convert_hf_to_gguf.py (and no error is raised):
    if name == "lm_head.weight":
        logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
        return []
  4. Run llama-cli to check that no LoRA layer is applied to lm_head at the relevant line in llama.cpp:
    ./llama-cli -m base/model/path/Base-F32.gguf \
    --lora lora/model/path/Lora-F32-LoRA.gguf \
    -p "Hello Gemma2" -n 50

Expected behaviour

I think this is a bug because a user might have trained an adapter that is applied to the lm_head layer, so skipping it on conversion will destroy the adapter's performance. I think the code should either:

  • raise an error saying "Cannot convert Gemma2 adapter with lm_head layer"

or

  • handle the lm_head layer (although merging adapters might then be tricky, as the lm_head layer shares its weights with the embed layer in Gemma2, probably requiring a new tensor to be created for the lm_head to merge the adapter into).

Comments

  • I think the script convert_lora_to_gguf.py was introduced in PR Refactor lora adapter support #8332, so maybe @ngxson knows whether skipping the lm_head is the desired outcome or whether it is actually a bug. Otherwise I'm happy to try to figure out why this happens.
  • This is not the case for, say, Phi3, which converts the lm_head LoRA layer correctly.
  • I can provide more code/models to reproduce the bug easily if that helps.
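
For anyone reproducing this, a quick way to check whether a saved PEFT adapter actually targets lm_head before converting it (a sketch using the safetensors library; the adapter path is the illustrative one from the sketch above):

    from safetensors import safe_open

    # Assumed path from the sketch above; PEFT normally saves the adapter as adapter_model.safetensors.
    with safe_open("gemma2-lora-adapter/adapter_model.safetensors", framework="pt") as f:
        has_lm_head = any("lm_head" in key for key in f.keys())
    print(has_lm_head)  # True means this adapter will hit the skipped-tensor path above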

Name and Version

version: 3524 (bc0f887)
built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.4.0

What operating system are you seeing the problem on?

macOS, but it should be a platform-independent problem.

Relevant log output

No response

@josharian

I have ~zero context here but I will note that the Gemma family of models uses a reversible embedding, so the lm_head layer is tied to be identical to the embedding layer.

@qnixsynapse
Contributor

Gemma 2 uses weight tying, so the lm_head weights are the same as the token embedding weights.

@ltoniazzi
Contributor Author

Gemma 2 uses weight tying, so the lm_head weights are the same as the token embedding weights.

@qnixsynapse @josharian Thanks, yes, I suspect this is the reason these weights are skipped, but I still feel the conversion should not succeed, as the adapter will not work correctly.

@ngxson
Collaborator

ngxson commented Aug 18, 2024

Yes, that's correct: the adapter should not have lm_head. In fact, we actively ignore this tensor in convert_hf_to_gguf.py.

The simple solution for now is to add a check in convert_lora_to_gguf.py, inside the function modify_tensors. We can check if the super() call returns empty for lm_head:

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    dest = super().modify_tensors(data_torch, name, bid)
    for dest_name, dest_data in dest:
        assert isinstance(dest_data, LoraTorchTensor)
        lora_a, lora_b = dest_data.get_lora_A_B()

Add the check:

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    dest = super().modify_tensors(data_torch, name, bid)
    if name == "lm_head.weight" and len(dest) == 0:
        raise ValueError(f"lm_head is present in adapter, but is ignored in base model")

What do you think? @compilade

@compilade
Collaborator

compilade commented Aug 18, 2024

I think it's still applied in the compute graph (where the token embeddings tensor is duplicated for the output), so it should probably not be ignored from LoRA adapters.

llama.cpp/src/llama.cpp, lines 12020-12021 at 2339a0b:

// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);

Although this doesn't affect tok_embd.weight (is that a problem?).

Not sure how to bypass the check. Maybe something like "don't call super().modify_tensors for Gemma2 when the tensor name is lm_head.weight", or maybe even generalize this to all model architectures.

Or yes, maybe an error could be appropriate.

@ngxson
Collaborator

ngxson commented Aug 18, 2024

I think it's still applied in the compute graph (where the token embeddings tensor is duplicated for the output), so it should probably not be ignored from LoRA adapters.

Yeah, you're right.

model.output points to the same tensor as model.tok_embd; it seems create_tensor_for can create different tensors (with different names) that point to the same memory:

model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading

struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) {

So I think we're at least good on the C++ side, as it can handle LoRA tensors separately even if output and tok_embd are the same.

Although this doesn't affect tok_embd.weight (is that a problem?).

I think that should not be a problem (not 100% sure). I suppose that PEFT sees output and tok_embd as 2 different tensors. We may need to check this later on, @ltoniazzi.

Not sure how to bypass the check. Maybe something like "don't call super().modify_tensors for Gemma2 when the tensor name is lm_head.weight", or maybe even generalize this to all model architectures.

The problem was that calling super().modify_tensors on Gemma2 returns 0 tensors, which causes lm_head.weight to be removed from the final GGUF LoRA adapter. We could either:

Option 1: Override the default behavior of Gemma2.modify_tensors and allow it to accept lm_head.weight only when converting a LoRA adapter.

We can do this by adding "and not self.is_lora" to the if condition below:

# lm_head is not used in llama.cpp, while autoawq will include this tensor in model
# To prevent errors, skip loading lm_head.weight.
if name == "lm_head.weight":
    logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
    return []
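
For illustration, the guarded condition could look like the sketch below; the self.is_lora flag comes from the suggestion above and is an assumption about the eventual implementation, not the merged fix:

# Sketch of Option 1: only skip lm_head.weight when NOT converting a LoRA adapter.
if name == "lm_head.weight" and not self.is_lora:
    logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
    return []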

Option 2: Or, as in my previous comment, just throw an error if super().modify_tensors returns 0 tensors for lm_head.weight:

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    dest = super().modify_tensors(data_torch, name, bid)
    if name == "lm_head.weight" and len(dest) == 0:
        raise ValueError(f"lm_head is present in adapter, but is ignored in base model")

@ltoniazzi
Contributor Author

ltoniazzi commented Aug 18, 2024

@compilade I think convert_lora_to_gguf.py currently skips the lm_head adapter, so the lines you mention in build_gemma2 will not find an adapter to apply.

As you mentioned, one could avoid skipping the lm_head adapter during conversion, and then I agree llm_build_lora_mm would apply it correctly to the base output layer.

But I think the tricky part might be merging the adapter (export-lora), because then one needs to save a new tensor for the lm_head as merging the adapter makes the output layer different from the embedding.
And then this new tensor will need to be applied instead of the embedding layer, altering the way build_gemma2 builds the architecture.

So it looks to me that either:

  • raising an error on conversion of a LoRA adapter containing lm_head, or
  • allowing conversion of lm_head but raising an error in export-lora

are the cleaner options.
Then, if users show a demand to merge lm_head adapters for Gemma2, this can be addressed further.

@ltoniazzi
Contributor Author

ltoniazzi commented Aug 19, 2024

I suppose that PEFT sees output and tok_embd as 2 different tensors.

Had a quick look, and it seems that for Gemma2 (and probably all models with config.tie_word_embeddings = True) lm_head and tok_embd are the same tensor in PEFT.

Indeed, merging the lm_head adapter merges it into the embed layer as well (huggingface/peft#2018).

Also, at inference, if the adapter is in the unmerged state (lora_layer.merged = False), then whether lm_head and tok_embd are the same should not matter, as the forward pass performs the adapter multiplication separately and then adds the result to the base output, as llm_build_lora_mm does.
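
A quick way to see the tying (a sketch, assuming the google/gemma-2-2b checkpoint and standard Hugging Face attribute names):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")  # assumed base model id
    print(model.config.tie_word_embeddings)  # expected: True
    # The output head and the token embedding share the same storage,
    # which is why merging an lm_head adapter also alters the embedding.
    print(model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr())  # expected: True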

@ngxson
Collaborator

ngxson commented Aug 19, 2024

Hmm, interesting. The problem with llama.cpp is that we currently don't support LoRA for tok_embd. Although it's simple to fix, I'm not sure if it's worth it, because I've never seen anyone fine-tune a LoRA with tok_embd.

But I think the tricky part might be merging the adapter (export-lora), because then one needs to save a new tensor for the lm_head as merging the adapter makes the output layer different from the embedding.

We could, for example, detect whether the model is missing lm_head and duplicate tok_embd if it is. But IMO that's too many small details, given that LoRA adapters for gemma/gemma2 are not very popular atm.

Probably we should go with the simple way for now: don't allow users to convert such an adapter in the first place. Then we can fix it when more users use LoRA adapters with gemma.

@ltoniazzi
Contributor Author

Probably we should go with the simple way for now: don't allow users to convert such an adapter in the first place

Agreed, I can have a go at it later this week.
