[fix] Fix model loading inconsistency after Peft training by using PeftModel #2980

Merged: 20 commits, Nov 8, 2024
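For context, here is a minimal sketch of the save/load round trip this PR fixes. The model name comes from the PR's test suite; the local path and the `model[0]` indexing are illustrative assumptions, not part of the diff.

```python
from peft import LoraConfig, TaskType
from sentence_transformers import SentenceTransformer

# Attach a LoRA adapter to the wrapped transformer (mirrors the new test below).
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
model[0].auto_model.add_adapter(
    LoraConfig(task_type=TaskType.FEATURE_EXTRACTION, target_modules=["query", "key", "value"])
)
model.save("local-peft-model")  # hypothetical path

# Previously, reloading could yield embeddings inconsistent with the trained
# model; with this PR the adapter is detected and rewrapped in a PeftModel.
reloaded = SentenceTransformer("local-peft-model")
```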
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
"scikit-learn",
"scipy",
"huggingface-hub>=0.20.0",
"Pillow",
"Pillow"
]

[project.urls]
@@ -52,7 +52,7 @@ train = ["datasets", "accelerate>=0.20.3"]
onnx = ["optimum[onnxruntime]>=1.23.1"]
onnx-gpu = ["optimum[onnxruntime-gpu]>=1.23.1"]
openvino = ["optimum-intel[openvino]>=1.20.0"]
-dev = ["datasets", "accelerate>=0.20.3", "pre-commit", "pytest", "pytest-cov"]
+dev = ["datasets", "accelerate>=0.20.3", "pre-commit", "pytest", "pytest-cov", "peft"]

[build-system]
requires = ["setuptools>=42", "wheel"]
36 changes: 34 additions & 2 deletions sentence_transformers/models/Transformer.py
@@ -11,7 +11,8 @@
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
-from transformers.utils import is_peft_available
+from transformers.utils.import_utils import is_peft_available
+from transformers.utils.peft_utils import find_adapter_config_file

logger = logging.getLogger(__name__)

@@ -73,7 +74,7 @@ def __init__(
        if config_args is None:
            config_args = {}

-       config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+       config = self._load_config(model_name_or_path, cache_dir, backend, config_args)
        self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)

        if max_seq_length is not None and "model_max_length" not in tokenizer_args:
@@ -98,6 +99,26 @@ def __init__(
        if tokenizer_name_or_path is not None:
            self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

+   def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]):
+       """Loads the configuration of a model"""
+       if find_adapter_config_file(model_name_or_path) is not None:
+           if not is_peft_available():
+               raise Exception(
+                   "Loading a PEFT model requires installing the `peft` package. You can install it via `pip install peft`."
+               )
+           if backend != "torch":
+               # TODO: Consider following these steps automatically so we can load PEFT models with other backends
+               raise ValueError(
+                   "PEFT models can currently only be loaded with the `torch` backend. "
+                   'To use other backends, load the model with `backend="torch"`, call `model[0].auto_model.merge_and_unload()`, '
+                   "save that model with `model.save_pretrained()` and then load the model with the desired backend."
+               )
+           from peft import PeftConfig
+
+           return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+
+       return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+
    def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
        """Loads the transformer model"""
        if backend == "torch":
@@ -109,13 +130,23 @@ def _load_model(model_name_or_path, config, cache_dir, backend, **model_args)
            self.auto_model = AutoModel.from_pretrained(
                model_name_or_path, config=config, cache_dir=cache_dir, **model_args
            )
+           self._load_peft_model(model_name_or_path, config, cache_dir, **model_args)
        elif backend == "onnx":
            self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args)
        elif backend == "openvino":
            self._load_openvino_model(model_name_or_path, config, cache_dir, **model_args)
        else:
            raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.")

+   def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+       if is_peft_available():
+           from peft import PeftConfig, PeftModel
+
+           if isinstance(config, PeftConfig):
+               self.auto_model = PeftModel.from_pretrained(
+                   self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+               )
+
    def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
        if isinstance(config, T5Config) or isinstance(config, MT5Config):
            raise ValueError("T5 models are not yet supported by the OpenVINO backend.")
@@ -305,6 +336,7 @@ def _backend_should_export(
        logger.warning(
            f"No {file_name!r} found in {load_path.as_posix()!r}. Exporting the model to {backend_name}."
        )
+
        if model_file_names:
            logger.warning(
                f"If you intended to load one of the {model_file_names} {backend_name} files, "
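The ValueError above spells out a manual workaround for non-torch backends. A sketch of those steps, assuming the hypothetical local paths "local-peft-model" and "merged-model":

```python
from sentence_transformers import SentenceTransformer

# Load the adapter checkpoint with the supported torch backend first.
model = SentenceTransformer("local-peft-model", backend="torch")

# merge_and_unload() folds the LoRA weights into the base weights and returns
# a plain transformers model, so assign the result back onto the module.
model[0].auto_model = model[0].auto_model.merge_and_unload()

# Save the merged checkpoint, then reload it with the desired backend.
model.save_pretrained("merged-model")
onnx_model = SentenceTransformer("merged-model", backend="onnx")
```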
27 changes: 27 additions & 0 deletions tests/test_sentence_transformer.py
@@ -16,6 +16,7 @@
import pytest
import torch
from huggingface_hub import CommitInfo, HfApi, RepoUrl
+from peft import PeftModel
from torch import nn

from sentence_transformers import SentenceTransformer, util
@@ -416,6 +417,32 @@ def transformers_init(*args, **kwargs):
    assert transformer_kwargs["model_args"]["attn_implementation"] == "eager"


+def test_load_checkpoint_with_peft_and_lora() -> None:
+    from peft import LoraConfig, TaskType
+
+    peft_config = LoraConfig(
+        target_modules=["query", "key", "value"],
+        task_type=TaskType.FEATURE_EXTRACTION,
+        inference_mode=False,
+        r=8,
+        lora_alpha=32,
+        lora_dropout=0.1,
+    )
+
+    with SafeTemporaryDirectory() as tmp_folder:
+        model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
+        model._modules["0"].auto_model.add_adapter(peft_config)
+        model.save(tmp_folder)
+        expecteds = model.encode(["Hello there!", "How are you?"], convert_to_tensor=True)
+
+        loaded_peft_model = SentenceTransformer(tmp_folder)
+        actuals = loaded_peft_model.encode(["Hello there!", "How are you?"], convert_to_tensor=True)
+
+        assert isinstance(model._modules["0"].auto_model, nn.Module)
+        assert isinstance(loaded_peft_model._modules["0"].auto_model, PeftModel)
+        assert torch.equal(expecteds, actuals)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
def test_encode_fp16() -> None:
    tiny_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
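As a quick sanity check outside pytest, the test's key assertion can be reproduced interactively; the checkpoint path is the hypothetical one from the earlier sketches:

```python
from peft import PeftModel
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("local-peft-model")
# With this PR, the wrapped transformer is a PeftModel rather than a plain
# AutoModel, so the adapter weights remain active after reloading.
assert isinstance(model[0].auto_model, PeftModel)
```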