From a95354a36ee65523a499b3eb42f70a4a0ea4322d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 4 Oct 2024 19:54:45 -0700 Subject: [PATCH 1/5] [Doc] Update README.md with Ray summit slides (#9088) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 3c0d4da6080d3..f0b7ce02d556d 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Easy, fast, and cheap LLM serving for everyone *Latest News* 🔥 +- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/sessioncatalog?tab.day=20241001&search.sessiontracks=1719251906298001uzJ2) from other vLLM contributors and users! - [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing). - [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). From dac914b0d6bc36de4eb4bf70a9d20954560893ea Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 4 Oct 2024 21:45:38 -0700 Subject: [PATCH 2/5] [Bugfix] use blockmanagerv1 for encoder-decoder (#9084) Co-authored-by: Roger Wang --- vllm/engine/arg_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cae95d20ca23d..1623ebb3aa74c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -903,6 +903,11 @@ def create_engine_config(self) -> EngineConfig: "--enable-prefix-caching is currently not " "supported for multimodal models and has been disabled.") self.enable_prefix_caching = False + if model_config.is_encoder_decoder_model: + logger.warning( + "Block Manager v2 does not support encoder-decoder models" + " currently. 
Using Block Manager v1 as fallback.") + self.use_v2_block_manager = False cache_config = CacheConfig( block_size=self.block_size if self.device != "neuron" else From 53b3a330273967a3c4124cbfef2cacac92f553ba Mon Sep 17 00:00:00 2001 From: hhzhang16 <54051230+hhzhang16@users.noreply.github.com> Date: Fri, 4 Oct 2024 22:05:37 -0700 Subject: [PATCH 3/5] [Bugfix] Fixes Phi3v & Ultravox Multimodal EmbeddingInputs (#8979) --- vllm/model_executor/models/phi3v.py | 20 +++++++---- vllm/model_executor/models/ultravox.py | 48 ++++++++++++++++---------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index ebfffb25360cd..b875a83f876be 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -467,9 +467,10 @@ def input_processor_for_phi3v(ctx: InputContext, input_height=h, num_crops=num_crops)) elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape + image_feature_size = [image_data.shape[0]] + image_data = [image_data] elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[1] for item in image_data] + image_feature_size = [item.shape[0] for item in image_data] else: raise TypeError(f"Invalid image type: {type(image_data)}") @@ -611,9 +612,6 @@ def _parse_and_validate_image_input( image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) - if pixel_values is None: - return None - if pixel_values is None and image_embeds is None: return None @@ -650,7 +648,17 @@ def _process_image_input( ) -> torch.Tensor: if image_input["type"] == "image_embeds": - return image_input["data"] + image_data = image_input["data"] + if is_list_of(image_data, torch.Tensor): + # it's already a list of tensors + return image_data + if len(image_data.shape) == 3: + # 3D tensor + return list(torch.unbind(image_data, dim=0)) + raise ValueError( + "We expect batched 2D tensors;" + "this can be either a list of 2D tensors or a single 3D tensor." + ) assert self.vision_embed_tokens is not None image_embeds = self.vision_embed_tokens(image_input["data"], diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index daa6e72dd1002..101cf38c96b01 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -38,6 +38,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP @@ -119,6 +120,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): if not isinstance(data, list): data = [data] + # If the audio inputs are embeddings, no need for preprocessing + if is_list_of(data, torch.Tensor, check="all"): + return MultiModalInputs({"audio_embeds": data}) + audio_features = [] for audio_input in data: if not isinstance(audio_input, tuple): @@ -165,25 +170,30 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): audios = [audios] audio_token_counts = [] - for audio_data, sample_rate in audios: - audio_length = audio_data.shape[0] - if sample_rate != feature_extractor.sampling_rate: - # Account for resampling. 
- adjustment = feature_extractor.sampling_rate / sample_rate - audio_length = math.ceil(adjustment * audio_length) - - feature_extractor_output_length = math.ceil( - (audio_length - (feature_extractor.hop_length - 1)) / - feature_extractor.hop_length) - - uv_config = ctx.get_hf_config(UltravoxConfig) - audio_num_tokens = min( - max( - 1, - math.ceil(feature_extractor_output_length / - (uv_config.stack_factor * 2))), - get_ultravox_max_audio_tokens(ctx)) - audio_token_counts.append(audio_num_tokens) + for audio in audios: + if isinstance(audio, torch.Tensor): + audio_num_tokens = audio.shape[1] + audio_token_counts.append(audio_num_tokens) + else: + audio_data, sample_rate = audio + audio_length = audio_data.shape[0] + if sample_rate != feature_extractor.sampling_rate: + # Account for resampling. + adjustment = feature_extractor.sampling_rate / sample_rate + audio_length = math.ceil(adjustment * audio_length) + + feature_extractor_output_length = math.ceil( + (audio_length - (feature_extractor.hop_length - 1)) / + feature_extractor.hop_length) + + uv_config = ctx.get_hf_config(UltravoxConfig) + audio_num_tokens = min( + max( + 1, + math.ceil(feature_extractor_output_length / + (uv_config.stack_factor * 2))), + get_ultravox_max_audio_tokens(ctx)) + audio_token_counts.append(audio_num_tokens) tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) From 15986f598c7b1f2969918c92f5c4cf7e28d5c0df Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Fri, 4 Oct 2024 23:57:05 -0700 Subject: [PATCH 4/5] [Model] Support Gemma2 embedding model (#9004) --- tests/conftest.py | 1 + .../embedding/language/test_embedding.py | 11 ++- vllm/model_executor/models/gemma2.py | 7 +- .../model_executor/models/gemma2_embedding.py | 82 +++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 5 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 vllm/model_executor/models/gemma2_embedding.py diff --git a/tests/conftest.py b/tests/conftest.py index b1833fdae5347..177b8a0640278 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -277,6 +277,7 @@ def __init__( SentenceTransformer( model_name, device="cpu", + trust_remote_code=True, ).to(dtype=torch_dtype)) else: model_kwargs = model_kwargs if model_kwargs is not None else {} diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index 6556998b68a74..be316c6e12da1 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -1,6 +1,6 @@ """Compare the outputs of HF and vLLM for Mistral models using greedy sampling. -Run `pytest tests/models/test_llama_embedding.py`. +Run `pytest tests/models/embedding/language/test_embedding.py`. """ import pytest import torch @@ -8,6 +8,7 @@ MODELS = [ "intfloat/e5-mistral-7b-instruct", + "BAAI/bge-multilingual-gemma2", ] @@ -28,6 +29,14 @@ def test_models( model: str, dtype: str, ) -> None: + # The example_prompts has ending "\n", for example: + # "Write a short story about a robot that dreams for the first time.\n" + # sentence_transformers will strip the input texts, see: + # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159 + # This makes the input_ids different between hf_model and vllm_model. + # So we need to strip the input texts to avoid test failing. 
+ example_prompts = [str(s).strip() for s in example_prompts] + with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 9fddaac3a0837..ddeaa0fbfc276 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -278,11 +278,14 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) hidden_states *= self.normalizer - residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/gemma2_embedding.py b/vllm/model_executor/models/gemma2_embedding.py new file mode 100644 index 0000000000000..1bcdaea93410f --- /dev/null +++ b/vllm/model_executor/models/gemma2_embedding.py @@ -0,0 +1,82 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.gemma2 import Gemma2Model +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + + +class Gemma2EmbeddingModel(nn.Module): + """A model that uses Gemma2 with additional embedding functionalities. + + This class encapsulates the Gemma2Model and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of Gemma2Model used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = Gemma2Model(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model.forward(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a72b9e8909db2..ccb0e155ff4aa 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -83,6 +83,7 @@ _EMBEDDING_MODELS = { "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), + "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"), } _MULTIMODAL_MODELS = { From cfadb9c68798c0cc4d674de19970a8e3b5ea1273 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 5 Oct 2024 06:56:40 -0700 Subject: [PATCH 5/5] [Bugfix] Deprecate registration of custom configs to huggingface (#9083) --- .../models/decoder_only/vision_language/test_internvl.py | 3 ++- .../models/encoder_decoder/vision_language/test_mllama.py | 7 ------- vllm/transformers_utils/config.py | 8 -------- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_internvl.py b/tests/models/decoder_only/vision_language/test_internvl.py index a756f8214edee..49cab75d8ea53 100644 --- a/tests/models/decoder_only/vision_language/test_internvl.py +++ b/tests/models/decoder_only/vision_language/test_internvl.py @@ -97,7 +97,8 @@ def __init__(self, hf_runner: HfRunner): self.tokenizer = hf_runner.tokenizer self.dtype = hf_runner.model.dtype - self.config = AutoConfig.from_pretrained(hf_runner.model_name) + self.config = AutoConfig.from_pretrained(hf_runner.model_name, + trust_remote_code=True) self.vision_config = self.config.vision_config self.use_thumbnail = self.config.use_thumbnail self.min_num = self.config.min_dynamic_patch diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 254185537e403..78a5c8158e16e 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -195,11 +195,6 @@ def _run_test( def process(hf_inputs: BatchEncoding): return hf_inputs - from transformers.models.mllama import MllamaConfig as MllamaConfigHf - - # use transformer's MllamaConfig for hf_runner - # and vllm's MllamaConfig for vllm_runner - AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True) with hf_runner(model, dtype=dtype, model_kwargs={"device_map": "auto"}, @@ -213,8 +208,6 @@ def process(hf_inputs: BatchEncoding): for prompts, images in inputs ] - from vllm.transformers_utils.configs.mllama import MllamaConfig - AutoConfig.register("mllama", MllamaConfig, exist_ok=True) for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, vllm_outputs_per_image): check_logprobs_close( diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 0f20e8d0c8213..bfba4ca77e1fe 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,4 +1,3 @@ -import contextlib import enum import json from pathlib import Path @@ -61,13 +60,6 @@ **_CONFIG_REGISTRY_OVERRIDE_HF } -for name, cls in _CONFIG_REGISTRY.items(): - with 
contextlib.suppress(ValueError): - if name in _CONFIG_REGISTRY_OVERRIDE_HF: - AutoConfig.register(name, cls, exist_ok=True) - else: - AutoConfig.register(name, cls) - class ConfigFormat(str, enum.Enum): AUTO = "auto"
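
Notes on the series follow; the code in them is illustrative sketches under stated assumptions, not vLLM source.

Patch 2 forces encoder-decoder models onto Block Manager v1 instead of erroring out. A condensed, runnable sketch of that guard (EngineArgsSketch is a stand-in for the one relevant EngineArgs field):

    import logging
    from dataclasses import dataclass

    logger = logging.getLogger("sketch")

    @dataclass
    class EngineArgsSketch:
        use_v2_block_manager: bool = True

    def apply_encoder_decoder_fallback(args: EngineArgsSketch,
                                       is_encoder_decoder_model: bool) -> EngineArgsSketch:
        # Mirror the new guard: warn and fall back, never raise.
        if is_encoder_decoder_model:
            logger.warning("Block Manager v2 does not support encoder-decoder models"
                           " currently. Using Block Manager v1 as fallback.")
            args.use_v2_block_manager = False
        return args

    print(apply_encoder_decoder_fallback(EngineArgsSketch(), True).use_v2_block_manager)  # False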
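
Patch 3 lets Phi-3-Vision accept precomputed image embeddings either as a single 3D tensor of shape (num_images, feature_size, hidden_size) or as a list of 2D tensors, normalizing both to a list. A minimal sketch of that normalization, with illustrative shapes:

    import torch

    def normalize_image_embeds(image_data):
        # Already a list of per-image 2D tensors: pass through.
        if isinstance(image_data, (list, tuple)) and all(
                isinstance(t, torch.Tensor) and t.dim() == 2 for t in image_data):
            return list(image_data)
        # A single batched 3D tensor: split along the image dimension.
        if isinstance(image_data, torch.Tensor) and image_data.dim() == 3:
            return list(torch.unbind(image_data, dim=0))
        raise ValueError("We expect batched 2D tensors; this can be either a "
                         "list of 2D tensors or a single 3D tensor.")

    embeds = torch.randn(2, 16, 3072)  # 2 images, 16 features each, hidden size 3072
    assert len(normalize_image_embeds(embeds)) == 2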
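
The Ultravox half of patch 3 short-circuits preprocessing when the inputs are already audio embeddings (audio.shape[1] becomes the token count) and otherwise keeps the original token-count formula. A worked, standalone version of that formula; the default constants (16 kHz feature-extractor rate, hop length 160, stack factor 8, 750-token cap) are assumptions for illustration:

    import math

    def ultravox_audio_token_count(audio_len: int, sample_rate: int,
                                   feat_sampling_rate: int = 16000,
                                   hop_length: int = 160,
                                   stack_factor: int = 8,
                                   max_audio_tokens: int = 750) -> int:
        # Account for resampling to the feature extractor's rate.
        if sample_rate != feat_sampling_rate:
            audio_len = math.ceil(feat_sampling_rate / sample_rate * audio_len)
        # Frames produced by the feature extractor.
        n_frames = math.ceil((audio_len - (hop_length - 1)) / hop_length)
        # Frames are stacked before the projector; clamp to [1, max_audio_tokens].
        return min(max(1, math.ceil(n_frames / (stack_factor * 2))), max_audio_tokens)

    print(ultravox_audio_token_count(audio_len=2 * 44100, sample_rate=44100))  # 13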
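
Patch 4 registers Gemma2Model among the embedding models, so BAAI/bge-multilingual-gemma2 resolves to Gemma2EmbeddingModel. A usage sketch against vLLM's offline embedding API (the output field names follow LLM.encode as of this series):

    from vllm import LLM

    llm = LLM(model="BAAI/bge-multilingual-gemma2")
    outputs = llm.encode(["Write a short story about a robot that dreams for the first time."])
    # Each output carries a last-token-pooled, L2-normalized embedding vector.
    print(len(outputs[0].outputs.embedding))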
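
Gemma2EmbeddingModel pools with PoolingType.LAST and normalize=True, i.e. it takes the final token's hidden state and L2-normalizes it. An equivalent pooling step outside vLLM, as a sketch:

    import torch
    import torch.nn.functional as F

    def last_token_pool(hidden_states: torch.Tensor,
                        seq_lens: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, max_seq_len, hidden); seq_lens: (batch,)
        batch_idx = torch.arange(hidden_states.size(0))
        pooled = hidden_states[batch_idx, seq_lens - 1]
        return F.normalize(pooled, p=2, dim=-1)

    h = torch.randn(2, 5, 8)
    print(last_token_pool(h, torch.tensor([5, 3])).shape)  # torch.Size([2, 8])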
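
The load_weights method in gemma2_embedding.py reuses the common stacked-parameter mapping: separate q/k/v (and gate/up) checkpoint tensors load as shards of fused parameters. A small sketch of just the name-remapping step:

    stacked_params_mapping = [
        # (fused param name, checkpoint shard name, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    def remap(name: str):
        for fused_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                return name.replace(shard_name, fused_name), shard_id
        return name, None  # handled by the default weight loader

    print(remap("model.layers.0.self_attn.k_proj.weight"))
    # ('model.layers.0.self_attn.qkv_proj.weight', 'k')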
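
Patch 5 stops registering vLLM's custom config classes with Hugging Face AutoConfig at import time, so Hugging Face-side code (as in the updated InternVL test) resolves remote configs with trust_remote_code instead. The model name below is only an example:

    from transformers import AutoConfig

    # Remote custom configs now come from the model repo, not from vLLM's registry.
    config = AutoConfig.from_pretrained("OpenGVLab/InternVL2-2B",
                                        trust_remote_code=True)
    print(type(config).__name__)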