Deep Fusion Modules #1338

Merged: 10 commits into pytorch:main, Aug 16, 2024

Conversation

@pbontrager (Contributor) commented on Aug 14, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This implements the modules from the Fusion Models RFC #1283 that are necessary for building Deep Fusion Models.

Changelog

This introduces the model_fusion folder in modules and the following modules:

  • fusion_models.DeepFusionModel
  • fusion_layer.FusionLayer
  • fusion_embed.FusionEmbedding
  • fusion_utils.register_fusion_module

Along with these modules, this adds an __init__.py, a new section in the modules API docs, and a unit test for each module/function. No existing code was edited.
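
As a quick orientation, here is a rough sketch of how these pieces are meant to compose. The constructor argument names and sizes below are assumptions inferred from the RFC and the docstring examples discussed in this PR, not a statement of the final API.

from torch import nn
from torchtune.modules.model_fusion import (
    DeepFusionModel,
    FusionEmbedding,
    FusionLayer,
    register_fusion_module,
)

# FusionEmbedding: a pretrained vocabulary kept frozen, plus a small trainable
# table for newly added special tokens (sizes here are placeholders).
embed = FusionEmbedding(vocab_size=32000, fusion_vocab_size=8, embed_dim=4096)

# FusionLayer: wrap a pretrained decoder layer with a new trainable fusion
# layer (e.g. cross-attention); nn.Identity() stand-ins used for illustration.
layer = FusionLayer(layer=nn.Identity(), fusion_layer=nn.Identity())

# DeepFusionModel ties a pretrained decoder and encoder together, as in the
# docstring example discussed below: DeepFusionModel(LLama3(), CLIP()).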

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

pytorch-bot (bot) commented on Aug 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1338

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d1c599b with merge base b74f4b4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Aug 14, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@RdoubleA (Contributor) left a comment:

very clean overall. excited to see this come together with the model builders.



class FusionEmbedding(nn.Module):
"""Fusion embedding supports training additional special tokens while keeping

Contributor:
could you give an example of additional special tokens? and also an example of how a user might use this layer in their own model?

Contributor:
+1 to both of these - I think given some of this naming is our creation, I would vote for adding as much information to the doc strings as possible. This includes examples, paper pointers and sample code for how to use this
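
To make the two requests above concrete, here is one hypothetical example of "additional special tokens": a handful of new tokens the pretrained tokenizer never saw (for instance an image-placeholder token for a fused vision encoder), whose IDs are appended after the original vocabulary. Token names and sizes below are invented for illustration.

import torch
from torchtune.modules.model_fusion import FusionEmbedding

# Pretend the base tokenizer has 128 tokens and we add two new special tokens
# for the fused encoder, e.g. "<|image_start|>" and "<|image_end|>", which get
# ids 128 and 129.
embed = FusionEmbedding(vocab_size=128, fusion_vocab_size=2, embed_dim=16)

# ids < 128 hit the pretrained table; 128 and 129 hit the new trainable rows.
tokens = torch.tensor([[5, 17, 128, 42, 129]])
out = embed(tokens)  # expected shape: [1, 5, 16]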

Three review threads on torchtune/modules/model_fusion/fusion_embed.py were resolved (now outdated).

"""
bs, seq_len = input.size()
vocab_size = self.embedding.num_embeddings

Contributor:
nit: could just save vocab_size as attribute directly in init?

@pbontrager (Author):
Each embedding keeps its own size already, so it feels redundant for us to track it as well. Though I should probably add a "num_embeddings" property so this can be used like an embedding.
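
A sketch of what such a property could look like, assuming the two tables are stored as self.embedding and self.fusion_embedding (attribute names assumed, not confirmed by this diff):

@property
def num_embeddings(self) -> int:
    # Mirror nn.Embedding.num_embeddings: pretrained rows + fusion rows.
    return self.embedding.num_embeddings + self.fusion_embedding.num_embeddings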

Four review threads on torchtune/modules/model_fusion/fusion_models.py were resolved (now outdated).
from torch import nn


def register_fusion_module(module: nn.Module):

Contributor:
an example here would be helpful, especially to understand how this is used in combination with the other fusion classes
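
For illustration, usage might look roughly like this; the projection layer and the training-time behavior are assumptions based on the RFC, not something shown in this diff.

from torch import nn
from torchtune.modules.model_fusion import register_fusion_module

# Hypothetical adapter that projects encoder outputs into the decoder's
# embedding dimension.
proj = nn.Linear(1024, 4096)

# Mark the whole module as a fusion module so its parameters are treated as
# newly added, trainable fusion parameters alongside FusionLayer and
# FusionEmbedding parameters.
register_fusion_module(proj)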

A review thread on torchtune/modules/model_fusion/__init__.py was resolved (now outdated).


dtype = self.embedding.weight.dtype
return torch.empty(bs, seq_len, self.dim, device=device, dtype=dtype)

def forward(self, input: Tensor) -> Tensor:

Contributor:
Similar to the comment above, can we add some examples of input to the doc string?

vocab_size = self.embedding.num_embeddings

mask = input < vocab_size
tokens = torch.masked_select(input, mask)

Contributor:
Will the shape of this be BS x num_tokens? I was looking at https://pytorch.org/docs/stable/generated/torch.masked_select.html and I think that's right, but wasn't sure

@pbontrager (Author):
Added the shape comment. And yes, it's [bs x num_tokens], where num_tokens = (input < vocab_size).sum()


mask = input < vocab_size
tokens = torch.masked_select(input, mask)
additional_tokens = torch.masked_select(input, ~mask) - vocab_size

Contributor:
So the tokenizer will always assign the special tokens IDs that are greater than the IDs in the original vocabulary? Is that a general assumption, or is it only applicable to Llama? If the latter, I would call this out explicitly. I can imagine cases where special/additional tokens might come from the first N tokens, since these might be reserved during pretraining. Let me know if I misunderstand.

@pbontrager (Author):
I didn't think of this situation. The main reason to use this module is that you need to add extra tokens the decoder training didn't account for, because it wasn't trained to be fused with another model. I think if you had extra "capacity" in your original embedding table you could use those, but you'd have to unfreeze and update it. What would generalizing this look like, defining a range of integers where the fusion embedding is used?

For now, I feel it'd be easier to explain this consideration in the docstring instead of trying to generalize for that case, and update this module if needed.
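
For reference, the ID split being discussed, under the assumption that fusion tokens are appended after the original vocabulary, works roughly like this (numbers invented for illustration):

import torch

vocab_size = 128                                # size of the pretrained table
input = torch.tensor([[5, 17, 128, 42, 130]])   # ids >= 128 are fusion tokens

mask = input < vocab_size
# torch.masked_select returns the selected elements as a flattened 1-D tensor.
tokens = torch.masked_select(input, mask)                            # tensor([5, 17, 42])
additional_tokens = torch.masked_select(input, ~mask) - vocab_size   # tensor([0, 2])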

from torch import nn, Tensor


class FusionLayer(nn.Module):

Contributor:
It's slightly weird that FusionLayer has an input param called fusion_layer. Wondering if FusionLayerWrapper is a better name.

adapt the pre-trained encoder to the pre-trained decoder.

Example::
>>> model = DeepFusionModel(LLama3(), CLIP())

Contributor:
what's Llama3() referring to?

@pbontrager (Author):
It's just meant to be a stand-in for any decoder language model. I could use LLM() instead.

Contributor:
maybe just use one of our builder functions so this isn't confusing?
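
One way the docstring example could read with a real builder and a stand-in encoder; llama3_8b is used purely as an illustrative builder, and the keyword names are assumed rather than taken from this diff.

from torch import nn
from torchtune.models.llama3 import llama3_8b
from torchtune.modules.model_fusion import DeepFusionModel

# Placeholder encoder standing in for CLIP() in the original example.
class StubEncoder(nn.Module):
    def forward(self, x):
        return x

model = DeepFusionModel(decoder=llama3_8b(), encoder=StubEncoder())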

) -> Union[Tensor, List[Tensor]]:
"""
Args:
tokens (Tensor): input tensor with shape [b x s]

Contributor:
nit: in the classes above you're using [batch_size x seq_length] to call out the shapes. Stay consistent

out = embed(tokens)

assert out.shape == (2, 10, dim)
assert_expected(out.mean(), torch.tensor(0.3409), atol=1e-3, rtol=1e-3)

Contributor:
noob question: where's this value coming from?

@pbontrager (Author):
This is just grabbed from the first run. There's nothing specific that needs to be parity-checked here; I just want to make sure that this value never changes in the future.

@joecummings (Contributor) left a comment:

Why do all these need to be in separate files? We don't separate these out in the LLM world, and I think Soumith said at one point (and I agree) that too many redirects can make it fatiguing to understand what the code is actually doing.

@kartikayk (Contributor) left a comment:

Stamping this since it overall looks good, but please address all of the outstanding comments!

@kartikayk (Contributor):
@joecummings my guess is that each of these files is going to include many more classes, especially as we add support for early fusion models (extrapolating from the RFC). If this is true, it might make sense to keep these lightweight.

@codecov-commenter commented on Aug 16, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 70.09%. Comparing base (d0b89e2) to head (bb9d440).
Report is 2 commits behind head on main.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1338       +/-   ##
===========================================
+ Coverage   27.24%   70.09%   +42.84%     
===========================================
  Files         262      269        +7     
  Lines       12129    12598      +469     
===========================================
+ Hits         3304     8830     +5526     
+ Misses       8825     3768     -5057     

☔ View full report in Codecov by Sentry.

@pbontrager merged commit 67f6a06 into pytorch:main on Aug 16, 2024
20 checks passed
@pbontrager deleted the deep_fusion branch on August 16, 2024 at 17:26
@RdoubleA mentioned this pull request on Sep 6, 2024