Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] support input embeddings for qwen2vl #8856

Merged
Merged
162 changes: 95 additions & 67 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
Union)
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, Type,
TypedDict, Union)

import torch
import torch.nn as nn
Expand Down Expand Up @@ -76,19 +76,31 @@
# === Vision Inputs === #


class Qwen2VLImageInputs(TypedDict):
pixel_values: torch.Tensor
class Qwen2VLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
"""

image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`

"""Shape: `(num_images, 3)`
whyiug marked this conversation as resolved.
Show resolved Hide resolved
This should be in `(grid_t, grid_h, grid_w)` format.
"""


class Qwen2VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of the language model backbone.
"""


Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
Qwen2VLImageEmbeddingInputs]


class Qwen2VLVideoInputs(TypedDict):
pixel_values_videos: torch.Tensor
"""Shape:
Expand Down Expand Up @@ -567,6 +579,11 @@ def mm_input_mapper_for_qwen2_vl(
data_type_key: str,
) -> MultiModalInputs:
"""Input mapper for Qwen2-VL."""
if data_type_key == "image" and isinstance(data, dict):
return MultiModalInputs({
"image_embeds": data.get("image_embeds"),
"image_grid_thw": data.get("image_grid_thw"),
})
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code)
Expand Down Expand Up @@ -775,62 +792,60 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
)["input_ids"]

# Expand image pad tokens.
if image_inputs is not None:
image_indices = [
def expand_pad_tokens(inputs, token_id, make_batched_fn, data_type_key):
whyiug marked this conversation as resolved.
Show resolved Hide resolved
indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.image_token_id
if token == token_id
]
image_inputs = make_batched_images(image_inputs)
assert len(image_indices) == len(image_inputs)

prompt_token_ids_with_image = []
for image_cnt, image in enumerate(image_inputs):
num_image_tokens = _get_llm_num_vision_tokens(
[image],
data_type_key="image",
inputs = make_batched_fn(inputs)
assert len(indices) == len(inputs)

prompt_token_ids_with_data = []
for cnt, data in enumerate(inputs):
num_tokens = _get_llm_num_vision_tokens(
[data] if data_type_key == "image" else data,
data_type_key=data_type_key,
image_processor=image_processor,
)
if image_cnt == 0:
non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
if cnt == 0:
end_idx = indices[cnt]
non_data_tokens = prompt_token_ids[:end_idx]
else:
non_image_tokens = prompt_token_ids[image_indices[image_cnt -
1] +
1:image_indices[image_cnt]]
prompt_token_ids_with_image.extend(non_image_tokens)
prompt_token_ids_with_image.extend(
hf_config.image_token_id for _ in range(num_image_tokens))
prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_image

# Expand video pad tokens.
non_data_tokens = prompt_token_ids[indices[cnt - 1] +
1:indices[cnt]]
prompt_token_ids_with_data.extend(non_data_tokens)
prompt_token_ids_with_data.extend(token_id
for _ in range(num_tokens))
prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
return prompt_token_ids_with_data

if image_inputs is not None:
if isinstance(image_inputs, dict):
prompt_token_ids_with_image = []
image_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.image_token_id
]
image_cnt = len(image_indices)
embed_dim = image_inputs.get('image_embeds').size(0)
assert embed_dim % image_cnt == 0
num_pad_tokens = embed_dim // image_cnt
for idx, token in enumerate(prompt_token_ids):
if idx in image_indices:
prompt_token_ids_with_image.extend([token] *
num_pad_tokens)
else:
prompt_token_ids_with_image.append(token)
prompt_token_ids = prompt_token_ids_with_image
else:
prompt_token_ids = expand_pad_tokens(image_inputs,
hf_config.image_token_id,
make_batched_images, "image")

if video_inputs is not None:
video_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.video_token_id
]
video_inputs = make_batched_videos(video_inputs)
assert len(video_indices) == len(video_inputs)

prompt_token_ids_with_video = []
for video_cnt, video in enumerate(video_inputs):
num_video_tokens = _get_llm_num_vision_tokens(
video,
data_type_key="video",
image_processor=image_processor,
)
if video_cnt == 0:
non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
else:
non_video_tokens = prompt_token_ids[video_indices[video_cnt -
1] +
1:video_indices[video_cnt]]
prompt_token_ids_with_video.extend(non_video_tokens)
prompt_token_ids_with_video.extend(
hf_config.video_token_id for _ in range(num_video_tokens))
prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_video
prompt_token_ids = expand_pad_tokens(video_inputs,
hf_config.video_token_id,
make_batched_videos, "video")

return LLMInputs(
prompt_token_ids=prompt_token_ids,
Expand Down Expand Up @@ -910,22 +925,32 @@ def _validate_and_reshape_mm_tensor(self,
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)

if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None

pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")

if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")

return Qwen2VLImageInputs(pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
return Qwen2VLImagePixelInputs(type="pixel_values",
data=pixel_values,
image_grid_thw=image_grid_thw)

if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Qwen2VLImageEmbeddingInputs(type="image_embeds",
data=image_embeds)

def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
Expand All @@ -947,7 +972,10 @@ def _parse_and_validate_video_input(

def _process_image_input(self,
image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
if image_input["type"] == "image_embeds":
return image_input["data"].type(self.visual.dtype)

pixel_values = image_input["data"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values,
grid_thw=image_input["image_grid_thw"])
return image_embeds
Expand Down
Loading