diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index c41903f84910d..b05cba3b5d423 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -281,7 +281,7 @@ Multimodal Language Models
     -
   * - :code:`Qwen2VLForConditionalGeneration`
     - Qwen2-VL
-    - Image\ :sup:`+` / Video\ :sup:`+`
+    - Image\ :sup:`E+` / Video\ :sup:`+`
     - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
     -
   * - :code:`UltravoxModel`
diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst
index ca5b125369c85..3f4f01e3ae7ac 100644
--- a/docs/source/models/vlm.rst
+++ b/docs/source/models/vlm.rst
@@ -60,7 +60,24 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
+
+    # Inference with image embeddings as input, with additional parameters
+    # Qwen2-VL is used as the example here, since it requires an additional parameter, image_grid_thw, to compute the positional encoding for the image embeddings.
+    image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
+    image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
+    mm_data['image'] = {
+        "image_embeds": image_embeds,
+        "image_grid_thw": image_grid_thw,
+    }
+    outputs = llm.generate({
+        "prompt": prompt,
+        "multi_modal_data": mm_data,
+    })
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+

     # Batch inference
     image_1 = PIL.Image.open(...)
     image_2 = PIL.Image.open(...)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index f895e693b7107..c82e8ed6ed1e0 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -23,8 +23,8 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import lru_cache, partial
-from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
-                    Union)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, Type, TypedDict, Union)

 import torch
 import torch.nn as nn
@@ -76,19 +76,31 @@
 # === Vision Inputs === #


-class Qwen2VLImageInputs(TypedDict):
-    pixel_values: torch.Tensor
+class Qwen2VLImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
     """Shape:
     `(num_patches, num_channels * patch_size * patch_size)`
     """

     image_grid_thw: torch.Tensor
     """Shape: `(num_images, 3)`
-
     This should be in `(grid_t, grid_h, grid_w)` format.
     """


+class Qwen2VLImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of the language model backbone.
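+
+    Image embeddings precomputed outside vLLM and passed in directly via
+    `multi_modal_data["image"]["image_embeds"]`.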
+    """
+
+
+Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
+                           Qwen2VLImageEmbeddingInputs]
+
+
 class Qwen2VLVideoInputs(TypedDict):
     pixel_values_videos: torch.Tensor
     """Shape:
@@ -567,6 +579,11 @@ def mm_input_mapper_for_qwen2_vl(
     data_type_key: str,
 ) -> MultiModalInputs:
     """Input mapper for Qwen2-VL."""
+    if data_type_key == "image" and isinstance(data, dict):
+        return MultiModalInputs({
+            "image_embeds": data.get("image_embeds"),
+            "image_grid_thw": data.get("image_grid_thw"),
+        })
     model_config = ctx.model_config
     image_processor = cached_get_image_processor(
         model_config.model, trust_remote_code=model_config.trust_remote_code)
@@ -739,6 +756,48 @@ def _get_llm_num_vision_tokens(
     return llm_num_vision_tokens


+def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
+                       data_type_key: str, image_processor: Any,
+                       prompt_token_ids: List[int]) -> List[int]:
+    """
+    Expand pad tokens for multi-modal inputs (e.g., images or videos).
+
+    Args:
+        inputs (list): The multi-modal inputs (e.g., images or videos).
+        token_id (int): The token ID used to represent the multi-modal input.
+        make_batched_fn (Callable): A function to batch the inputs.
+        data_type_key (str): The type of the multi-modal input.
+        image_processor (Any): The image processor used to process the inputs.
+        prompt_token_ids (List[int]): The list of token IDs in the prompt.
+
+    Returns:
+        List[int]: The list of token IDs for the multi-modal inputs.
+    """
+    indices = [
+        idx for idx, token in enumerate(prompt_token_ids) if token == token_id
+    ]
+    inputs = make_batched_fn(inputs)
+    assert len(indices) == len(inputs)
+
+    prompt_token_ids_with_data = []
+    for cnt, data in enumerate(inputs):
+        num_tokens = _get_llm_num_vision_tokens(
+            [data] if data_type_key == "image" else data,
+            data_type_key=data_type_key,
+            image_processor=image_processor,
+        )
+        if cnt == 0:
+            end_idx = indices[cnt]
+            non_data_tokens = prompt_token_ids[:end_idx]
+        else:
+            non_data_tokens = prompt_token_ids[indices[cnt - 1] +
+                                               1:indices[cnt]]
+        prompt_token_ids_with_data.extend(non_data_tokens)
+        prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
+    prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
+    return prompt_token_ids_with_data
+
+
 def input_processor_for_qwen2_vl(ctx: InputContext,
                                  llm_inputs: LLMInputs) -> LLMInputs:
     multi_modal_data = llm_inputs.get("multi_modal_data", None)
@@ -775,62 +834,38 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
     )["input_ids"]

     # Expand image pad tokens.
+
     if image_inputs is not None:
-        image_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.image_token_id
-        ]
-        image_inputs = make_batched_images(image_inputs)
-        assert len(image_indices) == len(image_inputs)
-
-        prompt_token_ids_with_image = []
-        for image_cnt, image in enumerate(image_inputs):
-            num_image_tokens = _get_llm_num_vision_tokens(
-                [image],
-                data_type_key="image",
-                image_processor=image_processor,
-            )
-            if image_cnt == 0:
-                non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
-            else:
-                non_image_tokens = prompt_token_ids[image_indices[image_cnt -
-                                                                  1] +
-                                                    1:image_indices[image_cnt]]
-            prompt_token_ids_with_image.extend(non_image_tokens)
-            prompt_token_ids_with_image.extend(
-                hf_config.image_token_id for _ in range(num_image_tokens))
-        prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_image
-
-    # Expand video pad tokens.
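+        # `image_inputs` is a dict when precomputed image embeddings are
+        # passed in; in that case each image placeholder token in the prompt
+        # is expanded to `image_embeds.size(0) // num_images` tokens instead
+        # of being computed through the image processor.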
+        if isinstance(image_inputs, dict):
+            prompt_token_ids_with_image = []
+            image_indices = [
+                idx for idx, token in enumerate(prompt_token_ids)
+                if token == hf_config.image_token_id
+            ]
+            image_cnt = len(image_indices)
+            embed_dim = image_inputs.get('image_embeds').size(0)
+            assert embed_dim % image_cnt == 0
+            num_pad_tokens = embed_dim // image_cnt
+            for idx, token in enumerate(prompt_token_ids):
+                if idx in image_indices:
+                    prompt_token_ids_with_image.extend([token] *
+                                                       num_pad_tokens)
+                else:
+                    prompt_token_ids_with_image.append(token)
+            prompt_token_ids = prompt_token_ids_with_image
+        else:
+            prompt_token_ids = _expand_pad_tokens(image_inputs,
+                                                  hf_config.image_token_id,
+                                                  make_batched_images, "image",
+                                                  image_processor,
+                                                  prompt_token_ids)
+
     if video_inputs is not None:
-        video_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.video_token_id
-        ]
-        video_inputs = make_batched_videos(video_inputs)
-        assert len(video_indices) == len(video_inputs)
-
-        prompt_token_ids_with_video = []
-        for video_cnt, video in enumerate(video_inputs):
-            num_video_tokens = _get_llm_num_vision_tokens(
-                video,
-                data_type_key="video",
-                image_processor=image_processor,
-            )
-            if video_cnt == 0:
-                non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
-            else:
-                non_video_tokens = prompt_token_ids[video_indices[video_cnt -
-                                                                  1] +
-                                                    1:video_indices[video_cnt]]
-            prompt_token_ids_with_video.extend(non_video_tokens)
-            prompt_token_ids_with_video.extend(
-                hf_config.video_token_id for _ in range(num_video_tokens))
-        prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_video
+        prompt_token_ids = _expand_pad_tokens(video_inputs,
+                                              hf_config.video_token_id,
+                                              make_batched_videos, "video",
+                                              image_processor,
+                                              prompt_token_ids)

     return LLMInputs(
         prompt_token_ids=prompt_token_ids,
@@ -910,22 +945,32 @@ def _validate_and_reshape_mm_tensor(self,
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
         image_grid_thw = kwargs.pop("image_grid_thw", None)

-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None

-        pixel_values = self._validate_and_reshape_mm_tensor(
-            pixel_values, "image pixel values")
-        image_grid_thw = self._validate_and_reshape_mm_tensor(
-            image_grid_thw, "image grid_thw")
+        if pixel_values is not None:
+            pixel_values = self._validate_and_reshape_mm_tensor(
+                pixel_values, "image pixel values")
+            image_grid_thw = self._validate_and_reshape_mm_tensor(
+                image_grid_thw, "image grid_thw")

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of image pixel values. "
-                             f"Got type: {type(pixel_values)}")
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image pixel values. "
+                                 f"Got type: {type(pixel_values)}")

-        return Qwen2VLImageInputs(pixel_values=pixel_values,
-                                  image_grid_thw=image_grid_thw)
+            return Qwen2VLImagePixelInputs(type="pixel_values",
+                                           data=pixel_values,
+                                           image_grid_thw=image_grid_thw)
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
" + f"Got type: {type(image_embeds)}") + return Qwen2VLImageEmbeddingInputs(type="image_embeds", + data=image_embeds) def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: @@ -947,7 +992,10 @@ def _parse_and_validate_video_input( def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) + if image_input["type"] == "image_embeds": + return image_input["data"].type(self.visual.dtype) + + pixel_values = image_input["data"].type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) return image_embeds