diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 966720a25..830564a57 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -37,6 +37,7 @@ from .evaluation import SentenceEvaluator from .fit_mixin import FitMixin from .models import Normalize, Pooling, Transformer +from .peft_mixin import PeftAdapterMixin from .quantization import quantize_embeddings from .util import ( batch_to_device, @@ -52,7 +53,7 @@ logger = logging.getLogger(__name__) -class SentenceTransformer(nn.Sequential, FitMixin): +class SentenceTransformer(nn.Sequential, FitMixin, PeftAdapterMixin): """ Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings. diff --git a/sentence_transformers/peft_mixin.py b/sentence_transformers/peft_mixin.py new file mode 100644 index 000000000..da7f2494b --- /dev/null +++ b/sentence_transformers/peft_mixin.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from functools import wraps + +from transformers.integrations.peft import PeftAdapterMixin as PeftAdapterMixinTransformers + +from .models import Transformer + + +def peft_wrapper(func): + """Wrapper to call the method on the auto_model with a check for PEFT compatibility.""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + self.check_peft_compatible_model() + method = getattr(self[0].auto_model, func.__name__) + return method(*args, **kwargs) + + return wrapper + + +class PeftAdapterMixin: + """ + Wrapper Mixin that adds the functionality to easily load and use adapters on the model. 
For + more details about adapters check out the documentation of PEFT + library: https://huggingface.co/docs/peft/index + + Currently supported PEFT methods follow those supported by transformers library, + you can find more information on: + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin + """ + + def has_peft_compatible_model(self) -> bool: + return isinstance(self[0], Transformer) and isinstance(self[0].auto_model, PeftAdapterMixinTransformers) + + def check_peft_compatible_model(self) -> None: + if not self.has_peft_compatible_model(): + raise ValueError( + "PEFT methods are only supported for Sentence Transformer models that use the Transformer module." + ) + + @peft_wrapper + def load_adapter(self, *args, **kwargs) -> None: + """ + Load adapter weights from file or remote Hub folder." If you are not familiar with adapters and PEFT methods, we + invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft + + Requires peft as a backend to load the adapter weights and the underlying model to be compatible with PEFT. + + Args: + *args: + Positional arguments to pass to the underlying AutoModel `load_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.load_adapter + **kwargs: + Keyword arguments to pass to the underlying AutoModel `load_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.load_adapter + """ + ... # Implementation handled by the wrapper + + @peft_wrapper + def add_adapter(self, *args, **kwargs) -> None: + """ + Adds a fresh new adapter to the current model for training purposes. 
If no adapter name is passed, a default + name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the + default adapter name). + + Requires peft as a backend to load the adapter weights and the underlying model to be compatible with PEFT. + + Args: + *args: + Positional arguments to pass to the underlying AutoModel `add_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.add_adapter + **kwargs: + Keyword arguments to pass to the underlying AutoModel `add_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.add_adapter + + """ + ... # Implementation handled by the wrapper + + @peft_wrapper + def set_adapter(self, *args, **kwargs) -> None: + """ + Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. + + Args: + *args: + Positional arguments to pass to the underlying AutoModel `set_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.set_adapter + **kwargs: + Keyword arguments to pass to the underlying AutoModel `set_adapter` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.set_adapter + """ + ... # Implementation handled by the wrapper + + @peft_wrapper + def disable_adapters(self) -> None: + """ + Disable all adapters that are attached to the model. This leads to inferring with the base model only. + """ + ... # Implementation handled by the wrapper + + @peft_wrapper + def enable_adapters(self) -> None: + """ + Enable adapters that are attached to the model. 
The model will use `self.active_adapter()` + """ + ... # Implementation handled by the wrapper + + @peft_wrapper + def active_adapters(self) -> list[str]: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters + for inference) returns the list of all active adapters so that users can deal with them accordingly. + + For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return + a single string. + """ + ... # Implementation handled by the wrapper + + @peft_wrapper + def active_adapter(self) -> str: ... # Implementation handled by the wrapper + + @peft_wrapper + def get_adapter_state_dict(self, *args, **kwargs) -> dict: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Gets the adapter state dict that should only contain the weights tensors of the specified adapter_name adapter. + If no adapter_name is passed, the active adapter is used. + + Args: + *args: + Positional arguments to pass to the underlying AutoModel `get_adapter_state_dict` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict + **kwargs: + Keyword arguments to pass to the underlying AutoModel `get_adapter_state_dict` function. More information can be found in the transformers documentation + https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict + """ + ... 
@pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test adapter methods.")
def test_multiple_adapters() -> None:
    """Exercise adding, loading, switching and disabling adapters via the model's adapter methods.

    Three adapters are attached ("default", "hub_adapter", "my_adapter"); embeddings of the same
    sentence are then compared across the active adapter, the default adapter and the adapter-free
    model. Finally, a model loaded from a non-Transformer checkpoint must reject `add_adapter`.
    """
    from peft import LoraConfig, TaskType, get_model_status

    sentence = "Hello, World!"
    model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
    embedding_before = model.encode(sentence)

    # Settings shared by both freshly created LoRA adapters.
    shared_lora_settings = {
        "task_type": TaskType.FEATURE_EXTRACTION,
        "inference_mode": False,
        "lora_dropout": 0.1,
        "init_lora_weights": False,  # Random initialization to test the adapter
    }

    # A fresh adapter under the default name
    default_config = LoraConfig(
        target_modules=["query", "key", "value"],
        r=8,
        lora_alpha=32,
        **shared_lora_settings,
    )
    model.add_adapter(default_config)

    # An adapter pulled from the hub
    model.load_adapter("sentence-transformers-testing/stsb-bert-tiny-lora", "hub_adapter")

    # A second fresh adapter registered under an explicit name
    named_config = LoraConfig(
        target_modules=["value"],
        r=2,
        lora_alpha=16,
        **shared_lora_settings,
    )
    model.add_adapter(named_config, "my_adapter")

    # The most recently added adapter is the active one; record its embedding.
    status = get_model_status(model)
    assert status.available_adapters == ["default", "hub_adapter", "my_adapter"]
    assert status.enabled
    assert status.active_adapters == ["my_adapter"]
    assert status.active_adapters == model.active_adapters()
    embedding_my_adapter = model.encode(sentence)

    # Switch to the default adapter.
    model.set_adapter("default")
    status = get_model_status(model)
    assert status.active_adapters == ["default"]
    embedding_default_adapter = model.encode(sentence)

    # Turn every adapter off.
    model.disable_adapters()
    status = get_model_status(model)
    assert not status.enabled
    embedding_no_adapter = model.encode(sentence)

    # Each configuration must yield a different vector ...
    distinct_pairs = [
        (embedding_my_adapter, embedding_default_adapter),
        (embedding_my_adapter, embedding_no_adapter),
        (embedding_default_adapter, embedding_no_adapter),
    ]
    for left, right in distinct_pairs:
        assert not np.allclose(left, right)
    # ... while disabling the adapters restores the original model's output.
    assert np.allclose(embedding_before, embedding_no_adapter)

    # Models that are not Transformer-based must raise an error.
    static_model = SentenceTransformer("sentence-transformers/average_word_embeddings_levy_dependency")
    with pytest.raises(ValueError, match="PEFT methods are only supported"):
        static_model.add_adapter(named_config)