From 5d6b39f3fe4b77b15b0fae6c98a702f7eaf7f1fc Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 2 Jan 2025 02:21:06 +0000
Subject: [PATCH 01/10] Done

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/molmo.py | 46 +++++------------------------
 1 file changed, 8 insertions(+), 38 deletions(-)

diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index cc25be9f5b6a9..c73e5091e7e8d 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1060,7 +1060,8 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
     prompt = inputs.get("prompt")
     multi_modal_data = inputs.get("multi_modal_data")
     image = None if multi_modal_data is None else multi_modal_data.get("image")
-
+    if image is None:
+        return inputs
     model_config = ctx.model_config
     processor = cached_get_processor(
         ctx.model_config.model,
@@ -1083,43 +1084,12 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
 
     image_processor = processor.image_processor
     max_total_crops = 1 + image_processor.max_crops
-    if image is not None:
-        images, image_input_idx, image_masks = pad_images(
-            max_total_crops,
-            out["images"],
-            out["image_input_idx"],
-            out.get("image_masks"),
-        )
-    else:
-        base_image_input_size = image_processor.base_image_input_size
-        image_patch_size = image_processor.image_patch_size
-        image_num_patch = (
-            base_image_input_size[0] // image_patch_size,
-            base_image_input_size[1] // image_patch_size,
-        )
-        n_pixels = image_patch_size * image_patch_size * 3
-        n_patches = image_num_patch[0] * image_num_patch[1]
-
-        image_length_w = image_processor.image_token_length_w
-        image_length_h = image_processor.image_token_length_h
-        tokens_per_image = image_length_w * image_length_h
-        images = torch.full(
-            (max_total_crops, n_patches, n_pixels),
-            -1,
-            dtype=torch.float32,
-        )
-        image_input_idx = torch.full(
-            (max_total_crops, tokens_per_image),
-            -1,
-            dtype=torch.int32,
-        )
-        if image_processor.image_padding_mask:
-            image_masks = torch.full(
-                (max_total_crops, n_patches),
-                -1,
-                dtype=torch.float32,
-            )
-
+    images, image_input_idx, image_masks = pad_images(
+        max_total_crops,
+        out["images"],
+        out["image_input_idx"],
+        out.get("image_masks"),
+    )
     image_data = dict(
         images=images,
         image_input_idx=image_input_idx,

From d2e0840ff42fdfd4d07164e17da2e5312c0ad2af Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 2 Jan 2025 02:22:27 +0000
Subject: [PATCH 02/10] Add comment

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/molmo.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index c73e5091e7e8d..48712a50fc8d8 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1060,6 +1060,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
     prompt = inputs.get("prompt")
     multi_modal_data = inputs.get("multi_modal_data")
     image = None if multi_modal_data is None else multi_modal_data.get("image")
+    # If there is no image, return directly.
     if image is None:
         return inputs
     model_config = ctx.model_config

From 16efb8502e623f9447ce95bfa400b1079e5e5116 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 2 Jan 2025 03:27:12 +0000
Subject: [PATCH 03/10] Optimize logic

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/molmo.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 48712a50fc8d8..2faf15041c3c5 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1060,9 +1060,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
     prompt = inputs.get("prompt")
     multi_modal_data = inputs.get("multi_modal_data")
     image = None if multi_modal_data is None else multi_modal_data.get("image")
-    # If there is no image, return directly.
-    if image is None:
-        return inputs
+
     model_config = ctx.model_config
     processor = cached_get_processor(
         ctx.model_config.model,
@@ -1082,6 +1080,16 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
         out = processor.process(prompt, image)
     else:
         out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
+    #
+    if image is None:
+        new_prompt_token_ids = out["input_ids"].tolist()
+        prompt = inputs.get("prompt")
+        if prompt is None:
+            prompt = tokenizer.decode(new_prompt_token_ids)
+        return token_inputs(
+            prompt_token_ids=new_prompt_token_ids,
+            prompt=prompt,
+        )
     image_processor = processor.image_processor
     max_total_crops = 1 + image_processor.max_crops
     images, image_input_idx, image_masks = pad_images(
@@ -1114,11 +1122,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
             offset = i
         size += 1
     image_data["image_start_end"] = (offset, offset + size)
-    prompt = inputs.get("prompt")
+
     if prompt is None:
         prompt = tokenizer.decode(new_prompt_token_ids)
-
     return token_inputs(
         prompt_token_ids=new_prompt_token_ids,
         prompt=prompt,

From ef37f8dfc2c3d7249b4b0a78d3397723d768ece5 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 2 Jan 2025 03:30:28 +0000
Subject: [PATCH 04/10] Add comments

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/molmo.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 2faf15041c3c5..67f9ea5f75a96 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1080,7 +1080,8 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
         out = processor.process(prompt, image)
     else:
         out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
-    #
+
+    # If there is no image, return directly.
     if image is None:
         new_prompt_token_ids = out["input_ids"].tolist()
         prompt = inputs.get("prompt")
@@ -1090,7 +1091,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
             prompt_token_ids=new_prompt_token_ids,
             prompt=prompt,
         )
-
+
     image_processor = processor.image_processor
     max_total_crops = 1 + image_processor.max_crops
     images, image_input_idx, image_masks = pad_images(

From 9762e8d5cad6cf958b0807b5d33d49a056d77478 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 2 Jan 2025 03:35:03 +0000
Subject: [PATCH 05/10] format

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/molmo.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 67f9ea5f75a96..0e8287bb56b6b 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1080,7 +1080,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
         out = processor.process(prompt, image)
     else:
         out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
-
+
     # If there is no image, return directly.
     if image is None:
         new_prompt_token_ids = out["input_ids"].tolist()
         prompt = inputs.get("prompt")
@@ -1091,7 +1091,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
         prompt_token_ids=new_prompt_token_ids,
         prompt=prompt,
     )
-
+
     image_processor = processor.image_processor
     max_total_crops = 1 + image_processor.max_crops
     images, image_input_idx, image_masks = pad_images(

From 3b0a807b25c507326e44b4c317dbc213c5816e65 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Fri, 3 Jan 2025 13:03:10 +0000
Subject: [PATCH 06/10] Backup

Signed-off-by: Jee Jee Li
---
 .../vision_language/test_models.py           | 11 +++
 .../vision_language/vlm_utils/model_utils.py | 88 ++++++++++++++++++-
 2 files changed, 97 insertions(+), 2 deletions(-)

diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 7db08166826eb..7a39dc2886e5e 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -346,6 +346,17 @@
         ),
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
     ),
+    "molmo": VLMTestInfo(
+        models=["allenai/Molmo-7B-D-0924"],
+        test_type=(VLMTestType.IMAGE),
+        prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        image_size_factors=[(),(1.0, 1.0, 1.0)],
+        patch_hf_runner=model_utils.mlomo_patch_hf_runner,
+        postprocess_inputs=model_utils.molmo_post_processor
+
+    ),
     # Tests for phi3v currently live in another file because of a bug in
     # transformers. Once this issue is fixed, we can enable them here instead.
     # https://github.com/huggingface/transformers/issues/34307
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index 3eca8fb9dcb1a..a8e6aa0a6fb75 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -5,11 +5,11 @@
 import re
 import types
 from pathlib import PosixPath
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union,Any,Dict
 
 import torch
 from PIL.Image import Image
-from transformers import AutoConfig, AutoTokenizer, BatchEncoding
+from transformers import AutoConfig, AutoTokenizer, BatchEncoding,GenerationConfig
 
 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side
@@ -222,6 +222,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
     return {"model_inputs": hf_inputs}
 
 
+def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
+    hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
+    return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
+
+
 ####### Prompt path encoders for models that need models on disk
 def qwen_prompt_path_encoder(
     tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset],
@@ -451,3 +456,82 @@ def _generate(self, *args, **kwargs):
     hf_model.model.generate = types.MethodType(_generate, hf_model.model)
 
     return hf_model
+
+####### Model-specific HuggingFace runner patchers
+def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Molmo."""
+    hf_processor = hf_model.processor
+
+
+    def _processor(*args, **kwargs):
+        return hf_processor.process(*args, **kwargs)
+
+    def _generate_greedy_logprobs_limit(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+        num_logprobs: int,
+        images: Optional[Any] = None,
+        audios: Optional[Any] = None,
+        videos: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> List[Any]:
+        all_inputs = self.get_inputs(prompts,
+                                     images=images,
+                                     videos=videos,
+                                     audios=audios)
+
+        # Process in batches for inference.
+        if len(all_inputs):
+            input_ids_lst = []
+            images_lst = []
+            images_input_idx_lst = []
+            imges_masks_lst = []
+            for inputs in all_inputs:
+                input_ids_lst.append(inputs["input_ids"])
+                images_lst.append(inputs["images"])
+                images_input_idx_lst.append(inputs["image_input_idx"])
+                imges_masks_lst.append(inputs["image_masks"])
+            batch_inputs={}
+            batch_inputs['input_ids']=torch.cat(input_ids_lst,dim=0)
+            batch_inputs['images']=torch.cat(images_lst,dim=0)
+            batch_inputs['image_input_idx']=torch.cat(images_input_idx_lst,dim=0)
+            batch_inputs['image_masks']=torch.cat(imges_masks_lst,dim=0)
+
+        outputs = self.model.generate_from_batch(
+            batch=self.wrap_device(batch_inputs, device=self.model.device.type),
+            generation_config=GenerationConfig(
+                max_new_tokens=max_tokens,
+                stop_strings="<|endoftext|>",
+                do_sample=False,
+            ),
+            tokenizer=self.tokenizer,
+            output_hidden_states=True,
+            return_dict_in_generate=True,
+        )
+
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []
+
+        for index in range(len(all_inputs)):
+            (
+                seq_logprobs_lst,
+                output_len,
+            ) = self._hidden_states_to_logprobs(
+                outputs.hidden_states, num_logprobs
+            )
+            all_logprobs.append(seq_logprobs_lst)
+            seq_ids = outputs.sequences[index]
+            output_ids = seq_ids[-output_len:]
+            all_output_ids.append(output_ids.tolist())
+            all_output_strs.append(self.tokenizer.decode(output_ids))
+        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
+    hf_model.processor = _processor
+    hf_model.generate_greedy_logprobs_limit = types.MethodType(
+        _generate_greedy_logprobs_limit, hf_model
+    )
+    return hf_model
\ No newline at end of file

From cc831b6db19bca33d5a7ee12495f4b8d4f4bf0c1 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 6 Jan 2025 06:03:16 +0000
Subject: [PATCH 07/10] Fix format

Signed-off-by: Jee Jee Li
---
 .../vision_language/vlm_utils/model_utils.py | 149 +++++++++---------
 1 file changed, 76 insertions(+), 73 deletions(-)

diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index a8e6aa0a6fb75..042756b54977b 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -5,17 +5,20 @@
 import re
 import types
 from pathlib import PosixPath
-from typing import Callable, List, Optional, Tuple, Union,Any,Dict
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from PIL.Image import Image
-from transformers import AutoConfig, AutoTokenizer, BatchEncoding,GenerationConfig
+from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
+                          GenerationConfig)
 
 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
-from .....conftest import HfRunner, ImageAsset, _ImageAssets
+from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
+                           PromptImageInput, PromptVideoInput, _ImageAssets)
+from ....utils import TokensTextLogprobs
 from .types import RunnerOutput
 
 
@@ -457,81 +460,81 @@ def _generate(self, *args, **kwargs):
 
     return hf_model
 
-####### Model-specific HuggingFace runner patchers
+def _generate_greedy_logprobs_limit(
+    self,
+    prompts: List[str],
+    max_tokens: int,
+    num_logprobs: int,
+    images: Optional[PromptImageInput] = None,
+    audios: Optional[PromptAudioInput] = None,
+    videos: Optional[PromptVideoInput] = None,
+    **kwargs: Any,
+) -> List[TokensTextLogprobs]:
+    all_inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
+    # Process in batches for inference.
+    if len(all_inputs):
+        input_ids_lst = []
+        images_lst = []
+        images_input_idx_lst = []
+        imges_masks_lst = []
+        for inputs in all_inputs:
+            input_ids_lst.append(inputs["input_ids"])
+            images_lst.append(inputs["images"])
+            images_input_idx_lst.append(inputs["image_input_idx"])
+            imges_masks_lst.append(inputs["image_masks"])
+        batch_inputs = {}
+        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
+        batch_inputs['images'] = torch.cat(images_lst, dim=0)
+        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
+                                                    dim=0)
+        batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
+
+    outputs = self.model.generate_from_batch(
+        batch=self.wrap_device(batch_inputs,
+                               device=self.model.device.type),
+        generation_config=GenerationConfig(
+            max_new_tokens=max_tokens,
+            stop_strings="<|endoftext|>",
+            do_sample=False,
+        ),
+        tokenizer=self.tokenizer,
+        output_hidden_states=True,
+        return_dict_in_generate=True,
+    )
+
+    all_logprobs: List[List[Dict[int, float]]] = []
+    all_output_ids: List[List[int]] = []
+    all_output_strs: List[str] = []
+
+    for index in range(len(all_inputs)):
+        (
+            seq_logprobs_lst,
+            output_len,
+        ) = self._hidden_states_to_logprobs(outputs.hidden_states,
+                                            num_logprobs)
+        all_logprobs.append(seq_logprobs_lst)
+        seq_ids = outputs.sequences[index]
+        output_ids = seq_ids[-output_len:]
+        all_output_ids.append(output_ids.tolist())
+        all_output_strs.append(self.tokenizer.decode(output_ids))
+    outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+    return [(output_ids, output_str, output_logprobs)
+            for output_ids, output_str, output_logprobs in outputs]
+
+
+####### Molmo-specific HuggingFace runner patchers
 def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for Molmo."""
     hf_processor = hf_model.processor
 
-
     def _processor(*args, **kwargs):
         return hf_processor.process(*args, **kwargs)
 
-    def _generate_greedy_logprobs_limit(
-        self,
-        prompts: List[str],
-        max_tokens: int,
-        num_logprobs: int,
-        images: Optional[Any] = None,
-        audios: Optional[Any] = None,
-        videos: Optional[Any] = None,
-        **kwargs: Any,
-    ) -> List[Any]:
-        all_inputs = self.get_inputs(prompts,
-                                     images=images,
-                                     videos=videos,
-                                     audios=audios)
-
-        # Process in batches for inference.
-        if len(all_inputs):
-            input_ids_lst = []
-            images_lst = []
-            images_input_idx_lst = []
-            imges_masks_lst = []
-            for inputs in all_inputs:
-                input_ids_lst.append(inputs["input_ids"])
-                images_lst.append(inputs["images"])
-                images_input_idx_lst.append(inputs["image_input_idx"])
-                imges_masks_lst.append(inputs["image_masks"])
-            batch_inputs={}
-            batch_inputs['input_ids']=torch.cat(input_ids_lst,dim=0)
-            batch_inputs['images']=torch.cat(images_lst,dim=0)
-            batch_inputs['image_input_idx']=torch.cat(images_input_idx_lst,dim=0)
-            batch_inputs['image_masks']=torch.cat(imges_masks_lst,dim=0)
-
-        outputs = self.model.generate_from_batch(
-            batch=self.wrap_device(batch_inputs, device=self.model.device.type),
-            generation_config=GenerationConfig(
-                max_new_tokens=max_tokens,
-                stop_strings="<|endoftext|>",
-                do_sample=False,
-            ),
-            tokenizer=self.tokenizer,
-            output_hidden_states=True,
-            return_dict_in_generate=True,
-        )
-
-        all_logprobs: List[List[Dict[int, float]]] = []
-        all_output_ids: List[List[int]] = []
-        all_output_strs: List[str] = []
-
-        for index in range(len(all_inputs)):
-            (
-                seq_logprobs_lst,
-                output_len,
-            ) = self._hidden_states_to_logprobs(
-                outputs.hidden_states, num_logprobs
-            )
-            all_logprobs.append(seq_logprobs_lst)
-            seq_ids = outputs.sequences[index]
-            output_ids = seq_ids[-output_len:]
-            all_output_ids.append(output_ids.tolist())
-            all_output_strs.append(self.tokenizer.decode(output_ids))
-        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
 
     hf_model.processor = _processor
-    hf_model.generate_greedy_logprobs_limit = types.MethodType(
-        _generate_greedy_logprobs_limit, hf_model
-    )
-    return hf_model
\ No newline at end of file
+    HfRunner.generate_greedy_logprobs_limit = _generate_greedy_logprobs_limit
+    return hf_model

From 0f91817f7eba96f74078748206db4562fb8fc4a9 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 6 Jan 2025 06:16:10 +0000
Subject: [PATCH 08/10] Fix format

Signed-off-by: Jee Jee Li
---
 tests/models/decoder_only/vision_language/test_models.py  | 2 +-
 .../decoder_only/vision_language/vlm_utils/model_utils.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 6f552caf242cd..babee8678d4d2 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -350,7 +350,7 @@
         image_size_factors=[(),(1.0, 1.0, 1.0)],
         patch_hf_runner=model_utils.mlomo_patch_hf_runner,
         postprocess_inputs=model_utils.molmo_post_processor
-
+
     ),
     # Tests for phi3v currently live in another file because of a bug in
     # transformers. Once this issue is fixed, we can enable them here instead.
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index 042756b54977b..fca44ca798fa5 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -536,5 +536,7 @@ def _processor(*args, **kwargs):
         return hf_processor.process(*args, **kwargs)
 
     hf_model.processor = _processor
-    HfRunner.generate_greedy_logprobs_limit = _generate_greedy_logprobs_limit
+    HfRunner.generate_greedy_logprobs_limit = \
+        _generate_greedy_logprobs_limit
+
     return hf_model

From 8686011d70c002cae04219528d0ebaf5d5976b46 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 6 Jan 2025 08:50:58 +0000
Subject: [PATCH 09/10] Fix format

Signed-off-by: Jee Jee Li
---
 .../decoder_only/vision_language/vlm_utils/model_utils.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index fca44ca798fa5..6c7a753af787e 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -536,7 +536,11 @@ def _processor(*args, **kwargs):
         return hf_processor.process(*args, **kwargs)
 
     hf_model.processor = _processor
-    HfRunner.generate_greedy_logprobs_limit = \
-        _generate_greedy_logprobs_limit
+
+    setattr(  # noqa: B010
+        hf_model,
+        "generate_greedy_logprobs_limit",
+        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
+    )
 
     return hf_model

From a38c64f45e9ebd77aa9e81f8fb628c526f0c6ae2 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 6 Jan 2025 09:28:50 +0000
Subject: [PATCH 10/10] Fix format

Signed-off-by: Jee Jee Li
---
 tests/models/decoder_only/vision_language/test_models.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index babee8678d4d2..146685738a1d0 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -349,8 +349,7 @@
         max_num_seqs=2,
        image_size_factors=[(),(1.0, 1.0, 1.0)],
         patch_hf_runner=model_utils.mlomo_patch_hf_runner,
-        postprocess_inputs=model_utils.molmo_post_processor
-
+        postprocess_inputs=model_utils.molmo_post_processor,
     ),
     # Tests for phi3v currently live in another file because of a bug in
     # transformers. Once this issue is fixed, we can enable them here instead.
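
A note on the batching step used throughout patches 06-09 (an illustrative aside, not part of the patch series): the runner patch concatenates the per-prompt tensors returned by Molmo's processor.process() into one batch before calling generate_from_batch. The standalone sketch below reproduces just that step with made-up tensor shapes, so it runs with only torch installed; the real shapes depend on the image and the processor configuration.

    import torch

    # Hypothetical per-prompt tensors shaped like Molmo processor output;
    # the leading dimension of 1 stands in for the per-prompt batch axis
    # that molmo_post_processor's unsqueeze(0) introduces.
    def fake_processor_output(seq_len=16, n_crops=13, n_patches=576,
                              n_pixels=588):
        return {
            "input_ids": torch.randint(0, 1000, (1, seq_len)),
            "images": torch.full((1, n_crops, n_patches, n_pixels), -1.0),
            "image_input_idx": torch.full((1, n_crops, 144), -1,
                                          dtype=torch.int32),
            "image_masks": torch.full((1, n_crops, n_patches), -1.0),
        }

    all_inputs = [fake_processor_output(), fake_processor_output()]

    # Same move as _generate_greedy_logprobs_limit: gather each field
    # across prompts and torch.cat along dim=0 into one batch dict.
    # Concatenation needs matching trailing shapes: crop padding provides
    # that for the image tensors, and equal-length prompts are assumed
    # here for input_ids.
    batch_inputs = {
        key: torch.cat([inp[key] for inp in all_inputs], dim=0)
        for key in ("input_ids", "images", "image_input_idx", "image_masks")
    }

    for key, value in batch_inputs.items():
        print(key, tuple(value.shape))  # leading batch dimension is now 2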