Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model][Bugfix] Fix batching with multi-image in PixtralHF #9518

Merged
merged 7 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,34 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:

return data

def _validate_image_sizes(self, images: List[torch.Tensor],
sizes: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(sizes, list):
sizes = [sizes]

total_images = sum(size.numel() // 2 for size in sizes)
if total_images != len(images):
raise ValueError("Mismatch in number of images. "
f"Expected {total_images}, got {len(images)}")
img_idx = 0
for size in sizes:
# Flatten the size tensor to a list of (height, width) pairs
size = size.view(-1, 2).tolist()
for expected_h, expected_w in size:
if img_idx >= len(images):
raise ValueError("Ran out of images before sizes. "
f"{img_idx} >= {len(images)}")
img = images[img_idx]
if img.shape[-2:] != (expected_h, expected_w):
raise ValueError(
"Image size mismatch. Expected "
f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
if img.shape[-3] != 3:
raise ValueError("Image channel mismatch. Expected 3, "
f"got {img.shape[-3]}")
img_idx += 1
return images

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
Expand All @@ -305,20 +333,28 @@ def _parse_and_validate_image_input(
# so we need to produce a list of tensors
if image_sizes is not None:
images = pixel_values
if isinstance(images, torch.Tensor):
# if passed as batch take all images
NN, N, B, C, W, H = images.shape
images = images.reshape(NN * N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# if passed as list flatten lists of tensors
while isinstance(images, list) and len(images) == 1:
images = images[0]

# TODO: Add validation based on image_sizes

def flatten_to_3d_tensors(item):
if isinstance(item, torch.Tensor):
if item.dim() >= 3:
return [t for t in item.view(-1, *item.shape[-3:])]
else:
raise ValueError(
f"Unexpected tensor dimension: {item.dim()}")
elif isinstance(item, list):
return [
t for subitem in item
for t in flatten_to_3d_tensors(subitem)
]
else:
raise ValueError(f"Unexpected type: {type(item)}")

# Restructure the batched images into a list of lists of images
images = flatten_to_3d_tensors(pixel_values)

return LlavaImagePixelInputs(
type="pixel_values",
data=images,
data=self._validate_image_sizes(images, image_sizes),
)

return LlavaImagePixelInputs(
Expand Down
11 changes: 6 additions & 5 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,17 +907,18 @@ def forward(
) -> torch.Tensor:
"""
Args:
pixel_values: tensor of token features for
all tokens of all images of shape (N_toks, D)
pixel_values: Each image to be processed will be a separate tensor
in pixel_values. This means it will be a list of tensors
because multiple requests batched can have multiple images,
each with their own shape potentially

Returns:
image_features: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
# pass images through initial convolution independently
patch_embeds_list = [
self.patch_conv(
img.reshape(-1, img.shape[-3], img.shape[-2],
img.shape[-1]).to(self.dtype))
self.patch_conv(img.unsqueeze(0).to(self.dtype))
for img in pixel_values
]

Expand Down