Skip to content

Commit

Permalink
[Model][Bugfix] Fix batching with multi-image in PixtralHF (vllm-proj…
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Oct 21, 2024
1 parent 1b2290b commit b55ffdf
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 17 deletions.
60 changes: 48 additions & 12 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,34 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:

return data

def _validate_image_sizes(self, images: List[torch.Tensor],
sizes: List[torch.Tensor]) -> List[torch.Tensor]:
if not isinstance(sizes, list):
sizes = [sizes]

total_images = sum(size.numel() // 2 for size in sizes)
if total_images != len(images):
raise ValueError("Mismatch in number of images. "
f"Expected {total_images}, got {len(images)}")
img_idx = 0
for size in sizes:
# Flatten the size tensor to a list of (height, width) pairs
size = size.view(-1, 2).tolist()
for expected_h, expected_w in size:
if img_idx >= len(images):
raise ValueError("Ran out of images before sizes. "
f"{img_idx} >= {len(images)}")
img = images[img_idx]
if img.shape[-2:] != (expected_h, expected_w):
raise ValueError(
"Image size mismatch. Expected "
f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
if img.shape[-3] != 3:
raise ValueError("Image channel mismatch. Expected 3, "
f"got {img.shape[-3]}")
img_idx += 1
return images

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
Expand All @@ -305,20 +333,28 @@ def _parse_and_validate_image_input(
# so we need to produce a list of tensors
if image_sizes is not None:
images = pixel_values
if isinstance(images, torch.Tensor):
# if passed as batch take all images
NN, N, B, C, W, H = images.shape
images = images.reshape(NN * N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# if passed as list flatten lists of tensors
while isinstance(images, list) and len(images) == 1:
images = images[0]

# TODO: Add validation based on image_sizes

def flatten_to_3d_tensors(item):
    """Recursively flatten a tensor or nested lists of tensors into a
    flat list of 3-D tensors.

    Any leading (batch) dims of a tensor are folded away so each output
    element keeps only the trailing three dims — presumably (C, H, W);
    confirm against the image pipeline feeding this helper.

    Args:
        item: A tensor with at least 3 dims, or an arbitrarily nested
            list of such tensors.

    Returns:
        A flat list of 3-D tensors, in traversal order.

    Raises:
        ValueError: For a tensor with fewer than 3 dims, or for any
            item that is neither a tensor nor a list.
    """
    if isinstance(item, torch.Tensor):
        if item.dim() >= 3:
            # reshape (not view) so non-contiguous inputs — e.g. the
            # result of a transpose/permute — are handled as well.
            return list(item.reshape(-1, *item.shape[-3:]))
        else:
            raise ValueError(
                f"Unexpected tensor dimension: {item.dim()}")
    elif isinstance(item, list):
        return [
            t for subitem in item
            for t in flatten_to_3d_tensors(subitem)
        ]
    else:
        raise ValueError(f"Unexpected type: {type(item)}")

# Restructure the batched images into a list of lists of images
images = flatten_to_3d_tensors(pixel_values)

return LlavaImagePixelInputs(
type="pixel_values",
data=images,
data=self._validate_image_sizes(images, image_sizes),
)

return LlavaImagePixelInputs(
Expand Down
11 changes: 6 additions & 5 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,17 +907,18 @@ def forward(
) -> torch.Tensor:
"""
Args:
pixel_values: tensor of token features for
all tokens of all images of shape (N_toks, D)
pixel_values: Each image to be processed will be a separate tensor
in pixel_values. This means it will be a list of tensors
because multiple requests batched can have multiple images,
each with their own shape potentially
Returns:
image_features: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
# pass images through initial convolution independently
patch_embeds_list = [
self.patch_conv(
img.reshape(-1, img.shape[-3], img.shape[-2],
img.shape[-1]).to(self.dtype))
self.patch_conv(img.unsqueeze(0).to(self.dtype))
for img in pixel_values
]

Expand Down

0 comments on commit b55ffdf

Please sign in to comment.