
Commit 516fc3c

Put warn in base class - force dummy config
ltoniazzi authored and Lorenzo Toniazzi committed Aug 24, 2024
1 parent 761ae2c commit 516fc3c
Showing 5 changed files with 23 additions and 16 deletions.
src/peft/mixed_model.py (4 changes: 3 additions & 1 deletion)
@@ -23,6 +23,8 @@
 from torch import nn
 from transformers.utils import PushToHubMixin
 
+from peft.utils.constants import DUMMY_MODEL_CONFIG
+
 from .config import PeftConfig
 from .peft_model import PeftModel
 from .tuners import (
@@ -120,7 +122,7 @@ def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str
         self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name)
         self.set_modules_to_save(peft_config, adapter_name)
 
-        self.config = getattr(model, "config", {"model_type": "custom"})
+        self.config = getattr(model, "config", DUMMY_MODEL_CONFIG)
 
         # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid
         # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected
src/peft/peft_model.py (5 changes: 4 additions & 1 deletion)
@@ -37,6 +37,8 @@
 from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.utils import PushToHubMixin
 
+from peft.utils.constants import DUMMY_MODEL_CONFIG
+
 from . import __version__
 from .config import PeftConfig
 from .tuners import (
@@ -1231,7 +1233,8 @@ def create_or_update_model_card(self, output_dir: str):
 
         card.data["library_name"] = "peft"
 
-        model_config = BaseTuner.get_model_config(self, default=None)
+        model_config = BaseTuner.get_model_config(self)
+        model_config = None if model_config == DUMMY_MODEL_CONFIG else model_config
         if model_config is not None and "_name_or_path" in model_config:
             card.data["base_model"] = model_config["_name_or_path"]
 
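As a quick illustration of the peft_model.py change: the dummy fallback returned by BaseTuner.get_model_config is mapped back to None before the model card is populated, so a model without a real config never gets a fabricated base_model entry. A minimal sketch of that logic (resolve_base_model is a hypothetical helper written only to mirror the two lines added above, not a PEFT function):

    from __future__ import annotations

    # Hypothetical helper mirroring the logic added to create_or_update_model_card.
    DUMMY_MODEL_CONFIG = {"model_type": "custom"}  # same constant as in src/peft/utils/constants.py


    def resolve_base_model(model_config) -> str | None:
        # Treat the dummy fallback as "no config at all" so no bogus base_model is recorded.
        model_config = None if model_config == DUMMY_MODEL_CONFIG else model_config
        if model_config is not None and "_name_or_path" in model_config:
            return model_config["_name_or_path"]
        return None


    assert resolve_base_model(DUMMY_MODEL_CONFIG) is None
    assert resolve_base_model({"model_type": "llama", "_name_or_path": "org/model"}) == "org/model"
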
src/peft/tuners/lora/model.py (10 changes: 1 addition & 9 deletions)
@@ -447,15 +447,7 @@ def _check_merge_allowed(self):
         Currently gptq quantization and replicated layers do not support merging.
         """
-        tied_target_modules = self._get_tied_target_modules(self.model)
-        if tied_target_modules:
-            warnings.warn(
-                f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
-                "This can lead to complications when merging the adapter. "
-                "You can opt to merge the adapter after cloning the weights (to untie the embeddings), "
-                "and then load the merged model with `tie_word_embeddings=False`: "
-                "\n```python\nAutoModelForCausalLM.from_pretrained('path/to/merged/model', tie_word_embeddings=False)\n```\n"
-            )
+        super()._check_merge_allowed()
         if getattr(self.model, "quantization_method", None) == "gptq":
             raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
         if self.peft_config.get("layer_replication"):
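The practical effect of the lora/model.py change, together with the tuners_utils.py hunk below, is that the tied-embeddings warning is no longer LoRA-specific: any tuner whose _check_merge_allowed defers to the base class now emits it. A stripped-down sketch of the pattern (the *Sketch class names are hypothetical stand-ins, not the actual PEFT classes):

    import warnings


    class BaseTunerSketch:
        """Hypothetical stand-in for peft.tuners.tuners_utils.BaseTuner."""

        def __init__(self, tied_target_modules):
            self.tied_target_modules = tied_target_modules

        def _check_merge_allowed(self):
            # Shared check: warn when tied embedding/LM-head modules are part of the adapter.
            if self.tied_target_modules:
                warnings.warn(
                    f"tie_word_embeddings=True and {self.tied_target_modules} are part of the adapter; "
                    "merging may require untying the embeddings first."
                )


    class LoraModelSketch(BaseTunerSketch):
        """Hypothetical stand-in for peft.tuners.lora.LoraModel."""

        def _check_merge_allowed(self):
            super()._check_merge_allowed()  # the shared warning now comes from the base class
            # LoRA-specific checks (gptq quantization, layer replication) would follow, as in the diff.


    LoraModelSketch(["lm_head"])._check_merge_allowed()  # emits a UserWarning
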
src/peft/tuners/tuners_utils.py (19 changes: 14 additions & 5 deletions)
@@ -30,7 +30,7 @@
 from transformers.pytorch_utils import Conv1D
 
 from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
-from peft.utils.constants import DUMMY_TARGET_MODULES, EMBEDDING_LAYER_NAMES
+from peft.utils.constants import DUMMY_MODEL_CONFIG, DUMMY_TARGET_MODULES, EMBEDDING_LAYER_NAMES
 from peft.utils.peft_types import PeftType
 
 from ..config import PeftConfig
@@ -361,7 +361,15 @@ def _check_merge_allowed(self):
         Raise a ValueError if it is not possible to merge the adapter with the given configuration.
         """
-        pass
+        tied_target_modules = self._get_tied_target_modules(self.model)
+        if tied_target_modules:
+            warnings.warn(
+                f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
+                "This can lead to complications when merging the adapter. "
+                "You can opt to merge the adapter after cloning the weights (to untie the embeddings), "
+                "and then load the merged model with `tie_word_embeddings=False`: "
+                "\n```python\nAutoModelForCausalLM.from_pretrained('path/to/merged/model', tie_word_embeddings=False)\n```\n"
+            )
 
     def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
         r"""
@@ -501,22 +509,23 @@ def _unloading_checks(self, adapter_names: Optional[list[str]]):
             raise ValueError("Cannot unload multiple adapters that specify `modules_to_save`.")
 
     @staticmethod
-    def get_model_config(model: nn.Module, default: dict | None = {"model_type": "custom"}) -> dict:
+    def get_model_config(model: nn.Module) -> dict:
         """
         This method gets the config from a model in dictionary form.
         If model has not attribute config, then this method returns a default config.
         Args:
             model (`nn.Module`):
                 Model to get the config from.
             default (`dict|None`, *optional*)::
                 What to return if model does not have a config attribute.
         """
-        model_config = getattr(model, "config", default)
+        model_config = getattr(model, "config", DUMMY_MODEL_CONFIG)
         if hasattr(model_config, "to_dict"):
             model_config = model_config.to_dict()
         return model_config
 
-    def _get_tied_target_modules(self, model):
+    def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
         tied_target_modules = []
         model_config = self.get_model_config(model)
         if model_config.get("tie_word_embeddings"):
src/peft/utils/constants.py (1 change: 1 addition & 0 deletions)
@@ -260,3 +260,4 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
 INCLUDE_LINEAR_LAYERS_SHORTHAND = "all-linear"
 TOKENIZER_CONFIG_NAME = "tokenizer_config.json"
 DUMMY_TARGET_MODULES = "dummy-target-modules"
+DUMMY_MODEL_CONFIG = {"model_type": "custom"}
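Taken together, these changes mean that a model without a config attribute now resolves to the shared DUMMY_MODEL_CONFIG constant wherever get_model_config is called. A minimal usage sketch, assuming a PEFT checkout that includes this commit (import paths taken from the files in this diff):

    from torch import nn

    from peft.tuners.tuners_utils import BaseTuner
    from peft.utils.constants import DUMMY_MODEL_CONFIG

    # A plain torch module has no `config` attribute, so the helper falls back to the
    # shared constant instead of the old inline {"model_type": "custom"} literal.
    config = BaseTuner.get_model_config(nn.Linear(4, 4))
    assert config == DUMMY_MODEL_CONFIG == {"model_type": "custom"}

    # A transformers model still yields its real config as a dict (e.g. including _name_or_path),
    # which create_or_update_model_card uses to fill in the card's base_model field.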
