[PixtralLarge] Update Pixtral conversion script to support large format! #34801

Merged
46 commits merged into main from pixtral-large-script on Jan 8, 2025

Conversation

ArthurZucker (Collaborator)

What does this PR do?

Updates the conversion script

Rocketknight1 force-pushed the pixtral-large-script branch 2 times, most recently from aba20bf to 24d9ee5 on December 23, 2024 12:04
Rocketknight1 marked this pull request as ready for review on December 23, 2024 16:11
Rocketknight1 (Member) commented Dec 23, 2024

This should be just about ready! Quick summary of the changes:

  • Made sure eps values and activations were handled correctly during conversion
  • Made sure the tokenizer gets special tokens assigned correctly during conversion
  • Made the biases in the multimodal projector a config flag (they are enabled in Pixtral-12B but not in Pixtral-Large)
  • BatchMixFeature.to() was buggy when the input was a nested list (changes pulled in from Fix case of nested tensors in BatchMixFeature #35063)
  • PixtralProcessor made some strange assumptions when lists were passed (changes pulled in from Fix the structure of images in PixtralProcessor #35107)
  • Added some float32 upcasts in Pixtral attention to match the behaviour of the vLLM reference code (vLLM uses xformers, which has a custom kernel that does those computations internally in float32)
  • The Pixtral chat template needed a lot of rewrites:
    • System message handling, including the datetime
    • Undocumented behaviour in mistral-common: when a message has exactly one text chunk plus one or more images, the text is moved to the end, after the image tokens, even if that isn't the order of the chunks the user passed in. When there are multiple text chunks, the original order is kept. If we don't get this exactly right, model generations are garbage (see the sketch after this list).
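
To make that ordering rule concrete, here is a minimal Python sketch of it. It is written for this summary rather than taken from the actual Jinja template; the chunk dicts and the reorder_chunks helper are hypothetical:

def reorder_chunks(chunks):
    # Hypothetical helper mirroring the mistral-common rule described above.
    # `chunks` is a list of dicts like {"type": "text", "text": ...} or {"type": "image", ...}.
    text_chunks = [c for c in chunks if c["type"] == "text"]
    image_chunks = [c for c in chunks if c["type"] == "image"]
    if len(text_chunks) == 1:
        # Exactly one text chunk: render it after all of the image tokens.
        return image_chunks + text_chunks
    # Multiple text chunks: keep the user's original ordering.
    return chunks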

TODO:

  • Should we use xformers instead of manual float32 attention? It would be more accurate + faster, but would add a dependency to the model.
  • Make sure the conversion script still works for older Pixtral-12B.
  • Make sure @zucchini-nlp is okay with the changes in the Processor.

Comment on lines +70 to +83
def _recursive_to(obj, device, *args, **kwargs):
    # Lists can be nested, so keep digging until we hit tensors
    if isinstance(obj, list):
        return [_recursive_to(o, device, *args, **kwargs) for o in obj]
    # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
    elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        # cast and send to device
        return obj.to(*args, **kwargs)
    elif isinstance(obj, torch.Tensor) and device is not None:
        # only send to device, don't cast
        return obj.to(device=device)
    else:
        return obj

Member

Note to reviewer: The previous BatchFeature.to() actually flattened the structure of nested inputs, which created several bugs! This fix preserves nested structure
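
For illustration, a quick way to see the behaviour (not part of the PR; the tensor shapes are made up, and it assumes the _recursive_to above has been pasted into scope):

import torch

# A nested pixel_values-style input: one inner list of image tensors per sample.
nested = [[torch.rand(3, 16, 16)], [torch.rand(3, 16, 16), torch.rand(3, 32, 32)]]
moved = _recursive_to(nested, None, dtype=torch.float16)

# The list-of-lists structure is preserved (the old BatchFeature.to() flattened it),
# and only floating point tensors get cast.
assert isinstance(moved[1], list) and len(moved[1]) == 2
assert moved[0][0].dtype == torch.float16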

Comment on lines +208 to +222
    if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
        # If there's a single sample, the image must belong to it
        images = [[images]]
    else:
        raise ValueError(
            "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
        )
elif isinstance(images, list) and is_image_or_image_url(images[0]):
    if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
        # If there's a single sample, all images must belong to it
        images = [images]
    else:
        raise ValueError(
            "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
        )
Member

Note to reviewer: Previously there were a lot of edge cases when users passed a single list of images. In some cases, the processor interpreted this as one image per sample rather than a list of images for one sample. This code avoids these error-prone inferences.
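
As a rough illustration of the convention this enforces (the checkpoint id, prompt strings, and images below are placeholders, not taken from the PR):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")  # placeholder checkpoint
img_a, img_b, img_c = (Image.new("RGB", (64, 64)) for _ in range(3))

# One text sample: a flat list of images means "all of these belong to that sample".
single = processor(text="[INST][IMG][IMG]Describe both images.[/INST]", images=[img_a, img_b])

# Several text samples: `images` must be a list of lists, one inner list per sample;
# anything else now raises the ValueError above instead of guessing.
batch = processor(
    text=["[INST][IMG]First.[/INST]", "[INST][IMG][IMG]Second.[/INST]"],
    images=[[img_a], [img_b, img_c]],
    padding=True,
)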

Comment on lines -494 to +506
-patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
+if len(pixel_values) > 1:
+    raise ValueError("Batching/padding not supported yet!")
+patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample]

 # flatten to a single sequence
-patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
 patch_embeds = self.ln_pre(patch_embeds)

 # positional embeddings
 position_ids = position_ids_in_meshgrid(
     patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
 ).to(self.device)

 position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)

Member

Note to reviewer: These changes are here to handle images being passed in as a list of lists now. Previously, images were passed in as a flat list even though the processor output a list of lists. The only reason this didn't cause an error was because the bug in BatchFeature.to() silently fixed the list structure and made it match the modeling code 😓
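
A rough shape walkthrough of the new flattening path, with made-up sizes and a bare Conv2d standing in for self.patch_conv (assumes a recent PyTorch that accepts unbatched 3D input to Conv2d):

import torch

hidden, patch = 8, 16
patch_conv = torch.nn.Conv2d(3, hidden, kernel_size=patch, stride=patch)

# pixel_values is now a list of lists: one inner list of image tensors per sample.
pixel_values = [[torch.rand(3, 64, 64), torch.rand(3, 32, 48)]]

patch_embeds_list = [patch_conv(img) for sample in pixel_values for img in sample]
# per-image shapes: (hidden, 4, 4) and (hidden, 2, 3)

patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
print(patch_embeds.shape)  # torch.Size([1, 22, 8]): 16 + 6 patches joined into one sequence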

Rocketknight1 (Member)

This should be ready for final review @ArthurZucker! I did ablation testing and reverted some of the dtype changes in modeling_pixtral.py - the results seem good without them and performance/memory improves.

ArthurZucker (Collaborator, Author) left a comment

Let's roll! A todo is to add another test for the new model 😉 Good to go otherwise

Comment on lines +67 to +79
def _recursive_to(obj, device, *args, **kwargs):
    # Lists can be nested, so keep digging until we hit tensors
    if isinstance(obj, list):
        return [_recursive_to(o, device, *args, **kwargs) for o in obj]
    # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
    elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        # cast and send to device
        return obj.to(*args, **kwargs)
    elif isinstance(obj, torch.Tensor) and device is not None:
        # only send to device, don't cast
        return obj.to(device=device)
    else:
        return obj
Collaborator Author

Should probably be fixed on the parent class

ArthurZucker merged commit 3f483be into main Jan 8, 2025
18 checks passed
ArthurZucker deleted the pixtral-large-script branch January 8, 2025 16:39