[Flamingo][multimodal] Vision encoder + text decoder #1150
Conversation
Co-authored-by: Kartikay Khandelwal <47255723+kartikayk@users.noreply.github.com>
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1150
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f5d72b4 with merge base 069b12b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
# projection
x = x.view(bsz, n_ims, n_tiles, n_tokens, embed_dim)
x = torch.cat([x, hidden_states], dim=-1)
Add shape comment here
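For example, the requested shape comments might look roughly like this; the exact layout of hidden_states is an assumption based on the concatenation along the last dim:

```python
# x: [bsz, n_ims, n_tiles, n_tokens, embed_dim] after regrouping the image/tile dims
x = x.view(bsz, n_ims, n_tiles, n_tokens, embed_dim)

# hidden_states is assumed to share the leading dims [bsz, n_ims, n_tiles, n_tokens, ...],
# so after concatenating along the last dim:
# x: [bsz, n_ims, n_tiles, n_tokens, embed_dim + hidden_dim]
x = torch.cat([x, hidden_states], dim=-1)
```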
@@ -227,3 +227,233 @@ def forward(
# reshape the output to be the same shape as the input
output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
return self.output_proj(output)


class GroupedQueryAttention(nn.Module):
It's not clear to me why this is in modules/attention.py but other stuff goes in models/flamingo. I would suggest moving this under models/flamingo too to be consistent unless there's a clear reason not to.
GQA isn't specific to Flamingo; for example, it's referenced in Meta's MobileLLM paper: https://arxiv.org/abs/2402.14905
So our "primary" attention implementation already supports GQA and MQA [ref]. I guess the question here is: if this entire module is specific to Flamingo, should it just reside within that folder until we figure out how to merge it back? That said, I'll raise the same concern as I did in Philip's PR: GQA isn't a great name for this module. I'd prefer something like MMMultiHeadAttention or similar.
Yeah, our naming here is kind of unclear. We already support GQA, but via our CausalSelfAttention class, so any new functionality here is mainly to support cross-attention. As is, the naming of these two classes doesn't actually convey how they differ. My two cents is that we should just call them MultiHeadAttention and MultiHeadCrossAttention to get this point across, but I think the renaming of CausalSelfAttention can be saved for another day.
will rename it after we are 100% sure :P
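For readers following along, here is a minimal sketch of what grouped-query attention refers to in this thread: num_heads query heads share a smaller set of num_kv_heads key/value heads, which are repeated per query group. The names and signature below are illustrative only, not torchtune's actual API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GQASketch(nn.Module):
    """Illustrative grouped-query attention: num_heads query heads share num_kv_heads K/V heads."""

    def __init__(self, embed_dim: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(num_heads * self.head_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each K/V head so every group of query heads has a matching K/V head
        repeat = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(repeat, dim=1)
        v = v.repeat_interleave(repeat, dim=1)
        out = F.scaled_dot_product_attention(q, k, v)  # [bsz, num_heads, seq_len, head_dim]
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.out_proj(out)
```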
attn_scale: Optional[nn.Module] = None,
mlp_scale: Optional[nn.Module] = None,
Sorry to keep harping on this, but I still don't understand why we don't just provide separate versions of the self-attention and MLP modules with scaling (see e.g. here for what the MLP would look like). Then we can provide different builders for TransformerSelfAttentionLayer with and without scaling, and users don't have to figure out what attn_scale and mlp_scale mean.
I think both ways work. I like your idea, but I also think it's convenient if the module already provides it. One benefit of keeping this logic in the transformer module is that we don't need 2x the implementation of every attention and MLP module, one gated and one that isn't. The con is that this module gets a bit bloated.
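To make the trade-off concrete, here is a minimal sketch (hypothetical names, not the exact torchtune API) of the pattern under discussion: the layer accepts optional attn_scale / mlp_scale modules and applies them to the branch outputs, with a zero-initialized tanh gate as the Flamingo-style scale. Passing None keeps the layer unscaled, which is what avoids duplicating gated and ungated variants of every module:

```python
from typing import Optional

import torch
import torch.nn as nn


class TanhGate(nn.Module):
    """Learnable tanh gate, initialized to zero so the gated branch starts as a no-op."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.tanh(self.scale)


class GatedTransformerLayerSketch(nn.Module):
    """Residual layer where attn_scale / mlp_scale default to identity when not provided."""

    def __init__(
        self,
        attn: nn.Module,
        mlp: nn.Module,
        attn_scale: Optional[nn.Module] = None,
        mlp_scale: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.attn = attn
        self.mlp = mlp
        self.attn_scale = attn_scale if attn_scale is not None else nn.Identity()
        self.mlp_scale = mlp_scale if mlp_scale is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn_scale(self.attn(x))  # scaled attention branch
        x = x + self.mlp_scale(self.mlp(x))    # scaled MLP branch
        return x
```

The alternative in the comment above would instead bake a TanhGate into dedicated gated attention/MLP modules and expose two builders, keeping the layer itself scale-free.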
from torchtune.modules import GroupedQueryAttention


class TransformerSelfAttentionLayer(nn.Module):
Would be good to just rename this to have "Multimodal" in the prefix.
This is also used in CLIP. I think we should only use the MM prefix for modules that touch text and image at the same time, which is not the case for this module. What do you think?
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files

@@            Coverage Diff            @@
##             main    #1150       +/-   ##
===========================================
+ Coverage   26.76%   69.14%   +42.38%
===========================================
  Files         205      225       +20
  Lines        9301    10096      +795
===========================================
+ Hits         2489     6981     +4492
+ Misses       6812     3115     -3697

☔ View full report in Codecov by Sentry.
@tarun292 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Context
What is the purpose of this PR? Is it to
- add a new feature
- fix a bug
- update tests and/or documentation
- other (please add here)

- Added flamingo vision encoder
- Ported text decoder from: Fused SelfAttention and Cross Attention Decoder #1146
- Updated CLIP attention module

TODO:
Changelog
torchtune/models/clip/_component_builders.py
torchtune/models/flamingo/_encoders.py
torchtune/models/flamingo/_component_builders.py
torchtune/modules/feedforward.py
torchtune/modules/multimodal_transformer.py
torchtune/modules/attention.py
torchtune/modules/model_fusion.py
torchtune/modules/tanh_gate.py
Test plan
CLIP tests pass, except one regression test.
Flamingo vision encoder shape tests pass; regression tests are still needed.
Flamingo text decoder instantiates, but has no shape or regression tests yet.
Need tests for the individual new modules, but testing the models as a whole should work for now.
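As a rough illustration of the distinction above: a shape test only asserts the output layout, while a regression test also pins numerics against a value captured from a trusted run. The module below is a hypothetical stand-in; the real Flamingo builders and signatures may differ.

```python
import torch


class DummyEncoder(torch.nn.Module):
    """Hypothetical stand-in for one of the new modules; not the real torchtune API."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def test_dummy_encoder_shape():
    bsz, n_ims, n_tiles, n_tokens, embed_dim = 2, 1, 4, 16, 32
    model = DummyEncoder(embed_dim)
    out = model(torch.randn(bsz, n_ims, n_tiles, n_tokens, embed_dim))
    # Shape test: only the output layout is checked
    assert out.shape == (bsz, n_ims, n_tiles, n_tokens, embed_dim)


def test_dummy_encoder_regression():
    # Regression test: fix the seed and compare against a constant captured from a
    # reference run, e.g. torch.testing.assert_close(out.mean(), torch.tensor(EXPECTED_MEAN))
    torch.manual_seed(0)
    model = DummyEncoder(4)
    out = model(torch.ones(2, 4))
    assert out.shape == (2, 4)  # placeholder; EXPECTED_MEAN would be hard-coded from a trusted run
```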
- run pre-commit hooks and linters (make sure you've first installed them via pre-commit install)
- add unit tests for any new functionality
- update docstrings for any new or updated methods or classes
- run unit tests via pytest tests
- run recipe tests via pytest tests -m integration_test
- manually run any new or modified recipes with sufficient proof of correctness