
Commit

Merge branch 'vllm-project:main' into main
sroy745 authored Oct 5, 2024
2 parents: 7d223b5 + cfadb9c · commit f327d91
Showing 12 changed files with 150 additions and 44 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -16,6 +16,7 @@ Easy, fast, and cheap LLM serving for everyone


*Latest News* 🔥
- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/sessioncatalog?tab.day=20241001&search.sessiontracks=1719251906298001uzJ2) from other vLLM contributors and users!
- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing).
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -277,6 +277,7 @@ def __init__(
SentenceTransformer(
model_name,
device="cpu",
trust_remote_code=True,
).to(dtype=torch_dtype))
else:
model_kwargs = model_kwargs if model_kwargs is not None else {}
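
For context, a hedged sketch of what this flag enables outside the test harness: sentence-transformers refuses to execute custom modeling code shipped in a Hub repository unless trust_remote_code=True is passed. The model id below is taken from the embedding test's MODELS list; whether this particular checkpoint actually requires the flag is an assumption.

from sentence_transformers import SentenceTransformer

# Minimal sketch: load an embedding checkpoint that may ship remote code.
model = SentenceTransformer(
    "BAAI/bge-multilingual-gemma2",
    device="cpu",
    trust_remote_code=True,  # allow custom code from the model repository
)
embeddings = model.encode(["What is the capital of France?"])
print(embeddings.shape)
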
3 changes: 2 additions & 1 deletion tests/models/decoder_only/vision_language/test_internvl.py
@@ -97,7 +97,8 @@ def __init__(self, hf_runner: HfRunner):
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype

self.config = AutoConfig.from_pretrained(hf_runner.model_name)
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
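
The same flag is needed on the transformers side. A hedged sketch (the model id is an assumption; the test reads it from hf_runner.model_name) of resolving an InternVL config whose class lives in remote code:

from transformers import AutoConfig

# Sketch only: InternVL2 repositories define their config class in remote code,
# so AutoConfig needs trust_remote_code=True to resolve it.
config = AutoConfig.from_pretrained("OpenGVLab/InternVL2-2B",
                                    trust_remote_code=True)
print(config.use_thumbnail, config.min_dynamic_patch)  # fields read by the test above
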
11 changes: 10 additions & 1 deletion tests/models/embedding/language/test_embedding.py
@@ -1,13 +1,14 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_llama_embedding.py`.
Run `pytest tests/models/embedding/language/test_embedding.py`.
"""
import pytest
import torch
import torch.nn.functional as F

MODELS = [
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-multilingual-gemma2",
]


@@ -28,6 +29,14 @@ def test_models(
model: str,
dtype: str,
) -> None:
    # example_prompts end with "\n", for example:
    # "Write a short story about a robot that dreams for the first time.\n"
    # sentence_transformers will strip the input texts, see:
    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
    # This makes the input_ids different between hf_model and vllm_model.
    # So we strip the input texts here to avoid the test failing.
example_prompts = [str(s).strip() for s in example_prompts]

with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)

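
A small standalone demo of the mismatch the comment describes (tokenizer choice is an assumption; the model id comes from MODELS above): a trailing newline changes the token ids, so stripping keeps the HF and vLLM inputs identical.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("intfloat/e5-mistral-7b-instruct")
with_newline = tok("Write a short story about a robot.\n").input_ids
stripped = tok("Write a short story about a robot.").input_ids
# The trailing "\n" adds token(s), so the two id sequences differ.
assert with_newline != stripped
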
7 changes: 0 additions & 7 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -195,11 +195,6 @@ def _run_test(
def process(hf_inputs: BatchEncoding):
return hf_inputs

from transformers.models.mllama import MllamaConfig as MllamaConfigHf

# use transformer's MllamaConfig for hf_runner
# and vllm's MllamaConfig for vllm_runner
AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
with hf_runner(model,
dtype=dtype,
model_kwargs={"device_map": "auto"},
@@ -213,8 +208,6 @@ def process(hf_inputs: BatchEncoding):
for prompts, images in inputs
]

from vllm.transformers_utils.configs.mllama import MllamaConfig
AutoConfig.register("mllama", MllamaConfig, exist_ok=True)
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
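
For reference, a minimal sketch (not the vLLM code path) of the registration hack being deleted here: it swapped which MllamaConfig AutoConfig resolves for the "mllama" model type, which is why exist_ok=True was needed to avoid an error on re-registration.

from transformers import AutoConfig
from transformers.models.mllama import MllamaConfig as MllamaConfigHf

# Re-registering an already-known model type raises ValueError unless exist_ok=True.
AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
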
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
@@ -903,6 +903,11 @@ def create_engine_config(self) -> EngineConfig:
"--enable-prefix-caching is currently not "
"supported for multimodal models and has been disabled.")
self.enable_prefix_caching = False
if model_config.is_encoder_decoder_model:
logger.warning(
"Block Manager v2 does not support encoder-decoder models"
" currently. Using Block Manager v1 as fallback.")
self.use_v2_block_manager = False

cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else
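
A hedged sketch of the fallback as seen from EngineArgs (the model name is an assumption; any encoder-decoder architecture would trigger it): even if v2 is requested explicitly, create_engine_config() now logs the warning and reverts to Block Manager v1.

from vllm import EngineArgs

# Illustrative only: an encoder-decoder model with the v2 block manager requested.
args = EngineArgs(model="facebook/bart-large-cnn", use_v2_block_manager=True)
engine_config = args.create_engine_config()
# After this change, args.use_v2_block_manager has been reset to False and a
# warning about the Block Manager v1 fallback has been logged.
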
7 changes: 5 additions & 2 deletions vllm/model_executor/models/gemma2.py
@@ -278,11 +278,14 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer

residual = None
else:
assert intermediate_tensors is not None
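
A standalone sketch of the semantics of the new branch (shapes and names are illustrative, not the vLLM API): caller-supplied embeddings are used as-is, while token-id lookups are still scaled by the Gemma2 normalizer (sqrt of the hidden size).

import torch

hidden_size = 8
embed_tokens = torch.nn.Embedding(32, hidden_size)
normalizer = hidden_size**0.5  # Gemma2 scales looked-up embeddings by sqrt(hidden_size)

def first_rank_hidden(input_ids, inputs_embeds=None):
    if inputs_embeds is not None:
        return inputs_embeds          # assumed to be pre-scaled by the caller
    return embed_tokens(input_ids) * normalizer

tokens = torch.tensor([1, 2, 3])
from_ids = first_rank_hidden(tokens)               # scaled lookup
from_embeds = first_rank_hidden(None, from_ids)    # bypasses the lookup entirely
assert torch.equal(from_ids, from_embeds)
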
82 changes: 82 additions & 0 deletions vllm/model_executor/models/gemma2_embedding.py
@@ -0,0 +1,82 @@
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class Gemma2EmbeddingModel(nn.Module):
"""A model that uses Gemma2 with additional embedding functionalities.
This class encapsulates the Gemma2Model and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of Gemma2Model used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""

def __init__(
self,
**kwargs,
) -> None:
super().__init__()
self.model = Gemma2Model(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model.forward(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.model.named_parameters())
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
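
Together with the registry entry added below (registry.py), this wrapper makes Gemma2-based embedding checkpoints loadable end to end. A hedged usage sketch, assuming the LLM.encode() embedding path and using a model id taken from the updated test:

from vllm import LLM

llm = LLM(model="BAAI/bge-multilingual-gemma2")
outputs = llm.encode(["What is the capital of France?"])
# Each output carries one pooled, L2-normalized vector (PoolingType.LAST, normalize=True).
print(len(outputs[0].outputs.embedding))
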
20 changes: 14 additions & 6 deletions vllm/model_executor/models/phi3v.py
@@ -467,9 +467,10 @@ def input_processor_for_phi3v(ctx: InputContext,
input_height=h,
num_crops=num_crops))
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
image_feature_size = [image_data.shape[0]]
image_data = [image_data]
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
image_feature_size = [item.shape[0] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")

@@ -611,9 +612,6 @@ def _parse_and_validate_image_input(
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None:
return None

if pixel_values is None and image_embeds is None:
return None

@@ -650,7 +648,17 @@ def _process_image_input(
) -> torch.Tensor:

if image_input["type"] == "image_embeds":
return image_input["data"]
image_data = image_input["data"]
if is_list_of(image_data, torch.Tensor):
# it's already a list of tensors
return image_data
if len(image_data.shape) == 3:
# 3D tensor
return list(torch.unbind(image_data, dim=0))
raise ValueError(
"We expect batched 2D tensors;"
"this can be either a list of 2D tensors or a single 3D tensor."
)

assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens(image_input["data"],
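
Taken together, these phi3v hunks let callers pass precomputed image embeddings instead of pixel values: each image is a 2D (num_tokens, hidden_size) tensor, and a batched 3D tensor is unbound into a list of 2D tensors. A hedged sketch of what that enables (the prompt format follows Phi-3-vision conventions; the embedding file and its shape are assumptions):

import torch
from vllm import LLM

llm = LLM(model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)
# Hypothetical precomputed embeddings for one image: shape (num_image_tokens, hidden_size).
image_embeds = torch.load("phi3v_image_embeds.pt")
outputs = llm.generate({
    "prompt": "<|user|>\n<|image_1|>\nWhat is shown in the image?<|end|>\n<|assistant|>\n",
    "multi_modal_data": {"image": image_embeds},
})
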
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -83,6 +83,7 @@
_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
}

_MULTIMODAL_MODELS = {
48 changes: 29 additions & 19 deletions vllm/model_executor/models/ultravox.py
@@ -38,6 +38,7 @@
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP

@@ -119,6 +120,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
if not isinstance(data, list):
data = [data]

# If the audio inputs are embeddings, no need for preprocessing
if is_list_of(data, torch.Tensor, check="all"):
return MultiModalInputs({"audio_embeds": data})

audio_features = []
for audio_input in data:
if not isinstance(audio_input, tuple):
@@ -165,25 +170,30 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
audios = [audios]

audio_token_counts = []
for audio_data, sample_rate in audios:
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)

feature_extractor_output_length = math.ceil(
(audio_length - (feature_extractor.hop_length - 1)) /
feature_extractor.hop_length)

uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audio_token_counts.append(audio_num_tokens)
for audio in audios:
if isinstance(audio, torch.Tensor):
audio_num_tokens = audio.shape[1]
audio_token_counts.append(audio_num_tokens)
else:
audio_data, sample_rate = audio
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)

feature_extractor_output_length = math.ceil(
(audio_length - (feature_extractor.hop_length - 1)) /
feature_extractor.hop_length)

uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audio_token_counts.append(audio_num_tokens)

tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)

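
For the raw-audio path that remains, a worked sketch of the token-count arithmetic above (all constant values are illustrative assumptions; the real values come from the feature extractor and UltravoxConfig, and the result is additionally capped by get_ultravox_max_audio_tokens):

import math

sampling_rate, hop_length, stack_factor = 16000, 160, 8   # assumed feature-extractor / config values
audio_length, sample_rate = 48000, 48000                   # a one-second clip recorded at 48 kHz

audio_length = math.ceil(audio_length * sampling_rate / sample_rate)     # resampled length: 16000
frames = math.ceil((audio_length - (hop_length - 1)) / hop_length)       # feature-extractor frames: 100
audio_num_tokens = max(1, math.ceil(frames / (stack_factor * 2)))        # stacked audio tokens: 7
print(audio_num_tokens)
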
8 changes: 0 additions & 8 deletions vllm/transformers_utils/config.py
@@ -1,4 +1,3 @@
import contextlib
import enum
import json
from pathlib import Path
@@ -61,13 +60,6 @@
**_CONFIG_REGISTRY_OVERRIDE_HF
}

for name, cls in _CONFIG_REGISTRY.items():
with contextlib.suppress(ValueError):
if name in _CONFIG_REGISTRY_OVERRIDE_HF:
AutoConfig.register(name, cls, exist_ok=True)
else:
AutoConfig.register(name, cls)


class ConfigFormat(str, enum.Enum):
AUTO = "auto"
