
[Model][Bugfix] Fix batching with multi-image in PixtralHF #9518

Merged · 7 commits merged into main from fix-pixtral-batching on Oct 21, 2024

Conversation

@mgoin (Member) commented Oct 18, 2024

Before this PR, PixtralHF would fail during forward passes where pixel_values is a list of lists of tensors. This can happen when multiple requests that each contain multiple images are batched together. Since each image can have a different shape, we end up with many separate tensors.

Due to how the HF processor works, when we call _parse_and_validate_image_input we can receive pixel_values as inputs that are either just a Tensor, a list of Tensors, or some mismatched list of Tensors and lists. This PR adds a normalization pass over pixel_values such that all of the image tensors from all requests are unrolled into a list of unbatched 3D tensors. This is the simplest way to avoid the complexity of some sub-requests being batched and others not. It could carry a performance penalty in cases where batched tensors could otherwise be used, but in practice I think it will be uncommon for images to have exactly the same size.
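For illustration, here is a minimal sketch of such a normalization pass (flatten_to_3d is a hypothetical name used for this write-up, not the exact code in the PR):

import torch
from typing import List, Union

NestedTensors = Union[torch.Tensor, List["NestedTensors"]]

def flatten_to_3d(pixel_values: NestedTensors) -> List[torch.Tensor]:
    # Recursively unroll any mix of stacked tensors and nested lists into
    # a flat list of per-image (channels, height, width) tensors.
    if isinstance(pixel_values, torch.Tensor):
        if pixel_values.ndim == 3:
            return [pixel_values]
        # Peel off one leading batch dimension and recurse.
        return [img for chunk in pixel_values.unbind(0)
                for img in flatten_to_3d(chunk)]
    return [img for item in pixel_values for img in flatten_to_3d(item)]

On the example inputs below, a pass like this produces exactly the flat lists of 3D tensors shown in the "after" dumps.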

I will provide some before and after examples of the types of pixel_values we receive and how we now normalize them into a simple List[torch.Tensor].

Example 1: 4 images of size 1024x1024 (this is the kv cache memory profiling pass when limit_mm_per_prompt={"image": 4})

Initial structure before restructuring:
Tensor shape: torch.Size([1, 1, 4, 3, 1024, 1024])

Final structure after restructuring:
List length: 4
Item 0:
  Tensor shape: torch.Size([3, 1024, 1024])
Item 1:
  Tensor shape: torch.Size([3, 1024, 1024])
Item 2:
  Tensor shape: torch.Size([3, 1024, 1024])
Item 3:
  Tensor shape: torch.Size([3, 1024, 1024])
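
Using the hypothetical flatten_to_3d sketch above, this restructuring can be reproduced directly:

import torch

# Profiling pass: one stacked tensor holding 4 images of 1024x1024.
batched = torch.zeros(1, 1, 4, 3, 1024, 1024)
flat = flatten_to_3d(batched)
print(len(flat))       # 4
print(flat[0].shape)   # torch.Size([3, 1024, 1024])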

Example 2: a batch of 3 requests: (1) 1 image of 688x1024; (2) 3 images of 688x1024 and 1 image of 704x1024; (3) 2 images of 688x1024

Initial structure before restructuring:
List length: 3
Item 0:
  Tensor shape: torch.Size([1, 1, 3, 688, 1024])
Item 1:
  List length: 1
  Item 0:
    List length: 4
    Item 0:
      Tensor shape: torch.Size([3, 688, 1024])
    Item 1:
      Tensor shape: torch.Size([3, 704, 1024])
    Item 2:
      Tensor shape: torch.Size([3, 688, 1024])
    Item 3:
      Tensor shape: torch.Size([3, 688, 1024])
Item 2:
  Tensor shape: torch.Size([1, 2, 3, 688, 1024])

Final structure after restructuring:
List length: 7
Item 0:
  Tensor shape: torch.Size([3, 688, 1024])
Item 1:
  Tensor shape: torch.Size([3, 688, 1024])
Item 2:
  Tensor shape: torch.Size([3, 704, 1024])
Item 3:
  Tensor shape: torch.Size([3, 688, 1024])
Item 4:
  Tensor shape: torch.Size([3, 688, 1024])
Item 5:
  Tensor shape: torch.Size([3, 688, 1024])
Item 6:
  Tensor shape: torch.Size([3, 688, 1024])

Minimal reproducible edge case

I also used this custom case to test offline batching with multiple images of both the same and different sizes:

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

sampling_params = SamplingParams(temperature=0.0, max_tokens=100)

model_name = "nm-testing/pixtral-12b-FP8-dynamic"
llm = LLM(model=model_name,
          max_num_seqs=4,
          enforce_eager=True,
          max_model_len=30000,
          limit_mm_per_prompt={"image": 4})

image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB")
image2 = ImageAsset("stop_sign").pil_image.convert("RGB")
input1 = {
    "prompt": "<s>[INST][IMG]Describe the image.[/INST]",
    "multi_modal_data": {"image": image1},
}
input2 = {
    "prompt": "<s>[INST][IMG][IMG][IMG][IMG]How many duplicated images are there?[/INST]",
    "multi_modal_data": {"image": [image1, image2, image1, image1]},
}
input3 = {
    "prompt": "<s>[INST][IMG][IMG]Are the images the same?[/INST]",
    "multi_modal_data": {"image": [image1, image1]},
}
outputs = llm.generate([input1, input2, input3], sampling_params=sampling_params)

for i, output in enumerate(outputs):
    print(f"\nOutput #{i}:", output.outputs[0].text)

Output:

Output #0: The image captures a breathtaking view of the Tokyo Skytree, the tallest tower in Japan, standing majestically against the backdrop of a clear blue sky. The tower, painted in a pristine white, is adorned with a distinctive red and white striped pattern near the top, adding a touch of color to the otherwise monochrome structure. The perspective of the photo is from a low angle, looking up towards the tower, emphasizing its impressive height and grandeur. In the foreground

Output #1: There are three duplicated images.

Output #2: Yes, the images are identical. They both feature a tall tower with a circular observation deck and a spire, framed by cherry blossom branches with pink flowers against a clear blue sky. The perspective and composition of the images are the same, with the tower centered and the cherry blossoms in the foreground.

Validation

With this PR, I am able to reproduce the MMMU benchmark for Pixtral:

vllm serve nm-testing/pixtral-12b-FP8-dynamic --max-num-seqs 8 --enforce-eager --max-model-len 30000 --port 9000 --disable-frontend-multiprocessing --limit-mm-per-prompt 'image=5'
 
python -m eval.run eval_vllm --model_name nm-testing/pixtral-12b-FP8-dynamic --url http://0.0.0.0:9000 --output_dir output/ --eval_name "mmmu"
================================================================================
Metrics:
{
    "explicit_prompt_relaxed_correctness": 0.5088888888888888,
    "anywhere_in_answer_relaxed_correctness": 0.5088888888888888
}
================================================================================


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337 (Member) commented:

We have a flatten_bn utility function, which accepts both batched tensors and nested lists. See if it can simplify your implementation.
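For context, flatten_bn flattens the batch (B) and per-request (N) dimensions of a multimodal input; a paraphrased sketch (not the exact vLLM source):

import torch

def flatten_bn(x, *, concat: bool = False):
    # Flatten the B and N dimensions of batched multimodal inputs.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)          # (B, N, ...) -> (B*N, ...)
    if concat:
        return torch.cat(x)             # list of (N, ...) -> (sum(N), ...)
    return [x_n for x_b in x for x_n in x_b]  # flatten one nesting level

Note that the list branch flattens only one level of nesting, which is relevant to the discussion below.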

@mgoin (Member, Author) commented Oct 21, 2024

Hi @DarkLight1337, I've improved the change so that it always produces a list of 3D tensors from pixel_values. This has slight performance implications since it removes batching from the patch_conv, but it greatly simplifies the logic. Please see the description for more details on why I made this choice; I'm interested to hear your thoughts.

@DarkLight1337 (Member) commented Oct 21, 2024

The number of dimensions in pixel_values should be the same regardless of single- or multi-image input; the only difference is whether they come in the form of nested lists or a single stacked tensor (for single-image input, the dimension for the number of images should have size 1, not be squeezed, for the HF processors I've seen so far). Is there a particular reason why flatten_bn is not sufficient for this?

@mgoin (Member, Author) commented Oct 21, 2024

flatten_bn is not sufficient because the pixel_values we get from HF are not as standard as you are expecting. See some of the details in their preprocess function.

Here is the structure from one of the batches in the description. You can see that there is an overall list with one entry per request, but there may be nested lists holding tensors of a different rank when the images within a request do not all have the same shape.

Initial structure before restructuring:
List length: 3
Item 0:
  Tensor shape: torch.Size([1, 1, 3, 688, 1024])
Item 1:
  List length: 1
  Item 0:
    List length: 4
    Item 0:
      Tensor shape: torch.Size([3, 688, 1024])
    Item 1:
      Tensor shape: torch.Size([3, 704, 1024])
    Item 2:
      Tensor shape: torch.Size([3, 688, 1024])
    Item 3:
      Tensor shape: torch.Size([3, 688, 1024])
Item 2:
  Tensor shape: torch.Size([1, 2, 3, 688, 1024])

I use this function to print the structure of pixel_values:

import torch

def print_image_structure(images, depth=0):
    """Recursively print the nesting and tensor shapes of pixel_values."""
    indent = "  " * depth
    if isinstance(images, torch.Tensor):
        print(f"{indent}Tensor shape: {images.shape}")
    elif isinstance(images, list):
        print(f"{indent}List length: {len(images)}")
        for i, item in enumerate(images):
            print(f"{indent}Item {i}:")
            print_image_structure(item, depth + 1)
    else:
        print(f"{indent}Unexpected type: {type(images)}")

@DarkLight1337 (Member) commented Oct 21, 2024

I see, thanks for the explanation! So the overall rank is the same (5-D input), but even after flatten_bn the output may still be a nested list. In that case I'm fine with your implementation, since the image size may be different for every image.

@mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Oct 21, 2024
@mgoin merged commit 5241aa1 into main on Oct 21, 2024
58 checks passed
charlifu pushed a commit to charlifu/vllm that referenced this pull request Oct 23, 2024
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Oct 23, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
MErkinSag pushed a commit to MErkinSag/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
@simon-mo deleted the fix-pixtral-batching branch on October 28, 2024 16:50
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024