Move warn in BaseTuner
ltoniazzi authored and Lorenzo Toniazzi committed Aug 21, 2024
1 parent 45644ca commit 44a02de
Showing 6 changed files with 27 additions and 22 deletions.
src/peft/mapping.py: 11 changes (1 addition, 10 deletions)
@@ -65,7 +65,7 @@
    XLoraConfig,
)
from .tuners.tuners_utils import BaseTuner as _BaseTuner
-from .utils import _prepare_prompt_learning_config, EMBEDDING_LAYER_NAMES
+from .utils import _prepare_prompt_learning_config


if TYPE_CHECKING:
@@ -130,13 +130,6 @@ def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:

    return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)

-def warn_if_tied_embeddings_in_target_modules(model_config, peft_config):
-    if model_config.get("tie_word_embeddings"):
-        for target_module in peft_config.target_modules:
-            if target_module in EMBEDDING_LAYER_NAMES:
-                warnings.warn(
-                    f"{model_config['tie_word_embeddings']=} and a tied {target_module=} is passed to peft config. This can lead to complications, for example when merging the adapter. Are you sure you want to use the target module {target_module}?"
-                )

def get_peft_model(
    model: PreTrainedModel,
@@ -174,8 +167,6 @@ def get_peft_model(
    new_name = model.__dict__.get("name_or_path", None)
    peft_config.base_model_name_or_path = new_name

-    warn_if_tied_embeddings_in_target_modules(model_config, peft_config)
-
    if (old_name is not None) and (old_name != new_name):
        warnings.warn(
            f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
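With this change the tied-embeddings check no longer lives in `peft.mapping`; it is performed by the tuner itself during adapter injection (see the `tuners_utils.py` hunks below). A minimal sketch of a call that would now trigger the warning from `BaseTuner` rather than from a module-level helper in `get_peft_model`; the model name and LoRA settings are illustrative assumptions, not taken from this commit:

```python
# Illustrative sketch, not part of the commit: gpt2 ties lm_head to the input
# embeddings (tie_word_embeddings=True), and "lm_head" is listed in
# EMBEDDING_LAYER_NAMES, so targeting it should surface the tied-embeddings warning.
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
lora_config = LoraConfig(r=8, target_modules=["lm_head"])

# After this commit the warning is emitted from BaseTuner.inject_adapter
# (reached via get_peft_model), not from a helper in peft.mapping.
peft_model = get_peft_model(base_model, lora_config)
```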
src/peft/peft_model.py: 4 changes (1 addition, 3 deletions)
@@ -1231,9 +1231,7 @@ def create_or_update_model_card(self, output_dir: str):

        card.data["library_name"] = "peft"

-        model_config = getattr(self, "config", None)
-        if hasattr(model_config, "to_dict"):
-            model_config = model_config.to_dict()
+        model_config = BaseTuner.get_model_config(self, default=None)
        if model_config is not None and "_name_or_path" in model_config:
            card.data["base_model"] = model_config["_name_or_path"]

src/peft/tuners/lora/model.py: 1 change (1 addition, 0 deletions)
@@ -447,6 +447,7 @@ def _check_merge_allowed(self):
        Currently gptq quantization and replicated layers do not support merging.
        """
+        self._warn_if_tied_embeddings_in_target_modules(self.model)
        if getattr(self.model, "quantization_method", None) == "gptq":
            raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
        if self.peft_config.get("layer_replication"):
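Because `_check_merge_allowed` now also calls the warning helper, the message is surfaced again when merging an adapter that targets a tied embedding layer. A short sketch, continuing the hypothetical setup from the example above:

```python
# Sketch only: peft_model is the hypothetical tied-embedding LoRA model from the
# previous example. merge_and_unload() goes through _check_merge_allowed, so the
# tied-embeddings warning is repeated here before the LoRA weights are merged.
merged_model = peft_model.merge_and_unload()
```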
src/peft/tuners/tuners_utils.py: 28 changes (23 additions, 5 deletions)
@@ -30,7 +30,7 @@
from transformers.pytorch_utils import Conv1D

from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
-from peft.utils.constants import DUMMY_TARGET_MODULES
+from peft.utils.constants import DUMMY_TARGET_MODULES, EMBEDDING_LAYER_NAMES
from peft.utils.peft_types import PeftType

from ..config import PeftConfig
@@ -108,7 +108,6 @@ def onload_layer(layer):
        offload_state_dict(safetensors_filename, layer.base_layer._hf_hook.weights_map)
        layer.base_layer._hf_hook.post_forward(layer.base_layer, torch.tensor([]))

-
class BaseTuner(nn.Module, ABC):
    r"""
    A base tuner model that provides the common methods and attributes for all tuners that are injectable into a
@@ -387,9 +386,7 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d
        _check_for_modules_to_save = getattr(peft_config, "modules_to_save", None) is not None
        _has_modules_to_save = False

-        model_config = getattr(model, "config", {"model_type": "custom"})
-        if hasattr(model_config, "to_dict"):
-            model_config = model_config.to_dict()
+        model_config = self.get_model_config(model)

        peft_config = self._prepare_adapter_config(peft_config, model_config)

@@ -430,6 +427,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d
            parent, target, target_name = _get_submodules(model, key)
            self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)

+        self._warn_if_tied_embeddings_in_target_modules(model=model)
+
        # Handle X-LoRA case.
        if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
            raise ValueError(
@@ -493,6 +492,25 @@ def _unloading_checks(self, adapter_names: Optional[list[str]]):
        )
        if is_modules_to_save_available and len(adapters_to_consider) > 1:
            raise ValueError("Cannot unload multiple adapters that specify `modules_to_save`.")
+
+    @staticmethod
+    def get_model_config(model, default={"model_type": "custom"}):
+        model_config = getattr(model, "config", default)
+        if hasattr(model_config, "to_dict"):
+            model_config = model_config.to_dict()
+        return model_config
+
+
+    def _warn_if_tied_embeddings_in_target_modules(self, model):
+        model_config = self.get_model_config(model)
+        if model_config.get("tie_word_embeddings"):
+            for target_module in self.targeted_module_names:
+                if target_module in EMBEDDING_LAYER_NAMES:
+                    warnings.warn(
+                        f"{model_config.get('tie_word_embeddings')=} and a tied {target_module=} is passed to peft config.\n"
+                        "This can lead to complications, for example when merging the adapter.\n"
+                        f"Are you sure you want to use the target module {target_module}?"
+                    )


class BaseTunerLayer(ABC):
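For reference, the new `get_model_config` staticmethod centralizes the `config`-to-`dict` normalization that was previously copy-pasted in `inject_adapter`, `_find_dim`, and the model-card helper. A small sketch of its behavior; the dummy classes are hypothetical stand-ins made up for illustration:

```python
# Illustrative sketch of BaseTuner.get_model_config; the dummy classes below are
# hypothetical, not part of the commit.
from peft.tuners.tuners_utils import BaseTuner


class DummyConfig:
    """Mimics a transformers PretrainedConfig exposing to_dict()."""

    def to_dict(self):
        return {"model_type": "llama", "tie_word_embeddings": True}


class DummyTransformersModel:
    config = DummyConfig()


class PlainTorchModule:
    """No .config attribute, like a custom nn.Module."""


print(BaseTuner.get_model_config(DummyTransformersModel()))
# {'model_type': 'llama', 'tie_word_embeddings': True}

print(BaseTuner.get_model_config(PlainTorchModule()))
# {'model_type': 'custom'}  <- falls back to the default dict
```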
src/peft/tuners/vera/model.py: 4 changes (1 addition, 3 deletions)
@@ -107,9 +107,7 @@ def _find_dim(self, config) -> tuple[int, int]:
        This will be used for determining the size of the shared vera_A and vera_B matrices.
        """
-        model_config = getattr(self.model, "config", {"model_type": "custom"})
-        if hasattr(model_config, "to_dict"):
-            model_config = model_config.to_dict()
+        model_config = self.get_model_config(self.model)

        peft_config = self._prepare_adapter_config(config, model_config)
        peft_config = _maybe_include_all_linear_layers(peft_config, self.model)
src/peft/utils/__init__.py: 1 change (0 additions, 1 deletion)
@@ -52,4 +52,3 @@
    cast_mixed_precision_params,
)
from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict, load_peft_weights
-from .constants import EMBEDDING_LAYER_NAMES
