From ee6fb36788255a60226c5734ef4a0c3963c2b119 Mon Sep 17 00:00:00 2001
From: Jee Jee Li <pandaleefree@gmail.com>
Date: Mon, 6 Jan 2025 23:22:25 +0800
Subject: [PATCH] [Bugfix][V1] Fix molmo text-only inputs (#11676)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
---
 .../vision_language/test_models.py            | 10 ++
 .../vision_language/vlm_utils/model_utils.py  | 99 ++++++++++++++++++-
 vllm/model_executor/models/molmo.py           | 56 ++++-------
 3 files changed, 123 insertions(+), 42 deletions(-)

diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index dc0b683c1f1cb..146685738a1d0 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -341,6 +341,16 @@
         ),
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
     ),
+    "molmo": VLMTestInfo(
+        models=["allenai/Molmo-7B-D-0924"],
+        test_type=(VLMTestType.IMAGE),
+        prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        image_size_factors=[(),(1.0, 1.0, 1.0)],
+        patch_hf_runner=model_utils.mlomo_patch_hf_runner,
+        postprocess_inputs=model_utils.molmo_post_processor,
+    ),
     # Tests for phi3v currently live in another file because of a bug in
     # transformers. Once this issue is fixed, we can enable them here instead.
     # https://github.com/huggingface/transformers/issues/34307
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index 3eca8fb9dcb1a..6c7a753af787e 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -5,17 +5,20 @@
 import re
 import types
 from pathlib import PosixPath
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from PIL.Image import Image
-from transformers import AutoConfig, AutoTokenizer, BatchEncoding
+from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
+                          GenerationConfig)

 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

-from .....conftest import HfRunner, ImageAsset, _ImageAssets
+from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
+                           PromptImageInput, PromptVideoInput, _ImageAssets)
+from ....utils import TokensTextLogprobs
 from .types import RunnerOutput


@@ -222,6 +225,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
     return {"model_inputs": hf_inputs}


+def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
+    hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
+    return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
+
+
 ####### Prompt path encoders for models that need models on disk
 def qwen_prompt_path_encoder(
     tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset],
@@ -451,3 +459,88 @@ def _generate(self, *args, **kwargs):
     hf_model.model.generate = types.MethodType(_generate, hf_model.model)

     return hf_model
+
+
+def _generate_greedy_logprobs_limit(
+    self,
+    prompts: List[str],
+    max_tokens: int,
+    num_logprobs: int,
+    images: Optional[PromptImageInput] = None,
+    audios: Optional[PromptAudioInput] = None,
+    videos: Optional[PromptVideoInput] = None,
+    **kwargs: Any,
+) -> List[TokensTextLogprobs]:
+    all_inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
+    # Process in batches for inference.
+    if len(all_inputs):
+        input_ids_lst = []
+        images_lst = []
+        images_input_idx_lst = []
+        imges_masks_lst = []
+        for inputs in all_inputs:
+            input_ids_lst.append(inputs["input_ids"])
+            images_lst.append(inputs["images"])
+            images_input_idx_lst.append(inputs["image_input_idx"])
+            imges_masks_lst.append(inputs["image_masks"])
+        batch_inputs = {}
+        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
+        batch_inputs['images'] = torch.cat(images_lst, dim=0)
+        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
+                                                    dim=0)
+        batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
+
+        outputs = self.model.generate_from_batch(
+            batch=self.wrap_device(batch_inputs,
+                                   device=self.model.device.type),
+            generation_config=GenerationConfig(
+                max_new_tokens=max_tokens,
+                stop_strings="<|endoftext|>",
+                do_sample=False,
+            ),
+            tokenizer=self.tokenizer,
+            output_hidden_states=True,
+            return_dict_in_generate=True,
+        )
+
+    all_logprobs: List[List[Dict[int, float]]] = []
+    all_output_ids: List[List[int]] = []
+    all_output_strs: List[str] = []
+
+    for index in range(len(all_inputs)):
+        (
+            seq_logprobs_lst,
+            output_len,
+        ) = self._hidden_states_to_logprobs(outputs.hidden_states,
+                                            num_logprobs)
+        all_logprobs.append(seq_logprobs_lst)
+        seq_ids = outputs.sequences[index]
+        output_ids = seq_ids[-output_len:]
+        all_output_ids.append(output_ids.tolist())
+        all_output_strs.append(self.tokenizer.decode(output_ids))
+    outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+    return [(output_ids, output_str, output_logprobs)
+            for output_ids, output_str, output_logprobs in outputs]
+
+
+####### Molmo-specific HuggingFace runner patchers
+def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Molmo."""
+    hf_processor = hf_model.processor
+
+    def _processor(*args, **kwargs):
+        return hf_processor.process(*args, **kwargs)
+
+    hf_model.processor = _processor
+
+    setattr(  # noqa: B010
+        hf_model,
+        "generate_greedy_logprobs_limit",
+        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
+    )
+
+    return hf_model
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index cc25be9f5b6a9..0e8287bb56b6b 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1081,45 +1081,25 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
     else:
         out = processor.process(None, image, tokens=inputs["prompt_token_ids"])

-    image_processor = processor.image_processor
-    max_total_crops = 1 + image_processor.max_crops
-    if image is not None:
-        images, image_input_idx, image_masks = pad_images(
-            max_total_crops,
-            out["images"],
-            out["image_input_idx"],
-            out.get("image_masks"),
-        )
-    else:
-        base_image_input_size = image_processor.base_image_input_size
-        image_patch_size = image_processor.image_patch_size
-        image_num_patch = (
-            base_image_input_size[0] // image_patch_size,
-            base_image_input_size[1] // image_patch_size,
-        )
-        n_pixels = image_patch_size * image_patch_size * 3
-        n_patches = image_num_patch[0] * image_num_patch[1]
-
-        image_length_w = image_processor.image_token_length_w
-        image_length_h = image_processor.image_token_length_h
-        tokens_per_image = image_length_w * image_length_h
-        images = torch.full(
-            (max_total_crops, n_patches, n_pixels),
-            -1,
-            dtype=torch.float32,
-        )
-        image_input_idx = torch.full(
-            (max_total_crops, tokens_per_image),
-            -1,
-            dtype=torch.int32,
+    # If there is no image, return directly.
+    if image is None:
+        new_prompt_token_ids = out["input_ids"].tolist()
+        prompt = inputs.get("prompt")
+        if prompt is None:
+            prompt = tokenizer.decode(new_prompt_token_ids)
+        return token_inputs(
+            prompt_token_ids=new_prompt_token_ids,
+            prompt=prompt,
         )
-        if image_processor.image_padding_mask:
-            image_masks = torch.full(
-                (max_total_crops, n_patches),
-                -1,
-                dtype=torch.float32,
-            )
+    image_processor = processor.image_processor
+    max_total_crops = 1 + image_processor.max_crops
+    images, image_input_idx, image_masks = pad_images(
+        max_total_crops,
+        out["images"],
+        out["image_input_idx"],
+        out.get("image_masks"),
+    )

     image_data = dict(
         images=images,
         image_input_idx=image_input_idx,
@@ -1143,11 +1123,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
                 offset = i
             size += 1
     image_data["image_start_end"] = (offset, offset + size)
-
     prompt = inputs.get("prompt")
     if prompt is None:
         prompt = tokenizer.decode(new_prompt_token_ids)
-
     return token_inputs(
         prompt_token_ids=new_prompt_token_ids,
         prompt=prompt,
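
For reference, the text-only code path this patch fixes can be exercised with a short offline-inference script. The sketch below is not part of the patch: the model name comes from the test entry above, while the prompt wording and sampling settings are illustrative assumptions.

# Minimal text-only Molmo run (sketch, not part of the patch).
# Assumptions: model name taken from the test entry above; prompt and
# sampling values are illustrative. A prompt without multi_modal_data now
# takes the new `if image is None` branch in input_processor_for_molmo and
# returns token_inputs() instead of entering the image-padding path.
from vllm import LLM, SamplingParams

llm = LLM(
    model="allenai/Molmo-7B-D-0924",
    trust_remote_code=True,  # Molmo ships custom processor/model code
    max_model_len=4096,
)

outputs = llm.generate(
    "User: Write one sentence about the night sky. Assistant:",
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)

Greedy decoding (temperature=0.0) mirrors the do_sample=False comparison used by the patched HfRunner above.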