From 5231f0898e559671c6c8cc48efc53a859fce1841 Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Sat, 31 Aug 2024 16:35:53 -0700
Subject: [PATCH] [Frontend][VLM] Add support for multiple multi-modal items
(#8049)
---
.buildkite/test-pipeline.yaml | 1 +
examples/openai_vision_api_client.py | 39 +++
tests/entrypoints/openai/test_serving_chat.py | 2 +
tests/entrypoints/openai/test_vision.py | 71 ++--
tests/entrypoints/test_chat_utils.py | 305 ++++++++++++++++++
vllm/entrypoints/chat_utils.py | 228 +++++++------
vllm/entrypoints/openai/serving_chat.py | 10 +-
.../openai/serving_tokenization.py | 4 +-
8 files changed, 524 insertions(+), 136 deletions(-)
create mode 100644 tests/entrypoints/test_chat_utils.py
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 235db72eee4b9..86eddb576c42a 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -90,6 +90,7 @@ steps:
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/openai
+ - pytest -v -s entrypoints/test_chat_utils.py
- label: Distributed Tests (4 GPUs) # 10min
working_dir: "/vllm-workspace/tests"
diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py
index be90394511f89..e1d4055763e5f 100644
--- a/examples/openai_vision_api_client.py
+++ b/examples/openai_vision_api_client.py
@@ -1,7 +1,13 @@
"""An example showing how to use vLLM to serve VLMs.
Launch the vLLM server with the following command:
+
+(single image inference with Llava)
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
+
+(multi-image inference with Phi-3.5-vision-instruct)
+vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
+ --trust-remote-code --limit-mm-per-prompt image=2
"""
import base64
@@ -84,3 +90,36 @@ def encode_image_base64_from_url(image_url: str) -> str:
result = chat_completion_from_base64.choices[0].message.content
print(f"Chat completion output:{result}")
+
+# Multi-image input inference
+image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
+image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
+chat_completion_from_url = client.chat.completions.create(
+ messages=[{
+ "role":
+ "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the animals in these images?"
+ },
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url_duck
+ },
+ },
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url_lion
+ },
+ },
+ ],
+ }],
+ model=model,
+ max_tokens=64,
+)
+
+result = chat_completion_from_url.choices[0].message.content
+print(f"Chat completion output:{result}")
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 3783b7cd66a6a..c3a6c65be1d90 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass
from unittest.mock import MagicMock
+from vllm.config import MultiModalConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -20,6 +21,7 @@ class MockModelConfig:
max_model_len = 100
tokenizer_revision = None
embedding_mode = False
+ multimodal_config = MultiModalConfig()
@dataclass
diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py
index d2ef3c2071efb..f61fa127b7d06 100644
--- a/tests/entrypoints/openai/test_vision.py
+++ b/tests/entrypoints/openai/test_vision.py
@@ -6,11 +6,10 @@
from vllm.multimodal.utils import encode_image_base64, fetch_image
-from ...utils import VLLM_PATH, RemoteOpenAIServer
+from ...utils import RemoteOpenAIServer
-MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
-LLAVA_CHAT_TEMPLATE = VLLM_PATH / "examples/template_llava.jinja"
-assert LLAVA_CHAT_TEMPLATE.exists()
+MODEL_NAME = "microsoft/Phi-3.5-vision-instruct"
+MAXIMUM_IMAGES = 2
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
@@ -24,13 +23,9 @@
@pytest.fixture(scope="module")
def server():
args = [
- "--dtype",
- "bfloat16",
- "--max-model-len",
- "4096",
- "--enforce-eager",
- "--chat-template",
- str(LLAVA_CHAT_TEMPLATE),
+ "--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs",
+ "5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt",
+ f"image={MAXIMUM_IMAGES}"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -84,7 +79,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
- completion_tokens=10, prompt_tokens=596, total_tokens=606)
+ completion_tokens=10, prompt_tokens=772, total_tokens=782)
message = choice.message
message = chat_completion.choices[0].message
@@ -139,7 +134,7 @@ async def test_single_chat_session_image_base64encoded(
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
- completion_tokens=10, prompt_tokens=596, total_tokens=606)
+ completion_tokens=10, prompt_tokens=772, total_tokens=782)
message = choice.message
message = chat_completion.choices[0].message
@@ -217,26 +212,22 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+@pytest.mark.parametrize(
+ "image_urls",
+ [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))])
async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
- image_url: str):
+ image_urls: List[str]):
messages = [{
"role":
"user",
"content": [
- {
- "type": "image_url",
- "image_url": {
- "url": image_url
- }
- },
- {
+ *({
"type": "image_url",
"image_url": {
"url": image_url
}
- },
+ } for image_url in image_urls),
{
"type": "text",
"text": "What's in this image?"
@@ -244,20 +235,30 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
],
}]
- with pytest.raises(openai.BadRequestError): # test multi-image input
- await client.chat.completions.create(
+ if len(image_urls) > MAXIMUM_IMAGES:
+ with pytest.raises(openai.BadRequestError): # test multi-image input
+ await client.chat.completions.create(
+ model=model_name,
+ messages=messages,
+ max_tokens=10,
+ temperature=0.0,
+ )
+
+ # the server should still work afterwards
+ completion = await client.completions.create(
+ model=model_name,
+ prompt=[0, 0, 0, 0, 0],
+ max_tokens=5,
+ temperature=0.0,
+ )
+ completion = completion.choices[0].text
+ assert completion is not None and len(completion) >= 0
+ else:
+ chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
)
-
- # the server should still work afterwards
- completion = await client.completions.create(
- model=model_name,
- prompt=[0, 0, 0, 0, 0],
- max_tokens=5,
- temperature=0.0,
- )
- completion = completion.choices[0].text
- assert completion is not None and len(completion) >= 0
+ message = chat_completion.choices[0].message
+ assert message.content is not None and len(message.content) >= 0
diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
new file mode 100644
index 0000000000000..53f99189beb1c
--- /dev/null
+++ b/tests/entrypoints/test_chat_utils.py
@@ -0,0 +1,305 @@
+import warnings
+
+import pytest
+from PIL import Image
+
+from vllm.assets.image import ImageAsset
+from vllm.config import ModelConfig
+from vllm.entrypoints.chat_utils import parse_chat_messages
+from vllm.multimodal.utils import encode_image_base64
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+
+PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
+
+
+@pytest.fixture(scope="module")
+def phi3v_model_config():
+ return ModelConfig(PHI3V_MODEL_ID,
+ PHI3V_MODEL_ID,
+ tokenizer_mode="auto",
+ trust_remote_code=True,
+ dtype="bfloat16",
+ seed=0,
+ limit_mm_per_prompt={
+ "image": 2,
+ })
+
+
+@pytest.fixture(scope="module")
+def phi3v_tokenizer():
+ return TokenizerGroup(
+ tokenizer_id=PHI3V_MODEL_ID,
+ enable_lora=False,
+ max_num_seqs=5,
+ max_input_length=None,
+ )
+
+
+@pytest.fixture(scope="module")
+def image_url():
+ image = ImageAsset('cherry_blossom')
+ base64 = encode_image_base64(image.pil_image)
+ return f"data:image/jpeg;base64,{base64}"
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_with_image_url(phi3v_model_config,
+ phi3v_tokenizer, image_url):
+ conversation, mm_future = parse_chat_messages([{
+ "role":
+ "user",
+ "content": [{
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "text",
+ "text": "What's in the image?"
+ }]
+ }], phi3v_model_config, phi3v_tokenizer)
+
+ assert conversation == [{
+ "role": "user",
+ "content": "<|image_1|>\nWhat's in the image?"
+ }]
+ mm_data = await mm_future
+ assert set(mm_data.keys()) == {"image"}
+ assert isinstance(mm_data["image"], Image.Image)
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_multiple_images(phi3v_model_config,
+ phi3v_tokenizer, image_url):
+ conversation, mm_future = parse_chat_messages([{
+ "role":
+ "user",
+ "content": [{
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "text",
+ "text": "What's in these images?"
+ }]
+ }], phi3v_model_config, phi3v_tokenizer)
+
+ assert conversation == [{
+ "role":
+ "user",
+ "content":
+ "<|image_1|>\n<|image_2|>\nWhat's in these images?"
+ }]
+ mm_data = await mm_future
+ assert set(mm_data.keys()) == {"image"}
+ assert len(mm_data["image"]) == 2
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_placeholder_already_in_prompt(
+ phi3v_model_config, phi3v_tokenizer, image_url):
+ conversation, mm_future = parse_chat_messages([{
+ "role":
+ "user",
+ "content": [{
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type":
+ "text",
+ "text":
+ "What's in <|image_1|> and how does it compare to <|image_2|>?"
+ }]
+ }], phi3v_model_config, phi3v_tokenizer)
+
+ assert conversation == [{
+ "role":
+ "user",
+ "content":
+ "What's in <|image_1|> and how does it compare to <|image_2|>?"
+ }]
+ mm_data = await mm_future
+ assert set(mm_data.keys()) == {"image"}
+ assert len(mm_data["image"]) == 2
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_placeholder_one_already_in_prompt(
+ phi3v_model_config, phi3v_tokenizer, image_url):
+ conversation, mm_future = parse_chat_messages([{
+ "role":
+ "user",
+ "content": [{
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type":
+ "text",
+ "text":
+ "What's in <|image_1|> and how does it compare to the other one?"
+ }]
+ }], phi3v_model_config, phi3v_tokenizer)
+
+ assert conversation == [{
+ "role":
+ "user",
+ "content":
+ "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the "
+ "other one?"
+ }]
+ mm_data = await mm_future
+ assert set(mm_data.keys()) == {"image"}
+ assert len(mm_data["image"]) == 2
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_multiple_images_across_messages(
+ phi3v_model_config, phi3v_tokenizer, image_url):
+ conversation, mm_future = parse_chat_messages([{
+ "role":
+ "user",
+ "content": [{
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "text",
+ "text": "What's in this image?"
+ }]
+ }, {
+ "role": "assistant",
+ "content": "Some stuff."
+ }, {
+ "role":
+ "user",
+ "content": [{
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "text",
+ "text": "What about this one?"
+ }]
+ }], phi3v_model_config, phi3v_tokenizer)
+
+ assert conversation == [
+ {
+ "role": "user",
+ "content": "<|image_1|>\nWhat's in this image?"
+ },
+ {
+ "role": "assistant",
+ "content": "Some stuff."
+ },
+ {
+ "role": "user",
+ "content": "<|image_2|>\nWhat about this one?"
+ },
+ ]
+ mm_data = await mm_future
+ assert set(mm_data.keys()) == {"image"}
+ assert len(mm_data["image"]) == 2
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
+ phi3v_model_config, phi3v_tokenizer, image_url):
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore",
+ message="coroutine 'async_get_and_parse_image' was never awaited")
+ with pytest.raises(
+ ValueError,
+ match="At most 2 image\\(s\\) may be provided in one request\\."
+ ):
+ parse_chat_messages([{
+ "role":
+ "user",
+ "content": [{
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "text",
+ "text": "What's in these images?"
+ }]
+ }], phi3v_model_config, phi3v_tokenizer)
+
+
+@pytest.mark.asyncio
+async def test_parse_chat_messages_rejects_too_many_images_across_messages(
+ phi3v_model_config, phi3v_tokenizer, image_url):
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore",
+ message="coroutine 'async_get_and_parse_image' was never awaited")
+ with pytest.raises(
+ ValueError,
+ match="At most 2 image\\(s\\) may be provided in one request\\."
+ ):
+ parse_chat_messages([{
+ "role":
+ "user",
+ "content": [{
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "text",
+ "text": "What's in this image?"
+ }]
+ }, {
+ "role": "assistant",
+ "content": "Some stuff."
+ }, {
+ "role":
+ "user",
+ "content": [{
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ }, {
+ "type": "text",
+ "text": "What about these two?"
+ }]
+ }], phi3v_model_config, phi3v_tokenizer)
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index c5368ac3bf026..c70c6d9330b10 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1,9 +1,10 @@
+import asyncio
import codecs
-from dataclasses import dataclass
+from collections import defaultdict
from functools import lru_cache
from pathlib import Path
-from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
- Union)
+from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
+ Optional, Tuple, Union)
# yapf conflicts with isort for this block
# yapf: disable
@@ -80,10 +81,90 @@ class ConversationMessage(TypedDict):
content: str
-@dataclass(frozen=True)
-class ChatMessageParseResult:
- messages: List[ConversationMessage]
- mm_futures: List[Awaitable[MultiModalDataDict]]
+class MultiModalItemTracker:
+ """
+ Tracks multi-modal items in a given request and ensures that the number
+ of multi-modal items in a given request does not exceed the configured
+ maximum per prompt.
+ """
+
+ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
+ self._model_config = model_config
+ self._tokenizer = tokenizer
+ self._allowed_items = (model_config.multimodal_config.limit_per_prompt
+ if model_config.multimodal_config else {})
+ self._consumed_items = {k: 0 for k in self._allowed_items}
+ self._futures: List[Awaitable[MultiModalDataDict]] = []
+
+ @staticmethod
+ @lru_cache(maxsize=None)
+ def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
+ return tokenizer.decode(token_index)
+
+ def add(self, modality: Literal["image", "audio"],
+ mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
+ """
+ Adds the multi-modal item to the current prompt and returns the
+ placeholder string to use, if any.
+ """
+ allowed_count = self._allowed_items.get(modality, 1)
+ current_count = self._consumed_items.get(modality, 0) + 1
+ if current_count > allowed_count:
+ raise ValueError(
+ f"At most {allowed_count} {modality}(s) may be provided in "
+ "one request.")
+
+ self._consumed_items[modality] = current_count
+ self._futures.append(mm_future)
+
+ # TODO: Let user specify how to insert image tokens into prompt
+ # (similar to chat template)
+ model_type = self._model_config.hf_config.model_type
+ if modality == "image":
+ if model_type == "phi3_v":
+ # Workaround since this token is not defined in the tokenizer
+ return f"<|image_{current_count}|>"
+ if model_type == "minicpmv":
+ return "(./)"
+ if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
+ # These models do not use image tokens in the prompt
+ return None
+ if model_type.startswith("llava"):
+ return MultiModalItemTracker._cached_token_str(
+ self._tokenizer,
+ self._model_config.hf_config.image_token_index)
+ if model_type in ("chameleon", "internvl_chat"):
+ return ""
+
+ raise TypeError(f"Unknown model type: {model_type}")
+ elif modality == "audio":
+ if model_type == "ultravox":
+ return "<|reserved_special_token_0|>"
+ raise TypeError(f"Unknown model type: {model_type}")
+ else:
+ raise TypeError(f"Unknown modality: {modality}")
+
+ @staticmethod
+ async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
+ mm_lists: Mapping[str, List[object]] = defaultdict(list)
+
+ # Merge all the multi-modal items
+ for single_mm_data in (await asyncio.gather(*futures)):
+ for mm_key, mm_item in single_mm_data.items():
+ if isinstance(mm_item, list):
+ mm_lists[mm_key].extend(mm_item)
+ else:
+ mm_lists[mm_key].append(mm_item)
+
+ # Unpack any single item lists for models that don't expect multiple.
+ return {
+ mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
+ for mm_key, mm_list in mm_lists.items()
+ }
+
+ def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
+ return MultiModalItemTracker._combine(
+ self._futures) if self._futures else None
def load_chat_template(
@@ -112,44 +193,30 @@ def load_chat_template(
return resolved_chat_template
-@lru_cache(maxsize=None)
-def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
- modality: Literal["image", "audio"]) -> Optional[str]:
- # TODO: Let user specify how to insert image tokens into prompt
- # (similar to chat template)
- model_type = model_config.hf_config.model_type
- if modality == "image":
- if model_type == "phi3_v":
- # Workaround since this token is not defined in the tokenizer
- return "<|image_1|>"
- if model_type == "minicpmv":
- return "(./)"
- if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
- # These models do not use image tokens in the prompt
- return None
- if model_type.startswith("llava"):
- return tokenizer.decode(model_config.hf_config.image_token_index)
- if model_type in ("chameleon", "internvl_chat"):
- return ""
-
- raise TypeError(f"Unknown model type: {model_type}")
- elif modality == "audio":
- if model_type == "ultravox":
- return "<|reserved_special_token_0|>"
- raise TypeError(f"Unknown model type: {model_type}")
- else:
- raise TypeError(f"Unknown modality: {modality}")
-
-
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
-def _get_full_multimodal_text_prompt(placeholder_token_str: str,
+def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""
- # NOTE: For now we assume all model architectures use the same
- # placeholder + text prompt format. This may change in the future.
- return f"{placeholder_token_str}\n{text_prompt}"
+ # Look through the text prompt to check for missing placeholders
+ missing_placeholders = []
+ for placeholder in placeholder_counts:
+
+ # For any existing placeholder in the text prompt, we leave it as is
+ placeholder_counts[placeholder] -= text_prompt.count(placeholder)
+
+ if placeholder_counts[placeholder] < 0:
+ raise ValueError(
+ f"Found more '{placeholder}' placeholders in input prompt than "
+ "actual multimodal data items.")
+
+ missing_placeholders.extend([placeholder] *
+ placeholder_counts[placeholder])
+
+ # NOTE: For now we always add missing placeholders at the front of
+ # the prompt. This may change to be customizable in the future.
+ return "\n".join(missing_placeholders + [text_prompt])
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
@@ -160,12 +227,12 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str,
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
- model_config: ModelConfig,
- tokenizer: AnyTokenizer,
-) -> ChatMessageParseResult:
+ mm_tracker: MultiModalItemTracker,
+) -> List[ConversationMessage]:
texts: List[str] = []
- mm_futures: List[Awaitable[MultiModalDataDict]] = []
- modality: Literal["image", "audio"] = "image"
+
+ # multimodal placeholder_string : count
+ mm_placeholder_counts: Dict[str, int] = {}
for part in parts:
part_type = part["type"]
@@ -173,11 +240,6 @@ def _parse_chat_message_content_parts(
text = _TextParser.validate_python(part)["text"]
texts.append(text)
elif part_type == "image_url":
- modality = "image"
- if len(mm_futures) > 0:
- raise NotImplementedError(
- "Multiple multimodal inputs is currently not supported.")
-
image_url = _ImageParser.validate_python(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
@@ -185,60 +247,44 @@ def _parse_chat_message_content_parts(
"'image_url.detail' is currently not supported and "
"will be ignored.")
- image_future = async_get_and_parse_image(image_url["url"])
- mm_futures.append(image_future)
+ image_coro = async_get_and_parse_image(image_url["url"])
+ placeholder = mm_tracker.add("image", image_coro)
+ if placeholder:
+ mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
+ placeholder, 0) + 1
elif part_type == "audio_url":
- modality = "audio"
- if len(mm_futures) > 0:
- raise NotImplementedError(
- "Multiple multimodal inputs is currently not supported.")
-
audio_url = _AudioParser.validate_python(part)["audio_url"]
- audio_future = async_get_and_parse_audio(audio_url["url"])
- mm_futures.append(audio_future)
+ audio_coro = async_get_and_parse_audio(audio_url["url"])
+ placeholder = mm_tracker.add("audio", audio_coro)
+ if placeholder:
+ mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
+ placeholder, 0) + 1
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
+ if mm_placeholder_counts:
+ text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
+ text_prompt)
- if mm_futures:
- placeholder_token_str = _mm_token_str(model_config, tokenizer,
- modality)
- if placeholder_token_str is not None:
- if placeholder_token_str in text_prompt:
- logger.warning(
- "Detected multi-modal token string in the text prompt. "
- "Skipping prompt formatting.")
- else:
- text_prompt = _get_full_multimodal_text_prompt(
- placeholder_token_str=placeholder_token_str,
- text_prompt=text_prompt,
- )
-
- messages = [ConversationMessage(role=role, content=text_prompt)]
-
- return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
+ return [ConversationMessage(role=role, content=text_prompt)]
def _parse_chat_message_content(
- message: ChatCompletionMessageParam,
- model_config: ModelConfig,
- tokenizer: AnyTokenizer,
-) -> ChatMessageParseResult:
+ message: ChatCompletionMessageParam,
+ mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")
if content is None:
- return ChatMessageParseResult(messages=[], mm_futures=[])
+ return []
if isinstance(content, str):
- messages = [ConversationMessage(role=role, content=content)]
- return ChatMessageParseResult(messages=messages, mm_futures=[])
+ return [ConversationMessage(role=role, content=content)]
return _parse_chat_message_content_parts(
role,
content, # type: ignore
- model_config,
- tokenizer,
+ mm_tracker,
)
@@ -246,18 +292,16 @@ def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
-) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
+) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
- mm_futures: List[Awaitable[MultiModalDataDict]] = []
+ mm_tracker = MultiModalItemTracker(model_config, tokenizer)
for msg in messages:
- parse_result = _parse_chat_message_content(msg, model_config,
- tokenizer)
+ sub_messages = _parse_chat_message_content(msg, mm_tracker)
- conversation.extend(parse_result.messages)
- mm_futures.extend(parse_result.mm_futures)
+ conversation.extend(sub_messages)
- return conversation, mm_futures
+ return conversation, mm_tracker.all_mm_data()
def apply_chat_template(
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index d31ac4995fe2f..f7576509d06c8 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -94,7 +94,7 @@ async def create_chat_completion(
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
- conversation, mm_futures = parse_chat_messages(
+ conversation, mm_data_future = parse_chat_messages(
request.messages, model_config, tokenizer)
tool_dicts = None if request.tools is None else [
@@ -116,12 +116,8 @@ async def create_chat_completion(
mm_data: Optional[MultiModalDataDict] = None
try:
- if len(mm_futures):
- # since we support only single mm data currently
- assert len(
- mm_futures
- ) == 1, "Multiple 'image_url' input is currently not supported."
- mm_data = await mm_futures[0]
+ if mm_data_future:
+ mm_data = await mm_data_future
except Exception as e:
logger.error("Error in loading multi-modal data: %s", e)
return self.create_error_response(str(e))
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 1aeabb7a7d729..fc9ca29e9cf86 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -65,10 +65,10 @@ async def create_tokenize(
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
- conversation, mm_futures = parse_chat_messages(
+ conversation, mm_data_future = parse_chat_messages(
request.messages, model_config, tokenizer)
- if mm_futures:
+ if mm_data_future:
logger.warning(
"Multi-modal inputs are ignored during tokenization")