Skip to content

Commit

Permalink
[Model] Add multi-image input support for LLaVA-Next offline inference (
Browse files Browse the repository at this point in the history
  • Loading branch information
zifeitong authored and root committed Aug 28, 2024
1 parent 235b5e5 commit e2bb9e0
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 52 deletions.
21 changes: 10 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]


def _read_prompts(filename: str) -> List[str]:
with open(filename, "r") as f:
Expand Down Expand Up @@ -161,7 +165,7 @@ def example_encoder_decoder_prompts(
decoder prompt) tuple.
Returns:
* Encoder prompt list
* Decoder prompt list (reverse of encoder prompt list)
'''
Expand Down Expand Up @@ -578,8 +582,7 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
images: Optional[PromptImageInput] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
Expand Down Expand Up @@ -623,10 +626,8 @@ def generate_w_logprobs(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None

Expand Down Expand Up @@ -676,10 +677,8 @@ def generate_greedy_logprobs(
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
Expand Down
93 changes: 80 additions & 13 deletions tests/models/test_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,22 @@
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

_PREFACE = (
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's "
"questions.")
_LIMIT_IMAGE_PER_PROMPT = 4

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
"[INST] <image>\nWhat's the content of the image? [/INST]",
"cherry_blossom":
f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
"[INST] <image>\nWhat is the season? [/INST]",
})

models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]
models = ["llava-hf/llava-v1.6-mistral-7b-hf"]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Expand Down Expand Up @@ -114,19 +112,43 @@ def run_test(
else:
raise ValueError("You must provide either `size_factors` or `sizes`")

_run_test(hf_runner,
vllm_runner,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend)


def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
max_model_len=10240,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
for prompts, images in inputs
]

with hf_runner(model, dtype=dtype,
Expand All @@ -136,7 +158,7 @@ def run_test(
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
for prompts, images in inputs
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
Expand Down Expand Up @@ -177,7 +199,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
Expand Down Expand Up @@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:
stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image

inputs = [(
[
"[INST] <image><image>\nDescribe 2 images. [/INST]",
"[INST] <image><image>\nDescribe 2 images. [/INST]",
"[INST] <image><image><image><image>\nDescribe 4 images. [/INST]",
"[INST] <image>\nWhat is the season? [/INST]"
],
[
[stop_sign, cherry_blossom],
# Images with different sizes and aspect-ratios
[
rescale_image_size(stop_sign, 0.1),
stop_sign,
],
[
stop_sign,
rescale_image_size(stop_sign, 0.25),
cherry_blossom.resize((183, 488)),
cherry_blossom.resize((488, 183))
],
cherry_blossom,
])]

_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
35 changes: 34 additions & 1 deletion tests/multimodal/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import numpy as np
import pytest
from PIL import Image
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.utils import async_fetch_image, fetch_image
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
repeat_and_pad_placeholder_tokens)

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
Expand Down Expand Up @@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],

data_image_async = await async_fetch_image(data_url)
assert _image_equals(data_image_sync, data_image_async)


@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
def test_repeat_and_pad_placeholder_tokens(model):
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index

tokenizer = AutoTokenizer.from_pretrained(model)

test_cases = [
("<image>", 2, "<image><image>", [32000, 32000]),
("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
("<image><image>", [3, 2], "<image><image><image><image><image>",
[32000, 32000, 32000, 32000, 32000]),
("Image:<image>Image:<image>!", [3, 2],
"Image:<image><image><image>Image:<image><image>!",
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
]

for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer=tokenizer,
prompt=prompt,
prompt_token_ids=tokenizer.encode(prompt,
add_special_tokens=False),
placeholder_token_id=image_token_id,
repeat_count=repeat_count,
)
assert new_prompt == expected_prompt
assert new_token_ids == expected_token_ids
8 changes: 4 additions & 4 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from array import array
from typing import Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -84,7 +84,7 @@ def input_processor_for_clip(
llm_inputs: LLMInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
Expand Down Expand Up @@ -217,7 +217,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`].
Args:
Expand Down
14 changes: 12 additions & 2 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import is_list_of

from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size,
Expand Down Expand Up @@ -223,6 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
input_height=height,
input_width=width,
)
elif is_list_of(image_data, Image.Image):
image_feature_size = [
get_llava_next_image_feature_size(hf_config,
input_height=img.height,
input_width=img.width)
for img in image_data
]
elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0]
else:
Expand Down Expand Up @@ -425,7 +433,10 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
other_patch_embeds = other_patch_embeds \
num_patches = num_patch_height * num_patch_width

# Image patches might be padded for batch processing
other_patch_embeds = other_patch_embeds[:num_patches] \
.view(num_patch_height, num_patch_width, height, width, -1)

if "unpad" in strategy:
Expand Down Expand Up @@ -496,7 +507,6 @@ def _process_image_input(
self,
image_input: LlavaNextImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:

if image_input["type"] == "image_embeds":
return [image_input["data"]]

Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import math
from array import array
from typing import Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from PIL import Image
Expand Down Expand Up @@ -93,7 +93,7 @@ def input_processor_for_siglip(
llm_inputs: LLMInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
Expand Down
Loading

0 comments on commit e2bb9e0

Please sign in to comment.