diff --git a/pyproject.toml b/pyproject.toml
index 11550a5e7..a161c51f4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
     "scikit-learn",
     "scipy",
     "huggingface-hub>=0.20.0",
-    "Pillow",
+    "Pillow"
 ]
 
 [project.urls]
@@ -52,7 +52,7 @@ train = ["datasets", "accelerate>=0.20.3"]
 onnx = ["optimum[onnxruntime]>=1.23.1"]
 onnx-gpu = ["optimum[onnxruntime-gpu]>=1.23.1"]
 openvino = ["optimum-intel[openvino]>=1.20.0"]
-dev = ["datasets", "accelerate>=0.20.3", "pre-commit", "pytest", "pytest-cov"]
+dev = ["datasets", "accelerate>=0.20.3", "pre-commit", "pytest", "pytest-cov", "peft"]
 
 [build-system]
 requires = ["setuptools>=42", "wheel"]
diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index b2b5053a3..6caf94dc4 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -11,7 +11,8 @@ import torch
 from torch import nn
 from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
-from transformers.utils import is_peft_available
+from transformers.utils.import_utils import is_peft_available
+from transformers.utils.peft_utils import find_adapter_config_file
 
 logger = logging.getLogger(__name__)
 
@@ -73,7 +74,7 @@ def __init__(
         if config_args is None:
             config_args = {}
 
-        config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+        config = self._load_config(model_name_or_path, cache_dir, backend, config_args)
         self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)
 
         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
@@ -98,6 +99,26 @@ def __init__(
         if tokenizer_name_or_path is not None:
             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
 
+    def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]):
+        """Loads the configuration of a model"""
+        if find_adapter_config_file(model_name_or_path) is not None:
+            if not is_peft_available():
+                raise Exception(
+                    "Loading a PEFT model requires installing the `peft` package. You can install it via `pip install peft`."
+                )
+            if backend != "torch":
+                # TODO: Consider following these steps automatically so we can load PEFT models with other backends
+                raise ValueError(
+                    "PEFT models can currently only be loaded with the `torch` backend. "
+                    'To use other backends, load the model with `backend="torch"`, call `model[0].auto_model.merge_and_unload()`, '
+                    "save that model with `model.save_pretrained()` and then load the model with the desired backend."
+                )
+            from peft import PeftConfig
+
+            return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+
+        return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+
     def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
         """Loads the transformer model"""
         if backend == "torch":
@@ -109,6 +130,7 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_ar
             self.auto_model = AutoModel.from_pretrained(
                 model_name_or_path, config=config, cache_dir=cache_dir, **model_args
             )
+            self._load_peft_model(model_name_or_path, config, cache_dir, **model_args)
         elif backend == "onnx":
             self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args)
         elif backend == "openvino":
@@ -116,6 +138,15 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_ar
         else:
             raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.")
 
+    def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+        if is_peft_available():
+            from peft import PeftConfig, PeftModel
+
+            if isinstance(config, PeftConfig):
+                self.auto_model = PeftModel.from_pretrained(
+                    self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+                )
+
     def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
         if isinstance(config, T5Config) or isinstance(config, MT5Config):
             raise ValueError("T5 models are not yet supported by the OpenVINO backend.")
@@ -305,6 +336,7 @@ def _backend_should_export(
             logger.warning(
                 f"No {file_name!r} found in {load_path.as_posix()!r}. Exporting the model to {backend_name}."
             )
+
             if model_file_names:
                 logger.warning(
                     f"If you intended to load one of the {model_file_names} {backend_name} files, "
diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py
index c3aaee807..bbac4f499 100644
--- a/tests/test_sentence_transformer.py
+++ b/tests/test_sentence_transformer.py
@@ -16,6 +16,7 @@
 import pytest
 import torch
 from huggingface_hub import CommitInfo, HfApi, RepoUrl
+from peft import PeftModel
 from torch import nn
 
 from sentence_transformers import SentenceTransformer, util
@@ -416,6 +417,32 @@ def transformers_init(*args, **kwargs):
     assert transformer_kwargs["model_args"]["attn_implementation"] == "eager"
 
 
+def test_load_checkpoint_with_peft_and_lora() -> None:
+    from peft import LoraConfig, TaskType
+
+    peft_config = LoraConfig(
+        target_modules=["query", "key", "value"],
+        task_type=TaskType.FEATURE_EXTRACTION,
+        inference_mode=False,
+        r=8,
+        lora_alpha=32,
+        lora_dropout=0.1,
+    )
+
+    with SafeTemporaryDirectory() as tmp_folder:
+        model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
+        model._modules["0"].auto_model.add_adapter(peft_config)
+        model.save(tmp_folder)
+        expecteds = model.encode(["Hello there!", "How are you?"], convert_to_tensor=True)
+
+        loaded_peft_model = SentenceTransformer(tmp_folder)
+        actuals = loaded_peft_model.encode(["Hello there!", "How are you?"], convert_to_tensor=True)
+
+        assert isinstance(model._modules["0"].auto_model, nn.Module)
+        assert isinstance(loaded_peft_model._modules["0"].auto_model, PeftModel)
+        assert torch.equal(expecteds, actuals)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
 def test_encode_fp16() -> None:
     tiny_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")