Use common context builder in processor kwarg tests
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
alex-jw-brooks committed Sep 20, 2024
1 parent 1cee215 commit d5f9efa
Showing 2 changed files with 77 additions and 72 deletions.
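In short, the tests stop building a ModelConfig through a file-local get_model_config helper and instead call the shared build_model_context helper from tests/models/utils.py, which returns an InputContext. A minimal sketch of the new call pattern (illustrative only, not part of the diff; it assumes vLLM and its test tree are importable, and uses the constants defined in the changed files):

```python
from vllm.inputs import LLMInputs
from vllm.inputs.registry import InputRegistry

from tests.models.utils import build_model_context  # imported as ..models.utils inside the test file

# Build an InputContext rather than a raw ModelConfig; the registries keep
# consuming the underlying config via ctx.model_config.
ctx = build_model_context("facebook/opt-125m",
                          processor_kwargs={"num_crops": 16})

dummy_registry = InputRegistry()
processor = dummy_registry.create_input_processor(ctx.model_config)
processor(LLMInputs(prompt_token_ids=[], prompt=""))
```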
5 changes: 4 additions & 1 deletion tests/models/utils.py
@@ -247,7 +247,8 @@ def check_logprobs_close(
def build_model_context(model_name: str,
tokenizer_name: Optional[str] = None,
trust_remote_code: bool = False,
processor_kwargs: Optional[Dict] = None):
processor_kwargs: Optional[Dict] = None,
limit_mm_per_prompt: Optional[Dict] = None):
"""Creates an InputContext for a given model.
Args:
@@ -256,6 +257,7 @@ def build_model_context(model_name: str,
trust_remote_code: Whether or not to allow loading remote code.
processor_kwargs: optional processor kwargs to be leveraged
in the input processor, mapper, dummy data creation, etc.
limit_mm_per_prompt: Multimodal limits.
Returns:
InputContext for the model being considered.
@@ -270,5 +272,6 @@ def build_model_context(model_name: str,
dtype="float32",
seed=0,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt=limit_mm_per_prompt,
)
return InputContext(model_config)
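For illustration (not part of the diff): the extra limit_mm_per_prompt parameter lets the multimodal tests below cap per-prompt inputs while building their context. A minimal sketch, assuming the same test helpers; the model ID is a placeholder for the MULTIMODAL_MODEL_ID constant defined in the test module, whose value is not shown in this excerpt:

```python
from vllm.multimodal import MultiModalRegistry

from tests.models.utils import build_model_context

MULTIMODAL_MODEL_ID = "..."  # placeholder for the test module's constant

# Forward processor kwargs plus a one-image-per-prompt limit into the
# ModelConfig backing the returned InputContext.
ctx = build_model_context(MULTIMODAL_MODEL_ID,
                          trust_remote_code=True,
                          processor_kwargs={"num_crops": 16},
                          limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
```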
144 changes: 73 additions & 71 deletions tests/multimodal/test_processor_kwargs.py
@@ -5,13 +5,14 @@
import pytest
import torch

from vllm.config import ModelConfig
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs.registry import InputRegistry
from vllm.model_executor.models.phi3v import Phi3VForCausalLM
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

from ..models.utils import build_model_context

# Used for fast tests where the model doesn't matter
DUMMY_MODEL_ID = "facebook/opt-125m"
# Used for tests that need a multimodal model
@@ -24,20 +25,6 @@
NUM_CROPS_OVERRIDE = 16


def get_model_config(model_name, trust_remote_code=False, processor_kwargs=None, limit_mm_per_prompt=None):
"""Creates a handle to a model config, which may have processor kwargs."""
# NOTE - values / architecture don't matter too much here since we patch
# the return values for stuff like the input processor anyway.
return ModelConfig(model_name,
model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
dtype="float16",
seed=0,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt=limit_mm_per_prompt)


# Mocks for all of the places that we use the processor_kwargs
# to override values in different callables
@pytest.fixture
@@ -78,7 +65,7 @@ def custom_dummy_data_factory(self,
yield


# lambda whose signature matches max token calcs + extra kwargs & mapper respectively
# lambdas whose signatures match the max token calcs & mapper, plus extra kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
"num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
@@ -89,8 +76,8 @@
def test_default_processor_is_a_noop():
"""Ensure that by default, there is no processor override."""
dummy_registry = InputRegistry()
model_config = get_model_config(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(model_config)
ctx = build_model_context(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(ctx.model_config)
proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
proc_outputs = processor(inputs=proc_inputs)
assert proc_inputs is proc_outputs
@@ -105,9 +92,9 @@ def test_processor_default_kwargs(use_processor_mock, num_crops):
# otherwise fall back to the default value
processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
model_config = get_model_config(DUMMY_MODEL_ID,
processor_kwargs=processor_kwargs)
processor = dummy_registry.create_input_processor(model_config)
ctx = build_model_context(DUMMY_MODEL_ID,
processor_kwargs=processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)

num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == expected_num_crops
@@ -117,19 +104,22 @@ def test_processor_default_kwargs(use_processor_mock, num_crops):
"processor_kwargs",
[
# Not part of the signature
{"does_not_exist": 100},
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{"ctx": "something bad"}
{
"ctx": "something bad"
}
])
def test_processor_with_sad_kwarg_overrides(use_processor_mock,
processor_kwargs):
"""Ensure that input processors filter out invalid processor_kwargs."""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
processor_kwargs=processor_kwargs)

model_config = get_model_config(DUMMY_MODEL_ID,
processor_kwargs=processor_kwargs)

processor = dummy_registry.create_input_processor(model_config)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == DEFAULT_NUM_CROPS

@@ -141,41 +131,45 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
dummy_registry = InputRegistry()
model_config = get_model_config(DUMMY_MODEL_ID,
processor_kwargs=processor_kwargs)
ctx = build_model_context(DUMMY_MODEL_ID,
processor_kwargs=processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(model_config)
mm_registry.init_mm_limits_per_prompt(ctx.model_config)

# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
model_config, seq_len=-1, mm_registry=mm_registry)
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == expected_seq_count


@pytest.mark.parametrize(
"processor_kwargs",
[
# Not part of the signature
{"does_not_exist": 100},
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{"ctx": "something bad"}
{
"ctx": "something bad"
}
])
def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
processor_kwargs):
"""Ensure that dummy data factory filters out invalid processor_kwargs."""
dummy_registry = InputRegistry()
model_config = get_model_config(DUMMY_MODEL_ID,
processor_kwargs=processor_kwargs)
ctx = build_model_context(DUMMY_MODEL_ID,
processor_kwargs=processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(model_config)
mm_registry.init_mm_limits_per_prompt(ctx.model_config)

# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
model_config, seq_len=-1, mm_registry=mm_registry)
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS


@@ -186,13 +180,13 @@ def test_max_tokens_kwarg_overrides(num_crops):
processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops

model_config = get_model_config(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt={"image": 1})
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(model_config)
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the processor_kwargs.
@@ -202,7 +196,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
{Phi3VForCausalLM: get_num_crops},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
model_config)
ctx.model_config)

assert expected_seq_count == max_multimodal_tokens

@@ -211,19 +205,23 @@ def test_max_tokens_kwarg_overrides(num_crops):
"processor_kwargs",
[
# Not part of the signature
{"does_not_exist": 100},
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{"ctx": "something bad"}
{
"ctx": "something bad"
}
])
def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs):
"""Ensure that max token calcs filters out invalid processor_kwargs."""
model_config = get_model_config(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt={"image": 1})
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(model_config)
mm_registry.init_mm_limits_per_prompt(ctx.model_config)

# Similar to before, but since these kwargs get filtered,
# we always get our default value back.
@@ -233,7 +231,7 @@ def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs):
{Phi3VForCausalLM: get_num_crops},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
model_config)
ctx.model_config)

assert max_multimodal_tokens == DEFAULT_NUM_CROPS

@@ -245,18 +243,18 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
# NOTE - we don't validate bad inputs for the default mapper, because it's
# through the automodel interface in transformers, so we can't easily
# inspect what kwargs are or are not allowed.
model_config = get_model_config(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs={"num_crops": num_crops},
limit_mm_per_prompt={"image": 1})
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs={"num_crops": num_crops},
limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(model_config)
mm_registry.init_mm_limits_per_prompt(ctx.model_config)

image = image_assets[0].pil_image
mm_inputs = {"image": image}

mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
# Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1

@@ -266,13 +264,13 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
"""Ensure custom mappers can use processor kwargs."""
processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
model_config = get_model_config(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt={"image": 1})
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(model_config)
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the processor_kwargs.
@@ -284,7 +282,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
"_default_input_mapper",
{Phi3VForCausalLM: custom_mapper},
):
mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1

@@ -293,20 +291,24 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
"processor_kwargs",
[
# Not part of the signature
{"does_not_exist": 100},
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{"ctx": "something bad"}
{
"ctx": "something bad"
}
])
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
processor_kwargs):
"""Ensure that custom mappers filters out invalid processor_kwargs."""
model_config = get_model_config(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt={"image": 1})
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
processor_kwargs=processor_kwargs,
limit_mm_per_prompt={"image": 1})

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(model_config)
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the processor_kwargs.
@@ -318,6 +320,6 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
"_default_input_mapper",
{Phi3VForCausalLM: custom_mapper},
):
mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
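The "sad kwarg" cases above all assert the same behaviour: overrides that do not correspond to keyword-only parameters of the target callable (an unknown name, or a positional parameter such as ctx) are dropped, so the default num_crops wins. A generic sketch of that filtering idea, shown here only to clarify the expectation; it is not vLLM's implementation:

```python
import inspect
from typing import Any, Callable, Dict


def filter_keyword_only_kwargs(fn: Callable, overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only overrides that name keyword-only parameters of fn."""
    params = inspect.signature(fn).parameters
    return {
        name: value
        for name, value in overrides.items()
        if name in params and params[name].kind is inspect.Parameter.KEYWORD_ONLY
    }


get_num_crops = lambda ctx, *, num_crops=4: num_crops

# "ctx" is positional and "does_not_exist" is unknown, so both are filtered
# out and the callable falls back to its default num_crops.
assert filter_keyword_only_kwargs(get_num_crops, {"ctx": "bad", "does_not_exist": 100}) == {}
assert filter_keyword_only_kwargs(get_num_crops, {"num_crops": 16}) == {"num_crops": 16}
```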
