diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 7c9123079c44f..8be786fd3f6f5 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -40,13 +40,13 @@ class Blip2ImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: (batch_size, num_channels, height, width)""" + """Shape: `(batch_size * num_images, num_channels, height, width)`""" class Blip2ImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 2d4f172ce0be6..b25f5d521a9bf 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -53,7 +53,7 @@ class ChameleonImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: `(batch_size, num_channels, height, width)`""" + """Shape: `(batch_size * num_images, num_channels, height, width)`""" def get_max_chameleon_image_tokens(ctx: InputContext): diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 7f213287f33b4..ca4d773190e0f 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -29,7 +29,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) from .interfaces import SupportsMultiModal -from .utils import (filter_weights, init_vllm_registered_model, +from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, merge_multimodal_embeddings) IMG_START = '' @@ -42,19 +42,17 @@ class InternVLImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: Union[torch.Tensor, List[torch.Tensor]] + data: torch.Tensor """ - Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` - - Note that `num_patches` may be different for each batch, in which case - the data is passed as a list instead of a batched tensor. + Shape: + `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` """ class InternVLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: Union[torch.Tensor, List[torch.Tensor]] - """Shape: `(batch_size, image_feature_size, hidden_size)` + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -357,7 +355,7 @@ def pixel_shuffle(self, x, scale_factor=0.5): x = x.permute(0, 2, 1, 3).contiguous() return x - def extract_feature(self, pixel_values): + def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = self.vision_model(pixel_values=pixel_values) vit_embeds = vit_embeds[:, 1:, :] @@ -370,17 +368,7 @@ def extract_feature(self, pixel_values): vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - if list(data.shape[1:]) != [2]: - raise ValueError( - f"The expected image sizes shape is batch dimension plus " - f"{[2]}. You supplied {data.shape}.") - - return data - - def _validate_pixel_values( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -389,10 +377,11 @@ def _validate_shape(d: torch.Tensor): actual_dims = tuple(d.shape) if actual_dims != expected_dims: - expected_expr = ("num_patches", *map(str, expected_dims)) + expected_expr = str(expected_dims) raise ValueError( - "The expected shape of pixel values in each batch element " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") for d in data: _validate_shape(d) @@ -413,12 +402,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") - # Flatten the B and N dimensions - image_embeds = image_embeds.flatten(0, 2) - return InternVLImageEmbeddingInputs( type="image_embeds", - data=image_embeds, + data=flatten_bn(image_embeds), ) self.img_context_token_id = image_token_id[0] @@ -428,12 +414,10 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # Flatten the B and N dimensions - pixel_values = pixel_values.flatten(0, 2) - return InternVLImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True).flatten(0, 1)), ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 03a0abf1db481..490c93294d50f 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -30,13 +30,13 @@ class LlavaImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: `(batch_size, num_channels, height, width)`""" + """Shape: `(batch_size * num_images, num_channels, height, width)`""" class LlavaImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 3a87242954114..048ca16974e3c 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -29,7 +29,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) -from .utils import (filter_weights, init_vllm_registered_model, +from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -47,15 +47,16 @@ class LlavaNextImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: Union[torch.Tensor, List[torch.Tensor]] """ - Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different for each batch, in which case - the data is passed as a list instead of a batched tensor. + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. """ image_sizes: NotRequired[torch.Tensor] """ - Shape: `(batch_size, 2)` + Shape: `(batch_size * num_images, 2)` This should be in `(height, width)` format. """ @@ -64,7 +65,7 @@ class LlavaNextImagePixelInputs(TypedDict): class LlavaNextImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -315,10 +316,19 @@ def __init__(self, torch.empty(config.text_config.hidden_size)) def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - if list(data.shape[1:]) != [2]: - raise ValueError( - f"The expected image sizes shape is batch dimension plus " - f"{[2]}. You supplied {data.shape}.") + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) return data @@ -335,7 +345,7 @@ def _validate_shape(d: torch.Tensor): if actual_dims != expected_dims: expected_expr = ("num_patches", *map(str, expected_dims)) raise ValueError( - "The expected shape of pixel values in each batch element " + "The expected shape of pixel values per image per batch " f"is {expected_expr}. You supplied {tuple(d.shape)}.") for d in data: @@ -357,22 +367,15 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - if not isinstance(image_sizes, torch.Tensor): + if not isinstance(image_sizes, (torch.Tensor, list)): raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") - # Remove the N dimension until multiple images are supported. - if isinstance(pixel_values, torch.Tensor): - pixel_values = pixel_values.squeeze(1) - else: - pixel_values = [t.squeeze(0) for t in pixel_values] - - image_sizes = image_sizes.squeeze(1) - return LlavaNextImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), - image_sizes=self._validate_image_sizes(image_sizes), + data=self._validate_pixel_values(flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True)), ) if image_embeds is not None: @@ -380,12 +383,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image embeds. " f"Got type: {type(image_embeds)}") - # Remove the N dimension until multiple images are supported. - image_embeds = image_embeds.squeeze(1) - return LlavaNextImageEmbeddingInputs( type="image_embeds", - data=image_embeds, + data=flatten_bn(image_embeds), ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 0700f0c29d708..46ee4c3208b7a 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -34,13 +34,13 @@ class PaliGemmaImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: (batch_size, num_channels, height, width)""" + """Shape: `(batch_size * num_images, num_channels, height, width)`""" class PaliGemmaImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 61f1d73976379..bec1d35388506 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -44,7 +44,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsMultiModal -from .utils import merge_multimodal_embeddings +from .utils import flatten_bn, merge_multimodal_embeddings logger = init_logger(__name__) @@ -75,15 +75,16 @@ class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: Union[torch.Tensor, List[torch.Tensor]] """ - Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different for each batch, in which case - the data is passed as a list instead of a batched tensor. + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. """ image_sizes: torch.Tensor """ - Shape: `(batch_size, 2)` + Shape: `(batch_size * num_images, 2)` This should be in `(height, width)` format. """ @@ -92,7 +93,7 @@ class Phi3VImagePixelInputs(TypedDict): class Phi3VImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: Union[torch.Tensor, List[torch.Tensor]] - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -511,10 +512,19 @@ def __init__(self, self.sampler = Sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - if list(data.shape[1:]) != [2]: - raise ValueError( - f"The expected shape of image sizes is batch dimension plus " - f"{[2]}. You supplied {tuple(data.shape)}.") + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) return data @@ -531,7 +541,7 @@ def _validate_shape(d: torch.Tensor): if actual_dims != expected_dims: expected_expr = ("num_patches", *map(str, expected_dims)) raise ValueError( - "The expected shape of pixel values in each batch element " + "The expected shape of pixel values per image per batch " f"is {expected_expr}. You supplied {tuple(d.shape)}.") for d in data: @@ -556,30 +566,24 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - if not isinstance(image_sizes, torch.Tensor): + if not isinstance(image_sizes, (torch.Tensor, list)): raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") - # Merge the B and N dimensions. - if isinstance(pixel_values, torch.Tensor): - pixel_values = pixel_values.flatten(0, 1) - else: - pixel_values = torch.cat(pixel_values) - - image_sizes = image_sizes.flatten(0, 1) - return Phi3VImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), - image_sizes=self._validate_image_sizes(image_sizes)) + data=self._validate_pixel_values(flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True))) if image_embeds is not None: if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + return Phi3VImageEmbeddingInputs( type="image_embeds", - data=image_embeds, + data=flatten_bn(image_embeds), ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index c81c2fd114eb8..03d6223225511 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -49,7 +49,7 @@ class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] data: Union[torch.Tensor, List[torch.Tensor]] - """Shape: `(batch_size, 80, M)""" + """Shape: `(batch_size * num_audios, 80, M)""" class UltravoxAudioEmbeddingInputs(TypedDict): diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 00026b7ebe2e1..6e7ee511bf27f 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,4 +1,5 @@ -from typing import Dict, Iterable, List, Optional, Protocol, Tuple +from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, + Union, overload) import numpy as np import torch @@ -55,6 +56,44 @@ def init_vllm_registered_model( ) +@overload +def flatten_bn(x: torch.Tensor) -> torch.Tensor: + ... + + +@overload +def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]: + ... + + +@overload +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: Literal[True], +) -> torch.Tensor: + ... + + +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: bool = False, +) -> Union[List[torch.Tensor], torch.Tensor]: + """ + Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs. + + The input tensor should have shape ``(B, N, ...)```. + """ + if isinstance(x, torch.Tensor): + return x.flatten(0, 1) + + if concat: + return torch.cat(x) + + return [x_n for x_b in x for x_n in x_b] + + def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: """ Recursively concatenates NestedTensors along any heterogeneously sized @@ -93,7 +132,8 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, This updates ``inputs_embeds`` in place. """ mask = (input_ids == placeholder_token_id) - num_expected_tokens = mask.sum() + num_expected_tokens = mask.sum().item() + assert isinstance(num_expected_tokens, int) flattened = _flatten_embeddings(multimodal_embeddings) *dims, embed_dim = flattened.shape diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index f26e3292c264d..c02e61596927a 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -18,7 +18,7 @@ logger = init_logger(__name__) -NestedTensors = Union[List["NestedTensors"], torch.Tensor] +NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ @@ -61,7 +61,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: tensors_ = cast(List[torch.Tensor], stacked) if any(t.shape != tensors_[0].shape for t in tensors_): # The tensors have incompatible shapes and can't be stacked. - return stacked + return tensors_ return torch.stack(tensors_)