From d5f9efa94a80e2a4751a69f109027df0334789c7 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Thu, 19 Sep 2024 21:55:42 -0400
Subject: [PATCH] Use common context builder in processor kwarg tests

Signed-off-by: Alex-Brooks
---
 tests/models/utils.py                     |   5 +-
 tests/multimodal/test_processor_kwargs.py | 144 +++++++++++-----------
 2 files changed, 77 insertions(+), 72 deletions(-)

diff --git a/tests/models/utils.py b/tests/models/utils.py
index 0c3e876dd6cdc..77a7e054bf683 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -247,7 +247,8 @@ def check_logprobs_close(
 def build_model_context(model_name: str,
                         tokenizer_name: Optional[str] = None,
                         trust_remote_code: bool = False,
-                        processor_kwargs: Optional[Dict] = None):
+                        processor_kwargs: Optional[Dict] = None,
+                        limit_mm_per_prompt: Optional[Dict] = None):
     """Creates an InputContext for a given model.
 
     Args:
@@ -256,6 +257,7 @@ def build_model_context(model_name: str,
         trust_remote_code: Whether or not to allow loading remote code.
         processor_kwargs: optional processor kwargs for to be leveraged in the
             input processor, mapper, dummy data creation, etc.
+        limit_mm_per_prompt: Multimodal limits.
 
     Returns:
         InputContext for the model being considered.
@@ -270,5 +272,6 @@ def build_model_context(model_name: str,
         dtype="float32",
         seed=0,
         processor_kwargs=processor_kwargs,
+        limit_mm_per_prompt=limit_mm_per_prompt,
     )
     return InputContext(model_config)
diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py
index 1c84cd265e26c..35df3fe1492e4 100644
--- a/tests/multimodal/test_processor_kwargs.py
+++ b/tests/multimodal/test_processor_kwargs.py
@@ -5,13 +5,14 @@
 import pytest
 import torch
 
-from vllm.config import ModelConfig
 from vllm.inputs import InputContext, LLMInputs
 from vllm.inputs.registry import InputRegistry
 from vllm.model_executor.models.phi3v import Phi3VForCausalLM
 from vllm.multimodal import MultiModalRegistry
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
 
+from ..models.utils import build_model_context
+
 # Used for fast tests where the model doesn't matter
 DUMMY_MODEL_ID = "facebook/opt-125m"
 # Used for tests that need a multimodal model
@@ -24,20 +25,6 @@
 NUM_CROPS_OVERRIDE = 16
 
 
-def get_model_config(model_name, trust_remote_code=False, processor_kwargs=None, limit_mm_per_prompt=None):
-    """Creates a handle to a model config, which may have processor kwargs."""
-    # NOTE - values / architecture don't matter too much here since we patch
-    # the return values for stuff like the input processor anyway.
-    return ModelConfig(model_name,
-                       model_name,
-                       tokenizer_mode="auto",
-                       trust_remote_code=trust_remote_code,
-                       dtype="float16",
-                       seed=0,
-                       processor_kwargs=processor_kwargs,
-                       limit_mm_per_prompt=limit_mm_per_prompt)
-
-
 # Mocks for all of the places that we use the processor_kwargs
 # to override values in different callables
 @pytest.fixture
@@ -78,7 +65,7 @@ def custom_dummy_data_factory(self,
         yield
 
-# lambda whose signature matches max token calcs + extra kwargs & mapper respectively
+# lambda whose signature matches max token calcs extra & mapper + extra kwargs
 get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
 custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
     "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
 }
@@ -89,8 +76,8 @@ def custom_dummy_data_factory(self,
 def test_default_processor_is_a_noop():
     """Ensure that by default, there is no processor override."""
     dummy_registry = InputRegistry()
-    model_config = get_model_config(DUMMY_MODEL_ID)
-    processor = dummy_registry.create_input_processor(model_config)
+    ctx = build_model_context(DUMMY_MODEL_ID)
+    processor = dummy_registry.create_input_processor(ctx.model_config)
     proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
     proc_outputs = processor(inputs=proc_inputs)
     assert proc_inputs is proc_outputs
@@ -105,9 +92,9 @@ def test_processor_default_kwargs(use_processor_mock, num_crops):
     # otherwise fall back to the default value
     processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
     expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
-    model_config = get_model_config(DUMMY_MODEL_ID,
-                                    processor_kwargs=processor_kwargs)
-    processor = dummy_registry.create_input_processor(model_config)
+    ctx = build_model_context(DUMMY_MODEL_ID,
+                              processor_kwargs=processor_kwargs)
+    processor = dummy_registry.create_input_processor(ctx.model_config)
 
     num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
     assert num_crops_val == expected_num_crops
@@ -117,19 +104,22 @@ def test_processor_default_kwargs(use_processor_mock, num_crops):
     "processor_kwargs",
     [
         # Not part of the signature
-        {"does_not_exist": 100},
+        {
+            "does_not_exist": 100
+        },
         # Part of the signature, not keyword only
-        {"ctx": "something bad"}
+        {
+            "ctx": "something bad"
+        }
     ])
 def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                             processor_kwargs):
     """Ensure that input processors filter out invalid processor_kwargs."""
     dummy_registry = InputRegistry()
+    ctx = build_model_context(DUMMY_MODEL_ID,
+                              processor_kwargs=processor_kwargs)
 
-    model_config = get_model_config(DUMMY_MODEL_ID,
-                                    processor_kwargs=processor_kwargs)
-
-    processor = dummy_registry.create_input_processor(model_config)
+    processor = dummy_registry.create_input_processor(ctx.model_config)
 
     num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
     assert num_crops_val == DEFAULT_NUM_CROPS
@@ -141,16 +131,16 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
     processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
     expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
     dummy_registry = InputRegistry()
-    model_config = get_model_config(DUMMY_MODEL_ID,
-                                    processor_kwargs=processor_kwargs)
+    ctx = build_model_context(DUMMY_MODEL_ID,
+                              processor_kwargs=processor_kwargs)
     mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(model_config)
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
 
     # NOTE: seq_len is thrown away here since this will leverage the
     # default dummy data factory that we have patched in, whose seq
     # len is solely dependent on the value of the processor_kwargs.
     seq_data, _ = dummy_registry.dummy_data_for_profiling(
-        model_config, seq_len=-1, mm_registry=mm_registry)
+        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
 
     assert len(seq_data.prompt_token_ids) == expected_seq_count
 
@@ -158,24 +148,28 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
     "processor_kwargs",
     [
         # Not part of the signature
-        {"does_not_exist": 100},
+        {
+            "does_not_exist": 100
+        },
         # Part of the signature, not keyword only
-        {"ctx": "something bad"}
+        {
+            "ctx": "something bad"
+        }
     ])
 def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
                                              processor_kwargs):
     """Ensure that dummy data factory filters out invalid processor_kwargs."""
     dummy_registry = InputRegistry()
-    model_config = get_model_config(DUMMY_MODEL_ID,
-                                    processor_kwargs=processor_kwargs)
+    ctx = build_model_context(DUMMY_MODEL_ID,
+                              processor_kwargs=processor_kwargs)
     mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(model_config)
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
 
     # NOTE: seq_len is thrown away here since this will leverage the
     # default dummy data factory that we have patched in, whose seq
     # len is solely dependent on the value of the processor_kwargs.
     seq_data, _ = dummy_registry.dummy_data_for_profiling(
-        model_config, seq_len=-1, mm_registry=mm_registry)
+        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
 
     assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
 
@@ -186,13 +180,13 @@ def test_max_tokens_kwarg_overrides(num_crops):
     processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
     expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
 
-    model_config = get_model_config(MULTIMODAL_MODEL_ID,
-                                    trust_remote_code=True,
-                                    processor_kwargs=processor_kwargs,
-                                    limit_mm_per_prompt={"image": 1})
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              processor_kwargs=processor_kwargs,
+                              limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(model_config)
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
     # our num_crops value back from the processor_kwargs.
@@ -202,7 +196,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
             {Phi3VForCausalLM: get_num_crops},
     ):
         max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
-            model_config)
+            ctx.model_config)
 
     assert expected_seq_count == max_multimodal_tokens
 
@@ -211,19 +205,23 @@ def test_max_tokens_kwarg_overrides(num_crops):
     "processor_kwargs",
     [
         # Not part of the signature
-        {"does_not_exist": 100},
+        {
+            "does_not_exist": 100
+        },
         # Part of the signature, not keyword only
-        {"ctx": "something bad"}
+        {
+            "ctx": "something bad"
+        }
     ])
 def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs):
     """Ensure that max token calcs filters out invalid processor_kwargs."""
-    model_config = get_model_config(MULTIMODAL_MODEL_ID,
-                                    trust_remote_code=True,
-                                    processor_kwargs=processor_kwargs,
-                                    limit_mm_per_prompt={"image": 1})
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              processor_kwargs=processor_kwargs,
+                              limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(model_config)
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
 
     # Similar before, but since these kwargs get filtered,
     # we always get our default value back.
@@ -233,7 +231,7 @@ def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs):
             {Phi3VForCausalLM: get_num_crops},
     ):
         max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
-            model_config)
+            ctx.model_config)
 
     assert max_multimodal_tokens == DEFAULT_NUM_CROPS
 
@@ -245,18 +243,18 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
     # NOTE - we don't validate bad inputs for the default mapper, because it's
     # through the automodel interface in transformers, so we can't easily
     # inspect what kwargs are or are not allowed.
-    model_config = get_model_config(MULTIMODAL_MODEL_ID,
-                                    trust_remote_code=True,
-                                    processor_kwargs={"num_crops": num_crops},
-                                    limit_mm_per_prompt={"image": 1})
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              processor_kwargs={"num_crops": num_crops},
+                              limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(model_config)
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
 
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}
 
-    mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
+    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
     # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
     assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
 
@@ -266,13 +264,13 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
     """Ensure custom mappers can use processor kwargs."""
     processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
     expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
-    model_config = get_model_config(MULTIMODAL_MODEL_ID,
-                                    trust_remote_code=True,
-                                    processor_kwargs=processor_kwargs,
-                                    limit_mm_per_prompt={"image": 1})
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              processor_kwargs=processor_kwargs,
+                              limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(model_config)
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
     # our num_crops value back from the processor_kwargs.
@@ -284,7 +282,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
             "_default_input_mapper",
             {Phi3VForCausalLM: custom_mapper},
     ):
-        mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
+        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
 
     assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
 
@@ -293,20 +291,24 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
     "processor_kwargs",
     [
         # Not part of the signature
-        {"does_not_exist": 100},
+        {
+            "does_not_exist": 100
+        },
         # Part of the signature, not keyword only
-        {"ctx": "something bad"}
+        {
+            "ctx": "something bad"
+        }
     ])
 def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                 processor_kwargs):
     """Ensure that custom mappers filters out invalid processor_kwargs."""
-    model_config = get_model_config(MULTIMODAL_MODEL_ID,
-                                    trust_remote_code=True,
-                                    processor_kwargs=processor_kwargs,
-                                    limit_mm_per_prompt={"image": 1})
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              processor_kwargs=processor_kwargs,
+                              limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(model_config)
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
     # our num_crops value back from the processor_kwargs.
@@ -318,6 +320,6 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
             "_default_input_mapper",
             {Phi3VForCausalLM: custom_mapper},
     ):
-        mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
+        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
 
     assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
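
Note for reviewers (not part of the diff above): the snippet below is a minimal sketch of the calling pattern this patch standardizes on. Every test now obtains its config through the shared build_model_context helper and hands ctx.model_config to the registries, instead of constructing a ModelConfig by hand. The test name test_num_crops_override_roundtrip is hypothetical; the sketch assumes the fixtures and constants already defined in tests/multimodal/test_processor_kwargs.py (use_processor_mock, DUMMY_MODEL_ID, NUM_CROPS_OVERRIDE) and the imports shown in the diff.

    # Illustrative sketch only; mirrors test_processor_default_kwargs above.
    def test_num_crops_override_roundtrip(use_processor_mock):  # hypothetical name
        # Build an InputContext via the shared helper, overriding num_crops.
        # The multimodal tests additionally pass limit_mm_per_prompt={"image": 1}.
        ctx = build_model_context(DUMMY_MODEL_ID,
                                  processor_kwargs={"num_crops": NUM_CROPS_OVERRIDE})

        # The registry consumes the wrapped ModelConfig via ctx.model_config.
        dummy_registry = InputRegistry()
        processor = dummy_registry.create_input_processor(ctx.model_config)

        # With the mocked processor from use_processor_mock in place, the
        # override is echoed back through the processor kwargs.
        assert processor(LLMInputs(prompt_token_ids=[],
                                   prompt="")) == NUM_CROPS_OVERRIDE

The only behavioral difference from the old get_model_config path is that limit_mm_per_prompt is now threaded through build_model_context, so multimodal limits are configured in the same place as the processor kwargs.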