From d30c70ab1cd3f9ee7aa4195a35c85578f27569f4 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 18 Oct 2024 21:04:14 +0000
Subject: [PATCH 1/6] Fix Pixtral batching with multi-image

---
 vllm/model_executor/models/llava.py   | 22 +++++++++++
 vllm/model_executor/models/pixtral.py | 55 +++++++++++++++++++++++----
 2 files changed, 69 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index a83b7d05df7aa..cd6897e5f65dd 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -301,19 +301,41 @@ def _parse_and_validate_image_input(
             raise ValueError("Incorrect type of pixel values. "
                              f"Got type: {type(pixel_values)}")
 
+        def print_image_structure(images, depth=0):
+            indent = "  " * depth
+            if isinstance(images, torch.Tensor):
+                print(f"{indent}Tensor shape: {images.shape}")
+            elif isinstance(images, list):
+                print(f"{indent}List length: {len(images)}")
+                for i, item in enumerate(images):
+                    print(f"{indent}Item {i}:")
+                    print_image_structure(item, depth + 1)
+            else:
+                print(f"{indent}Unexpected type: {type(images)}")
+
         # Case for models like PixtralHF that have dynamic image sizes
         # so we need to produce a list of tensors
         if image_sizes is not None:
             images = pixel_values
             if isinstance(images, torch.Tensor):
                 # if passed as batch take all images
+                print("Processing tensor input:")
+                print(f"Original tensor shape: {images.shape}")
                 NN, N, B, C, W, H = images.shape
                 images = images.reshape(NN * N * B, C, W, H)
+                print(f"Reshaped tensor shape: {images.shape}")
                 images = [images[i] for i in range(images.size(0))]
+                print("After conversion to list:")
+                print_image_structure(images)
             elif isinstance(images, list):
                 # if passed as list flatten lists of tensors
+                print("Processing list input:")
+                print_image_structure(images)
                 while isinstance(images, list) and len(images) == 1:
+                    print(f"Unwrapping list of length 1")
                     images = images[0]
+                print("Final structure after unwrapping:")
+                print_image_structure(images)
 
             # TODO: Add validation based on image_sizes
             return LlavaImagePixelInputs(
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index d09cbe5ca02e9..675119a51f50d 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -919,25 +919,64 @@ def __init__(self, config: PixtralVisionConfig):
         self.patch_positional_embedding = PixtralRotaryEmbedding(
             config, self.device)
 
+    def process_image(self, img: torch.Tensor) -> torch.Tensor:
+        """Process a single image tensor, handling various input shapes."""
+        print(f"Original image shape: {img.shape}")
+
+        # Ensure the tensor has at least 4 dimensions (batch, channels, height, width)
+        if img.dim() < 4:
+            img = img.unsqueeze(0)  # Add batch dimension if not present
+
+        # If there are more than 4 dimensions, flatten all but the last 3
+        if img.dim() > 4:
+            img = img.view(-1, *img.shape[-3:])
+
+        # Ensure we have the correct number of channels
+        if img.shape[1] != self.config.num_channels:
+            raise ValueError(f"Expected {self.config.num_channels} channels, but got {img.shape[1]}")
+
+        img = img.to(self.dtype)
+        print(f"Processed image shape before patch_conv: {img.shape}")
+        result = self.patch_conv(img)
+        print(f"Processed image shape after patch_conv: {result.shape}")
+        return result
+
+    def process_input(self, item: Union[torch.Tensor, List], depth: int = 0) -> List[torch.Tensor]:
+        """Recursively process input items, handling nested lists."""
+        patch_embeds = []
+        indent = "  " * depth
+
print(f"{indent}Type at depth {depth}: {type(item)}") + + if isinstance(item, torch.Tensor): + print(f"{indent}Shape of tensor: {item.shape}") + patch_embeds.append(self.process_image(item)) + elif isinstance(item, list): + print(f"{indent}Length of list: {len(item)}") + for subitem in item: + patch_embeds.extend(self.process_input(subitem, depth + 1)) + else: + raise ValueError(f"Unexpected type in input: {type(item)}") + + return patch_embeds + def forward( self, pixel_values: List[torch.Tensor], ) -> torch.Tensor: """ Args: - pixel_values: tensor of token features for - all tokens of all images of shape (N_toks, D) + pixel_values: Each image to be processed will be a separate tensor + in pixel_values. This means it can be a list of tensors or even a list of lists of tensors + Returns: image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ + print(f"Type of pixel_values: {type(pixel_values)}") + print(f"Length of pixel_values: {len(pixel_values)}") + # pass images through initial convolution independently - patch_embeds_list = [ - self.patch_conv( - img.reshape(-1, img.shape[-3], img.shape[-2], - img.shape[-1]).to(self.dtype)) - for img in pixel_values - ] + patch_embeds_list = self.process_input(pixel_values) # flatten to a single sequence patch_embeds = torch.cat( From b5780457b62db7ea6665d66f9f51a01b8d6099da Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 18 Oct 2024 21:05:34 +0000 Subject: [PATCH 2/6] Remove debug print --- vllm/model_executor/models/llava.py | 22 ---------------------- vllm/model_executor/models/pixtral.py | 9 --------- 2 files changed, 31 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index cd6897e5f65dd..a83b7d05df7aa 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -301,41 +301,19 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. 
" f"Got type: {type(pixel_values)}") - def print_image_structure(images, depth=0): - indent = " " * depth - if isinstance(images, torch.Tensor): - print(f"{indent}Tensor shape: {images.shape}") - elif isinstance(images, list): - print(f"{indent}List length: {len(images)}") - for i, item in enumerate(images): - print(f"{indent}Item {i}:") - print_image_structure(item, depth + 1) - else: - print(f"{indent}Unexpected type: {type(images)}") - # Case for models like PixtralHF that have dynamic image sizes # so we need to produce a list of tensors if image_sizes is not None: images = pixel_values if isinstance(images, torch.Tensor): # if passed as batch take all images - print("Processing tensor input:") - print(f"Original tensor shape: {images.shape}") NN, N, B, C, W, H = images.shape images = images.reshape(NN * N * B, C, W, H) - print(f"Reshaped tensor shape: {images.shape}") images = [images[i] for i in range(images.size(0))] - print("After conversion to list:") - print_image_structure(images) elif isinstance(images, list): # if passed as list flatten lists of tensors - print("Processing list input:") - print_image_structure(images) while isinstance(images, list) and len(images) == 1: - print(f"Unwrapping list of length 1") images = images[0] - print("Final structure after unwrapping:") - print_image_structure(images) # TODO: Add validation based on image_sizes return LlavaImagePixelInputs( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 675119a51f50d..f35b7f3d804be 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -921,7 +921,6 @@ def __init__(self, config: PixtralVisionConfig): def process_image(self, img: torch.Tensor) -> torch.Tensor: """Process a single image tensor, handling various input shapes.""" - print(f"Original image shape: {img.shape}") # Ensure the tensor has at least 4 dimensions (batch, channels, height, width) if img.dim() < 4: @@ -936,22 +935,16 @@ def process_image(self, img: torch.Tensor) -> torch.Tensor: raise ValueError(f"Expected {self.config.num_channels} channels, but got {img.shape[1]}") img = img.to(self.dtype) - print(f"Processed image shape before patch_conv: {img.shape}") result = self.patch_conv(img) - print(f"Processed image shape after patch_conv: {result.shape}") return result def process_input(self, item: Union[torch.Tensor, List], depth: int = 0) -> List[torch.Tensor]: """Recursively process input items, handling nested lists.""" patch_embeds = [] - indent = " " * depth - print(f"{indent}Type at depth {depth}: {type(item)}") if isinstance(item, torch.Tensor): - print(f"{indent}Shape of tensor: {item.shape}") patch_embeds.append(self.process_image(item)) elif isinstance(item, list): - print(f"{indent}Length of list: {len(item)}") for subitem in item: patch_embeds.extend(self.process_input(subitem, depth + 1)) else: @@ -972,8 +965,6 @@ def forward( image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ - print(f"Type of pixel_values: {type(pixel_values)}") - print(f"Length of pixel_values: {len(pixel_values)}") # pass images through initial convolution independently patch_embeds_list = self.process_input(pixel_values) From 40883548af1fac5fd8584abaa7d50dd217309b2f Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 18 Oct 2024 21:18:37 +0000 Subject: [PATCH 3/6] Format --- vllm/model_executor/models/pixtral.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py 
index 675119a51f50d..f35b7f3d804be 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -921,7 +921,6 @@ def __init__(self, config: PixtralVisionConfig):
 
     def process_image(self, img: torch.Tensor) -> torch.Tensor:
         """Process a single image tensor, handling various input shapes."""
-        print(f"Original image shape: {img.shape}")
 
         # Ensure the tensor has at least 4 dimensions (batch, channels, height, width)
         if img.dim() < 4:
@@ -936,22 +935,16 @@ def process_image(self, img: torch.Tensor) -> torch.Tensor:
             raise ValueError(f"Expected {self.config.num_channels} channels, but got {img.shape[1]}")
 
         img = img.to(self.dtype)
-        print(f"Processed image shape before patch_conv: {img.shape}")
         result = self.patch_conv(img)
-        print(f"Processed image shape after patch_conv: {result.shape}")
         return result
 
     def process_input(self, item: Union[torch.Tensor, List], depth: int = 0) -> List[torch.Tensor]:
         """Recursively process input items, handling nested lists."""
         patch_embeds = []
-        indent = "  " * depth
-        print(f"{indent}Type at depth {depth}: {type(item)}")
 
         if isinstance(item, torch.Tensor):
-            print(f"{indent}Shape of tensor: {item.shape}")
             patch_embeds.append(self.process_image(item))
         elif isinstance(item, list):
-            print(f"{indent}Length of list: {len(item)}")
             for subitem in item:
                 patch_embeds.extend(self.process_input(subitem, depth + 1))
         else:
             raise ValueError(f"Unexpected type in input: {type(item)}")
@@ -972,8 +965,6 @@ def forward(
         image_features: tensor of token features for
             all tokens of all images of shape (N_toks, D)
         """
-        print(f"Type of pixel_values: {type(pixel_values)}")
-        print(f"Length of pixel_values: {len(pixel_values)}")
 
         # pass images through initial convolution independently
         patch_embeds_list = self.process_input(pixel_values)

From 40883548af1fac5fd8584abaa7d50dd217309b2f Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 18 Oct 2024 21:18:37 +0000
Subject: [PATCH 3/6] Format

---
 vllm/model_executor/models/pixtral.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index f35b7f3d804be..b99e717d51280 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -922,7 +922,7 @@ def __init__(self, config: PixtralVisionConfig):
     def process_image(self, img: torch.Tensor) -> torch.Tensor:
         """Process a single image tensor, handling various input shapes."""
 
-        # Ensure the tensor has at least 4 dimensions (batch, channels, height, width)
+        # Ensure the tensor has at least 4 dimensions (b, c, h, w)
         if img.dim() < 4:
             img = img.unsqueeze(0)  # Add batch dimension if not present
 
@@ -932,13 +932,15 @@ def process_image(self, img: torch.Tensor) -> torch.Tensor:
 
         # Ensure we have the correct number of channels
         if img.shape[1] != self.config.num_channels:
-            raise ValueError(f"Expected {self.config.num_channels} channels, but got {img.shape[1]}")
+            raise ValueError(f"Expected {self.config.num_channels} channels, "
+                             f"but got {img.shape[1]}")
 
         img = img.to(self.dtype)
         result = self.patch_conv(img)
         return result
 
-    def process_input(self, item: Union[torch.Tensor, List], depth: int = 0) -> List[torch.Tensor]:
+    def process_input(self, item: Union[torch.Tensor,
+                                        List]) -> List[torch.Tensor]:
         """Recursively process input items, handling nested lists."""
         patch_embeds = []
 
@@ -946,7 +948,7 @@ def process_input(self, item: Union[torch.Tensor, List], depth: int = 0) -> List
             patch_embeds.append(self.process_image(item))
         elif isinstance(item, list):
             for subitem in item:
-                patch_embeds.extend(self.process_input(subitem, depth + 1))
+                patch_embeds.extend(self.process_input(subitem))
         else:
             raise ValueError(f"Unexpected type in input: {type(item)}")
 
@@ -959,7 +961,9 @@ def forward(
         """
         Args:
             pixel_values: Each image to be processed will be a separate tensor
-                in pixel_values. This means it can be a list of tensors or even a list of lists of tensors
+                in pixel_values. This means it can be a list of tensors, or
+                even a list of lists of tensors in the case of multiple
+                requests batched that have multiple images each
 
         Returns:
             image_features: tensor of token features for

From b12510522fc2a8766651028427102fdad8b01663 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Mon, 21 Oct 2024 15:02:26 +0000
Subject: [PATCH 4/6] Simplify everything by using a list of tensors

---
 vllm/model_executor/models/llava.py   | 29 +++++++++++++------
 vllm/model_executor/models/pixtral.py | 41 +++------------------------
 2 files changed, 24 insertions(+), 46 deletions(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index a83b7d05df7aa..aa8f611aa1604 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -305,15 +305,26 @@ def _parse_and_validate_image_input(
         # so we need to produce a list of tensors
         if image_sizes is not None:
             images = pixel_values
-            if isinstance(images, torch.Tensor):
-                # if passed as batch take all images
-                NN, N, B, C, W, H = images.shape
-                images = images.reshape(NN * N * B, C, W, H)
-                images = [images[i] for i in range(images.size(0))]
-            elif isinstance(images, list):
-                # if passed as list flatten lists of tensors
-                while isinstance(images, list) and len(images) == 1:
-                    images = images[0]
+
+            def flatten_to_3d_tensors(item):
+                if isinstance(item, torch.Tensor):
+                    if item.dim() == 3:
+                        return [item]
+                    elif item.dim() > 3:
+                        return [t for t in item.view(-1, *item.shape[-3:])]
+                    else:
+                        raise ValueError(
+                            f"Unexpected tensor dimension: {item.dim()}")
+                elif isinstance(item, list):
+                    return [
+                        t for subitem in item
+                        for t in flatten_to_3d_tensors(subitem)
+                    ]
+                else:
+                    raise ValueError(f"Unexpected type: {type(item)}")
+
+            # Restructure the batched images into a flat list of images
+            images = flatten_to_3d_tensors(pixel_values)
 
             # TODO: Add validation based on image_sizes
             return LlavaImagePixelInputs(
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index d493ffb78aa09..aa762f53c7b01 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -901,41 +901,6 @@ def __init__(self, config: PixtralVisionConfig):
         self.patch_positional_embedding = PixtralRotaryEmbedding(
             config, self.device)
 
-    def process_image(self, img: torch.Tensor) -> torch.Tensor:
-        """Process a single image tensor, handling various input shapes."""
-
-        # Ensure the tensor has at least 4 dimensions (b, c, h, w)
-        if img.dim() < 4:
-            img = img.unsqueeze(0)  # Add batch dimension if not present
-
-        # If there are more than 4 dimensions, flatten all but the last 3
-        if img.dim() > 4:
-            img = img.view(-1, *img.shape[-3:])
-
-        # Ensure we have the correct number of channels
-        if img.shape[1] != self.config.num_channels:
-            raise ValueError(f"Expected {self.config.num_channels} channels, "
-                             f"but got {img.shape[1]}")
-
-        img = img.to(self.dtype)
-        result = self.patch_conv(img)
-        return result
-
-    def process_input(self, item: Union[torch.Tensor,
-                                        List]) -> List[torch.Tensor]:
-        """Recursively process input items, handling nested lists."""
-        patch_embeds = []
-
-        if isinstance(item, torch.Tensor):
-            patch_embeds.append(self.process_image(item))
-        elif isinstance(item, list):
-            for subitem in item:
-                patch_embeds.extend(self.process_input(subitem))
-        else:
-            raise ValueError(f"Unexpected type in input: {type(item)}")
-
-        return patch_embeds
-
     def forward(
         self,
         pixel_values: List[torch.Tensor],
@@ -951,9 +916,11 @@ def forward(
             image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
         """
-        # pass images through initial convolution independently
-        patch_embeds_list = self.process_input(pixel_values)
+        patch_embeds_list = [
+            self.patch_conv(img.unsqueeze(0).to(self.dtype))
+            for img in pixel_values
+        ]
 
         # flatten to a single sequence
         patch_embeds = torch.cat(

From 47c20a12ae9a61617d66a3219a33cfa5f522200b Mon Sep 17 00:00:00 2001
From: mgoin
Date: Mon, 21 Oct 2024 15:05:27 +0000
Subject: [PATCH 5/6] Cleanup

---
 vllm/model_executor/models/llava.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index aa8f611aa1604..d878280de2bd2 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -308,9 +308,7 @@ def _parse_and_validate_image_input(
 
             def flatten_to_3d_tensors(item):
                 if isinstance(item, torch.Tensor):
-                    if item.dim() == 3:
-                        return [item]
-                    elif item.dim() > 3:
+                    if item.dim() >= 3:
                         return [t for t in item.view(-1, *item.shape[-3:])]
                     else:
                         raise ValueError(

From 954e99b521987d692da041a4d7d0e28bbd776abc Mon Sep 17 00:00:00 2001
From: mgoin
Date: Mon, 21 Oct 2024 16:24:56 +0000
Subject: [PATCH 6/6] Add image_size validation

---
 vllm/model_executor/models/llava.py   | 31 +++++++++++++++++++++++++--
 vllm/model_executor/models/pixtral.py |  6 +++---
 2 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index d878280de2bd2..a666dcba290f2 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -287,6 +287,34 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
 
         return data
 
+    def _validate_image_sizes(self, images: List[torch.Tensor],
+                              sizes: List[torch.Tensor]) -> List[torch.Tensor]:
+        if not isinstance(sizes, list):
+            sizes = [sizes]
+
+        total_images = sum(size.numel() // 2 for size in sizes)
+        if total_images != len(images):
+            raise ValueError("Mismatch in number of images. "
+                             f"Expected {total_images}, got {len(images)}")
+        img_idx = 0
+        for size in sizes:
+            # Flatten the size tensor to a list of (height, width) pairs
+            size = size.view(-1, 2).tolist()
+            for expected_h, expected_w in size:
+                if img_idx >= len(images):
+                    raise ValueError("Ran out of images before sizes. "
+                                     f"{img_idx} >= {len(images)}")
+                img = images[img_idx]
+                if img.shape[-2:] != (expected_h, expected_w):
+                    raise ValueError(
+                        "Image size mismatch. Expected "
+                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
+                if img.shape[-3] != 3:
+                    raise ValueError("Image channel mismatch. Expected 3, "
Expected 3, " + f"got {img.shape[-3]}") + img_idx += 1 + return images + def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -324,10 +352,9 @@ def flatten_to_3d_tensors(item): # Restructure the batched images into a list of lists of images images = flatten_to_3d_tensors(pixel_values) - # TODO: Add validation based on image_sizes return LlavaImagePixelInputs( type="pixel_values", - data=images, + data=self._validate_image_sizes(images, image_sizes), ) return LlavaImagePixelInputs( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index aa762f53c7b01..f33871c0d5acc 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -908,9 +908,9 @@ def forward( """ Args: pixel_values: Each image to be processed will be a separate tensor - in pixel_values. This means it can be a list of tensors, or - even a list of lists of tensors in the case of multiple - requests batched that have multiple images each + in pixel_values. This means it will be a list of tensors + because multiple requests batched can have multiple images, + each with their own shape potentially Returns: image_features: tensor of token features for