
Warn if using tied target module with tie_word_embeddings #2025

Merged (29 commits), Aug 29, 2024

Conversation

@ltoniazzi (Contributor) commented Aug 20, 2024

Context

Solving issue #2018.

  • Raise a warning if the user is requesting an output target_module when the embeddings are tied, because this could lead to errors, for example when merging the adapter.
  • Also refactored the code that retrieves the model config.

Todo

  • Check whether loading with tie_word_embeddings=False is an actual option. Load Gemma2 with differently fine-tuned lm_head weights and check that the lm_head is not replaced with the embedding (even if cloned). If it works, try to merge an adapter into lm_head and then load it to check that embed_tokens and lm_head are kept separate. (The main concern is that the loading model's architecture might ignore any lm_head weight present in the safetensors file, as happens in llama.cpp, for example.)
  • In the end I only checked how to save the base model as untied, so the user can work from there, although I DID NOT DO A THOROUGH CHECK OF THIS APPROACH; I only verified the following:
from transformers import AutoModelForCausalLM
import torch

# Load original tied model
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False)
# Set the randomly initialized lm_head to the previously tied embeddings
model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
assert torch.equal(model.lm_head.weight.data, model.model.embed_tokens.weight.data)

# Save the untied model
untied_model_dir = "tmp_model"
model.save_pretrained(untied_model_dir)
model.config.save_pretrained(untied_model_dir)
# Now use the original model but in untied format
model = AutoModelForCausalLM.from_pretrained(untied_model_dir)

assert torch.equal(model.lm_head.weight.data, model.model.embed_tokens.weight.data)
assert model.model.embed_tokens.weight.data.data_ptr() != model.lm_head.weight.data.data_ptr()
  • Add a note to the warning about porting to other formats.

src/peft/mapping.py (review thread, outdated, resolved)
@ltoniazzi ltoniazzi marked this pull request as ready for review August 20, 2024 14:02
src/peft/mapping.py (3 review threads, outdated, resolved)
@BenjaminBossan (Member) left a comment

Thanks for creating this PR. I think we need to rethink the approach here, as the current one will not work in all situations.

  1. get_peft_model is a very generic function and is also used for prompt tuning methods, for instance. Therefore, we cannot assume that peft_config.target_modules exists.
  2. Not all methods allow merging the weights, so we should not warn in those cases (false warnings should be avoided as much as possible).
  3. Even if peft_config.target_modules does exist, it could be a string, so looping over it will not always be correct.
  4. As we already observed, it will not work for custom models with tied weights, but let's consider this out of scope for now.

So how can we correctly identify when a warning is needed? My proposal is that this needs to be solved on a different level:

The check for whether there is a tied target layer needs to live at the corresponding method's model level (e.g. LoraModel), as only there can we really know which layers are targeted. Thankfully, the models that support merging all inherit from BaseTuner. There, we have the inject_adapter method. If you look at this line, you can see that all modules that are actually targeted are stored in self.targeted_module_names. Therefore, after exiting the loop, we can add a new method that takes this list and checks if any of the keys are tied weights, using the logic you proposed.

This new check should be implemented as a new method on the BaseTuner class, so that subclasses such as LoraModel may choose to override the method if there ever is a need.

Additionally, I wonder if there should be a warning when the user attempts to merge. One could argue that this is too late, but even at this point, there are workarounds: If the user clones the tied weights, they can merge without affecting the other weight (at the cost of extra memory).

This additional warning could be added to the _check_merge_allowed method and it could re-use the same method as mentioned above to perform the check. However, the warning message should be a bit different.
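
A minimal sketch of the kind of check being proposed here, assuming targeted_module_names has already been collected by inject_adapter and the base model's config is available as a dict; the helper name and the hard-coded layer names are illustrative, not the final PEFT implementation:

# Illustrative helper; "embed_tokens"/"lm_head" are assumed names for the tied layers.
def find_tied_target_modules(targeted_module_names, model_config,
                             tied_layer_names=("embed_tokens", "lm_head")):
    """Return the targeted module names that point at tied embedding weights."""
    if not model_config.get("tie_word_embeddings"):
        return []
    return [
        name for name in targeted_module_names
        if name.split(".")[-1] in tied_layer_names
    ]

# Example: targeting lm_head on a model whose config ties the embeddings
find_tied_target_modules(
    ["lm_head", "model.layers.0.self_attn.q_proj"],
    {"tie_word_embeddings": True},
)  # -> ['lm_head']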

I know this is all a bit more complicated than initially thought and not necessarily what you "signed up for". So let me know if you still want to work on this or not, in which case I'll put this on my backlog.

@ltoniazzi (Contributor, Author)

I know this is all a bit more complicated than initially thought and not necessarily what you "signed up for".

Not at all thanks, sounds really good, I'll have a go!

@ltoniazzi ltoniazzi marked this pull request as draft August 21, 2024 10:42
@BenjaminBossan (Member)

Not at all thanks, sounds really good, I'll have a go!

Thanks a lot.

@ltoniazzi ltoniazzi force-pushed the bug/warn-if-tied-embedding branch from c236129 to 44a02de Compare August 21, 2024 12:51
@ltoniazzi (Contributor, Author) commented Aug 21, 2024

@BenjaminBossan I made a version addressing your suggestions. Also, I refactored getting the model config in the code base.

However, the warning message should be a bit different.

I feel like the new message can be the same. Let me know.

(I can't run the whole test suite as I do not have a cuda-compatible gpu.)

@ltoniazzi ltoniazzi marked this pull request as ready for review August 21, 2024 14:00
@ltoniazzi changed the title from "Add warning if using output target module whith tied embeddings" to "Add warning if using output target module with tied embeddings" on Aug 21, 2024
@ltoniazzi changed the title from "Add warning if using output target module with tied embeddings" to "Warn if using output target module with tied embeddings" on Aug 21, 2024
@BenjaminBossan (Member) left a comment

Thanks for updating the PR.

I feel like the new message can be the same. Let me know.

I think the error message is good as is for when the model is being initialized. When merging, I think we could show a different warning, where we mention that if the weight is cloned beforehand, merging should work, at the cost of higher memory usage.

To implement this, I would change the _warn_if_tied_embeddings_in_target_modules method from warning to just performing the check and returning a bool (renaming the method accordingly). Then during injection, if the check returns True, the current warning is given, and during merging, if the check returns True, the adapted warning is given. WDYT?
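
Roughly, that split could look like the following sketch (the class is a toy stand-in for a BaseTuner subclass; the method bodies and warning texts are illustrative, not the final code):

import warnings

class TunerSketch:
    TIED_LAYER_NAMES = ("embed_tokens", "lm_head")  # assumed names for the tied layers

    def __init__(self, targeted_module_names, tie_word_embeddings):
        self.targeted_module_names = list(targeted_module_names)
        self.tie_word_embeddings = tie_word_embeddings  # normally read from the model config

    def _tied_target_modules(self):
        # Check only, no warning: each call site phrases its own message.
        if not self.tie_word_embeddings:
            return []
        return [n for n in self.targeted_module_names
                if n.split(".")[-1] in self.TIED_LAYER_NAMES]

    def inject_adapter(self):
        if self._tied_target_modules():
            warnings.warn("Targeting tied embeddings; this can cause complications when merging.")

    def _check_merge_allowed(self):
        if self._tied_target_modules():
            warnings.warn("Tied embeddings are targeted; clone the tied weight before merging "
                          "to avoid changing the other weight (at the cost of extra memory).")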

(I can't run the whole test suite as I do not have a cuda-compatible gpu.)

This is fine.

src/peft/tuners/tuners_utils.py (review thread, outdated, resolved)
@BenjaminBossan (Member) left a comment

Thanks for the changes. I have a few small suggestions for improvements; please check them out.

It would also be great to add unit tests for this, but probably this will be a bit more complicated. I leave it up to you if you want to give this a try, otherwise I'll work on it in a subsequent PR.

Review threads on src/peft/tuners/tuners_utils.py (3) and src/peft/tuners/lora/model.py (1), all outdated and resolved.
@ltoniazzi ltoniazzi force-pushed the bug/warn-if-tied-embedding branch from 516fc3c to 3a51e67 Compare August 24, 2024 08:53
@ltoniazzi (Contributor, Author) commented Aug 24, 2024

Sure, very happy to write tests! I'll put them in tests/test_tuners_utils.py.

Just one question: to mock models with tied embeddings, should I use the test model "HuggingFaceH4/tiny-random-LlamaForCausalLM" but loaded with:

model = AutoModelForCausalLM.from_pretrained(model_id, tie_word_embeddings=True)

@ltoniazzi ltoniazzi force-pushed the bug/warn-if-tied-embedding branch from cd3e830 to cf4bf3e Compare August 25, 2024 09:00
@ltoniazzi ltoniazzi marked this pull request as draft August 25, 2024 09:02
@ltoniazzi ltoniazzi marked this pull request as ready for review August 25, 2024 10:35
@ltoniazzi (Contributor, Author)

@BenjaminBossan I added the test here tests/test_tuners_utils.py 👍

tests/test_tuners_utils.py (review thread, outdated, resolved)
@BenjaminBossan (Member) left a comment

Thanks so much for making the updates, using DUMMY_MODEL_CONFIG consistently, and extending the tests. This looks quite good already, but I have some suggestions for improvements; please check.

I just did this as it's a bit unclear if in this case the model_config needs to default to None or if it can be the DUMMY one, let me know!

The change you made looks good as is.

Just one question: to mock models with tied embeddings, should I use the test model "HuggingFaceH4/tiny-random-LlamaForCausalLM" but loaded with:

I didn't know that this was an option. Yes, looks like the right choice.

warnings.warn(
f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
"This can lead to complications when merging the adapter. "
"You can opt to merge the adapter after cloning the weights (to untie the embeddings), "
Member:

Honestly, I didn't know about the option to pass tie_word_embeddings=False. Is there even a need to clone the weights in that case?

Contributor (Author):

Looks like it works; I added the code to create the untied model to the warning message.

config = BaseTuner.get_model_config(ModelWithNoConfig())
assert config == DUMMY_MODEL_CONFIG

def test_warn_for_tied_embeddings_inject_and_merge(self):
Member:

Thanks a lot for adding these tests. They are already looking quite good. I think, however, that this last test can be simplified a bit.

As you correctly observed, there are 6 scenarios to test:

  • Warning for get_peft_model and warning for merging.
  • Valid warning vs no tied embeddings vs tied embeddings but not targeted.

Instead of cramming those into a single test, let's make this 6 separate tests. It should also be fine to make it 3 tests, where get_peft_model and merging are checked together. Hopefully, this should make the assert_warning_triggered function unnecessary.

You probably also had a bit of an issue that unrelated warnings could be recorded. Maybe this can be made simpler by using the recwarn fixture. Then you can just check that any warning has been recorded with the corresponding message, something like:

assert any(str(warning.message).startswith(msg) for warning in recwarn.list)
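
For illustration, such a test could look like this sketch (the test name and LoRA config are assumptions; the model ID and the warning prefix are taken from elsewhere in this thread):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model


def test_warns_when_targeting_tied_lm_head(recwarn):
    # Tiny test model loaded with tied embeddings, as discussed above.
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/tiny-random-LlamaForCausalLM", tie_word_embeddings=True
    )
    config = LoraConfig(target_modules=["lm_head"])
    get_peft_model(model, config)

    msg = "Model with `tie_word_embeddings=True`"
    assert any(str(warning.message).startswith(msg) for warning in recwarn.list)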

pass


class TestBaseTunerMethods(unittest.TestCase):
Member:

Let's split this test class into 2: One for get_model_config and one for the tied embeddings.

@ltoniazzi ltoniazzi force-pushed the bug/warn-if-tied-embedding branch from a2f7354 to 7926888 Compare August 26, 2024 12:48
@ltoniazzi ltoniazzi marked this pull request as draft August 26, 2024 12:48
@ltoniazzi ltoniazzi marked this pull request as ready for review August 27, 2024 12:47
@ltoniazzi (Contributor, Author)

@BenjaminBossan I think I've addressed the comments 👍

@BenjaminBossan (Member)

Thanks for the latest updates. I only have one more question, namely about how to untie the weights. In the script you provide, you clone the weights, but is that even necessary if tie_word_embeddings=False is passed? To give an example, I tried this (I changed the model, but only because I had already downloaded it):

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", tie_word_embeddings=False)
>>> model.model.embed_tokens.weight.data_ptr()
126062054867008
>>> model.lm_head.weight.data_ptr()  # <= different data ptr
126051845931072
>>> model.model.embed_tokens.weight.sum()
tensor(952564.6250, grad_fn=<SumBackward0>)
>>> model.lm_head.weight.sum()
tensor(255.3427, grad_fn=<SumBackward0>)

>>> from peft import LoraConfig, get_peft_model
>>> config = LoraConfig(init_lora_weights=False, target_modules=["embed_tokens"])
>>> model = get_peft_model(model, config)
>>> unloaded = model.merge_and_unload()
>>> unloaded.model.embed_tokens.weight.sum()  # <= embed weights changed
tensor(985655.8125)
>>> unloaded.lm_head.weight.sum()  # <= lm head stayed the same
tensor(255.3427)

@ltoniazzi (Contributor, Author)

In the script you provide, you clone the weights, but is that even necessary if tie_word_embeddings=False is passed? To give an example, I tried this (I changed the model, but only because I had already downloaded it):

Yes, I agree with your script, but the user wants to fix lm_head after loading: when loading with tie_word_embeddings=False, the lm_head is randomly initialized, so the user wants to set it equal to the embedding layer, which means they want a clone of it.

This cloning also seems to be needed to save the model correctly. If you do not clone (besides actually re-tying the embeddings), then when you load the saved untied model the last assertion below will fail; if you clone, it will pass:

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False)
# Set the randomly initialized lm_head to the previously tied embeddings
model.lm_head.weight.data = model.model.embed_tokens.weight.data
assert torch.equal(model.lm_head.weight.data, model.model.embed_tokens.weight.data)

# Save the untied model
untied_model_dir = "tmp_model"
model.save_pretrained(untied_model_dir)
model.config.save_pretrained(untied_model_dir)
# Now use the original model but in untied format
model = AutoModelForCausalLM.from_pretrained(untied_model_dir)

assert model.model.embed_tokens.weight.data.data_ptr() != model.lm_head.weight.data.data_ptr()
assert torch.equal(model.lm_head.weight.data, model.model.embed_tokens.weight.data)

@BenjaminBossan (Member)

Because when loading with tie_word_embeddings=False the lm_head is randomly initialized, so the user wants to set it to equal the embed layer, which means they want a clone of it.

Oh wow, I did not know that the LM head will be randomly initialized, that's quite surprising IMO. I would have expected to get the same parameter values, just not tied. Thanks for making me aware of that.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ltoniazzi (Contributor, Author)

Not sure how to reproduce the error in the GitHub Actions run:

ruff check src tests examples docs scripts docker
All checks passed!
ruff format --check src tests examples docs scripts docker
189 files already formatted
doc-builder style src/peft tests docs/source --max_len 119 --check_only
Traceback (most recent call last):
  File "/opt/hostedtoolcache/Python/3.8.18/x64/bin/doc-builder", line 8, in <module>
    sys.exit(main())
  File "/opt/hostedtoolcache/Python/3.8.[18](https://github.com/huggingface/peft/actions/runs/10578337704/job/29356715757?pr=2025#step:5:19)/x64/lib/python3.8/site-packages/doc_builder/commands/doc_builder_cli.py", line 47, in main
    args.func(args)
  File "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/doc_builder/commands/style.py", line 28, in style_command
    raise ValueError(f"{len(changed)} files should be restyled!")
ValueError: 1 files should be restyled!
make: *** [Makefile:11: quality] Error 1
Error: Process completed with exit code 2.

@BenjaminBossan (Member)

@ltoniazzi could you please run make style to make the linter happy? The ruff version being used is 0.6.2.

@BenjaminBossan (Member)

I ran make style locally and this is the diff I get:

@@ -530,8 +530,8 @@ model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
     @staticmethod
     def get_model_config(model: nn.Module) -> dict:
         """
-        This method gets the config from a model in dictionary form.
-        If model has not attribute config, then this method returns a default config.
+        This method gets the config from a model in dictionary form. If model has not attribute config, then this
+        method returns a default config.

@ltoniazzi (Contributor, Author)

could you please run make style

Done!

@ltoniazzi changed the title from "Warn if using output target module with tied embeddings" to "Warn if using tied target module with tie_word_embeddings" on Aug 28, 2024
@BenjaminBossan (Member) left a comment

Thanks, very nicely done PR. I just have two tiny comments for cosmetic reasons, otherwise this can be merged.

)
return model

def _is_warn_triggered(self, rrecwarn, endswith):
Member:

Did you call it rrecwarn to avoid naming conflicts? If yes, how about just passing the recwarn.list, which is all we need, and call it warning_list or so.

# Now use the original model but in untied format
model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
```
"""
Member:

I see why you left-aligned the code snippet so that it is nicely printed. But this is really an eyesore to read in code. Here is a trick that lets us use the correct indentation but still get a nice warning message, using textwrap.dedent:

            example_code = textwrap.dedent(
                """
                ```python
                from transformers import AutoModelForCausalLM

                # Load original tied model
                model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False)

                # Set the randomly initialized lm_head to the previously tied embeddings
                model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

                # Save the untied model
                untied_model_dir = "dir/for/untied/model"
                model.save_pretrained(untied_model_dir)
                model.config.save_pretrained(untied_model_dir)

                # Now use the original model but in untied format
                model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
                ```
                """
            )
            warnings.warn(
                f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
                "This can lead to complications. "
                "You can opt to merge the adapter after cloning the weights (to untie the embeddings). "
                "You can untie the embeddings by loading the model with `tie_word_embeddings=False`. For example:"
                + example_code
            )

The textwrap module is from the standard library and needs to be imported.

Contributor (Author):

Addressed both, thanks!

@BenjaminBossan (Member) left a comment

Thanks so much, great work, hopefully this will help users in the future to avoid this potential pitfall.

@BenjaminBossan BenjaminBossan merged commit 679bcd8 into huggingface:main Aug 29, 2024
14 checks passed
@ltoniazzi (Contributor, Author)

@BenjaminBossan Thanks so much for your help! ❤️

Btw, a test on main failed, do you think it's related to this PR?

@BenjaminBossan (Member)

Btw, a test on main failed, do you think it's related to this PR?

Don't worry, this is a known issue with X-LoRA that came about with a recent change in transformers.

@ltoniazzi ltoniazzi deleted the bug/warn-if-tied-embedding branch August 29, 2024 15:34