Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vision cross attention mask transform #1141

Merged
merged 11 commits into from
Jul 10, 2024
Merged

Conversation

RdoubleA
Copy link
Contributor

@RdoubleA RdoubleA commented Jul 3, 2024

Context

Multimodal models that use cross-attention layers in the text LLM backbone to attend to outputs of a vision encoder on the images require a cross-attention mask to ensure the correct text tokens attend to the right images. For multiple, interleaved images in a single sample, we follow the approach in Flamingo (Fig. 7 of https://arxiv.org/pdf/2204.14198) where text tokens after an image attend to that particular image (more details in docstrings). To create this mask, this PR adds a transform class that takes a sample dictionary (i.e., during data preprocessing in one of torchtune's dataset classes), computes this cross-attention mask and adds it to the dictionary.

Handling variable number of tiles and number of images across samples will be done in the batch collator in a follow-up PR. For now, we return the masks as a list of tensors, one mask per image, which can have varied n_tiles.

Changelog

  • Add base Transform interface
  • Add VisionCrossAttentionMask
  • Add unit tests for VisionCrossAttentionMask

Test plan

pytest tests/torchtune/modules/transforms/test_transforms.py

Docs

image

Copy link

pytorch-bot bot commented Jul 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1141

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bead59a with merge base 37636a8 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 3, 2024
Copy link
Contributor

@felipemello1 felipemello1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks great! I left a few comments more around naming, since tiles/patches/image is so confusing in general. I think that this class would benefit from a visual small example too.

regarding the interval parts, IMO we should try to make the operations as explicit as possible. Instead of tok1, tok2; i; vision_mask[0], we should try to give them names, like idx_start, idx_end, etc.

participate in cross-attention with an image token will show True in the mask
and follow these rules:
1) Text tokens immediately following the image token up until the next image token
2) Consecutive image tokens attend to all subsequent text tokens
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make it more visual? Something like this: https://github.com/huggingface/transformers/blob/60bb571e993b7d73257fb64044726b569fef9403/src/transformers/models/llava_next/modeling_llava_next.py#L446

Or a link to the paper + page where they have an image for it


class CrossAttentionMask(Transform):
"""
Computes the cross-attention mask for text + image inputs. Text tokens that
Copy link
Contributor

@felipemello1 felipemello1 Jul 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this CrossAttention is specific for text + images, should we indicate it in the name? Something like MultimodalCrossAttentionMask or VisionTextCrossAttentionMask?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's a good point. Will rename

text sequence.

Args:
num_patches (int): Number of patches per image, excluding class token.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should add a link to modules/VisionTransformer for a in-depth explanation of what num_patches mean. For better clarity, would it make sense to rename it num_patches_per_tile, since later we multiply it by n_tiles?

If we say "number of patches per image", it may be confusing, because an image can have a variable number of patches.

later on you say:

single_image_seq_len = n_tiles * self.num_patches + 1
image_seq_len = single_image_seq_len * n_img

So tile != image. Image is a set of tiles.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah sorry, probably my shallow understanding of patches/images/tiles. What I intended was num_patches per tile. If it makes sense I'd like to keep the name consistent with your vision transformer (either patch_grid_size or patch_size maybe?), whichever parameter you use to compute num patches.

I also assumed at this point tiles is padded to the max in all the images. Is this incorrect? Where does the padding happen?

Copy link
Contributor

@felipemello1 felipemello1 Jul 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

patches_per_tile is a fixed size, and its calculated as (tile_size // patch_size)**2.

What I did in VisionTransformer was to ask the user to pass tile_size and patch_size, and I calculated it for them. The VisionTransformer has a helper function that saves this value: https://github.com/felipemello1/torchtune/blob/f683812626ad4559464840112ddce516487bea5c/torchtune/modules/vision_transformer.py#L249

Maybe get it from the model, or ask for tile_size and patch_size, to avoid user confusion?

pass


class Compose(Transform):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the torchvision compose has a different behavior, I wonder if it makes sense to change Compose to something else, so users dont get confused with tv.Compose. Maybe "ComposeTransforms"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about Pipeline?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my own understanding, the main difference with torchvision compose is that we support multiple inputs and multiple outputs here? Can we not just use torchvision compose with a single dict?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried naming something as Pipeline, as Kartikay said it would confuse people, because it is also used by other libraries :P. I guess sklearn?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ebsmothers our Compose needs to have a slightly different forward signature to unfold dictionary inputs. From torchvision:

def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

but to avoid confusion, I agree should name it something else. Just haven't figured out what yet

"""
Returns a list of tuples of the form (start, end) where start is the index
of the current image token and end is the index of the next image token, exclusive.
If the image token attends until the end of the sequence, end will be -1.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should we add Args:,Returns, Examples?

torchtune/modules/transforms/_transforms.py Show resolved Hide resolved

def __call__(self, *, tokens, images, **kwargs):
# We are still at sample level pre-collating
n_img, n_tiles, _, _, _ = images.shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Maybe add a comment with the other dimensions, so we know what they are, but keep the "_", so we know they are not used?

You said "# We are still at sample level pre-collating"
So is n_img == bsz? If so, for consistency with VisionTransformer, should we rename it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah please add type and shape info for arguments to __call__

# We are still at sample level pre-collating
n_img, n_tiles, _, _, _ = images.shape
text_seq_len = len(tokens)
single_image_seq_len = n_tiles * self.num_patches + 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe add comment explaining that +1 is for CLS, if thats the case

Comment on lines 136 to 138
image_num
* single_image_seq_len : (image_num + 1)
* single_image_seq_len,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: split line differently if linter allows it, as written this is confusing

* single_image_seq_len,
] = True

kwargs.update({"encoder_mask": mask, "tokens": tokens, "images": images})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to also update with tokens and images? Isn't this a no-op for those args?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since they are explicit keyword args, they get unfolded from kwargs and you have to add them back in

@RdoubleA RdoubleA changed the title [WIP] Add basic transforms for multimodal Vision cross attention mask transform Jul 9, 2024
@RdoubleA RdoubleA marked this pull request as ready for review July 9, 2024 00:15
torchtune/modules/transforms/_transforms.py Outdated Show resolved Hide resolved
torchtune/modules/transforms/_transforms.py Show resolved Hide resolved
torchtune/modules/transforms/_transforms.py Show resolved Hide resolved
torchtune/modules/transforms/_transforms.py Outdated Show resolved Hide resolved
torchtune/modules/transforms/_transforms.py Outdated Show resolved Hide resolved
torchtune/modules/transforms/_transforms.py Show resolved Hide resolved
mask[start:end, :] = True
masks.append(mask)

kwargs.update({"encoder_mask": masks, "tokens": tokens, "images": images})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's just me but I find this whole kwargs pattern kinda unintuitive. Like really we are covertly using this transform to also add its inputs into the dictionary (does that mean that it doesn't already contain them?). Maybe I need to see the callsite, but I don't understand why we can't just have e.g.

sample: Dict[str, Any]
mask_transform = VisionCrossAttentionMask(...)
sample["encoder_mask"] = mask_transform(sample["tokens"], sample["images"])

Alternatively if you want to go more of the compose route:

class VisionCrossAttentionMask:
	def __call__(sample: Dict[str, Any]) -> Dict[str, Any]:
		# just key into sample directly in the transform
		sample["encoder_mask"] = construct_mask(sample["tokens"], sample["images"])
		return sample

mask_transform = VisionCrossAttentionMask(...)
sample = mask_transform(sample)

But rn we are kinda halfway in between the two and it feels clunky to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @pbontrager who'sthought about this a lot more

We were between two approaches, the kwargs approach here and the sample approach you described where you key into the fields you need for the transform. One issue with that was this implicit assumption that token ids are always under "tokens", image tensors are always under "images", etc., so we thought we may have to provide these keys as attributes to the transform class. Then the key management gets a bit messy to handle and error prone.

Although, thinking about this more, the collator will also implicitly assume the keys "tokens", "labels", "images" etc will be present. If we're consistent with these across the library, and the fact that the number of these keys in the sample dict will remain quite small ~5-10 max, it might be fine to keep these assumptions, or do something similar to checkpointing with model key, optim key, etc as constants.

I don't have a strong opinion either way, although it would be nice to keep a forward signature of just sample.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah imo you're gonna need the keys somewhere anyways, better to just be explicit about it than do this weird back and forth between dict values and standalone arguments.

Copy link
Contributor

@felipemello1 felipemello1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mostly nits about naming/docs. Feel free to ignore.

tests/torchtune/modules/transforms/test_transforms.py Outdated Show resolved Hide resolved
tests/torchtune/modules/transforms/test_transforms.py Outdated Show resolved Hide resolved
tests/torchtune/modules/transforms/test_transforms.py Outdated Show resolved Hide resolved
torchtune/modules/transforms/_transforms.py Outdated Show resolved Hide resolved
torchtune/modules/transforms/_transforms.py Show resolved Hide resolved
torchtune/modules/transforms/_transforms.py Outdated Show resolved Hide resolved
torchtune/modules/transforms/_transforms.py Outdated Show resolved Hide resolved
Copy link
Contributor

@felipemello1 felipemello1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.

@RdoubleA RdoubleA merged commit bbc48e0 into pytorch:main Jul 10, 2024
29 checks passed
@RdoubleA RdoubleA deleted the mm_transforms branch July 10, 2024 00:17
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants