Refactor: PEFT method registration function #2282

Merged
8 changes: 6 additions & 2 deletions src/peft/__init__.py
@@ -15,6 +15,7 @@
__version__ = "0.14.1.dev0"

from .auto import (
MODEL_TYPE_TO_PEFT_MODEL_MAPPING,
AutoPeftModel,
AutoPeftModelForCausalLM,
AutoPeftModelForFeatureExtraction,
@@ -25,12 +26,13 @@
)
from .config import PeftConfig, PromptLearningConfig
from .mapping import (
MODEL_TYPE_TO_PEFT_MODEL_MAPPING,
PEFT_TYPE_TO_CONFIG_MAPPING,
PEFT_TYPE_TO_MIXED_MODEL_MAPPING,
PEFT_TYPE_TO_TUNER_MAPPING,
get_peft_config,
get_peft_model,
inject_adapter_in_model,
)
from .mapping_func import get_peft_model
from .mixed_model import PeftMixedModel
from .peft_model import (
PeftModel,
@@ -112,6 +114,8 @@
__all__ = [
"MODEL_TYPE_TO_PEFT_MODEL_MAPPING",
"PEFT_TYPE_TO_CONFIG_MAPPING",
"PEFT_TYPE_TO_MIXED_MODEL_MAPPING",
"PEFT_TYPE_TO_TUNER_MAPPING",
"TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
"AdaLoraConfig",
"AdaLoraModel",
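
With this change, all four registries are exposed from the top-level `peft` namespace. A minimal sketch of the new import surface (the lookups assume the registries have been populated at import time):

```python
from peft import (
    MODEL_TYPE_TO_PEFT_MODEL_MAPPING,
    PEFT_TYPE_TO_CONFIG_MAPPING,
    PEFT_TYPE_TO_MIXED_MODEL_MAPPING,
    PEFT_TYPE_TO_TUNER_MAPPING,
)
from peft.utils import PeftType

# Registry lookups: task type -> model class, PEFT type -> config/tuner class.
model_cls = MODEL_TYPE_TO_PEFT_MODEL_MAPPING["CAUSAL_LM"]
config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA]
tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[PeftType.LORA]
```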
11 changes: 10 additions & 1 deletion src/peft/auto.py
@@ -29,7 +29,6 @@
)

from .config import PeftConfig
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
from .peft_model import (
PeftModel,
PeftModelForCausalLM,
@@ -43,6 +42,16 @@
from .utils.other import check_file_exists_on_hf_hub


MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = {
"SEQ_CLS": PeftModelForSequenceClassification,
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
"CAUSAL_LM": PeftModelForCausalLM,
"TOKEN_CLS": PeftModelForTokenClassification,
"QUESTION_ANS": PeftModelForQuestionAnswering,
"FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
}


class _BaseAutoPeftModel:
_target_class = None
_target_peft_class = None
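
`MODEL_TYPE_TO_PEFT_MODEL_MAPPING` now lives in `peft.auto`, next to the `PeftModelFor*` imports it references, and is re-exported from the package root (see the `__init__.py` hunk above). For example:

```python
from peft.auto import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
from peft.peft_model import PeftModelForSeq2SeqLM

# The mapping resolves a task type string to the task-specific PeftModel subclass.
assert MODEL_TYPE_TO_PEFT_MODEL_MAPPING["SEQ_2_SEQ_LM"] is PeftModelForSeq2SeqLM
```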
219 changes: 9 additions & 210 deletions src/peft/mapping.py
@@ -14,123 +14,23 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any

import torch

from peft.tuners.xlora.model import XLoraModel

from .config import PeftConfig
from .mixed_model import PeftMixedModel
from .peft_model import (
PeftModel,
PeftModelForCausalLM,
PeftModelForFeatureExtraction,
PeftModelForQuestionAnswering,
PeftModelForSeq2SeqLM,
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
)
from .tuners import (
AdaLoraConfig,
AdaLoraModel,
AdaptionPromptConfig,
BOFTConfig,
BOFTModel,
BoneConfig,
BoneModel,
CPTConfig,
CPTEmbedding,
FourierFTConfig,
FourierFTModel,
HRAConfig,
HRAModel,
IA3Config,
IA3Model,
LNTuningConfig,
LNTuningModel,
LoHaConfig,
LoHaModel,
LoKrConfig,
LoKrModel,
LoraConfig,
LoraModel,
MultitaskPromptTuningConfig,
OFTConfig,
OFTModel,
PolyConfig,
PolyModel,
PrefixTuningConfig,
PromptEncoderConfig,
PromptTuningConfig,
VBLoRAConfig,
VBLoRAModel,
VeraConfig,
VeraModel,
XLoraConfig,
)
from .tuners.tuners_utils import BaseTuner, BaseTunerLayer
from .utils import _prepare_prompt_learning_config
from .utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
from .utils import PeftType


if TYPE_CHECKING:
from transformers import PreTrainedModel

from .config import PeftConfig
from .tuners.tuners_utils import BaseTuner

MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = {
"SEQ_CLS": PeftModelForSequenceClassification,
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
"CAUSAL_LM": PeftModelForCausalLM,
"TOKEN_CLS": PeftModelForTokenClassification,
"QUESTION_ANS": PeftModelForQuestionAnswering,
"FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
}

PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = {
"ADAPTION_PROMPT": AdaptionPromptConfig,
"PROMPT_TUNING": PromptTuningConfig,
"PREFIX_TUNING": PrefixTuningConfig,
"P_TUNING": PromptEncoderConfig,
"LORA": LoraConfig,
"LOHA": LoHaConfig,
"LORAPLUS": LoraConfig,
"LOKR": LoKrConfig,
"ADALORA": AdaLoraConfig,
"BOFT": BOFTConfig,
"IA3": IA3Config,
"MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig,
"OFT": OFTConfig,
"POLY": PolyConfig,
"LN_TUNING": LNTuningConfig,
"VERA": VeraConfig,
"FOURIERFT": FourierFTConfig,
"XLORA": XLoraConfig,
"HRA": HRAConfig,
"VBLORA": VBLoRAConfig,
"CPT": CPTConfig,
"BONE": BoneConfig,
}

PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = {
"LORA": LoraModel,
"LOHA": LoHaModel,
"LOKR": LoKrModel,
"ADALORA": AdaLoraModel,
"BOFT": BOFTModel,
"IA3": IA3Model,
"OFT": OFTModel,
"POLY": PolyModel,
"LN_TUNING": LNTuningModel,
"VERA": VeraModel,
"FOURIERFT": FourierFTModel,
"XLORA": XLoraModel,
"HRA": HRAModel,
"VBLORA": VBLoRAModel,
"CPT": CPTEmbedding,
"BONE": BoneModel,
}
# these will be filled by the register_peft_method function
PEFT_TYPE_TO_CONFIG_MAPPING: dict[PeftType, type[PeftConfig]] = {}
PEFT_TYPE_TO_TUNER_MAPPING: dict[PeftType, type[BaseTuner]] = {}
PEFT_TYPE_TO_MIXED_MODEL_MAPPING: dict[PeftType, type[BaseTuner]] = {}
PEFT_TYPE_TO_PREFIX_MAPPING: dict[PeftType, str] = {}
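
These registries start empty and are filled at import time by the new `register_peft_method` function named in the comment above. Its actual definition is not part of this hunk; the following is a hypothetical sketch of what such a registration step could look like — the signature and the prefix-defaulting behavior are assumptions, not the PR's implementation:

```python
from peft.mapping import (
    PEFT_TYPE_TO_CONFIG_MAPPING,
    PEFT_TYPE_TO_MIXED_MODEL_MAPPING,
    PEFT_TYPE_TO_PREFIX_MAPPING,
    PEFT_TYPE_TO_TUNER_MAPPING,
)
from peft.utils import PeftType


def register_peft_method(*, name, config_cls, model_cls, prefix=None, is_mixed_compatible=False):
    """Hypothetical registration helper -- signature and body are assumptions."""
    peft_type = PeftType(name.upper())  # e.g. "lora" -> PeftType.LORA
    PEFT_TYPE_TO_CONFIG_MAPPING[peft_type] = config_cls
    PEFT_TYPE_TO_TUNER_MAPPING[peft_type] = model_cls
    # Assumed default: derive the state-dict prefix from the method name.
    PEFT_TYPE_TO_PREFIX_MAPPING[peft_type] = prefix if prefix is not None else f"{name}_"
    if is_mixed_compatible:
        PEFT_TYPE_TO_MIXED_MODEL_MAPPING[peft_type] = model_cls
```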


def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:
@@ -144,107 +44,6 @@ def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:
return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
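
Usage of `get_peft_config` is unchanged; it still dispatches on the `peft_type` key. A minimal example (the hyperparameters are illustrative):

```python
from peft import get_peft_config

# Rebuild a config object from a plain dict, as stored in adapter_config.json.
config = get_peft_config({"peft_type": "LORA", "task_type": "CAUSAL_LM", "r": 8})
```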


def get_peft_model(
model: PreTrainedModel,
peft_config: PeftConfig,
adapter_name: str = "default",
mixed: bool = False,
autocast_adapter_dtype: bool = True,
revision: Optional[str] = None,
low_cpu_mem_usage: bool = False,
) -> PeftModel | PeftMixedModel:
"""
Returns a Peft model object from a model and a config, where the model will be modified in-place.

Args:
model ([`transformers.PreTrainedModel`]):
Model to be wrapped.
peft_config ([`PeftConfig`]):
Configuration object containing the parameters of the Peft model.
adapter_name (`str`, `optional`, defaults to `"default"`):
The name of the adapter to be injected; if not provided, the default adapter name ("default") is used.
mixed (`bool`, `optional`, defaults to `False`):
Whether to allow mixing different (compatible) adapter types.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 or bfloat16 to float32, as this is typically required for stable training, and only affects
select PEFT tuners.
revision (`str`, `optional`, defaults to `main`):
The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
the base model.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as
False if you intend to train the model, unless the adapter weights will be replaced by different weights
before training starts.
"""
model_config = BaseTuner.get_model_config(model)
old_name = peft_config.base_model_name_or_path
new_name = model.__dict__.get("name_or_path", None)
peft_config.base_model_name_or_path = new_name

# Especially in notebook environments there could be a case that a user wants to experiment with different
# configuration values. However, it is likely that there won't be any changes for new configs on an already
# initialized PEFT model. The best we can do is warn the user about it.
if any(isinstance(module, BaseTunerLayer) for module in model.modules()):
warnings.warn(
"You are trying to modify a model with PEFT for a second time. If you want to reload the model with a "
"different config, make sure to call `.unload()` before."
)

if (old_name is not None) and (old_name != new_name):
warnings.warn(
f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
"Please ensure that the correct base model is loaded when loading this checkpoint."
)

if revision is not None:
if peft_config.revision is not None and peft_config.revision != revision:
warnings.warn(
f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}"
)
peft_config.revision = revision

if (
(isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"]))
and (peft_config.init_lora_weights == "eva")
and not low_cpu_mem_usage
):
warnings.warn(
"lora with eva initialization used with low_cpu_mem_usage=False. "
"Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization."
)

prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type)
if prefix and adapter_name in prefix:
warnings.warn(
f"Adapter name {adapter_name} should not be contained in the prefix {prefix}."
"This may lead to reinitialization of the adapter weights during loading."
)

if mixed:
# note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it
return PeftMixedModel(model, peft_config, adapter_name=adapter_name)

if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
return PeftModel(
model,
peft_config,
adapter_name=adapter_name,
autocast_adapter_dtype=autocast_adapter_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
)

if peft_config.is_prompt_learning:
peft_config = _prepare_prompt_learning_config(peft_config, model_config)
return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
model,
peft_config,
adapter_name=adapter_name,
autocast_adapter_dtype=autocast_adapter_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
)
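
The function body above is moved to the new `mapping_func.py` module (see the `.mapping_func` import in `__init__.py`); the public call pattern is unchanged. A typical invocation, for reference — the model name and hyperparameters are illustrative:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16)
# Wraps the base model in PeftModelForCausalLM via MODEL_TYPE_TO_PEFT_MODEL_MAPPING.
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()
```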


def inject_adapter_in_model(
peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default", low_cpu_mem_usage: bool = False
) -> torch.nn.Module:
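
A minimal usage sketch of `inject_adapter_in_model`, whose signature is unchanged (the module names and dimensions are illustrative):

```python
import torch

from peft import LoraConfig, inject_adapter_in_model


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        self.head = torch.nn.Linear(16, 2)

    def forward(self, x):
        return self.head(self.linear(x))


# Only modules named in target_modules receive LoRA layers; the model is
# modified in place and returned.
lora_config = LoraConfig(target_modules=["linear"], r=4, lora_alpha=8)
model = inject_adapter_in_model(lora_config, TinyModel())
```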