diff --git a/src/peft/mapping.py b/src/peft/mapping.py index f5be9fdcc0..bb6e30436e 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -64,7 +64,7 @@ VeraModel, XLoraConfig, ) -from .tuners.tuners_utils import BaseTuner as _BaseTuner +from .tuners.tuners_utils import BaseTuner from .utils import _prepare_prompt_learning_config @@ -103,7 +103,7 @@ "HRA": HRAConfig, } -PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = { +PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = { "LORA": LoraModel, "LOHA": LoHaModel, "LOKR": LoKrModel, @@ -159,13 +159,11 @@ def get_peft_model( The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for the base model """ - model_config = getattr(model, "config", {"model_type": "custom"}) - if hasattr(model_config, "to_dict"): - model_config = model_config.to_dict() - + model_config = BaseTuner.get_model_config(model) old_name = peft_config.base_model_name_or_path new_name = model.__dict__.get("name_or_path", None) peft_config.base_model_name_or_path = new_name + if (old_name is not None) and (old_name != new_name): warnings.warn( f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. " diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py index be57e8f9bc..b0324b4dbf 100644 --- a/src/peft/mixed_model.py +++ b/src/peft/mixed_model.py @@ -23,6 +23,8 @@ from torch import nn from transformers.utils import PushToHubMixin +from peft.utils.constants import DUMMY_MODEL_CONFIG + from .config import PeftConfig from .peft_model import PeftModel from .tuners import ( @@ -120,7 +122,7 @@ def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name) self.set_modules_to_save(peft_config, adapter_name) - self.config = getattr(model, "config", {"model_type": "custom"}) + self.config = getattr(model, "config", DUMMY_MODEL_CONFIG) # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index a4bc958099..b7d270b4aa 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -37,6 +37,8 @@ from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput from transformers.utils import PushToHubMixin +from peft.utils.constants import DUMMY_MODEL_CONFIG + from . import __version__ from .config import PeftConfig from .tuners import ( @@ -1231,9 +1233,8 @@ def create_or_update_model_card(self, output_dir: str): card.data["library_name"] = "peft" - model_config = getattr(self, "config", None) - if hasattr(model_config, "to_dict"): - model_config = model_config.to_dict() + model_config = BaseTuner.get_model_config(self) + model_config = None if model_config == DUMMY_MODEL_CONFIG else model_config if model_config is not None and "_name_or_path" in model_config: card.data["base_model"] = model_config["_name_or_path"] diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 5f1b0a6b49..bea131bf91 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -447,6 +447,7 @@ def _check_merge_allowed(self): Currently gptq quantization and replicated layers do not support merging. 
""" + super()._check_merge_allowed() if getattr(self.model, "quantization_method", None) == "gptq": raise ValueError("Cannot merge LORA layers when the model is gptq quantized") if self.peft_config.get("layer_replication"): diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 0021279076..285b5c9410 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -17,6 +17,7 @@ import logging import os import re +import textwrap import warnings from abc import ABC, abstractmethod from contextlib import contextmanager @@ -30,7 +31,7 @@ from transformers.pytorch_utils import Conv1D from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND -from peft.utils.constants import DUMMY_TARGET_MODULES, SEQ_CLS_HEAD_NAMES +from peft.utils.constants import DUMMY_MODEL_CONFIG, DUMMY_TARGET_MODULES, EMBEDDING_LAYER_NAMES, SEQ_CLS_HEAD_NAMES from peft.utils.peft_types import PeftType, TaskType from ..config import PeftConfig @@ -361,7 +362,36 @@ def _check_merge_allowed(self): Raise a ValueError if it is not possible to merge the adapter with the given configuration. """ - pass + example_code = textwrap.dedent( + """ + ```python + from transformers import AutoModelForCausalLM + + # Load original tied model + model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False) + + # Set the randomly initialized lm_head to the previously tied embeddings + model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone() + + # Save the untied model + untied_model_dir = "dir/for/untied/model" + model.save_pretrained(untied_model_dir) + model.config.save_pretrained(untied_model_dir) + + # Now use the original model but in untied format + model = AutoModelForCausalLM.from_pretrained(untied_model_dir) + ``` + """ + ) + tied_target_modules = self._get_tied_target_modules(self.model) + if tied_target_modules: + warnings.warn( + f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. " + "This can lead to complications. " + "You can opt to merge the adapter after cloning the weights (to untie the embeddings). " + "You can untie the embeddings by loading the model with `tie_word_embeddings=False`. For example:" + + example_code + ) def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True) -> None: r""" @@ -387,9 +417,7 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d _check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None _has_modules_to_save = False - model_config = getattr(model, "config", {"model_type": "custom"}) - if hasattr(model_config, "to_dict"): - model_config = model_config.to_dict() + model_config = self.get_model_config(model) peft_config = self._prepare_adapter_config(peft_config, model_config) @@ -430,6 +458,15 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d parent, target, target_name = _get_submodules(model, key) self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key) + tied_target_modules = self._get_tied_target_modules(model=model) + if tied_target_modules: + warnings.warn( + f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. " + "This can lead to complications, for example when merging the adapter " + "or converting your model to formats other than safetensors. " + "See for example https://github.com/huggingface/peft/issues/2018." 
+            )
+
         # Handle X-LoRA case.
         if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
             raise ValueError(
@@ -494,6 +531,30 @@ def _unloading_checks(self, adapter_names: Optional[list[str]]):
             if is_modules_to_save_available and len(adapters_to_consider) > 1:
                 raise ValueError("Cannot unload multiple adapters that specify `modules_to_save`.")
 
+    @staticmethod
+    def get_model_config(model: nn.Module) -> dict:
+        """
+        This method gets the config from a model in dictionary form. If the model does not have a `config`
+        attribute, a default config (`DUMMY_MODEL_CONFIG`) is returned instead.
+
+        Args:
+            model (`nn.Module`):
+                Model to get the config from.
+        """
+        model_config = getattr(model, "config", DUMMY_MODEL_CONFIG)
+        if hasattr(model_config, "to_dict"):
+            model_config = model_config.to_dict()
+        return model_config
+
+    def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
+        tied_target_modules = []
+        model_config = self.get_model_config(model)
+        if model_config.get("tie_word_embeddings"):
+            for target_module in self.targeted_module_names:
+                if target_module in EMBEDDING_LAYER_NAMES:
+                    tied_target_modules.append(target_module)
+        return tied_target_modules
+
 
 class BaseTunerLayer(ABC):
     r"""
diff --git a/src/peft/tuners/vera/model.py b/src/peft/tuners/vera/model.py
index e31d8e0ebc..8ef3506738 100644
--- a/src/peft/tuners/vera/model.py
+++ b/src/peft/tuners/vera/model.py
@@ -107,9 +107,7 @@ def _find_dim(self, config) -> tuple[int, int]:
 
         This will be used for determining the size of the shared vera_A and vera_B matrices.
         """
-        model_config = getattr(self.model, "config", {"model_type": "custom"})
-        if hasattr(model_config, "to_dict"):
-            model_config = model_config.to_dict()
+        model_config = self.get_model_config(self.model)
 
         peft_config = self._prepare_adapter_config(config, model_config)
         peft_config = _maybe_include_all_linear_layers(peft_config, self.model)
diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py
index a5d65fb3d6..b071503d8d 100644
--- a/src/peft/utils/constants.py
+++ b/src/peft/utils/constants.py
@@ -261,3 +261,4 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
 INCLUDE_LINEAR_LAYERS_SHORTHAND = "all-linear"
 TOKENIZER_CONFIG_NAME = "tokenizer_config.json"
 DUMMY_TARGET_MODULES = "dummy-target-modules"
+DUMMY_MODEL_CONFIG = {"model_type": "custom"}
diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py
index e008c3f308..e2766ec17a 100644
--- a/tests/test_tuners_utils.py
+++ b/tests/test_tuners_utils.py
@@ -43,12 +43,14 @@
     get_peft_model,
 )
 from peft.tuners.tuners_utils import (
+    BaseTuner,
     BaseTunerLayer,
     _maybe_include_all_linear_layers,
     check_target_module_exists,
     inspect_matched_modules,
 )
 from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
+from peft.utils.constants import DUMMY_MODEL_CONFIG
 
 from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu
 
@@ -1065,3 +1067,87 @@ def forward(self, X):
 
         with pytest.raises(TypeError, match="get_model_status is not supported for PeftMixedModel"):
             model.get_model_status()
+
+
+# Tests for BaseTuner
+class MockModelConfig:
+    config = {"mock_key": "mock_value"}
+
+    def to_dict(self):
+        return self.config
+
+
+class ModelWithConfig(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.config = MockModelConfig()
+
+
+class ModelWithDictConfig(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.config = MockModelConfig.config
+
+
+class 
ModelWithNoConfig(nn.Module): + pass + + +class TestBaseTunerGetModelConfig(unittest.TestCase): + def test_get_model_config_use_to_dict(self): + config = BaseTuner.get_model_config(ModelWithConfig()) + assert config == MockModelConfig.config + + def test_get_model_config_as_dict(self): + config = BaseTuner.get_model_config(ModelWithDictConfig()) + assert config == MockModelConfig.config + + def test_get_model_config_with_no_config(self): + config = BaseTuner.get_model_config(ModelWithNoConfig()) + assert config == DUMMY_MODEL_CONFIG + + +class TestBaseTunerWarnForTiedEmbeddings: + model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM" + warn_end_inject = "huggingface/peft/issues/2018." + warn_end_merge = ( + "# Now use the original model but in untied format\n" + "model = AutoModelForCausalLM.from_pretrained(untied_model_dir)\n```\n" + ) + + def _get_peft_model(self, tie_word_embeddings, target_module): + model = get_peft_model( + AutoModelForCausalLM.from_pretrained(self.model_id, tie_word_embeddings=tie_word_embeddings), + LoraConfig(target_modules=[target_module]), + ) + return model + + def _is_warn_triggered(self, warning_list, endswith): + return any(str(warning.message).endswith(endswith) for warning in warning_list) + + def test_warn_for_tied_embeddings_inject(self, recwarn): + self._get_peft_model(tie_word_embeddings=True, target_module="lm_head") + assert self._is_warn_triggered(recwarn.list, self.warn_end_inject) + + def test_warn_for_tied_embeddings_merge(self, recwarn): + model = self._get_peft_model(tie_word_embeddings=True, target_module="lm_head") + model.merge_and_unload() + assert self._is_warn_triggered(recwarn.list, self.warn_end_merge) + + def test_no_warn_for_untied_embeddings_inject(self, recwarn): + self._get_peft_model(tie_word_embeddings=False, target_module="lm_head") + assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject) + + def test_no_warn_for_untied_embeddings_merge(self, recwarn): + model_not_tied = self._get_peft_model(tie_word_embeddings=False, target_module="lm_head") + model_not_tied.merge_and_unload() + assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge) + + def test_no_warn_for_no_target_module_inject(self, recwarn): + self._get_peft_model(tie_word_embeddings=True, target_module="q_proj") + assert not self._is_warn_triggered(recwarn.list, self.warn_end_inject) + + def test_no_warn_for_no_target_module_merge(self, recwarn): + model_no_target_module = self._get_peft_model(tie_word_embeddings=True, target_module="q_proj") + model_no_target_module.merge_and_unload() + assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge)
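
Illustrative sketch, not part of the diff above: a minimal example of how the new behavior surfaces from user code. It assumes the tiny tied-embeddings checkpoint used in the tests (`HuggingFaceH4/tiny-random-LlamaForCausalLM`) is available on the Hub; any small causal LM loaded with `tie_word_embeddings=True` should behave the same way.

```python
import warnings

from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTuner

# Tiny model loaded with tied embeddings, mirroring the tests above (assumed checkpoint).
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/tiny-random-LlamaForCausalLM", tie_word_embeddings=True
)

# The new static helper returns the config as a plain dict (or DUMMY_MODEL_CONFIG as a fallback).
assert "model_type" in BaseTuner.get_model_config(model)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Targeting `lm_head` on a tied model triggers the new warning emitted by `inject_adapter`.
    peft_model = get_peft_model(model, LoraConfig(target_modules=["lm_head"]))

assert any("tie_word_embeddings=True" in str(w.message) for w in caught)
```

The same check also runs through `_check_merge_allowed`, so calling `merge_and_unload()` on such a model emits the longer warning that ends with the untying recipe shown in `example_code`.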