From e2c7d044969ced0d7667a5c420a10b79105aa9b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Sat, 11 Jan 2025 15:05:09 +0100
Subject: [PATCH] [Bugfix] Fix RobertaModel loading (#11940)

Signed-off-by: NickLucche
Signed-off-by: Bowen Wang
---
 .../test_model_load_with_params.py           | 27 +++++++++-
 .../embedding/language/test_embedding.py     |  1 +
 vllm/model_executor/models/roberta.py        | 51 +++++++++++++++----
 3 files changed, 67 insertions(+), 12 deletions(-)

diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py
index ed321ba9f00c1..0609fd96825e3 100644
--- a/tests/model_executor/test_model_load_with_params.py
+++ b/tests/model_executor/test_model_load_with_params.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from vllm.model_executor.layers.pooler import PoolingType
+from vllm.model_executor.layers.pooler import CLSPool, PoolingType
 from vllm.model_executor.models.bert import BertEmbeddingModel
 from vllm.model_executor.models.roberta import RobertaEmbeddingModel
 from vllm.platforms import current_platform
@@ -92,3 +92,28 @@ def test_roberta_model_loading_with_params(vllm_runner):
 
         # assert output
         assert output
+
+
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_facebook_roberta_model_loading_with_params(vllm_runner):
+    """
+    Test loading roberta-base model with no lm_head.
+    """
+    model_name = "FacebookAI/roberta-base"
+    with vllm_runner(model_name=model_name,
+                     dtype="float16",
+                     max_model_len=MAX_MODEL_LEN) as model:
+        output = model.encode("Write a short story about a robot that"
+                              " dreams for the first time.\n")
+
+        model_tokenizer = model.model.llm_engine.tokenizer
+        assert model_tokenizer.tokenizer_id == model_name
+
+        model = model.model.llm_engine.model_executor\
+            .driver_worker.model_runner.model
+        assert not hasattr(model, "lm_head")
+        assert isinstance(model, RobertaEmbeddingModel)
+        assert isinstance(model._pooler, CLSPool)
+
+        assert output
diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py
index 7749806548cd9..04ab4dd7371a3 100644
--- a/tests/models/embedding/language/test_embedding.py
+++ b/tests/models/embedding/language/test_embedding.py
@@ -25,6 +25,7 @@
         pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
         pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
         pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
+        pytest.param("sentence-transformers/stsb-roberta-base-v2"),
     ],
 )
 @pytest.mark.parametrize("dtype", ["half"])
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index ba1a78ac640fd..5997a76890c9d 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -1,3 +1,4 @@
+import itertools
 from typing import Iterable, List, Optional, Tuple
 
 import torch
@@ -20,6 +21,30 @@
 from .interfaces import SupportsCrossEncoding
 
 
+def roberta_task_weights_filter(
+    all_weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
+                                                               torch.Tensor]]]:
+    """
+    Separate task-specific weights that are applied on top
+    of the encoder-decoder bert base.
+    To do so, return two generators over the original iterator.
+    Also, remove the "roberta." prefix to make it loadable
+    from vanilla BertModel.
+    """
+    # Copy of a lazy iterator without in-memory overhead so both
+    # iterators can be iterated upon independently.
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def encoder_decoder_weights():
+        for name, weight in all_weights1:
+            if name.startswith("roberta."):
+                yield (name[len("roberta."):], weight)
+
+    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
+                                       if not n.startswith("roberta."))
+
+
 class RobertaEmbedding(nn.Module):
 
     def __init__(self, config: RobertaConfig):
@@ -152,6 +177,18 @@ def _build_model(self,
                          prefix=prefix,
                          embedding_class=RobertaEmbedding)
 
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        weights = self.hf_to_vllm_mapper.apply(weights)
+        # Separate weights in "roberta"-prefixed and all else (not in memory).
+        # For use with models like FacebookAI/roberta-base.
+        bert_weights, task_weights = roberta_task_weights_filter(weights)
+        loaded = self.model.load_weights(bert_weights)
+        if not len(loaded):
+            # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
+            # which use the same architecture, but have no "roberta" prefix.
+            loaded = self.model.load_weights(task_weights)
+        assert len(loaded), "Unable to load RobertaEmbeddingModel"
+
 
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
@@ -181,20 +218,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
-        self_weights = []
-
-        def weight_filter():
-            for name, weight in weights:
-                if name.startswith("roberta."):
-                    yield (name[len("roberta."):], weight)
-                else:
-                    self_weights.append((name, weight))
-
-        self.roberta.load_weights(weight_filter())
+        bert_weights, task_weights = roberta_task_weights_filter(weights)
+        self.roberta.load_weights(bert_weights)
 
         params_dict = dict(self.named_parameters())
 
-        for name, loaded_weight in self_weights:
+        for name, loaded_weight in task_weights:
             if name.startswith("classifier"):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
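
Note on the weight-splitting approach: the snippet below is not part of the patch; it re-implements the itertools.tee-based split used by roberta_task_weights_filter as a standalone sketch so the behavior can be inspected in isolation. The helper name split_roberta_weights and the toy checkpoint keys are made up for illustration; only itertools and torch are assumed.

import itertools
from typing import Iterable, Iterator, Tuple

import torch


def split_roberta_weights(
    all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterator[Tuple[str, torch.Tensor]],
           Iterator[Tuple[str, torch.Tensor]]]:
    # Duplicate the lazy iterator; tee only buffers references to items that
    # one branch has consumed ahead of the other, so tensors are never copied.
    prefixed, unprefixed = itertools.tee(all_weights)

    def encoder_weights():
        # Keep the "roberta."-prefixed entries and strip the prefix so the
        # names line up with a vanilla BertModel's parameters.
        for name, weight in prefixed:
            if name.startswith("roberta."):
                yield name[len("roberta."):], weight

    task_weights = ((n, w) for n, w in unprefixed
                    if not n.startswith("roberta."))
    return encoder_weights(), task_weights


# Toy checkpoint entries (hypothetical names, for demonstration only).
toy_weights = [
    ("roberta.embeddings.word_embeddings.weight", torch.zeros(4, 2)),
    ("roberta.encoder.layer.0.attention.self.query.weight", torch.zeros(2, 2)),
    ("classifier.dense.weight", torch.zeros(2, 2)),
]

encoder, task = split_roberta_weights(toy_weights)
print([name for name, _ in encoder])
# -> ['embeddings.word_embeddings.weight',
#     'encoder.layer.0.attention.self.query.weight']
print([name for name, _ in task])
# -> ['classifier.dense.weight']

This mirrors how the patched RobertaEmbeddingModel.load_weights first feeds the prefixed branch to the underlying BertModel and, if nothing was loaded, falls back to the unprefixed branch for checkpoints such as sentence-transformers/stsb-roberta-base-v2 that use the same layout without the "roberta." prefix.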