ENH Warn if using tied target modules (#2025)
When users are targeting tied weights (e.g. the embedding and the LM head),
merging the adapter will lead to errors. Now users are warned about this
possibility when they create such a PEFT model and also when they try to
merge.
ltoniazzi authored Aug 29, 2024
1 parent 850eeb5 commit 679bcd8
Showing 8 changed files with 166 additions and 18 deletions.
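
A rough sketch of the behavior this commit adds (the tiny checkpoint below is the one used in the new tests; everything else is illustrative, not part of the commit): the warning fires both at adapter creation and at merge time.

```python
# Sketch, not part of the commit: trigger the new tied-embedding warnings.
# "HuggingFaceH4/tiny-random-LlamaForCausalLM" is the checkpoint used in the
# new tests; any causal LM with tie_word_embeddings=True behaves the same way.
import warnings

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/tiny-random-LlamaForCausalLM", tie_word_embeddings=True
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Targeting lm_head while the embeddings are tied warns at creation time ...
    peft_model = get_peft_model(base, LoraConfig(target_modules=["lm_head"]))
    # ... and again, with an untying recipe, when merging the adapter.
    peft_model.merge_and_unload()

print([str(w.message) for w in caught])
```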
10 changes: 4 additions & 6 deletions src/peft/mapping.py
@@ -64,7 +64,7 @@
VeraModel,
XLoraConfig,
)
from .tuners.tuners_utils import BaseTuner as _BaseTuner
from .tuners.tuners_utils import BaseTuner
from .utils import _prepare_prompt_learning_config


@@ -103,7 +103,7 @@
"HRA": HRAConfig,
}

PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = {
PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = {
"LORA": LoraModel,
"LOHA": LoHaModel,
"LOKR": LoKrModel,
@@ -159,13 +159,11 @@ def get_peft_model(
The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
the base model
"""
model_config = getattr(model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()

model_config = BaseTuner.get_model_config(model)
old_name = peft_config.base_model_name_or_path
new_name = model.__dict__.get("name_or_path", None)
peft_config.base_model_name_or_path = new_name

if (old_name is not None) and (old_name != new_name):
warnings.warn(
f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
4 changes: 3 additions & 1 deletion src/peft/mixed_model.py
@@ -23,6 +23,8 @@
from torch import nn
from transformers.utils import PushToHubMixin

from peft.utils.constants import DUMMY_MODEL_CONFIG

from .config import PeftConfig
from .peft_model import PeftModel
from .tuners import (
@@ -120,7 +122,7 @@ def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str
self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name)
self.set_modules_to_save(peft_config, adapter_name)

self.config = getattr(model, "config", {"model_type": "custom"})
self.config = getattr(model, "config", DUMMY_MODEL_CONFIG)

# the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid
# numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected
7 changes: 4 additions & 3 deletions src/peft/peft_model.py
@@ -37,6 +37,8 @@
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin

from peft.utils.constants import DUMMY_MODEL_CONFIG

from . import __version__
from .config import PeftConfig
from .tuners import (
@@ -1231,9 +1233,8 @@ def create_or_update_model_card(self, output_dir: str):

card.data["library_name"] = "peft"

model_config = getattr(self, "config", None)
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
model_config = BaseTuner.get_model_config(self)
model_config = None if model_config == DUMMY_MODEL_CONFIG else model_config
if model_config is not None and "_name_or_path" in model_config:
card.data["base_model"] = model_config["_name_or_path"]

1 change: 1 addition & 0 deletions src/peft/tuners/lora/model.py
Expand Up @@ -447,6 +447,7 @@ def _check_merge_allowed(self):
Currently gptq quantization and replicated layers do not support merging.
"""
super()._check_merge_allowed()
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
if self.peft_config.get("layer_replication"):
73 changes: 68 additions & 5 deletions src/peft/tuners/tuners_utils.py
Expand Up @@ -17,6 +17,7 @@
import logging
import os
import re
import textwrap
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
@@ -30,7 +31,7 @@
from transformers.pytorch_utils import Conv1D

from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
from peft.utils.constants import DUMMY_TARGET_MODULES, SEQ_CLS_HEAD_NAMES
from peft.utils.constants import DUMMY_MODEL_CONFIG, DUMMY_TARGET_MODULES, EMBEDDING_LAYER_NAMES, SEQ_CLS_HEAD_NAMES
from peft.utils.peft_types import PeftType, TaskType

from ..config import PeftConfig
@@ -361,7 +362,36 @@ def _check_merge_allowed(self):
Raise a ValueError if it is not possible to merge the adapter with the given configuration.
"""
pass
example_code = textwrap.dedent(
"""
```python
from transformers import AutoModelForCausalLM
# Load original tied model
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False)
# Set the randomly initialized lm_head to the previously tied embeddings
model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
# Save the untied model
untied_model_dir = "dir/for/untied/model"
model.save_pretrained(untied_model_dir)
model.config.save_pretrained(untied_model_dir)
# Now use the original model but in untied format
model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
```
"""
)
tied_target_modules = self._get_tied_target_modules(self.model)
if tied_target_modules:
warnings.warn(
f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
"This can lead to complications. "
"You can opt to merge the adapter after cloning the weights (to untie the embeddings). "
"You can untie the embeddings by loading the model with `tie_word_embeddings=False`. For example:"
+ example_code
)

def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
r"""
@@ -387,9 +417,7 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d
_check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None
_has_modules_to_save = False

model_config = getattr(model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
model_config = self.get_model_config(model)

peft_config = self._prepare_adapter_config(peft_config, model_config)

@@ -430,6 +458,15 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d
parent, target, target_name = _get_submodules(model, key)
self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)

tied_target_modules = self._get_tied_target_modules(model=model)
if tied_target_modules:
warnings.warn(
f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
"This can lead to complications, for example when merging the adapter "
"or converting your model to formats other than safetensors. "
"See for example https://github.com/huggingface/peft/issues/2018."
)

# Handle X-LoRA case.
if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
raise ValueError(
@@ -494,6 +531,32 @@ def _unloading_checks(self, adapter_names: Optional[list[str]]):
if is_modules_to_save_available and len(adapters_to_consider) > 1:
raise ValueError("Cannot unload multiple adapters that specify `modules_to_save`.")

@staticmethod
def get_model_config(model: nn.Module) -> dict:
"""
This method gets the config from a model in dictionary form. If the model has no `config` attribute, a
default dummy config is returned instead.

Args:
model (`nn.Module`):
Model to get the config from.
"""
model_config = getattr(model, "config", DUMMY_MODEL_CONFIG)
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
return model_config

def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
tied_target_modules = []
model_config = self.get_model_config(model)
if model_config.get("tie_word_embeddings"):
for target_module in self.targeted_module_names:
if target_module in EMBEDDING_LAYER_NAMES:
tied_target_modules.append(target_module)
return tied_target_modules


class BaseTunerLayer(ABC):
r"""
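
A minimal usage sketch of the new `BaseTuner.get_model_config` helper added above (the toy modules are illustrative stand-ins, not part of this commit):

```python
# Sketch, not part of the commit: how BaseTuner.get_model_config falls back
# to the new DUMMY_MODEL_CONFIG when a model carries no config attribute.
import torch.nn as nn

from peft.tuners.tuners_utils import BaseTuner
from peft.utils.constants import DUMMY_MODEL_CONFIG


class NoConfigModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)


class DictConfigModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = {"model_type": "llama", "tie_word_embeddings": True}


# A model without a config attribute yields the dummy config ...
assert BaseTuner.get_model_config(NoConfigModel()) == DUMMY_MODEL_CONFIG
# ... while dict configs (or configs exposing to_dict()) come back as plain dicts,
# which is what _get_tied_target_modules inspects for tie_word_embeddings.
assert BaseTuner.get_model_config(DictConfigModel())["tie_word_embeddings"] is True
```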
4 changes: 1 addition & 3 deletions src/peft/tuners/vera/model.py
@@ -107,9 +107,7 @@ def _find_dim(self, config) -> tuple[int, int]:
This will be used for determining the size of the shared vera_A and vera_B matrices.
"""
model_config = getattr(self.model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
model_config = self.get_model_config(self.model)

peft_config = self._prepare_adapter_config(config, model_config)
peft_config = _maybe_include_all_linear_layers(peft_config, self.model)
1 change: 1 addition & 0 deletions src/peft/utils/constants.py
@@ -261,3 +261,4 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
INCLUDE_LINEAR_LAYERS_SHORTHAND = "all-linear"
TOKENIZER_CONFIG_NAME = "tokenizer_config.json"
DUMMY_TARGET_MODULES = "dummy-target-modules"
DUMMY_MODEL_CONFIG = {"model_type": "custom"}
84 changes: 84 additions & 0 deletions tests/test_tuners_utils.py
@@ -43,12 +43,14 @@
get_peft_model,
)
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
_maybe_include_all_linear_layers,
check_target_module_exists,
inspect_matched_modules,
)
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
from peft.utils.constants import DUMMY_MODEL_CONFIG

from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu

@@ -1065,3 +1067,85 @@ def forward(self, X):

with pytest.raises(TypeError, match="get_model_status is not supported for PeftMixedModel"):
model.get_model_status()


# Tests for BaseTuner
class MockModelConfig:
config = {"mock_key": "mock_value"}

def to_dict(self):
return self.config


class ModelWithConfig(nn.Module):
def __init__(self):
super().__init__()
self.config = MockModelConfig()


class ModelWithDictConfig(nn.Module):
def __init__(self):
super().__init__()
self.config = MockModelConfig.config


class ModelWithNoConfig(nn.Module):
pass


class TestBaseTunerGetModelConfig(unittest.TestCase):
def test_get_model_config_use_to_dict(self):
config = BaseTuner.get_model_config(ModelWithConfig())
assert config == MockModelConfig.config

def test_get_model_config_as_dict(self):
config = BaseTuner.get_model_config(ModelWithDictConfig())
assert config == MockModelConfig.config

def test_get_model_config_with_no_config(self):
config = BaseTuner.get_model_config(ModelWithNoConfig())
assert config == DUMMY_MODEL_CONFIG


class TestBaseTunerWarnForTiedEmbeddings:
model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
warn_end_inject = "huggingface/peft/issues/2018."
warn_end_merge = (
"# Now use the original model but in untied format\n"
"model = AutoModelForCausalLM.from_pretrained(untied_model_dir)\n```\n"
)

def _get_peft_model(self, tie_word_embeddings, target_module):
model = get_peft_model(
AutoModelForCausalLM.from_pretrained(self.model_id, tie_word_embeddings=tie_word_embeddings),
LoraConfig(target_modules=[target_module]),
)
return model

def _is_warn_triggered(self, warning_list, endswith):
return any(str(warning.message).endswith(endswith) for warning in warning_list)

def test_warn_for_tied_embeddings_inject(self, recwarn):
self._get_peft_model(tie_word_embeddings=True, target_module="lm_head")
assert self._is_warn_triggered(recwarn.list, self.warn_end_inject)

def test_warn_for_tied_embeddings_merge(self, recwarn):
model = self._get_peft_model(tie_word_embeddings=True, target_module="lm_head")
model.merge_and_unload()
assert self._is_warn_triggered(recwarn.list, self.warn_end_merge)

def test_no_warn_for_untied_embeddings_inject(self, recwarn):
self._get_peft_model(tie_word_embeddings=False, target_module="lm_head")
assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject)

def test_no_warn_for_untied_embeddings_merge(self, recwarn):
model_not_tied = self._get_peft_model(tie_word_embeddings=False, target_module="lm_head")
model_not_tied.merge_and_unload()
assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge)

def test_no_warn_for_no_target_module_inject(self, recwarn):
self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject)

def test_no_warn_for_no_target_module_merge(self, recwarn):
model_no_target_module = self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
model_no_target_module.merge_and_unload()
assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge)
