From ee6fb36788255a60226c5734ef4a0c3963c2b119 Mon Sep 17 00:00:00 2001
From: Jee Jee Li <pandaleefree@gmail.com>
Date: Mon, 6 Jan 2025 23:22:25 +0800
Subject: [PATCH] [Bugfix][V1] Fix molmo text-only inputs (#11676)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
---
 .../vision_language/test_models.py            | 10 ++
 .../vision_language/vlm_utils/model_utils.py  | 99 ++++++++++++++++++-
 vllm/model_executor/models/molmo.py           | 56 ++++-------
 3 files changed, 123 insertions(+), 42 deletions(-)

diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index dc0b683c1f1cb..146685738a1d0 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -341,6 +341,16 @@
         ),
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
     ),
+    "molmo": VLMTestInfo(
+        models=["allenai/Molmo-7B-D-0924"],
+        test_type=(VLMTestType.IMAGE),
+        prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:", # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        image_size_factors=[(),(1.0, 1.0, 1.0)],
+        patch_hf_runner=model_utils.mlomo_patch_hf_runner,
+        postprocess_inputs=model_utils.molmo_post_processor,
+    ),
     # Tests for phi3v currently live in another file because of a bug in
     # transformers. Once this issue is fixed, we can enable them here instead.
     # https://github.com/huggingface/transformers/issues/34307
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index 3eca8fb9dcb1a..6c7a753af787e 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -5,17 +5,20 @@
 import re
 import types
 from pathlib import PosixPath
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from PIL.Image import Image
-from transformers import AutoConfig, AutoTokenizer, BatchEncoding
+from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
+                          GenerationConfig)
 
 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
-from .....conftest import HfRunner, ImageAsset, _ImageAssets
+from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
+                           PromptImageInput, PromptVideoInput, _ImageAssets)
+from ....utils import TokensTextLogprobs
 from .types import RunnerOutput
 
 
@@ -222,6 +225,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
     return {"model_inputs": hf_inputs}
 
 
+def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
+    hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
+    return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
+
+
 ####### Prompt path encoders for models that need models on disk
 def qwen_prompt_path_encoder(
         tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset],
@@ -451,3 +459,88 @@ def _generate(self, *args, **kwargs):
     hf_model.model.generate = types.MethodType(_generate, hf_model.model)
 
     return hf_model
+
+
+def _generate_greedy_logprobs_limit(
+    self,
+    prompts: List[str],
+    max_tokens: int,
+    num_logprobs: int,
+    images: Optional[PromptImageInput] = None,
+    audios: Optional[PromptAudioInput] = None,
+    videos: Optional[PromptVideoInput] = None,
+    **kwargs: Any,
+) -> List[TokensTextLogprobs]:
+    all_inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
+    # Process in batches for inference.
+    if len(all_inputs):
+        input_ids_lst = []
+        images_lst = []
+        images_input_idx_lst = []
+        imges_masks_lst = []
+        for inputs in all_inputs:
+            input_ids_lst.append(inputs["input_ids"])
+            images_lst.append(inputs["images"])
+            images_input_idx_lst.append(inputs["image_input_idx"])
+            imges_masks_lst.append(inputs["image_masks"])
+        batch_inputs = {}
+        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
+        batch_inputs['images'] = torch.cat(images_lst, dim=0)
+        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
+                                                    dim=0)
+        batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
+
+        outputs = self.model.generate_from_batch(
+            batch=self.wrap_device(batch_inputs,
+                                   device=self.model.device.type),
+            generation_config=GenerationConfig(
+                max_new_tokens=max_tokens,
+                stop_strings="<|endoftext|>",
+                do_sample=False,
+            ),
+            tokenizer=self.tokenizer,
+            output_hidden_states=True,
+            return_dict_in_generate=True,
+        )
+
+    all_logprobs: List[List[Dict[int, float]]] = []
+    all_output_ids: List[List[int]] = []
+    all_output_strs: List[str] = []
+
+    for index in range(len(all_inputs)):
+        (
+            seq_logprobs_lst,
+            output_len,
+        ) = self._hidden_states_to_logprobs(outputs.hidden_states,
+                                            num_logprobs)
+        all_logprobs.append(seq_logprobs_lst)
+        seq_ids = outputs.sequences[index]
+        output_ids = seq_ids[-output_len:]
+        all_output_ids.append(output_ids.tolist())
+        all_output_strs.append(self.tokenizer.decode(output_ids))
+    outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+    return [(output_ids, output_str, output_logprobs)
+            for output_ids, output_str, output_logprobs in outputs]
+
+
+####### Molmo-specific HuggingFace runner patchers
+def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Molmo."""
+    hf_processor = hf_model.processor
+
+    def _processor(*args, **kwargs):
+        return hf_processor.process(*args, **kwargs)
+
+    hf_model.processor = _processor
+
+    setattr(  # noqa: B010
+        hf_model,
+        "generate_greedy_logprobs_limit",
+        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
+    )
+
+    return hf_model
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index cc25be9f5b6a9..0e8287bb56b6b 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1081,45 +1081,25 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
     else:
         out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
 
-    image_processor = processor.image_processor
-    max_total_crops = 1 + image_processor.max_crops
-    if image is not None:
-        images, image_input_idx, image_masks = pad_images(
-            max_total_crops,
-            out["images"],
-            out["image_input_idx"],
-            out.get("image_masks"),
-        )
-    else:
-        base_image_input_size = image_processor.base_image_input_size
-        image_patch_size = image_processor.image_patch_size
-        image_num_patch = (
-            base_image_input_size[0] // image_patch_size,
-            base_image_input_size[1] // image_patch_size,
-        )
-        n_pixels = image_patch_size * image_patch_size * 3
-        n_patches = image_num_patch[0] * image_num_patch[1]
-
-        image_length_w = image_processor.image_token_length_w
-        image_length_h = image_processor.image_token_length_h
-        tokens_per_image = image_length_w * image_length_h
-        images = torch.full(
-            (max_total_crops, n_patches, n_pixels),
-            -1,
-            dtype=torch.float32,
-        )
-        image_input_idx = torch.full(
-            (max_total_crops, tokens_per_image),
-            -1,
-            dtype=torch.int32,
+    # If there is no image, return directly.
+    if image is None:
+        new_prompt_token_ids = out["input_ids"].tolist()
+        prompt = inputs.get("prompt")
+        if prompt is None:
+            prompt = tokenizer.decode(new_prompt_token_ids)
+        return token_inputs(
+            prompt_token_ids=new_prompt_token_ids,
+            prompt=prompt,
         )
-        if image_processor.image_padding_mask:
-            image_masks = torch.full(
-                (max_total_crops, n_patches),
-                -1,
-                dtype=torch.float32,
-            )
 
+    image_processor = processor.image_processor
+    max_total_crops = 1 + image_processor.max_crops
+    images, image_input_idx, image_masks = pad_images(
+        max_total_crops,
+        out["images"],
+        out["image_input_idx"],
+        out.get("image_masks"),
+    )
     image_data = dict(
         images=images,
         image_input_idx=image_input_idx,
@@ -1143,11 +1123,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
                 offset = i
             size += 1
     image_data["image_start_end"] = (offset, offset + size)
-
     prompt = inputs.get("prompt")
     if prompt is None:
         prompt = tokenizer.decode(new_prompt_token_ids)
-
     return token_inputs(
         prompt_token_ids=new_prompt_token_ids,
         prompt=prompt,