[PixtralLarge] Update Pixtral conversion script to support large format! #34801
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This should be just about ready!
```python
def _recursive_to(obj, device, *args, **kwargs):
    # Lists can be nested, so keep digging until we hit tensors
    if isinstance(obj, list):
        return [_recursive_to(o, device, *args, **kwargs) for o in obj]
    # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
    elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        # cast and send to device
        return obj.to(*args, **kwargs)
    elif isinstance(obj, torch.Tensor) and device is not None:
        # only send to device, don't cast
        return obj.to(device=device)
    else:
        return obj
```
Note to reviewer: the previous `BatchFeature.to()` actually flattened the structure of nested inputs, which created several bugs! This fix preserves the nested structure.
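A minimal sketch of that behaviour, assuming only that `torch` is installed (`_recursive_to` is copied from the diff above, standalone rather than as a `BatchFeature` method):

```python
import torch

def _recursive_to(obj, device, *args, **kwargs):
    # Lists can be nested, so keep digging until we hit tensors
    if isinstance(obj, list):
        return [_recursive_to(o, device, *args, **kwargs) for o in obj]
    # Cast only floating point tensors, so tokenizer LongTensors keep their dtype
    elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        return obj.to(*args, **kwargs)
    elif isinstance(obj, torch.Tensor) and device is not None:
        return obj.to(device=device)
    else:
        return obj

# A nested batch: two samples holding different numbers of tensors
nested = [[torch.zeros(2, 2)], [torch.zeros(3, 3), torch.ones(4, dtype=torch.long)]]
out = _recursive_to(nested, "cpu", dtype=torch.float16)
```

The `[[...], [...]]` nesting survives the call, the float tensors come back as `float16`, and the `LongTensor` keeps its integer dtype.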
```python
            if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
                # If there's a single sample, the image must belong to it
                images = [[images]]
            else:
                raise ValueError(
                    "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
                )
        elif isinstance(images, list) and is_image_or_image_url(images[0]):
            if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
                # If there's a single sample, all images must belong to it
                images = [images]
            else:
                raise ValueError(
                    "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
                )
```
Note to reviewer: Previously there were a lot of edge cases when users passed a single list of images. In some cases, the processor interpreted this as one image per sample rather than a list of images for one sample. This code avoids these error-prone inferences.
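The rule is easier to see in isolation. In this sketch, `Image` and `normalize_images` are hypothetical stand-in names rather than the actual transformers API, but the branching mirrors the snippet above:

```python
class Image:
    """Stand-in for a PIL image or an image URL."""

def is_image_or_image_url(obj):
    # Hypothetical stand-in for the real helper of the same name
    return isinstance(obj, Image)

def normalize_images(text, images):
    """Return `images` as a list of lists: one inner list per text sample."""
    single_sample = isinstance(text, str) or (isinstance(text, list) and len(text) == 1)
    if is_image_or_image_url(images):
        # A bare image can only belong to a single sample
        if single_sample:
            return [[images]]
        raise ValueError("multiple samples require a list of lists of images")
    if isinstance(images, list) and images and is_image_or_image_url(images[0]):
        # A flat list of images can only belong to a single sample
        if single_sample:
            return [images]
        raise ValueError("multiple samples require a list of lists of images")
    return images  # already nested: one list of images per sample

img = Image()
a = normalize_images("a cat", img)         # bare image, single sample -> [[img]]
b = normalize_images(["one"], [img, img])  # flat list, single sample -> [[img, img]]
```

With multiple text samples and a flat (non-nested) image list, the helper raises instead of guessing a per-sample assignment.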
```diff
-        patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
+        if len(pixel_values) > 1:
+            raise ValueError("Batching/padding not supported yet!")
+        patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample]

         # flatten to a single sequence
-        patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+        patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
         patch_embeds = self.ln_pre(patch_embeds)

         # positional embeddings
         position_ids = position_ids_in_meshgrid(
             patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
         ).to(self.device)

         position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
```
Note to reviewer: These changes are here to handle images being passed in as a list of lists now. Previously, images were passed in as a flat list even though the processor output a list of lists. The only reason this didn't cause an error was that the bug in `BatchFeature.to()` silently fixed the list structure and made it match the modeling code 😓
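For a single image, the two flattening expressions in the diff are equivalent. This standalone check (assuming only `torch`; the tensor is a fake conv output, not real model data) shows the old batched path and the new unbatched path produce the same `(1, num_patches, channels)` layout:

```python
import torch

# Fake conv feature map for one image: (channels, h_patches, w_patches)
p = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)

# Old path: a batch dim was added before the conv, so p arrived as (1, C, H, W)
old = p.unsqueeze(0).flatten(2).permute(0, 2, 1)  # (1, H*W, C)

# New path: the conv runs on (C, H, W) directly; the batch dim is added at the end
new = p.flatten(1).T.unsqueeze(0)                 # (1, H*W, C)
```

Both flatten the spatial grid row-major and move channels last, so `old` and `new` are elementwise identical.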
This should be ready for final review @ArthurZucker! I did ablation testing and reverted some of the dtype changes in …
Let's roll! A todo is to add another test for the new model 😉 Good to go otherwise
```python
def _recursive_to(obj, device, *args, **kwargs):
    # Lists can be nested, so keep digging until we hit tensors
    if isinstance(obj, list):
        return [_recursive_to(o, device, *args, **kwargs) for o in obj]
    # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
    elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        # cast and send to device
        return obj.to(*args, **kwargs)
    elif isinstance(obj, torch.Tensor) and device is not None:
        # only send to device, don't cast
        return obj.to(device=device)
    else:
        return obj
```
Should probably be fixed on the parent class
What does this PR do?
Updates the Pixtral conversion script to support the Pixtral Large format.