Skip to content

Commit

Permalink
[Hotfix][Pixtral] Fix multiple images bugs (vllm-project#8415)
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
  • Loading branch information
patrickvonplaten authored and garg-amit committed Oct 28, 2024
1 parent 883ae2a commit 7328512
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 76 deletions.
Binary file added tests/models/fixtures/pixtral_chat.pickle
Binary file not shown.
Binary file added tests/models/fixtures/pixtral_chat_engine.pickle
Binary file not shown.
188 changes: 146 additions & 42 deletions tests/models/test_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,167 @@
Run `pytest tests/models/test_mistral.py`.
"""
import pickle
import uuid
from typing import Any, Dict, List

import pytest
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk

from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataBuiltins

from vllm.sampling_params import SamplingParams
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

MODELS = ["mistralai/Pixtral-12B-2409"]
IMG_URLS = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/231/200/300",
"https://picsum.photos/id/27/500/500",
"https://picsum.photos/id/17/150/600",
]
PROMPT = "Describe each image in one short sentence."


def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
return [{
"role":
"user",
"content": [{
"type": "text",
"text": PROMPT,
}] + [{
"type": "image_url",
"image_url": {
"url": url
}
} for url in urls],
}]


def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
msg = _create_msg_format(urls)

tokenizer = MistralTokenizer.from_model("pixtral")

request = ChatCompletionRequest(messages=msg) # type: ignore[type-var]
tokenized = tokenizer.encode_chat_completion(request)

engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)

images = []
for chunk in request.messages[0].content:
if isinstance(chunk, ImageURLChunk):
images.append(image_from_chunk(chunk))

mm_data = MultiModalDataBuiltins(image=images)
engine_inputs["multi_modal_data"] = mm_data

return engine_inputs


MSGS = [
_create_msg_format(IMG_URLS[:1]),
_create_msg_format(IMG_URLS[:2]),
_create_msg_format(IMG_URLS),
]
ENGINE_INPUTS = [
_create_engine_inputs(IMG_URLS[:1]),
_create_engine_inputs(IMG_URLS[:2]),
_create_engine_inputs(IMG_URLS),
]

SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
LIMIT_MM_PER_PROMPT = dict(image=4)

MAX_MODEL_LEN = [8192, 65536]
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"


def load_logprobs(filename: str) -> Any:
with open(filename, 'rb') as f:
return pickle.load(f)


@pytest.mark.skip(
reason=
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
def test_chat(
vllm_runner,
example_prompts,
max_model_len: int,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
image_urls = [
"https://picsum.photos/id/237/200/300",
"https://picsum.photos/seed/picsum/200/300"
]
expected = [
"The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa
"The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa
]
prompt = "Describe the image in one short sentence."

sampling_params = SamplingParams(max_tokens=512, temperature=0.0)

with vllm_runner(model, dtype=dtype,
tokenizer_mode="mistral") as vllm_model:

for i, image_url in enumerate(image_urls):
messages = [
{
"role":
"user",
"content": [{
"type": "text",
"text": prompt
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}]
},
]

outputs = vllm_model.model.chat(messages,
sampling_params=sampling_params)
assert outputs[0].outputs[0].text == expected[i]
EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
with vllm_runner(
model,
dtype=dtype,
tokenizer_mode="mistral",
enable_chunked_prefill=False,
max_model_len=max_model_len,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
) as vllm_model:
outputs = []
for msg in MSGS:
output = vllm_model.model.chat(msg,
sampling_params=SAMPLING_PARAMS)

outputs.extend(output)

logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
check_logprobs_close(outputs_0_lst=logprobs,
outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
name_0="output",
name_1="h100_ref")


@pytest.mark.skip(
reason=
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
args = EngineArgs(
model=model,
tokenizer_mode="mistral",
enable_chunked_prefill=False,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
dtype=dtype,
)
engine = LLMEngine.from_engine_args(args)

engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)

outputs = []
count = 0
while True:
out = engine.step()
count += 1
for request_output in out:
if request_output.finished:
outputs.append(request_output)

if count == 2:
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
SAMPLING_PARAMS)
if not engine.has_unfinished_requests():
break

logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
check_logprobs_close(outputs_0_lst=logprobs,
outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
name_0="output",
name_1="h100_ref")
83 changes: 49 additions & 34 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
from array import array
from dataclasses import dataclass, fields
from itertools import tee
Expand All @@ -15,11 +14,12 @@

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
Expand Down Expand Up @@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder

mm_config = ctx.model_config.multimodal_config
max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
patch_size = mm_encoder.mm_config.image_patch_size
image_token_id = mm_encoder.special_ids.img

# approximate image size
size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
mm_config = ctx.model_config.multimodal_config
num_images = mm_config.limit_per_prompt.get("image", 1)

# dummy size
size = 256
image = Image.new("RGB", (size, size), color=0)
img_chunk = ImageChunk(image=image)

tokens = mm_encoder(img_chunk).tokens
token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
tokens)
image_feature_size = (size**2) // (patch_size**2)

num_image_tokens = image_feature_size * num_images

token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * num_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - num_image_tokens)

seq_data = SequenceData(token_ids)
mm_data = {"image": max_num_images_per_request * [image]}
mm_data = {"image": num_images * [image]}
return seq_data, mm_data


Expand Down Expand Up @@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
return MultiModalInputs({"images": images})


def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: Optional[List[torch.Tensor]],
image_id: int) -> torch.Tensor:
text_locations = input_ids != image_id
image_locations = input_ids == image_id

seq_len = input_ids.shape[0]
def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is not None and "image" in multi_modal_data:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)

N_txt = text_locations.sum().item()
_, D_txt = inputs_embeds.shape
N_img, D_img = image_features.shape
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img

assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
"to image features dim {D_img}")
assert (seq_len == N_txt +
N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
f"{(N_txt, N_img, image_locations.sum().item())}")
if image_token_id not in llm_inputs['prompt_token_ids']:
raise ValueError(
(f"You've passed {llm_inputs=} without {image_token_id=}"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."))

inputs_embeds[image_locations, :] = image_features
return inputs_embeds
return llm_inputs


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):

def __init__(self,
Expand Down Expand Up @@ -201,11 +206,21 @@ def _parse_and_validate_image_input(
return None

if isinstance(images, torch.Tensor):
# always take last images
images = [images[-1][i] for i in range(images.size(1))]
# if passed as batch take all images
N, B, C, W, H = images.shape
images = images.reshape(N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# always take last images
images = [images[-1][i] for i in range(len(images[0]))]
# if passed as list flatten lists of tensors
flatten_images = []
for imgs_per_req in images:
imgs_per_req = [
imgs_per_req[i] for i in range(imgs_per_req.size(0))
] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req

flatten_images.extend(imgs_per_req)

images = flatten_images

return images

Expand Down

0 comments on commit 7328512

Please sign in to comment.