diff --git a/tests/models/fixtures/pixtral_chat.pickle b/tests/models/fixtures/pixtral_chat.pickle new file mode 100644 index 0000000000000..43d4c883c3a49 Binary files /dev/null and b/tests/models/fixtures/pixtral_chat.pickle differ diff --git a/tests/models/fixtures/pixtral_chat_engine.pickle b/tests/models/fixtures/pixtral_chat_engine.pickle new file mode 100644 index 0000000000000..19dbeaecc8dff Binary files /dev/null and b/tests/models/fixtures/pixtral_chat_engine.pickle differ diff --git a/tests/models/test_pixtral.py b/tests/models/test_pixtral.py index dc60cf7eae8b1..62ccaf1b79522 100644 --- a/tests/models/test_pixtral.py +++ b/tests/models/test_pixtral.py @@ -2,13 +2,92 @@ Run `pytest tests/models/test_mistral.py`. """ +import pickle +import uuid +from typing import Any, Dict, List + import pytest +from mistral_common.protocol.instruct.messages import ImageURLChunk +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.multimodal import image_from_chunk + +from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt +from vllm.multimodal import MultiModalDataBuiltins -from vllm.sampling_params import SamplingParams +from .utils import check_logprobs_close pytestmark = pytest.mark.vlm MODELS = ["mistralai/Pixtral-12B-2409"] +IMG_URLS = [ + "https://picsum.photos/id/237/400/300", + "https://picsum.photos/id/231/200/300", + "https://picsum.photos/id/27/500/500", + "https://picsum.photos/id/17/150/600", +] +PROMPT = "Describe each image in one short sentence." + + +def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]: + return [{ + "role": + "user", + "content": [{ + "type": "text", + "text": PROMPT, + }] + [{ + "type": "image_url", + "image_url": { + "url": url + } + } for url in urls], + }] + + +def _create_engine_inputs(urls: List[str]) -> TokensPrompt: + msg = _create_msg_format(urls) + + tokenizer = MistralTokenizer.from_model("pixtral") + + request = ChatCompletionRequest(messages=msg) # type: ignore[type-var] + tokenized = tokenizer.encode_chat_completion(request) + + engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens) + + images = [] + for chunk in request.messages[0].content: + if isinstance(chunk, ImageURLChunk): + images.append(image_from_chunk(chunk)) + + mm_data = MultiModalDataBuiltins(image=images) + engine_inputs["multi_modal_data"] = mm_data + + return engine_inputs + + +MSGS = [ + _create_msg_format(IMG_URLS[:1]), + _create_msg_format(IMG_URLS[:2]), + _create_msg_format(IMG_URLS), +] +ENGINE_INPUTS = [ + _create_engine_inputs(IMG_URLS[:1]), + _create_engine_inputs(IMG_URLS[:2]), + _create_engine_inputs(IMG_URLS), +] + +SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) +LIMIT_MM_PER_PROMPT = dict(image=4) + +MAX_MODEL_LEN = [8192, 65536] +FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle" +FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle" + + +def load_logprobs(filename: str) -> Any: + with open(filename, 'rb') as f: + return pickle.load(f) @pytest.mark.skip( @@ -16,49 +95,74 @@ "Model is too big, test passed on A100 locally but will OOM on CI machine." 
) @pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models( +def test_chat( vllm_runner, - example_prompts, + max_model_len: int, model: str, dtype: str, - max_tokens: int, - num_logprobs: int, ) -> None: - image_urls = [ - "https://picsum.photos/id/237/200/300", - "https://picsum.photos/seed/picsum/200/300" - ] - expected = [ - "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa - "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa - ] - prompt = "Describe the image in one short sentence." - - sampling_params = SamplingParams(max_tokens=512, temperature=0.0) - - with vllm_runner(model, dtype=dtype, - tokenizer_mode="mistral") as vllm_model: - - for i, image_url in enumerate(image_urls): - messages = [ - { - "role": - "user", - "content": [{ - "type": "text", - "text": prompt - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }] - }, - ] - - outputs = vllm_model.model.chat(messages, - sampling_params=sampling_params) - assert outputs[0].outputs[0].text == expected[i] + EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT) + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="mistral", + enable_chunked_prefill=False, + max_model_len=max_model_len, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + ) as vllm_model: + outputs = [] + for msg in MSGS: + output = vllm_model.model.chat(msg, + sampling_params=SAMPLING_PARAMS) + + outputs.extend(output) + + logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) + check_logprobs_close(outputs_0_lst=logprobs, + outputs_1_lst=EXPECTED_CHAT_LOGPROBS, + name_0="output", + name_1="h100_ref") + + +@pytest.mark.skip( + reason= + "Model is too big, test passed on A100 locally but will OOM on CI machine." 
+) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_model_engine(vllm_runner, model: str, dtype: str) -> None: + EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE) + args = EngineArgs( + model=model, + tokenizer_mode="mistral", + enable_chunked_prefill=False, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + dtype=dtype, + ) + engine = LLMEngine.from_engine_args(args) + + engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS) + engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS) + + outputs = [] + count = 0 + while True: + out = engine.step() + count += 1 + for request_output in out: + if request_output.finished: + outputs.append(request_output) + + if count == 2: + engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2], + SAMPLING_PARAMS) + if not engine.has_unfinished_requests(): + break + + logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) + check_logprobs_close(outputs_0_lst=logprobs, + outputs_1_lst=EXPECTED_ENGINE_LOGPROBS, + name_0="output", + name_1="h100_ref") diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 010cf85f45e07..b26fd558fa1ea 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,4 +1,3 @@ -import math from array import array from dataclasses import dataclass, fields from itertools import tee @@ -15,11 +14,12 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, tokenizer = cached_get_tokenizer( ctx.model_config.tokenizer, tokenizer_mode=ctx.model_config.tokenizer_mode) - mm_encoder = tokenizer.instruct.mm_encoder - mm_config = ctx.model_config.multimodal_config - max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1) + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + patch_size = mm_encoder.mm_config.image_patch_size + image_token_id = mm_encoder.special_ids.img - # approximate image size - size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size) + mm_config = ctx.model_config.multimodal_config + num_images = mm_config.limit_per_prompt.get("image", 1) + # dummy size + size = 256 image = Image.new("RGB", (size, size), color=0) - img_chunk = ImageChunk(image=image) - tokens = mm_encoder(img_chunk).tokens - token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE, - tokens) + image_feature_size = (size**2) // (patch_size**2) + + num_image_tokens = image_feature_size * num_images + + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * num_image_tokens + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - num_image_tokens) seq_data = SequenceData(token_ids) - mm_data = {"image": max_num_images_per_request * [image]} + mm_data = {"image": 
num_images * [image]}
     return seq_data, mm_data
 
 
@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
     return MultiModalInputs({"images": images})
 
 
-def merge_multimodal_embeddings(input_ids: torch.Tensor,
-                                inputs_embeds: torch.Tensor,
-                                image_features: Optional[List[torch.Tensor]],
-                                image_id: int) -> torch.Tensor:
-    text_locations = input_ids != image_id
-    image_locations = input_ids == image_id
-
-    seq_len = input_ids.shape[0]
+def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is not None and "image" in multi_modal_data:
+        tokenizer = cached_get_tokenizer(
+            ctx.model_config.tokenizer,
+            tokenizer_mode=ctx.model_config.tokenizer_mode)
 
-    N_txt = text_locations.sum().item()
-    _, D_txt = inputs_embeds.shape
-    N_img, D_img = image_features.shape
+        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+        image_token_id = mm_encoder.special_ids.img
 
-    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
-                              "to image features dim {D_img}")
-    assert (seq_len == N_txt +
-            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
-                     f"{(N_txt, N_img, image_locations.sum().item())}")
+        if image_token_id not in llm_inputs['prompt_token_ids']:
+            raise ValueError(
+                (f"You've passed {llm_inputs=} without {image_token_id=}."
+                 " Make sure to process your input via mistral_common's"
+                 " tokenizer or pass a chat completion request."
+                 " For more info, see: "
+                 "https://github.com/vllm-project/vllm/issues/8411."))
 
-    inputs_embeds[image_locations, :] = image_features
-    return inputs_embeds
+    return llm_inputs
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
 class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
@@ -201,11 +206,21 @@ def _parse_and_validate_image_input(
             return None
 
         if isinstance(images, torch.Tensor):
-            # always take last images
-            images = [images[-1][i] for i in range(images.size(1))]
+            # if passed as a batch, take all images
+            N, B, C, W, H = images.shape
+            images = images.reshape(N * B, C, W, H)
+            images = [images[i] for i in range(images.size(0))]
        elif isinstance(images, list):
-            # always take last images
-            images = [images[-1][i] for i in range(len(images[0]))]
+            # if passed as a list, flatten the per-request lists of tensors
+            flatten_images = []
+            for imgs_per_req in images:
+                imgs_per_req = [
+                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
+                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
+
+                flatten_images.extend(imgs_per_req)
+
+            images = flatten_images
 
         return images
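
As a quick illustration of the behaviour change in `_parse_and_validate_image_input` (every image is now kept, instead of only the last request's images), the following standalone sketch mirrors the flattening logic from the diff. The tensor shapes, image counts, and variable names here are made up for the example and are not taken from the PR or the test fixtures.

import torch

# Batched input: an (N, B, C, W, H) tensor, e.g. 2 requests with 3 images each.
images = torch.zeros(2, 3, 3, 16, 16)
N, B, C, W, H = images.shape
flat = [img for img in images.reshape(N * B, C, W, H)]
assert len(flat) == 6  # all 6 images kept, not just the last request's 3

# List input: one entry per request, either a stacked tensor or a list of
# per-image tensors; both forms are flattened into a single flat list.
images = [torch.zeros(2, 3, 16, 16), [torch.zeros(3, 16, 16)]]
flatten_images = []
for imgs_per_req in images:
    if isinstance(imgs_per_req, torch.Tensor):
        imgs_per_req = [imgs_per_req[i] for i in range(imgs_per_req.size(0))]
    flatten_images.extend(imgs_per_req)
assert len(flatten_images) == 3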