From 0d31364c344a23f0b01ab7bc969853c96b71d900 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 28 Jun 2024 20:09:56 +0800 Subject: [PATCH] [Core] Registry for processing model inputs (#5214) Co-authored-by: ywang96 --- .../input_processing_pipeline.rst | 20 ++ .../input_processing/model_inputs_index.rst | 39 ++++ .../dev/multimodal/multimodal_index.rst | 8 +- docs/source/index.rst | 1 + docs/source/models/adding_model.rst | 4 +- examples/phi3v_example.py | 3 +- .../{test_processor.py => test_mapper.py} | 69 +++--- vllm/config.py | 3 + vllm/engine/arg_utils.py | 64 +++--- vllm/engine/async_llm_engine.py | 8 +- vllm/engine/llm_engine.py | 13 +- vllm/inputs/__init__.py | 19 ++ vllm/{inputs.py => inputs/data.py} | 18 +- vllm/inputs/registry.py | 207 ++++++++++++++++++ vllm/model_executor/models/clip.py | 77 ++++++- vllm/model_executor/models/llava.py | 40 +++- vllm/model_executor/models/llava_next.py | 138 +++++++++--- vllm/model_executor/models/phi3v.py | 74 +++++-- vllm/multimodal/__init__.py | 11 +- vllm/multimodal/base.py | 84 ++++--- vllm/multimodal/image.py | 78 ++----- vllm/multimodal/registry.py | 133 ++++------- vllm/sequence.py | 4 +- vllm/transformers_utils/image_processor.py | 4 - vllm/worker/cpu_model_runner.py | 20 +- vllm/worker/model_runner.py | 31 +-- 26 files changed, 778 insertions(+), 392 deletions(-) create mode 100644 docs/source/dev/input_processing/input_processing_pipeline.rst create mode 100644 docs/source/dev/input_processing/model_inputs_index.rst rename tests/multimodal/{test_processor.py => test_mapper.py} (71%) create mode 100644 vllm/inputs/__init__.py rename vllm/{inputs.py => inputs/data.py} (90%) create mode 100644 vllm/inputs/registry.py diff --git a/docs/source/dev/input_processing/input_processing_pipeline.rst b/docs/source/dev/input_processing/input_processing_pipeline.rst new file mode 100644 index 0000000000000..e0c773781115f --- /dev/null +++ b/docs/source/dev/input_processing/input_processing_pipeline.rst @@ -0,0 +1,20 @@ +.. _input_processing_pipeline: + +Input Processing Pipeline +========================= + +1. Input data is passed to :class:`~vllm.LLMEngine` (or :class:`~vllm.AsyncLLMEngine`). + +2. Tokenize the data if necessary. + +3. Process the inputs using :meth:`INPUT_REGISTRY.process_input `. + + - For example, add placeholder tokens to reserve KV cache for multi-modal embeddings. + +4. Send the processed inputs to :class:`~vllm.executor.executor_base.ExecutorBase`. + +5. Distribute the inputs via :class:`~vllm.worker.worker_base.WorkerBase` to :class:`~vllm.worker.model_runner_base.ModelRunnerBase`. + +6. If the data contains multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input `. + + - For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model. diff --git a/docs/source/dev/input_processing/model_inputs_index.rst b/docs/source/dev/input_processing/model_inputs_index.rst new file mode 100644 index 0000000000000..594edeb746bb4 --- /dev/null +++ b/docs/source/dev/input_processing/model_inputs_index.rst @@ -0,0 +1,39 @@ +.. _input_processing: + +Input Processing +================ + +.. currentmodule:: vllm.inputs + +vLLM provides a mechanism for defining input processors for each model so that the inputs are processed +in :class:`~vllm.LLMEngine` before they are passed to model executors. + +Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input +data in addition to input prompt, but it can be extended to text-only language models when needed. + +Guides +++++++ + +.. toctree:: + :maxdepth: 1 + + input_processing_pipeline + +Module Contents ++++++++++++++++ + +LLM Engine Inputs +----------------- + +.. autoclass:: vllm.inputs.LLMInputs + :members: + :show-inheritance: + +Registry +-------- + +.. autodata:: vllm.inputs.INPUT_REGISTRY + +.. automodule:: vllm.inputs.registry + :members: + :show-inheritance: diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index a25eceecc276b..f6fdfc1debffb 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -12,10 +12,6 @@ By default, vLLM models do not support multi-modal inputs. To enable multi-modal you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data `, as well as :meth:`MULTIMODAL_REGISTRY.register_input ` for each modality type to support. -.. contents:: - :local: - :backlinks: none - Module Contents +++++++++++++++ @@ -24,9 +20,7 @@ Module Contents Registry -------- -.. data:: vllm.multimodal.MULTIMODAL_REGISTRY - - The global :class:`MultiModalRegistry` which is used by model runners. +.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY .. autoclass:: vllm.multimodal.MultiModalRegistry :members: diff --git a/docs/source/index.rst b/docs/source/index.rst index 05133eb6d867a..3a9f5a3d81e84 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -120,6 +120,7 @@ Documentation dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention + dev/input_processing/model_inputs_index dev/multimodal/multimodal_index dev/dockerfile/dockerfile diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index cbc8099e6f70f..f282b594590be 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -37,7 +37,7 @@ For instance, vLLM's `OPT model None: self.model = model self.tokenizer = tokenizer @@ -159,6 +160,8 @@ def __init__( sliding_window_len=self.get_hf_config_sliding_window()) self.served_model_name = get_served_model_name(model, served_model_name) + self.multimodal_config = multimodal_config + if not self.skip_tokenizer_init: self._verify_tokenizer_mode() self._verify_embedding_mode() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 16374098b23d4..c392155e8981b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -643,6 +643,36 @@ def create_engine_config(self, ) -> EngineConfig: raise ValueError( "BitsAndBytes load format and QLoRA adapter only support " f"'bitsandbytes' quantization, but got {self.quantization}") + if self.image_input_type: + if (not self.image_token_id or not self.image_input_shape + or not self.image_feature_size): + raise ValueError( + 'Specify `image_token_id`, `image_input_shape` and ' + '`image_feature_size` together with `image_input_type`.') + + if self.image_processor is None: + self.image_processor = self.model + if self.disable_image_processor: + if self.image_processor != self.model: + warnings.warn( + "You've specified an image processor " + f"({self.image_processor}) but also disabled " + "it via `--disable-image-processor`.", + stacklevel=2) + + self.image_processor = None + + vision_language_config = VisionLanguageConfig( + image_input_type=VisionLanguageConfig. + get_image_input_enum_type(self.image_input_type), + image_token_id=self.image_token_id, + image_input_shape=str_to_int_tuple(self.image_input_shape), + image_feature_size=self.image_feature_size, + image_processor=self.image_processor, + image_processor_revision=self.image_processor_revision, + ) + else: + vision_language_config = None device_config = DeviceConfig(device=self.device) model_config = ModelConfig( @@ -666,7 +696,8 @@ def create_engine_config(self, ) -> EngineConfig: max_logprobs=self.max_logprobs, disable_sliding_window=self.disable_sliding_window, skip_tokenizer_init=self.skip_tokenizer_init, - served_model_name=self.served_model_name) + served_model_name=self.served_model_name, + multimodal_config=vision_language_config) cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, @@ -742,37 +773,6 @@ def create_engine_config(self, ) -> EngineConfig: model_loader_extra_config=self.model_loader_extra_config, ) - if self.image_input_type: - if (not self.image_token_id or not self.image_input_shape - or not self.image_feature_size): - raise ValueError( - 'Specify `image_token_id`, `image_input_shape` and ' - '`image_feature_size` together with `image_input_type`.') - - if self.image_processor is None: - self.image_processor = self.model - if self.disable_image_processor: - if self.image_processor != self.model: - warnings.warn( - "You've specified an image processor " - f"({self.image_processor}) but also disabled " - "it via `--disable-image-processor`.", - stacklevel=2) - - self.image_processor = None - - vision_language_config = VisionLanguageConfig( - image_input_type=VisionLanguageConfig. - get_image_input_enum_type(self.image_input_type), - image_token_id=self.image_token_id, - image_input_shape=str_to_int_tuple(self.image_input_shape), - image_feature_size=self.image_feature_size, - image_processor=self.image_processor, - image_processor_revision=self.image_processor_revision, - ) - else: - vision_language_config = None - decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7994b873fe9bd..848e05f033a8e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -278,9 +278,11 @@ async def process_model_inputs_async( else: prompt_token_ids = inputs["prompt_token_ids"] - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) + + return self.input_processor(llm_inputs) async def add_request_async( self, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4b427b1fb2f22..9b720d6138868 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -20,7 +20,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import LLMInputs, PromptInputs +from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -227,6 +227,9 @@ def __init__( self.generation_config_fields = _load_generation_config_dict( model_config) + self.input_processor = INPUT_REGISTRY.create_input_processor( + self.model_config) + self.model_executor = executor_class( model_config=model_config, cache_config=cache_config, @@ -511,9 +514,11 @@ def process_model_inputs( else: prompt_token_ids = inputs["prompt_token_ids"] - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) + + return self.input_processor(llm_inputs) def add_request( self, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py new file mode 100644 index 0000000000000..d094156962955 --- /dev/null +++ b/vllm/inputs/__init__.py @@ -0,0 +1,19 @@ +from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, + PromptStrictInputs, TextPrompt, TextTokensPrompt, + TokensPrompt, parse_and_batch_prompt) +from .registry import InputContext, InputRegistry + +INPUT_REGISTRY = InputRegistry() +""" +The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine` +to dispatch data processing according to the target model. + +See also: + :ref:`input_processing_pipeline` +""" + +__all__ = [ + "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", + "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs", + "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry" +] diff --git a/vllm/inputs.py b/vllm/inputs/data.py similarity index 90% rename from vllm/inputs.py rename to vllm/inputs/data.py index 026903e19a26e..9b163b9cfb666 100644 --- a/vllm/inputs.py +++ b/vllm/inputs/data.py @@ -101,8 +101,7 @@ class TextTokensPrompt(TypedDict): """The prompt text.""" prompt_token_ids: List[int] - """The token IDs of the prompt. If None, we use the - tokenizer to convert the prompts to token IDs.""" + """The token IDs of the prompt.""" multi_modal_data: NotRequired["MultiModalData"] """ @@ -125,6 +124,21 @@ class TextTokensPrompt(TypedDict): class LLMInputs(TypedDict): + """ + The inputs in :class:`~vllm.LLMEngine` before they are + passed to the model executor. + """ + prompt_token_ids: List[int] + """The token IDs of the prompt.""" + prompt: NotRequired[Optional[str]] + """ + The original prompt text corresponding to the token IDs, if available. + """ + multi_modal_data: NotRequired[Optional["MultiModalData"]] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py new file mode 100644 index 0000000000000..8f4e108b8cca5 --- /dev/null +++ b/vllm/inputs/registry.py @@ -0,0 +1,207 @@ +import functools +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, + TypeVar) + +from torch import nn +from transformers import PretrainedConfig + +from vllm.logger import init_logger + +from .data import LLMInputs + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VisionLanguageConfig + from vllm.multimodal import MultiModalData + from vllm.sequence import SequenceData + +logger = init_logger(__name__) + +C = TypeVar("C", bound=PretrainedConfig) + + +@dataclass(frozen=True) +class InputContext: + """ + Contains information about the model which may be used to + modify the inputs. + """ + + model_config: "ModelConfig" + """The configuration of the model.""" + + def get_multimodal_config(self) -> "VisionLanguageConfig": + """ + Get the multimodal configuration of the model. + + Raises: + ValueError: If the model is not multimodal. + """ + + multimodal_config = self.model_config.multimodal_config + if multimodal_config is None: + raise ValueError("No multimodal config found") + + return multimodal_config + + def get_hf_config(self, hf_config_type: Type[C]) -> C: + """ + Get the HuggingFace configuration + (:class:`transformers.PretrainedConfig`) of the model, + additionally checking its type. + + Raises: + ValueError: If the model is not of the specified type. + """ + + hf_config = self.model_config.hf_config + if not isinstance(hf_config, hf_config_type): + raise TypeError("Invalid type of HuggingFace config. " + f"Expected type: {hf_config_type}, but " + f"found type: {type(hf_config)}") + + return hf_config + + +N = TypeVar("N", bound=Type[nn.Module]) + +DummyDataFactory = Callable[[InputContext, int], + Tuple["SequenceData", Optional["MultiModalData"]]] +""" +Create dummy data to be inputted into the model. + +Note: + :data:`InputProcessor` is not applied to the dummy data. +""" + +InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs] +"""Preprocess the inputs to the model.""" + + +class InputRegistry: + """ + A registry to dispatch data processing + according to the target model. + """ + + def __init__(self) -> None: + self._dummy_factories_by_model_type: Dict[Type[nn.Module], + DummyDataFactory] = {} + self._input_processors_by_model_type: Dict[Type[nn.Module], + InputProcessor] = {} + + def _default_dummy_data_factory( + self, + ctx: InputContext, + seq_len: int, + ) -> Tuple["SequenceData", Optional["MultiModalData"]]: + """ + The default dummy data factory represents the longest possible text + that can be inputted to the model. + + Note: + :data:`InputProcessor` is not applied to the dummy data. + """ + # Avoid circular import + from vllm.sequence import SequenceData + + dummy_seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + + return dummy_seq_data, dummy_multi_modal_data + + def register_dummy_data(self, factory: DummyDataFactory): + """ + Register a dummy data factory to a model class. + + During memory profiling, the provided function is invoked to create + dummy data to be inputted into the model. The resulting memory usage + should be an upper bound of what the model would use at inference time. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_factories_by_model_type: + logger.warning( + "Model class %s already has dummy data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def dummy_data_for_profiling(self, model_config: "ModelConfig", + seq_len: int): + """ + Create dummy data for profiling the memory usage of a model. + + The model is identified by ``model_config``. + + TODO: Add guide [ref: PR #5276] + """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + dummy_factory = self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + + return dummy_factory(InputContext(model_config), seq_len) + + def _default_input_processor(self, ctx: InputContext, + inputs: LLMInputs) -> LLMInputs: + """The default input processor is a no-op.""" + return inputs + + def register_input_processor(self, processor: InputProcessor): + """ + Register an input processor to a model class. + + The provided function is invoked on each input to the model. This + happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`. + + See also: + :ref:`input_processing_pipeline` + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._input_processors_by_model_type: + logger.warning( + "Model class %s already has input processor " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._input_processors_by_model_type[model_cls] = processor + + return model_cls + + return wrapper + + def process_input(self, model_config: "ModelConfig", + inputs: LLMInputs) -> LLMInputs: + """ + Apply an input processor to an instance of model inputs. + + The model is identified by ``model_config``. + + See also: + :ref:`input_processing_pipeline` + """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + + processor = self._input_processors_by_model_type \ + .get(model_cls, self._default_input_processor) + + return processor(InputContext(model_config), inputs) + + def create_input_processor(self, model_config: "ModelConfig"): + """ + Create an input processor (see :meth:`process_input`) for a + specific model. + """ + return functools.partial(self.process_input, model_config) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index aa4e87228a7e4..77fbade056ee6 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,22 +1,83 @@ """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn as nn +from PIL import Image from transformers import CLIPVisionConfig from transformers.models.clip.modeling_clip import CLIPAttention from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal.image import ImageFeatureData, ImagePixelData +from vllm.sequence import SequenceData -def get_clip_num_patches(image_size: int, patch_size: int) -> int: +def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: assert image_size % patch_size == 0 - return (image_size // patch_size)**2 + return image_size // patch_size + + +def get_clip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_clip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int: + return get_clip_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + + +def dummy_seq_data_for_clip( + hf_config: CLIPVisionConfig, + seq_len: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_clip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + token_ids = [image_token_id] * image_feature_size + token_ids += [0] * (seq_len - image_feature_size) + return SequenceData(token_ids) + + +def dummy_pixel_data_for_clip( + hf_config: CLIPVisionConfig, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return ImagePixelData(image) + + +def dummy_feature_data_for_clip( + hf_config: CLIPVisionConfig, + *, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_clip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + values = torch.zeros((1, image_feature_size, hf_config.hidden_size), + dtype=torch.float16) + return ImageFeatureData(values) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa @@ -39,8 +100,8 @@ def __init__(self, config: CLIPVisionConfig): bias=False, ) - self.num_patches = get_clip_num_patches(self.image_size, - self.patch_size) + self.num_patches = get_clip_num_patches(image_size=self.image_size, + patch_size=self.patch_size) self.num_positions = self.num_patches + 1 self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) @@ -101,7 +162,7 @@ def __init__(self, self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 125e3ddea7df3..bdcb6331730ab 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -2,10 +2,11 @@ import torch import torch.nn as nn -from transformers import LlavaConfig +from transformers import CLIPVisionConfig, LlavaConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VisionLanguageConfig +from vllm.inputs import INPUT_REGISTRY, InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -16,10 +17,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import get_dummy_image_data +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData from vllm.sequence import SamplerOutput +from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip, + dummy_seq_data_for_clip) from .interfaces import SupportsVision _KEYS_TO_MODIFY_MAPPING = { @@ -83,9 +85,35 @@ class LlavaImageFeatureInputs(TypedDict): LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] -@MULTIMODAL_REGISTRY.register_image_feature_input() -@MULTIMODAL_REGISTRY.register_image_pixel_input() -@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) +def dummy_data_for_llava(ctx: InputContext, seq_len: int): + multimodal_config = ctx.get_multimodal_config() + hf_config = ctx.get_hf_config(LlavaConfig) + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + seq_data = dummy_seq_data_for_clip( + vision_config, + seq_len, + image_token_id=hf_config.image_token_index, + ) + + image_input_type = multimodal_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + mm_data: MultiModalData + if image_input_type == ImageInputType.PIXEL_VALUES: + mm_data = dummy_pixel_data_for_clip(vision_config) + elif image_input_type == ImageInputType.IMAGE_FEATURES: + mm_data = dummy_feature_data_for_clip(vision_config) + + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +@MULTIMODAL_REGISTRY.register_image_feature_input_mapper() +@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper() +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) class LlavaForConditionalGeneration(nn.Module, SupportsVision): def __init__(self, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 841818d8db6ff..cebc828165ed7 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -3,14 +3,14 @@ import torch import torch.nn as nn -from PIL import Image -from transformers import LlavaNextConfig +from transformers import CLIPVisionConfig, LlavaNextConfig from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig +from vllm.config import CacheConfig, VisionLanguageConfig +from vllm.inputs import INPUT_REGISTRY, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -22,9 +22,11 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData -from vllm.multimodal.image import ImagePixelData, get_dummy_image_data -from vllm.sequence import SamplerOutput, SequenceData +from vllm.multimodal.image import ImagePixelData +from vllm.sequence import SamplerOutput +from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip, + dummy_seq_data_for_clip, get_clip_patch_grid_length) from .interfaces import SupportsVision from .llava import LlavaMultiModalProjector, merge_vision_embeddings @@ -58,41 +60,118 @@ class LlavaNextImageFeatureInputs(TypedDict): LlavaNextImageFeatureInputs] -def _get_dummy_image_data( - seq_len: int, - model_config: ModelConfig, - vlm_config: VisionLanguageConfig, -) -> Tuple[SequenceData, MultiModalData]: - seq_data, fake_mm_data = get_dummy_image_data(seq_len, model_config, - vlm_config) +def _get_llava_next_num_unpadded_features( + height: int, + width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, +) -> Tuple[int, int]: + # Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111 + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + aspect_ratio: float = width / height + current_aspect_ratio: float = current_width / current_height + if aspect_ratio > current_aspect_ratio: + new_height = (height * current_width) // width + current_height = new_height + else: + new_width = (width * current_height) // height + current_width = new_width + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + + +def _get_llava_next_image_feature_size( + hf_config: LlavaNextConfig, + *, + input_height: int, + input_width: int, +) -> int: + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + num_patches = get_clip_patch_grid_length( + image_size=vision_config.image_size, + patch_size=vision_config.patch_size, + ) + base_feature_size = num_patches * num_patches + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_size=(input_height, input_width), + grid_pinpoints=hf_config.image_grid_pinpoints, + patch_size=vision_config.image_size, + ) + + ( + unpadded_feature_size, + newline_feature_size, + ) = _get_llava_next_num_unpadded_features(input_height, input_width, + num_patches, + num_patch_height, + num_patch_width) + + return unpadded_feature_size + newline_feature_size + base_feature_size + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + - config_input_type = vlm_config.image_input_type - ImageInputType = VisionLanguageConfig.ImageInputType +def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): + multimodal_config = ctx.get_multimodal_config() + hf_config = ctx.get_hf_config(LlavaNextConfig) + vision_config = hf_config.vision_config + + #TODO: change the logic for dummy data to support dynamic shape + _, _, dummy_height, dummy_width = multimodal_config.image_input_shape + image_feature_size = _get_llava_next_image_feature_size( + hf_config, input_height=dummy_height, input_width=dummy_width) + + if isinstance(vision_config, CLIPVisionConfig): + seq_data = dummy_seq_data_for_clip( + vision_config, + seq_len, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + image_input_type = multimodal_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + mm_data: MultiModalData + if image_input_type == ImageInputType.PIXEL_VALUES: + mm_data = dummy_pixel_data_for_clip( + vision_config, + image_width_override=dummy_width, + image_height_override=dummy_height, + ) + elif image_input_type == ImageInputType.IMAGE_FEATURES: + mm_data = dummy_feature_data_for_clip( + vision_config, + image_feature_size_override=image_feature_size, + ) - if config_input_type == ImageInputType.PIXEL_VALUES: - _, c, h, w = vlm_config.image_input_shape - mode = {1: "L", 3: "RGB"}[c] - fake_mm_data = ImagePixelData(Image.new(mode, (w, h), color=0)) + return seq_data, mm_data - return seq_data, fake_mm_data + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) -def _image_pixel_processor( - data: ImagePixelData, - model_config: ModelConfig, - vlm_config: VisionLanguageConfig, -) -> Dict[str, torch.Tensor]: +def _pixel_mapper(ctx: InputContext, + data: ImagePixelData) -> Dict[str, torch.Tensor]: image = data.image if isinstance(image, torch.Tensor): - pixel_values = image.to(model_config.dtype) + pixel_values = image.to(ctx.model_config.dtype) batch_size, _, _, h, w = pixel_values.shape image_sizes = torch.tensor([(w, h) for _ in range(batch_size)]) return {"pixel_values": pixel_values, "image_sizes": image_sizes} # Temporary patch before dynamic number of image tokens is supported - _, _, h, w = vlm_config.image_input_shape + _, _, h, w = ctx.get_multimodal_config().image_input_shape if (w, h) != (image.width, image.height): logger.warning( "Dynamic image shape is currently not supported. " @@ -101,11 +180,12 @@ def _image_pixel_processor( data.image = image.resize((w, h)) return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \ - ._default_input_processor(data, model_config, vlm_config) + ._default_input_mapper(ctx, data) -@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor) -@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data) +@MULTIMODAL_REGISTRY.register_image_feature_input_mapper() +@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): def __init__(self, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 0bbe93241b139..5d8ffd5215c52 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -22,7 +22,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig +from vllm.config import CacheConfig, VisionLanguageConfig +from vllm.inputs import INPUT_REGISTRY, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -34,9 +35,10 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import ImagePixelData, get_dummy_image_data +from vllm.multimodal.image import ImagePixelData from vllm.sequence import SamplerOutput +from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsVision logger = init_logger(__name__) @@ -107,7 +109,6 @@ def __init__(self, self.num_img_tokens = config.img_processor['num_img_tokens'] self.image_dim_out = image_dim_out - self.img_sizes = None # global_gn and sub_gn for hd transform, serves as line separator self.use_hd_transform = config.embd_layer.get('use_hd_transform', @@ -134,7 +135,6 @@ def __init__(self, self.img_projection = nn.Sequential(*layers) self.vocab_size = config.vocab_size - self.img_features = None self.layer_idx = config.img_processor.get('layer_idx', -2) self.type_feature = config.img_processor.get('type_feature', 'patch') @@ -260,9 +260,44 @@ class Phi3VImagePixelInputs(TypedDict): """Shape: (batch_size, 2)""" -# FIXME(Isotr0py): Remove these after dynamic num_img_tokens is supported -# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py -def calc_padded_size(width, height, padding_unit=336): +def _get_phi3v_image_feature_size( + *, + input_height: int, + input_width: int, +) -> int: + h, w = input_height, input_width + + # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178 + return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12 + + +def dummy_data_for_phi3v(ctx: InputContext, seq_len: int): + multimodal_config = ctx.get_multimodal_config() + + #TODO: change the logic for dummy data to support dynamic shape + _, _, dummy_height, dummy_width = multimodal_config.image_input_shape + image_feature_size = _get_phi3v_image_feature_size( + input_height=dummy_height, + input_width=dummy_width, + ) + + seq_data = dummy_seq_data_for_clip( + CLIP_VIT_LARGE_PATCH14_336_CONFIG, + seq_len, + image_token_id=32044, + image_feature_size_override=image_feature_size, + ) + mm_data = dummy_pixel_data_for_clip( + CLIP_VIT_LARGE_PATCH14_336_CONFIG, + image_width_override=dummy_width, + image_height_override=dummy_height, + ) + + return seq_data, mm_data + + +# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py +def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): target_height = int(np.ceil(height / padding_unit) * padding_unit) top_padding = int((target_height - height) / 2) bottom_padding = target_height - height - top_padding @@ -271,8 +306,8 @@ def calc_padded_size(width, height, padding_unit=336): return padded_width, padded_height -# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py -def calc_hd_transform_size(width, height, hd_num=16): +# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py +def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): transposed = False if width < height: width, height = height, width @@ -287,7 +322,8 @@ def calc_hd_transform_size(width, height, hd_num=16): new_width = int(scale * 336) new_height = int(new_width / ratio) - padded_width, padded_height = calc_padded_size(new_width, new_height) + padded_width, padded_height = _calc_padded_size(width=new_width, + height=new_height) if transposed: padded_width, padded_height = padded_height, padded_width @@ -295,17 +331,15 @@ def calc_hd_transform_size(width, height, hd_num=16): return padded_width, padded_height -def _image_processor( - data: ImagePixelData, - model_config: ModelConfig, - vlm_config: VisionLanguageConfig, -) -> Dict[str, torch.Tensor]: +def _image_processor(ctx: InputContext, + data: ImagePixelData) -> Dict[str, torch.Tensor]: image = data.image if isinstance(image, Image.Image): # Temporary patch before dynamic number of image tokens is supported - _, _, h, w = vlm_config.image_input_shape - if (w, h) != calc_hd_transform_size(image.width, image.height): + _, _, h, w = ctx.get_multimodal_config().image_input_shape + if (w, h) != _calc_hd_transform_size(width=image.width, + height=image.height): logger.warning( "Dynamic image shape is currently not supported. " "Resizing input image to (%d, %d).", w, h) @@ -313,11 +347,11 @@ def _image_processor( data.image = image.resize((w, h)) return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \ - ._default_input_processor(data, model_config, vlm_config) + ._default_input_mapper(ctx, data) -@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor) -@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) +@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_image_processor) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) class Phi3VForCausalLM(nn.Module, SupportsVision): def __init__(self, diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 270012e7d1c3b..20bd87b8c4436 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,5 +1,14 @@ from .base import MultiModalData, MultiModalPlugin -from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry +from .registry import MultiModalRegistry + +MULTIMODAL_REGISTRY = MultiModalRegistry() +""" +The global :class:`~MultiModalRegistry` is used by model runners to +dispatch data processing according to its modality and the target model. + +See also: + :ref:`input_processing_pipeline` +""" __all__ = [ "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 847752449ba80..d47cdd559ad89 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -2,7 +2,8 @@ from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type, TypeVar) -from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.config import ModelConfig +from vllm.inputs import InputContext from vllm.logger import init_logger if TYPE_CHECKING: @@ -23,7 +24,7 @@ class MultiModalData: Finally, register the new plugin to :const:`vllm.multimodal.MULTIMODAL_REGISTRY`. - This enables models to call :meth:`MultiModalRegistry.register_input` for + This enables models to call :meth:`MultiModalRegistry.map_input` for the new modality. """ pass @@ -32,10 +33,9 @@ class MultiModalData: D = TypeVar("D", bound=MultiModalData) N = TypeVar("N", bound=Type["nn.Module"]) -MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], - Dict[str, "torch.Tensor"]] +MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]] """Return a dictionary to be passed as keyword arguments to -:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers +:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers and processors in HuggingFace Transformers.""" @@ -50,16 +50,9 @@ class MultiModalPlugin(ABC, Generic[D]): (i.e., the modality of the data). """ - @classmethod - def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]: - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - return get_model_architecture(model_config)[0] - def __init__(self) -> None: - self._input_processors: Dict[Type["nn.Module"], - MultiModalInputProcessor[D]] = {} + self._input_mappers: Dict[Type["nn.Module"], + MultiModalInputMapper[D]] = {} @abstractmethod def get_data_type(self) -> Type[D]: @@ -70,57 +63,62 @@ def get_data_type(self) -> Type[D]: raise NotImplementedError @abstractmethod - def _default_input_processor( - self, data: D, model_config: ModelConfig, - vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]: + def _default_input_mapper(self, ctx: InputContext, + data: D) -> Dict[str, "torch.Tensor"]: """Return a dictionary to be passed as keyword arguments to - :meth:`torch.nn.Module.forward`. This is similar in concept to + :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers and processors in HuggingFace Transformers. """ raise NotImplementedError - def register_input_processor(self, - processor: Optional[ - MultiModalInputProcessor[D]] = None): + def register_input_mapper( + self, + mapper: Optional[MultiModalInputMapper[D]] = None, + ): """ - Register an input processor to a model class. + Register an input mapper to a model class. When the model receives input data that matches the modality served by - this plugin (see :meth:`get_data_type`), the provided input processor is - applied to preprocess the data. If `None` is provided, then the default - input processor is applied instead. + this plugin (see :meth:`get_data_type`), the provided function is + invoked to transform the data into a dictionary of model inputs. + If `None` is provided, then the default input mapper is used instead. + + See also: + :ref:`input_processing_pipeline` """ def wrapper(model_cls: N) -> N: - if model_cls in self._input_processors: + if model_cls in self._input_mappers: logger.warning( - "Model class %s already has an input processor " + "Model class %s already has an input mapper " "registered to %s. It is overwritten by the new one.", model_cls, self) - self._input_processors[model_cls] = processor \ - or self._default_input_processor + self._input_mappers[model_cls] = mapper \ + or self._default_input_mapper return model_cls return wrapper - def process_input( - self, data: D, model_config: ModelConfig, - vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]: + def map_input(self, model_config: ModelConfig, + data: D) -> Dict[str, "torch.Tensor"]: """ - Apply an input processor to a :class:`~MultiModalData` instance passed - to the model. - - The model is identified by ``model_config``. ``vlm_config`` is - for compatibility purposes and may be merged into ``model_config`` - in the near future. + Apply an input mapper to a :class:`~MultiModalData` instance passed + to the model, transforming the data into a dictionary of model inputs. + + The model is identified by ``model_config``. + + TODO: Add guide [ref: PR #5276] """ - model_cls = self.get_model_cls(model_config) + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) - processor = self._input_processors.get(model_cls) - if processor is None: - raise KeyError(f"No input processor in {self} is registered for " + mapper = self._input_mappers.get(model_cls) + if mapper is None: + raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") - return processor(data, model_config, vlm_config) + return mapper(InputContext(model_config), data) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 08fb09d111605..a9691575c2eaf 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,70 +1,28 @@ -from typing import Dict, Tuple, Type, Union +from functools import lru_cache +from typing import Dict, Type, Union import torch from PIL import Image -from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.config import ModelConfig +from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.sequence import SequenceData -from vllm.transformers_utils.image_processor import cached_get_image_processor +from vllm.transformers_utils.image_processor import get_image_processor from .base import MultiModalData, MultiModalPlugin logger = init_logger(__name__) - -def _get_dummy_seq_data(seq_len: int, - vlm_config: VisionLanguageConfig) -> SequenceData: - # NOTE: We assume that token is repeated `image_feature_size` times - # and then concatenated with the text prompt - # TODO: Enable other ways of inserting the image into the prompt - - token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size - token_ids += [0] * (seq_len - vlm_config.image_feature_size) - - return SequenceData(token_ids) - - -def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor: - if vlm_config.image_processor is None: - values_dtype = torch.float16 - else: - values_dtype = torch.uint8 - - return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype) - - -def get_dummy_image_data( - seq_len: int, - model_config: ModelConfig, - vlm_config: VisionLanguageConfig, -) -> Tuple[SequenceData, MultiModalData]: - """Standard dummy data factory for image data (to be used in - :meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`).""" - seq_data = _get_dummy_seq_data(seq_len, vlm_config) - values = _get_dummy_values(vlm_config) - - config_input_type = vlm_config.image_input_type - ImageInputType = VisionLanguageConfig.ImageInputType - - fake_mm_data: MultiModalData - if config_input_type == ImageInputType.PIXEL_VALUES: - fake_mm_data = ImagePixelData(values) - elif config_input_type == ImageInputType.IMAGE_FEATURES: - fake_mm_data = ImageFeatureData(values) - else: - raise NotImplementedError - - return seq_data, fake_mm_data +cached_get_image_processor = lru_cache(get_image_processor) class ImagePixelData(MultiModalData): """ The pixel data of an image. Can be one of: - - :class:``PIL.Image``: An image object. Requires that a HuggingFace + - :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace processor is available to the model. - - :class:``torch.Tensor``: The raw pixel data which is passed to the model + - :class:`torch.Tensor`: The raw pixel data which is passed to the model without additional pre-processing. """ @@ -89,8 +47,8 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]): def get_data_type(self) -> Type[ImagePixelData]: return ImagePixelData - def _get_hf_image_processor(self, model_config: ModelConfig, - vlm_config: VisionLanguageConfig): + def _get_hf_image_processor(self, model_config: ModelConfig): + vlm_config = model_config.multimodal_config if vlm_config is None or vlm_config.image_processor is None: return None @@ -100,14 +58,13 @@ def _get_hf_image_processor(self, model_config: ModelConfig, revision=vlm_config.image_processor_revision, ) - def _default_input_processor( - self, data: ImagePixelData, model_config: ModelConfig, - vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: + def _default_input_mapper(self, ctx: InputContext, + data: ImagePixelData) -> Dict[str, torch.Tensor]: + model_config = ctx.model_config image = data.image if isinstance(image, Image.Image): - image_processor = self._get_hf_image_processor( - model_config, vlm_config) + image_processor = self._get_hf_image_processor(model_config) if image_processor is None: raise RuntimeError("No HuggingFace processor is available" "to process the image object") @@ -147,9 +104,10 @@ class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]): def get_data_type(self) -> Type[ImageFeatureData]: return ImageFeatureData - def _default_input_processor( - self, data: ImageFeatureData, model_config: ModelConfig, - vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: + def _default_input_mapper( + self, ctx: InputContext, + data: ImageFeatureData) -> Dict[str, torch.Tensor]: + model_config = ctx.model_config image_features = data.image_features.to(model_config.dtype) return {"image_features": image_features} diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 4789ce5ce4cfe..abc88e4f9a9d8 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,46 +1,35 @@ import functools -from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, - Tuple, Type, TypeVar) +from typing import Any, Optional, Sequence, Type, TypeVar -from vllm.config import ModelConfig, VisionLanguageConfig +from torch import nn + +from vllm.config import ModelConfig from vllm.logger import init_logger -from .base import MultiModalData, MultiModalPlugin +from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData, ImagePixelPlugin) -if TYPE_CHECKING: - import torch - from torch import nn - - from vllm.sequence import SequenceData - logger = init_logger(__name__) D = TypeVar("D", bound=MultiModalData) -N = TypeVar("N", bound=Type["nn.Module"]) - -MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], - Dict[str, "torch.Tensor"]] -MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig], - Tuple["SequenceData", MultiModalData]] +N = TypeVar("N", bound=Type[nn.Module]) class MultiModalRegistry: """ - This registry is used by model runners to dispatch data processing + A registry to dispatch data processing according to its modality and the target model. """ DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin()) - def __init__(self, - *, - plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS - ) -> None: + def __init__( + self, + *, + plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS, + ) -> None: self._plugins_by_data_type = {p.get_data_type(): p for p in plugins} - self._dummy_factories_by_model_type: Dict[Type["nn.Module"], - MultiModalDummyFactory] = {} def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None: data_type = plugin.get_data_type() @@ -62,95 +51,53 @@ def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]): msg = f"Unknown multi-modal data type: {data_type}" raise NotImplementedError(msg) - def register_dummy_data(self, factory: MultiModalDummyFactory): - """ - Register a dummy data factory to a model class. - - During memory profiling, the provided function is invoked to create - dummy data to be inputted into the model. The modality and shape of - the dummy data should be an upper bound of what the model would receive - at inference time. - """ - - def wrapper(model_cls: N) -> N: - if model_cls in self._dummy_factories_by_model_type: - logger.warning( - "Model class %s already has dummy data " - "registered to %s. It is overwritten by the new one.", - model_cls, self) - - self._dummy_factories_by_model_type[model_cls] = factory - - return model_cls - - return wrapper - - def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig, - vlm_config: VisionLanguageConfig): - """Create dummy data for memory profiling.""" - model_cls = MultiModalPlugin.get_model_cls(model_config) - dummy_factory = self._dummy_factories_by_model_type.get(model_cls) - if dummy_factory is None: - msg = f"No dummy data defined for model class: {model_cls}" - raise NotImplementedError(msg) - - return dummy_factory(seq_len, model_config, vlm_config) - - def register_input( - self, - data_type: Type[D], - processor: Optional[MultiModalInputProcessor[D]] = None): + def register_input_mapper( + self, + data_type: Type[D], + mapper: Optional[MultiModalInputMapper[D]] = None, + ): """ - Register an input processor for a specific modality to a model class. + Register an input mapper for a specific modality to a model class. - See :meth:`MultiModalPlugin.register_input_processor` for more details. + See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ return self._get_plugin_for_data_type(data_type) \ - .register_input_processor(processor) + .register_input_mapper(mapper) - def register_image_pixel_input( - self, - processor: Optional[ - MultiModalInputProcessor[ImagePixelData]] = None): + def register_image_pixel_input_mapper( + self, + mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None, + ): """ - Register an input processor for image pixel data to a model class. + Register an input mapper for image pixel data to a model class. - See :meth:`MultiModalPlugin.register_input_processor` for more details. + See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ - return self.register_input(ImagePixelData, processor) + return self.register_input_mapper(ImagePixelData, mapper) - def register_image_feature_input( + def register_image_feature_input_mapper( self, - processor: Optional[ - MultiModalInputProcessor[ImageFeatureData]] = None): + mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None, + ): """ - Register an input processor for image feature data to a model class. + Register an input mapper for image feature data to a model class. - See :meth:`MultiModalPlugin.register_input_processor` for more details. + See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ - return self.register_input(ImageFeatureData, processor) + return self.register_input_mapper(ImageFeatureData, mapper) - def process_input(self, data: MultiModalData, model_config: ModelConfig, - vlm_config: VisionLanguageConfig): + def map_input(self, model_config: ModelConfig, data: MultiModalData): """ - Apply an input processor to a :class:`~MultiModalData` instance passed + Apply an input mapper to a :class:`~MultiModalData` instance passed to the model. - See :meth:`MultiModalPlugin.process_input` for more details. + See :meth:`MultiModalPlugin.map_input` for more details. """ return self._get_plugin_for_data_type(type(data)) \ - .process_input(data, model_config, vlm_config) + .map_input(model_config, data) - def create_input_processor(self, model_config: ModelConfig, - vlm_config: VisionLanguageConfig): + def create_input_mapper(self, model_config: ModelConfig): """ - Create an input processor (see :meth:`process_input`) for a - specific model. + Create an input mapper (see :meth:`map_input`) for a specific model. """ - return functools.partial(self.process_input, - model_config=model_config, - vlm_config=vlm_config) - - -MULTIMODAL_REGISTRY = MultiModalRegistry() -"""The global :class:`~MultiModalRegistry` which is used by model runners.""" + return functools.partial(self.map_input, model_config) diff --git a/vllm/sequence.py b/vllm/sequence.py index c618c36926119..a50aaf4204cb6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -8,12 +8,12 @@ import torch -from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams if TYPE_CHECKING: + from vllm.inputs import LLMInputs from vllm.multimodal import MultiModalData from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -221,7 +221,7 @@ class Sequence: def __init__( self, seq_id: int, - inputs: LLMInputs, + inputs: "LLMInputs", block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index 3239b1d0cfa2f..2bb5215d4846f 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -1,4 +1,3 @@ -from functools import lru_cache from typing import Optional from transformers import AutoImageProcessor @@ -40,6 +39,3 @@ def get_image_processor( raise e return processor - - -cached_get_image_processor = lru_cache(get_image_processor) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 148332f3402f0..e689f485e0ea2 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -110,15 +110,9 @@ def __init__( self.block_size, ) - # Create processor for multi-modal data - if self.vision_language_config is not None: - self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ - .create_input_processor( - self.model_config, - self.vision_language_config, - ) - else: - self.multi_modal_input_processor = None + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) # Lazy initialization. self.model: nn.Module # Set after init_Model @@ -168,13 +162,7 @@ def _prepare_prompt( mm_data = seq_group_metadata.multi_modal_data if mm_data is not None: - # Process multi-modal data - if self.multi_modal_input_processor is None: - raise ValueError( - "Multi-modal inputs are only supported by " - "vision language models.") - - mm_kwargs = self.multi_modal_input_processor(mm_data) + mm_kwargs = self.multi_modal_input_mapper(mm_data) for k, v in mm_kwargs.items(): multi_modal_kwargs_list[k].append(v) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 181442490a82c..93a10070db27a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -15,6 +15,7 @@ ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed.parallel_state import graph_capture +from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -25,7 +26,7 @@ from vllm.model_executor.models.interfaces import supports_lora from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_runner_base import ( @@ -191,15 +192,9 @@ def __init__( self.block_size, ) if num_attn_heads else None - # Create processor for multi-modal data - if self.vision_language_config is not None: - self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ - .create_input_processor( - self.model_config, - self.vision_language_config, - ) - else: - self.multi_modal_input_processor = None + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) # Lazy initialization self.model: nn.Module # Set after load_model @@ -506,12 +501,7 @@ def _prepare_model_input_tensors( mm_data = seq_group_metadata.multi_modal_data if mm_data is not None: # Process multi-modal data - if self.multi_modal_input_processor is None: - raise ValueError( - "Multi-modal inputs are only supported by " - "vision language models.") - - mm_kwargs = self.multi_modal_input_processor(mm_data) + mm_kwargs = self.multi_modal_input_mapper(mm_data) for k, v in mm_kwargs.items(): multi_modal_kwargs_list[k].append(v) @@ -764,12 +754,9 @@ def profile_run(self) -> None: seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - if vlm_config is None: - seq_data = SequenceData([0] * seq_len) - dummy_multi_modal_data = None - else: - seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ - .dummy_data_for_profiling(seq_len, model_config, vlm_config) + seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ + .dummy_data_for_profiling(model_config, seq_len) + assert len(seq_data.prompt_token_ids) == seq_len seq = SequenceGroupMetadata( request_id=str(group_id),