Deep Fusion Modules #1338
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1338
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit d1c599b with merge base b74f4b4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
very clean overall. excited to see this come together with the model builders.
class FusionEmbedding(nn.Module):
    """Fusion embedding supports training additional special tokens while keeping
could you give an example of additional special tokens? and also an example of how a user might use this layer in their own model?
+1 to both of these - given that some of this naming is our own creation, I would vote for adding as much information to the docstrings as possible. This includes examples, paper pointers, and sample code for how to use this.
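To the docstring point, a plain-PyTorch sketch like this might help illustrate what "additional special tokens" means here (the sizes, ids, and token names below are made up for the example, not taken from the PR):

import torch
from torch import nn

# The pretrained decoder ships with a frozen vocab; the fused model needs a couple of
# new special tokens (e.g. "<|image|>", "<|end_image|>") whose ids are appended after
# the original vocab and whose embeddings are the only ones that get trained.
vocab_size, num_new_tokens, dim = 128, 2, 16

base_embed = nn.Embedding(vocab_size, dim)      # pretrained weights, kept frozen
base_embed.requires_grad_(False)
new_embed = nn.Embedding(num_new_tokens, dim)   # trainable rows for the new tokens

# id 128 -> "<|image|>", id 129 -> "<|end_image|>", everything below 128 is old vocab
tokens = torch.tensor([[3, 128, 42, 129, 7]])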
"""
bs, seq_len = input.size()
vocab_size = self.embedding.num_embeddings
nit: could just save vocab_size as attribute directly in init?
Each embedding already keeps its own size, so it feels redundant for us to store it as well. Though I should probably add a "num_embeddings" property so this can be used like an embedding.
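For reference, a rough sketch of that property (class and attribute names assumed, not the implementation under review):

from torch import nn

class FusionEmbeddingSketch(nn.Module):
    """Toy stand-in to show the property, not the torchtune class."""

    def __init__(self, vocab_size: int, fusion_vocab_size: int, dim: int) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.fusion_embedding = nn.Embedding(fusion_vocab_size, dim)

    @property
    def num_embeddings(self) -> int:
        # derived from the two tables rather than stored again in __init__
        return self.embedding.num_embeddings + self.fusion_embedding.num_embeddings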
from torch import nn


def register_fusion_module(module: nn.Module):
an example here would be helpful, especially to understand how this is used in combination with the other fusion classes
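For illustration, a hedged usage sketch: as I read the RFC, registering a module marks its parameters as newly added, trainable fusion parameters as opposed to frozen pretrained ones. The projection layer, dims, and import path below are assumptions, not taken from the PR.

from torch import nn
from torchtune.modules.model_fusion import register_fusion_module  # assumed import path

# A small projection bolted onto a pretrained encoder so its outputs match the
# decoder's hidden size: exactly the kind of newly added piece that should be
# treated as a fusion module rather than part of the frozen pretrained weights.
projection = nn.Linear(1024, 4096)
register_fusion_module(projection)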
dtype = self.embedding.weight.dtype
return torch.empty(bs, seq_len, self.dim, device=device, dtype=dtype)

def forward(self, input: Tensor) -> Tensor:
Similar to the comment above, can we add some examples of input to the doc string?
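For example, the docstring could show an input along these lines (toy values):

import torch

vocab_size = 128   # size of the pretrained vocab in this toy example
input = torch.tensor(
    [[12, 56, 128, 9, 3],    # 128 -> first additional special token
     [7, 129, 44, 44, 1]]    # 129 -> second additional special token
)
# shape [bs=2, seq_len=5]; ids < 128 are looked up in the frozen embedding table,
# ids >= 128 are re-based (id - vocab_size) and looked up in the trainable fusion table.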
vocab_size = self.embedding.num_embeddings

mask = input < vocab_size
tokens = torch.masked_select(input, mask)
Will the shape of this be BS x num_tokens? I was looking at https://pytorch.org/docs/stable/generated/torch.masked_select.html and I think that's right, but wasn't sure
Will add the shape comment. And yes, per those docs masked_select flattens the result, so it's a 1-D tensor with num_tokens = (input < vocab_size).sum() elements.
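A quick standalone check of what the linked docs describe (toy values): masked_select flattens its result, so the output is 1-D with mask.sum() elements rather than a [bs, num_tokens] grid.

import torch

vocab_size = 128
input = torch.tensor([[12, 56, 128, 9], [7, 129, 44, 1]])  # [bs=2, seq_len=4]
mask = input < vocab_size
tokens = torch.masked_select(input, mask)

print(mask.sum().item())  # 6 ids fall inside the original vocab
print(tokens.shape)       # torch.Size([6]), a flattened 1-D tensor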
mask = input < vocab_size
tokens = torch.masked_select(input, mask)
additional_tokens = torch.masked_select(input, ~mask) - vocab_size
So the tokenizer will always assign the special tokens IDs that are greater than the IDs in the vocabulary? Is that a general assumption, or is this only applicable to Llama? If the latter, then I would think this through a bit more. I can imagine cases where special/additional tokens might come from the first N tokens, since these might be reserved during pretraining. Let me know if I misunderstand.
I didn't think of this situation. The main reason to use this module is that you need to add extra tokens the decoder training didn't account for, because it wasn't trained to be fused with another model. If you had extra "capacity" in your original embedding table you could use those rows, but you'd have to unfreeze and update that table. What would generalizing this look like: defining a range of integers where the fusion embedding is used?
I feel it'd be easier to explain this consideration in the docstring instead of trying to generalize for that case for now, and update this module later if needed.
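For reference, a condensed plain-PyTorch sketch of the convention being discussed, i.e. new token ids appended after the original vocab (sizes and values made up; not the torchtune code):

import torch
from torch import nn

vocab_size, num_new, dim = 128, 2, 16
embedding = nn.Embedding(vocab_size, dim)       # pretrained table (frozen in practice)
fusion_embedding = nn.Embedding(num_new, dim)   # trainable table for the new tokens

input = torch.tensor([[12, 128, 9], [129, 7, 1]])                   # [bs=2, seq_len=3]
mask = input < vocab_size
tokens = torch.masked_select(input, mask)                           # original-vocab ids
additional_tokens = torch.masked_select(input, ~mask) - vocab_size  # re-based new ids

out = torch.empty(*input.shape, dim, dtype=embedding.weight.dtype)
out[mask] = embedding(tokens)                     # positions that held original-vocab ids
out[~mask] = fusion_embedding(additional_tokens)  # positions that held new special ids
print(out.shape)                                  # torch.Size([2, 3, 16])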
from torch import nn, Tensor


class FusionLayer(nn.Module):
It's slightly weird that FusionLayer has an input param called fusion_layer. Wondering if FusionLayerWrapper is a better name.
adapt the pre-trained encoder to the pre-trained decoder.

Example::
    >>> model = DeepFusionModel(LLama3(), CLIP())
what's Llama3() referring to?
It's just meant to be a stand-in for any decoder language model. I could use LLM() instead.
maybe just use one of our builder functions so this isn't confusing?
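For instance, the docstring example could lean on an existing decoder builder and keep the encoder generic (sketch only; llama3_8b is, as far as I know, an existing torchtune builder, while my_vision_encoder is a hypothetical stand-in for whichever encoder builder ends up being used):

Example::
    >>> from torchtune.models.llama3 import llama3_8b
    >>> decoder = llama3_8b()
    >>> encoder = my_vision_encoder()  # hypothetical stand-in for an encoder builder
    >>> model = DeepFusionModel(decoder, encoder)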
) -> Union[Tensor, List[Tensor]]:
    """
    Args:
        tokens (Tensor): input tensor with shape [b x s]
nit: in the classes above you're using [batch_size x seq_length] to call out the shapes. Stay consistent
out = embed(tokens)

assert out.shape == (2, 10, dim)
assert_expected(out.mean(), torch.tensor(0.3409), atol=1e-3, rtol=1e-3)
noob question: where's this value coming from?
This is just grabbed from the first run. There's nothing specific that needs to be parity checked here; I just want to make sure this value never changes in the future.
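In other words, the pattern is roughly this (generic sketch, not the torchtune test itself): seed the RNG, run the module once, record the summary statistic, then pin it so any future numeric drift fails loudly.

import torch
from torch import nn

torch.manual_seed(0)
embed = nn.Embedding(10, 4)
out = embed(torch.tensor([[1, 2, 3]]))

# On the very first run, print the statistic and hard-code it into the assert below;
# from then on the test only exists to catch unintended numeric changes.
print(out.mean())
# torch.testing.assert_close(out.mean(), torch.tensor(<value from first run>), atol=1e-3, rtol=1e-3)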
Why do all these need to be in separate files? We don't separate these out in the LLM world, and I think Soumith said at one point (and I agree) that too many redirects can make it fatiguing to understand what the code is actually doing.
Stamping this since it overall looks good, but please address all of the outstanding comments!
@joecummings my guess is that each of these files is going to include many more classes, especially as we add support for early fusion models (extrapolating from the RFC). If this is true, it might make sense to keep these lightweight.
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1338       +/-   ##
===========================================
+ Coverage   27.24%   70.09%   +42.84%
===========================================
  Files         262      269        +7
  Lines       12129    12598      +469
===========================================
+ Hits         3304     8830     +5526
+ Misses       8825     3768     -5057

☔ View full report in Codecov by Sentry.
Context
What is the purpose of this PR?
This implements modules from the Fusion Models RFC #1283 necessary for building Deep Fusion Models.
Changelog
This introduces the model_fusion folder in modules and the following modules: FusionEmbedding, FusionLayer, register_fusion_module, and DeepFusionModel.
Along with these modules there is an __init__.py, a new section in the modules API docs, and a unit test for each module/function added. No existing code was edited.
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/torchtune/modules/vision_transformer.py, line 285 at commit 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models