Skip to content

Commit

Permalink
[Model] Add multi-image support for minicpmv (vllm-project#7122)
Browse files Browse the repository at this point in the history
Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
  • Loading branch information
3 people authored and sfc-gh-mkeralapura committed Aug 12, 2024
1 parent 2f1fd0b commit b1735b0
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 37 deletions.
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from collections import UserList
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union

import pytest
import torch
Expand Down Expand Up @@ -508,7 +508,8 @@ def generate_greedy_logprobs(
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
Expand Down
146 changes: 133 additions & 13 deletions tests/models/test_minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@

pytestmark = pytest.mark.vlm


class NestedInputs(UserDict):

def __init__(self, model_inputs: BatchFeature):
super().__init__({"model_inputs": model_inputs})

self.model_inputs = model_inputs

def to(self, device: torch.types.Device):
return NestedInputs(self.model_inputs.to(device))


# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
Expand All @@ -23,7 +35,7 @@
"cherry_blossom":
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
"(<image>./</image>)\nWhat is the season?<|eot_id|>" \
"<|start_header_id|>assistant<|end_header_id|>\n\n"
"<|start_header_id|>assistant<|end_header_id|>\n\n",
})

models = ["openbmb/MiniCPM-Llama3-V-2_5"]
Expand Down Expand Up @@ -94,22 +106,10 @@ def run_test(
]

with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():

class NestedInputs(UserDict):

def __init__(self, model_inputs: BatchFeature):
super().__init__({"model_inputs": model_inputs})

self.model_inputs = model_inputs

def to(self, device: torch.types.Device):
return NestedInputs(self.model_inputs.to(device))

hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)

hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
Expand Down Expand Up @@ -161,3 +161,123 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


HF_MULTIIMAGE_IMAGE_PROMPT = \
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
"(<image>./</image>)\n(<image>./</image>)\n" \
"Describe these images.<|eot_id|>" \
"<|start_header_id|>assistant<|end_header_id|>\n\n"


def run_multi_image_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]

inputs_per_case = [
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
[[rescale_image_size(image, factor) for image in images]
for factor in size_factors])
]

# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
vllm_outputs_per_case = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images,
stop_token_ids=stop_token_ids)
for prompts, images in inputs_per_case
]

with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images,
tokenizer=tokenizer)
for prompts, images in inputs_per_case
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
vllm_outputs_per_case):
check_logprobs_close(
outputs_0_lst=[
trunc_hf_output(hf_output) for hf_output in hf_outputs
],
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
size_factors, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
run_multi_image_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
56 changes: 35 additions & 21 deletions vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,20 @@ def forward(self, x: torch.Tensor,
return x


def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
version_float = getattr(config, "version", None)

# The old configs do not include version number
# TODO: Remove this after the HF repos are updated
if version_float is None:
if config.hidden_size == 2304 and config.query_num == 64:
return (2, 0)
return (2, 5)

version_str = str(version_float)
return tuple(int(x) for x in version_str.split("."))


def get_max_minicpmv_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig)
return getattr(hf_config, "query_num", 64)
Expand Down Expand Up @@ -421,36 +435,43 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs

model_config = ctx.model_config

version = get_version_by_config(model_config.hf_config)
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
image_processor = cached_get_image_processor(model_config.tokenizer)

def get_placeholder(image_size: Tuple[int, int], num_image: int):
if version == (2, 0) or version == (2, 5):
return image_processor. \
get_slice_image_placeholder(image_size)
return image_processor. \
get_slice_image_placeholder(image_size, num_image)

prompt = llm_inputs.get("prompt")
if prompt is None:
token_ids = llm_inputs.get("prompt_token_ids")
prompt = tokenizer.decode(token_ids)
image_processor = cached_get_image_processor(model_config.tokenizer)

pattern = "(<image>./</image>)"
image = multi_modal_data["image"]
images = multi_modal_data["image"]
if isinstance(images, Image.Image):
images = [images]
image_tags = re.findall(pattern, prompt)

if len(image_tags) == 0:
new_token_ids = token_ids
new_prompt = prompt
else:
if len(image_tags) > 1:
logger.warning("Multiple image input is not supported yet, "
"so any extra image tokens will be treated "
"as plain text.")

text_chunks = prompt.split(pattern)
new_prompt = (text_chunks[0] +
image_processor.get_slice_image_placeholder(image.size) +
"".join(text_chunks[1:]))

new_prompt_chunks: List[str] = []
for i in range(len(images)):
new_prompt_chunks += [
text_chunks[i],
get_placeholder(images[i].size, i)
]
new_prompt_chunks.append(text_chunks[-1])
new_prompt = "".join(new_prompt_chunks)
new_token_ids = tokenizer.encode(new_prompt)

llm_inputs = LLMInputs(
Expand Down Expand Up @@ -478,14 +499,7 @@ def __init__(
self.config = config
self.multimodal_config = multimodal_config

if not hasattr(self.config, "version"):
if self.config.hidden_size == 2304 and self.config.query_num == 64:
self.version = (2, 0)
else:
self.version = (2, 5)
else:
self.version = str(self.config.version).split(".")
self.version = tuple([int(x) for x in self.version])
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config, cache_config, quant_config)
self.vpm = self.init_vision_module()
param_dtype = torch.get_default_dtype()
Expand Down
2 changes: 1 addition & 1 deletion vllm/multimodal/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _get_hf_image_processor(self, model_config: ModelConfig):
def _default_input_mapper(self, ctx: InputContext,
data: object) -> MultiModalInputs:
model_config = ctx.model_config
if isinstance(data, Image.Image):
if isinstance(data, (Image.Image, list)):
image_processor = self._get_hf_image_processor(model_config)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
Expand Down

0 comments on commit b1735b0

Please sign in to comment.