Warn if using tied target module with tie_word_embeddings #2025

Merged Aug 29, 2024 · 29 commits
Commits
b610bbc
Add warning if using output target module with tied embeddings
ltoniazzi Aug 20, 2024
dbd1eab
Add embed
ltoniazzi Aug 20, 2024
b0d3e61
Use EMBEDDING_LAYER_NAMES
ltoniazzi Aug 20, 2024
c8d6236
Update src/peft/mapping.py
ltoniazzi Aug 20, 2024
4e24774
Remove line
ltoniazzi Aug 20, 2024
39c9c4b
Update src/peft/mapping.py
ltoniazzi Aug 20, 2024
388b3f4
Move warn in BaseTuner
ltoniazzi Aug 21, 2024
b5fa3ce
Rename BaseTuner in mapping
ltoniazzi Aug 21, 2024
f36f3af
Modify warning message
ltoniazzi Aug 21, 2024
acde170
Style
ltoniazzi Aug 21, 2024
06f0d00
Precommit
ltoniazzi Aug 21, 2024
d9b2b53
Docstr for get_model_config and separate warnings
ltoniazzi Aug 22, 2024
fe5c86f
Reword
ltoniazzi Aug 22, 2024
5ba0e9f
Typo
ltoniazzi Aug 22, 2024
775a325
Put warn in base class - force dummy config
ltoniazzi Aug 24, 2024
fc0541e
Fix bug from rebase
ltoniazzi Aug 24, 2024
a1e5c17
Add get model config tests
ltoniazzi Aug 25, 2024
0861e1e
Add warning test for merging/loading
ltoniazzi Aug 25, 2024
82cab47
Refactor warn test
ltoniazzi Aug 25, 2024
667c033
Refactor warn test
ltoniazzi Aug 25, 2024
5c6c456
Refactor warn test
ltoniazzi Aug 25, 2024
12c73b5
Update tests/test_tuners_utils.py
ltoniazzi Aug 25, 2024
7f290ea
Update merge warning message
ltoniazzi Aug 26, 2024
7926888
Decouple tests
ltoniazzi Aug 26, 2024
6d360a9
Add instructions to save an untied model
Aug 27, 2024
964fdae
Add instructions to save an untied model
Aug 27, 2024
89928ca
Add comment on different format
Aug 27, 2024
809f689
Run make style
Aug 28, 2024
37a56eb
Address comments
ltoniazzi Aug 28, 2024
10 changes: 4 additions & 6 deletions src/peft/mapping.py
@@ -64,7 +64,7 @@
VeraModel,
XLoraConfig,
)
from .tuners.tuners_utils import BaseTuner as _BaseTuner
from .tuners.tuners_utils import BaseTuner
from .utils import _prepare_prompt_learning_config


@@ -103,7 +103,7 @@
"HRA": HRAConfig,
}

PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = {
PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = {
"LORA": LoraModel,
"LOHA": LoHaModel,
"LOKR": LoKrModel,
@@ -159,13 +159,11 @@ def get_peft_model(
The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
the base model
"""
model_config = getattr(model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()

model_config = BaseTuner.get_model_config(model)
old_name = peft_config.base_model_name_or_path
new_name = model.__dict__.get("name_or_path", None)
peft_config.base_model_name_or_path = new_name

if (old_name is not None) and (old_name != new_name):
warnings.warn(
f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
4 changes: 3 additions & 1 deletion src/peft/mixed_model.py
@@ -23,6 +23,8 @@
from torch import nn
from transformers.utils import PushToHubMixin

from peft.utils.constants import DUMMY_MODEL_CONFIG

from .config import PeftConfig
from .peft_model import PeftModel
from .tuners import (
@@ -120,7 +122,7 @@ def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str
self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name)
self.set_modules_to_save(peft_config, adapter_name)

self.config = getattr(model, "config", {"model_type": "custom"})
self.config = getattr(model, "config", DUMMY_MODEL_CONFIG)

# the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid
# numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected
7 changes: 4 additions & 3 deletions src/peft/peft_model.py
@@ -37,6 +37,8 @@
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin

from peft.utils.constants import DUMMY_MODEL_CONFIG

from . import __version__
from .config import PeftConfig
from .tuners import (
@@ -1231,9 +1233,8 @@ def create_or_update_model_card(self, output_dir: str):

card.data["library_name"] = "peft"

model_config = getattr(self, "config", None)
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
model_config = BaseTuner.get_model_config(self)
model_config = None if model_config == DUMMY_MODEL_CONFIG else model_config
if model_config is not None and "_name_or_path" in model_config:
card.data["base_model"] = model_config["_name_or_path"]

1 change: 1 addition & 0 deletions src/peft/tuners/lora/model.py
@@ -447,6 +447,7 @@ def _check_merge_allowed(self):

Currently gptq quantization and replicated layers do not support merging.
"""
super()._check_merge_allowed()
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
if self.peft_config.get("layer_replication"):
73 changes: 68 additions & 5 deletions src/peft/tuners/tuners_utils.py
@@ -17,6 +17,7 @@
import logging
import os
import re
import textwrap
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
@@ -30,7 +31,7 @@
from transformers.pytorch_utils import Conv1D

from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
from peft.utils.constants import DUMMY_TARGET_MODULES, SEQ_CLS_HEAD_NAMES
from peft.utils.constants import DUMMY_MODEL_CONFIG, DUMMY_TARGET_MODULES, EMBEDDING_LAYER_NAMES, SEQ_CLS_HEAD_NAMES
from peft.utils.peft_types import PeftType, TaskType

from ..config import PeftConfig
@@ -361,7 +362,36 @@ def _check_merge_allowed(self):

Raise a ValueError if it is not possible to merge the adapter with the given configuration.
"""
pass
example_code = textwrap.dedent(
"""
```python
from transformers import AutoModelForCausalLM

# Load original tied model
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False)

# Set the randomly initialized lm_head to the previously tied embeddings
model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

# Save the untied model
untied_model_dir = "dir/for/untied/model"
model.save_pretrained(untied_model_dir)
model.config.save_pretrained(untied_model_dir)

# Now use the original model but in untied format
model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
```
"""
)
tied_target_modules = self._get_tied_target_modules(self.model)
if tied_target_modules:
warnings.warn(
f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
"This can lead to complications. "
"You can opt to merge the adapter after cloning the weights (to untie the embeddings). "
"You can untie the embeddings by loading the model with `tie_word_embeddings=False`. For example:"
+ example_code
)

def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
r"""
@@ -387,9 +417,7 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
_check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None
_has_modules_to_save = False

model_config = getattr(model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
model_config = self.get_model_config(model)

peft_config = self._prepare_adapter_config(peft_config, model_config)

@@ -430,6 +458,15 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
parent, target, target_name = _get_submodules(model, key)
self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)

tied_target_modules = self._get_tied_target_modules(model=model)
if tied_target_modules:
warnings.warn(
f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
"This can lead to complications, for example when merging the adapter "
"or converting your model to formats other than safetensors. "
"See for example https://github.com/huggingface/peft/issues/2018."
)

# Handle X-LoRA case.
if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
raise ValueError(
@@ -494,6 +531,32 @@ def _unloading_checks(self, adapter_names: Optional[list[str]]):
if is_modules_to_save_available and len(adapters_to_consider) > 1:
raise ValueError("Cannot unload multiple adapters that specify `modules_to_save`.")

@staticmethod
def get_model_config(model: nn.Module) -> dict:
"""
This method gets the config from a model in dictionary form. If the model does not have a `config` attribute,
this method returns a default config.

Args:
model (`nn.Module`):
Model to get the config from.
default (`dict|None`, *optional*):
What to return if the model does not have a config attribute.
"""
model_config = getattr(model, "config", DUMMY_MODEL_CONFIG)
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
return model_config

def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
tied_target_modules = []
model_config = self.get_model_config(model)
if model_config.get("tie_word_embeddings"):
for target_module in self.targeted_module_names:
if target_module in EMBEDDING_LAYER_NAMES:
tied_target_modules.append(target_module)
return tied_target_modules


class BaseTunerLayer(ABC):
r"""
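Outside the diff, a minimal sketch of the fallback behavior introduced by `get_model_config` (illustrative only, not part of the PR; it assumes a plain module with no `config` attribute):

```python
import torch.nn as nn

from peft.tuners.tuners_utils import BaseTuner
from peft.utils.constants import DUMMY_MODEL_CONFIG

# A module without a `config` attribute falls back to the dummy config,
# replacing the repeated getattr/to_dict boilerplate removed in the hunks above.
assert BaseTuner.get_model_config(nn.Linear(4, 4)) == DUMMY_MODEL_CONFIG  # {"model_type": "custom"}
```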
4 changes: 1 addition & 3 deletions src/peft/tuners/vera/model.py
@@ -107,9 +107,7 @@ def _find_dim(self, config) -> tuple[int, int]:

This will be used for determining the size of the shared vera_A and vera_B matrices.
"""
model_config = getattr(self.model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
model_config = self.get_model_config(self.model)

peft_config = self._prepare_adapter_config(config, model_config)
peft_config = _maybe_include_all_linear_layers(peft_config, self.model)
1 change: 1 addition & 0 deletions src/peft/utils/constants.py
@@ -261,3 +261,4 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
INCLUDE_LINEAR_LAYERS_SHORTHAND = "all-linear"
TOKENIZER_CONFIG_NAME = "tokenizer_config.json"
DUMMY_TARGET_MODULES = "dummy-target-modules"
DUMMY_MODEL_CONFIG = {"model_type": "custom"}
84 changes: 84 additions & 0 deletions tests/test_tuners_utils.py
@@ -43,12 +43,14 @@
get_peft_model,
)
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
_maybe_include_all_linear_layers,
check_target_module_exists,
inspect_matched_modules,
)
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
from peft.utils.constants import DUMMY_MODEL_CONFIG

from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu

@@ -1065,3 +1067,85 @@ def forward(self, X):

with pytest.raises(TypeError, match="get_model_status is not supported for PeftMixedModel"):
model.get_model_status()


# Tests for BaseTuner
class MockModelConfig:
config = {"mock_key": "mock_value"}

def to_dict(self):
return self.config


class ModelWithConfig(nn.Module):
def __init__(self):
self.config = MockModelConfig()


class ModelWithDictConfig(nn.Module):
def __init__(self):
self.config = MockModelConfig.config


class ModelWithNoConfig(nn.Module):
pass


class TestBaseTunerGetModelConfig(unittest.TestCase):
def test_get_model_config_use_to_dict(self):
config = BaseTuner.get_model_config(ModelWithConfig())
assert config == MockModelConfig.config

def test_get_model_config_as_dict(self):
config = BaseTuner.get_model_config(ModelWithDictConfig())
assert config == MockModelConfig.config

def test_get_model_config_with_no_config(self):
config = BaseTuner.get_model_config(ModelWithNoConfig())
assert config == DUMMY_MODEL_CONFIG


class TestBaseTunerWarnForTiedEmbeddings:
model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
warn_end_inject = "huggingface/peft/issues/2018."
warn_end_merge = (
"# Now use the original model but in untied format\n"
"model = AutoModelForCausalLM.from_pretrained(untied_model_dir)\n```\n"
)

def _get_peft_model(self, tie_word_embeddings, target_module):
model = get_peft_model(
AutoModelForCausalLM.from_pretrained(self.model_id, tie_word_embeddings=tie_word_embeddings),
LoraConfig(target_modules=[target_module]),
)
return model

def _is_warn_triggered(self, warning_list, endswith):
return any(str(warning.message).endswith(endswith) for warning in warning_list)

def test_warn_for_tied_embeddings_inject(self, recwarn):
self._get_peft_model(tie_word_embeddings=True, target_module="lm_head")
assert self._is_warn_triggered(recwarn.list, self.warn_end_inject)

def test_warn_for_tied_embeddings_merge(self, recwarn):
model = self._get_peft_model(tie_word_embeddings=True, target_module="lm_head")
model.merge_and_unload()
assert self._is_warn_triggered(recwarn.list, self.warn_end_merge)

def test_no_warn_for_untied_embeddings_inject(self, recwarn):
self._get_peft_model(tie_word_embeddings=False, target_module="lm_head")
assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject)

def test_no_warn_for_untied_embeddings_merge(self, recwarn):
model_not_tied = self._get_peft_model(tie_word_embeddings=False, target_module="lm_head")
model_not_tied.merge_and_unload()
assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge)

def test_no_warn_for_no_target_module_inject(self, recwarn):
self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject)

def test_no_warn_for_no_target_module_merge(self, recwarn):
model_no_target_module = self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
model_no_target_module.merge_and_unload()
assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge)
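For illustration (not part of the diff), a minimal sketch of how the two warnings surface to users, mirroring the tests above and using the same tiny Llama checkpoint:

```python
import warnings

from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/tiny-random-LlamaForCausalLM", tie_word_embeddings=True
)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Targeting lm_head on a model with tied embeddings triggers the injection-time warning
    model = get_peft_model(base, LoraConfig(target_modules=["lm_head"]))
    # Merging triggers the second warning, which includes the untying instructions
    model.merge_and_unload()

print([str(w.message) for w in caught if "tie_word_embeddings=True" in str(w.message)])
```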