
[Flamingo][multimodal] Vision encoder + text decoder #1150

Closed
wants to merge 68 commits into from

Conversation

@felipemello1 (Contributor) commented Jul 8, 2024

Context

What is the purpose of this PR?

TODO:

  • CLIP transformer has to be replaced with TransformerSelfAttentionLayer
  • Confirm location and naming of all new transformer-related modules
  • regression tests
  • parity check
  • docstrings
  • builder with text decoder + vision encoder (probably in another PR)

Changelog

torchtune/models/clip/_component_builders.py

  • Replaced the PyTorch-native transformer with torchtune's version.

torchtune/models/flamingo/_encoders.py

  • FlamingoVisionAdapter: takes in the CLIP embedding and outputs the final vision embedding
  • FlamingoVisionEncoder: wrapper that calls CLIP + the adapter (see the sketch below)
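
A minimal sketch of how the wrapper could compose the two pieces; the class and argument names mirror the PR description, but the exact signatures here are assumptions:

import torch
from torch import nn


class FlamingoVisionEncoder(nn.Module):
    """Wrapper that runs CLIP and then the Flamingo vision adapter."""

    def __init__(self, clip: nn.Module, adapter: nn.Module) -> None:
        super().__init__()
        self.clip = clip
        self.adapter = adapter

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # Assumed interface: CLIP returns its final embedding plus the
        # intermediate hidden states that the adapter projects.
        x, hidden_states = self.clip(images)
        return self.adapter(x, hidden_states)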

torchtune/models/flamingo/_component_builders.py

  • flamingo_vision_encoder: instantiates clip_vision_encoder + FlamingoVisionAdapter and returns a FlamingoVisionEncoder
  • flamingo_text_decoder: instantiates llama3 + TransformerCrossAttentionLayer using FusionLayer and returns an MMTransformerDecoder (the interleaving idea is sketched below)
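
A hedged sketch of the interleaving idea behind flamingo_text_decoder: wrap some of the llama3 layers with a fusion layer holding a new cross-attention block. The factory arguments and the fusion_interval name are illustrative, not the PR's actual signature:

from typing import Callable

from torch import nn


def interleave_fusion_layers(
    num_layers: int,
    fusion_interval: int,
    make_decoder_layer: Callable[[], nn.Module],
    make_cross_attn_layer: Callable[[], nn.Module],
    make_fusion_layer: Callable[..., nn.Module],
) -> nn.ModuleList:
    """Wrap every fusion_interval-th decoder layer with a cross-attention fusion layer."""
    layers = []
    for i in range(num_layers):
        layer = make_decoder_layer()
        if (i + 1) % fusion_interval == 0:
            layer = make_fusion_layer(layer=layer, fusion_layer=make_cross_attn_layer())
        layers.append(layer)
    return nn.ModuleList(layers)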

torchtune/modules/feedforward.py

  • MLP: a simple feed-forward block for the transformer MLP. Used by CLIP and Flamingo (a minimal version is sketched below).
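
For reference, a minimal feed-forward block of this shape might look like the following (layer names are placeholders, not necessarily the PR's):

import torch
from torch import nn


class MLP(nn.Module):
    """Plain transformer feed-forward block: linear -> activation -> linear."""

    def __init__(self, dim: int, hidden_dim: int, activation: nn.Module = nn.GELU()) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.activation = activation
        self.w2 = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.activation(self.w1(x)))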

torchtune/modules/multimodal_transformer.py

  • TransformerSelfAttentionLayer: used in Flamingo with gates; used in CLIP as a vanilla transformer layer (see the sketch below)
  • TransformerCrossAttentionLayer: used in flamingo_text_decoder
  • MMTransformerDecoder: used in flamingo_text_decoder
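
To make the gating concrete, here is a rough sketch of what a self-attention layer with optional attn_scale / mlp_scale hooks could look like; with no gates it reduces to the vanilla pre-norm layer CLIP needs (argument names beyond attn_scale / mlp_scale are assumptions):

from typing import Optional

import torch
from torch import nn


class TransformerSelfAttentionLayer(nn.Module):
    """Pre-norm self-attention layer with optional gates on each residual branch."""

    def __init__(
        self,
        attn: nn.Module,
        mlp: nn.Module,
        attn_norm: nn.Module,
        mlp_norm: nn.Module,
        attn_scale: Optional[nn.Module] = None,
        mlp_scale: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.attn = attn
        self.mlp = mlp
        self.attn_norm = attn_norm
        self.mlp_norm = mlp_norm
        # nn.Identity() when no gate is given, so CLIP gets a plain layer.
        self.attn_scale = attn_scale if attn_scale is not None else nn.Identity()
        self.mlp_scale = mlp_scale if mlp_scale is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn_scale(self.attn(self.attn_norm(x)))
        x = x + self.mlp_scale(self.mlp(self.mlp_norm(x)))
        return x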

torchtune/modules/attention.py

  • GroupedQueryAttention: used in flamingo_text_decoder, CLIP, and flamingo_vision_encoder. IMPORTANT: llama3 inside Flamingo will use this, while text-only llama3 will keep using CausalSelfAttention (the core grouped-query trick is sketched below).
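
The defining trick in grouped-query attention is that there are fewer key/value heads than query heads, and each k/v head is shared across a group of query heads. A small sketch of that expansion step (not the PR's exact code):

import torch


def expand_kv_heads(kv: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Repeat each key/value head so it lines up one-to-one with the query heads.

    kv has shape (batch, num_kv_heads, seq_len, head_dim), where num_heads is a
    multiple of num_kv_heads. MHA is the special case num_kv_heads == num_heads;
    MQA is num_kv_heads == 1.
    """
    num_kv_heads = kv.shape[1]
    q_per_kv = num_heads // num_kv_heads
    return kv.repeat_interleave(q_per_kv, dim=1)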

torchtune/modules/model_fusion/

  • fusion_embed.py: used in flamingo_text_decoder
  • fusion_layer.py: used in flamingo_text_decoder (see the sketch below)
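
A rough sketch of the FusionLayer idea, following the Flamingo recipe of running a new, gated cross-attention block alongside the original decoder layer; the forward signature here is an assumption:

import torch
from torch import nn


class FusionLayer(nn.Module):
    """Wraps an original decoder layer with a new fusion (cross-attention) layer."""

    def __init__(self, layer: nn.Module, fusion_layer: nn.Module) -> None:
        super().__init__()
        self.layer = layer                # pretrained llama3 self-attention layer
        self.fusion_layer = fusion_layer  # new TransformerCrossAttentionLayer

    def forward(self, x: torch.Tensor, encoder_input: torch.Tensor) -> torch.Tensor:
        # Cross-attend to the vision encoder output, then run the original layer.
        x = self.fusion_layer(x, encoder_input=encoder_input)
        return self.layer(x)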

torchtune/modules/tanh_gate.py

  • TanhGate: used in the Flamingo vision adapter and the cross-attention layers (sketched below).
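
For context, a tanh gate is typically just a learned scalar passed through tanh, initialized at zero so the gated branch starts as a no-op (a minimal sketch, not necessarily the PR's exact parameterization):

import torch
from torch import nn


class TanhGate(nn.Module):
    """Scales its input by tanh of a learned scalar; tanh(0) = 0 at init."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.tanh(self.scale)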

Test plan

  • CLIP tests pass, except one regression test.

  • Flamingo vision encoder shape tests pass; regression tests are still needed.

  • The Flamingo text decoder instantiates, but there are no shape or regression tests yet.

  • Tests for the individual new modules are still needed, but testing the models as a whole should work for now.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)

  • add unit tests for any new functionality

  • update docstrings for any new or updated methods or classes

  • run unit tests via pytest tests

  • run recipe tests via pytest tests -m integration_test

  • manually run any new or modified recipes with sufficient proof of correctness

    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)


pytorch-bot bot commented Jul 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1150

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f5d72b4 with merge base 069b12b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 8, 2024
torchtune/models/flamingo/_encoders.py (outdated review threads, resolved)

# projection
# un-flatten the CLIP output back to (bsz, n_ims, n_tiles, n_tokens, embed_dim)
x = x.view(bsz, n_ims, n_tiles, n_tokens, embed_dim)
# concatenate the selected hidden states along the embedding dim;
# assumes hidden_states shares the leading (bsz, n_ims, n_tiles, n_tokens) dims
x = torch.cat([x, hidden_states], dim=-1)
@ebsmothers (Contributor) commented Jul 8, 2024:

Add shape comment here

torchtune/models/flamingo/_encoders.py (review thread, resolved)
@@ -227,3 +227,233 @@ def forward(
# reshape the output to be the same shape as the input
output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
return self.output_proj(output)


class GroupedQueryAttention(nn.Module):
Contributor:

It's not clear to me why this is in modules/attention.py but other stuff goes in models/flamingo. I would suggest moving this under models/flamingo too to be consistent unless there's a clear reason not to.

Contributor:

GQA isn't specific to Flamingo; for example, it's referenced in Meta's MobileLLM paper: https://arxiv.org/abs/2402.14905

Contributor:

So our "primary" attention implementation already supports GQA and MQA [ref]. I guess the question here is: if this entire module is specific to Flamingo, should it just reside within that folder until we figure out how to merge back? That said, I'll raise the same concern as I did in Philip's PR: GQA isn't a great name for this module. I'd prefer something like MMMultiHeadAttention or similar.

Contributor:

Yeah, our naming here is kinda unclear. We already support GQA, but via our CausalSelfAttention class, so any new functionality here is mainly to support cross-attention. As is, the naming of these two classes doesn't actually convey how they differ. My two cents is that we should just call them MultiHeadAttention and MultiHeadCrossAttention to get this point across, but I think the renaming of CausalSelfAttention can be saved for another day.

Contributor (Author):

will rename it after we are 100% sure :P

Comment on lines +34 to +35
attn_scale: Optional[nn.Module] = None,
mlp_scale: Optional[nn.Module] = None,
Contributor:

Sorry to keep harping on this, but I still don't understand why we don't just provide separate versions of the self-attention and MLP modules with scaling (see e.g. here for what the MLP would look like). Then we can provide different builders for TransformerSelfAttentionLayer with and without scaling, and users don't have to figure out what attn_scale and mlp_scale mean.

Contributor (Author):

I think that both ways work. I like your idea, but I also think it's convenient if the module already provides it. I guess that one benefit of this logic being in the transformer module is that we don't have to have 2x the implementation of every scale and MLP module, one gated and one that isn't. The con is that this module gets a bit bloated.
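
For comparison, the alternative the reviewer describes above would look roughly like this: a wrapper that owns its own gate, so the layer never needs attn_scale / mlp_scale (GatedMLP is a hypothetical name, not part of this PR):

import torch
from torch import nn


class GatedMLP(nn.Module):
    """Hypothetical wrapper: an MLP that carries its own tanh gate."""

    def __init__(self, mlp: nn.Module) -> None:
        super().__init__()
        self.mlp = mlp
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.gate) * self.mlp(x)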

from torchtune.modules import GroupedQueryAttention


class TransformerSelfAttentionLayer(nn.Module):
Contributor:

Would be good to just rename this to have "Multimodal" in the prefix.

Contributor (Author):

This is also used in CLIP. I think we should only use MM for modules that touch text and image at the same time, which is not the case for this module. What do you think?

@felipemello1 felipemello1 changed the title [Flamingo][multimodal] Vision encoder for flamingo [Flamingo][multimodal] Vision encoder + text decoder Jul 10, 2024
@codecov-commenter

codecov-commenter commented Jul 10, 2024

Codecov Report

Attention: Patch coverage is 78.38542% with 83 lines in your changes missing coverage. Please review.

Project coverage is 69.14%. Comparing base (06a125e) to head (f5d72b4).
Report is 4 commits behind head on main.

Files Patch % Lines
torchtune/modules/multimodal_transformer.py 52.74% 43 Missing ⚠️
torchtune/modules/model_fusion/fusion_embed.py 41.37% 17 Missing ⚠️
torchtune/modules/attention.py 79.31% 12 Missing ⚠️
torchtune/modules/model_fusion/fusion_layer.py 55.55% 8 Missing ⚠️
...tune/models/flamingo/test_flamingo_text_decoder.py 83.33% 2 Missing ⚠️
...ne/models/flamingo/test_flamingo_vision_encoder.py 97.95% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1150       +/-   ##
===========================================
+ Coverage   26.76%   69.14%   +42.38%     
===========================================
  Files         205      225       +20     
  Lines        9301    10096      +795     
===========================================
+ Hits         2489     6981     +4492     
+ Misses       6812     3115     -3697     

☔ View full report in Codecov by Sentry.

@facebook-github-bot

@tarun292 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
