Flamingo Model Components #1357

Merged: 18 commits merged into pytorch:main on Sep 5, 2024

Conversation

pbontrager (Contributor)

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Reimplementation of #1150 based on refactor

Changelog

Adds new flamingo folder

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;


pytorch-bot bot commented Aug 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1357

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ce16f38 with merge base f437639:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Aug 16, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
codecov-commenter commented Aug 20, 2024

Codecov Report

Attention: Patch coverage is 34.16149% with 106 lines in your changes missing coverage. Please review.

Project coverage is 27.59%. Comparing base (5155c4a) to head (20aa7f8).
Report is 1 commit behind head on main.

File                                                     Patch %   Lines missing
...torchtune/models/flamingo/test_flamingo_encoder.py    28.84%    37 ⚠️
torchtune/models/flamingo/_component_builders.py         26.19%    31 ⚠️
torchtune/models/flamingo/_encoder.py                    35.48%    20 ⚠️
torchtune/models/clip/_component_builders.py             30.00%     7 ⚠️
...torchtune/models/flamingo/test_flamingo_decoder.py    50.00%     6 ⚠️
torchtune/modules/feed_forward.py                        16.66%     5 ⚠️

❗ There is a different number of reports uploaded between BASE (5155c4a) and HEAD (20aa7f8).

HEAD has 1 fewer upload than BASE:
Flag uploads: BASE (5155c4a) = 3, HEAD (20aa7f8) = 2
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1357       +/-   ##
===========================================
- Coverage   72.72%   27.59%   -45.14%     
===========================================
  Files         271      277        +6     
  Lines       12811    13040      +229     
===========================================
- Hits         9317     3598     -5719     
- Misses       3494     9442     +5948     

☔ View full report in Codecov by Sentry.

if key.startswith("layer"):
new_key = key.replace("layer.", "")
local_key = key[len(prefix) :]
if local_key.startswith("layer"):

I think this should be:

if local_key.startswith("layer.layer"):
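
For illustration, a minimal sketch of the remapping pattern under discussion, assuming a hook that strips the module prefix before testing for the nested name (the function and key names here are illustrative, not the PR's actual hook):

# Illustrative sketch only; the PR's real state-dict hook may differ.
def _remap_keys(state_dict: dict, prefix: str) -> dict:
    remapped = {}
    for key, value in state_dict.items():
        local_key = key[len(prefix):]              # strip the module prefix first
        if local_key.startswith("layer.layer"):    # the check suggested above
            new_key = prefix + local_key.replace("layer.", "", 1)
        else:
            new_key = key
        remapped[new_key] = value
    return remapped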

) -> Tensor:
"""
Args:
x (Tensor): input tensor with shape [b x i x t x e x d]
@SalmanMohammadi (Collaborator) commented Aug 23, 2024

sorry this hurt my tiny brain

Suggested change
x (Tensor): input tensor with shape [b x i x t x e x d]
x (Tensor): input tensor with shape [b, i, t, e, d]

from the encoder. Each hidden state has the same shape as x.

Returns:
Tensor: output tensor of a sequence of embedings [b x s x d]
@SalmanMohammadi (Collaborator) commented Aug 23, 2024

Suggested change
Tensor: output tensor of a sequence of embedings [b x s x d]
Tensor: output tensor of a sequence of embedings [b, s, d]

@ebsmothers (Contributor) left a comment

Mostly a bunch of nits for now, will give a more thorough pass tomorrow though

@@ -1,6 +1,6 @@
from torchtune.models.clip._transforms import CLIPImageTransform

def _clip_vit_224_transform():
def clip_vit_224_transform():
Contributor:

Should be added to public API now?

Contributor (Author):

There is zero reason for a builder to exist and not be public. It's not a utility but a model. The only reason it's not added to the docs yet is that there should be a model builder as well, not just the transform.

Comment on lines +27 to +32
stitch these building blocks into higher-level components. This design has
two benefits:
- The building blocks themselves are very flexible. For example, ``GroupedQueryAttention``
can take either nn.Linear or nn.LoRALinear for ``q_proj``.
- Builder functions expose a set of configurable params which keep the constructors of
the building blocks simple.
Contributor:

imo this is not the place to sell the design

torchtune/models/flamingo/_component_builders.py (2 outdated comments, resolved)
num_heads (int): The number of attention heads in each transformer layer.
clip_embed_dim (int): The dimensionality of each patch embedding in CLIP.
clip_num_layers (int): The number of transformer layers.
clip_hidden_states (Optional[List[int]]): The indices of CLIP hidden layers to return
Contributor:

In the function it's List[int], not Optional

@@ -14,7 +16,7 @@ class FeedForward(nn.Module):
gate_proj (nn.Module): Projection from input dim to hidden dim, fed through activation
and multiplied by up_proj.
down_proj (nn.Module): Final projection to output dim.
up_proj (nn.Module): Projection from input dim to hidden dim, multiplied by
up_proj (Optional[nn.Module]): Projection from input dim to hidden dim, multiplied by
Contributor:

Good thing this class wasn't super well-defined anyways. But might be worth updating the docstring to explain the case of no up_proj
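
To make the optional case concrete, here is a hedged sketch (not torchtune's exact implementation) of how a FeedForward with an optional up_proj could behave: with up_proj the activated gate output is multiplied by up_proj(x); without it, the activation output goes straight to down_proj.

from typing import Optional

import torch
import torch.nn as nn

# Hedged sketch of an optional-up_proj feed-forward; illustrative only.
class FeedForwardSketch(nn.Module):
    def __init__(
        self,
        gate_proj: nn.Module,
        down_proj: nn.Module,
        up_proj: Optional[nn.Module] = None,
        activation: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.gate_proj = gate_proj
        self.down_proj = down_proj
        self.up_proj = up_proj
        self.activation = activation if activation is not None else nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.activation(self.gate_proj(x))
        if self.up_proj is not None:
            h = h * self.up_proj(x)   # gated variant (e.g. SwiGLU-style MLPs)
        return self.down_proj(h)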

Comment on lines 27 to 34
class TestFlamingoVisionEncoder:
    def test_flamingo_text_decoder_initialization(self, decoder_config):
        # Attempt to instantiate the Flamingo text decoder
        try:
            decoder = flamingo_decoder(**decoder_config)
            print("Flamingo text decoder instantiated successfully.")
        except Exception as e:
            pytest.fail(f"Failed to instantiate Flamingo text decoder: {str(e)}")
Contributor:

Maybe I'm being dense but what is the point of this test?

Contributor (Author):

It was a placeholder that's been updated.


def clip_mlp(in_dim: int, out_dim: int, hidden_dim: int, activation: nn.Module, quantize_base: bool = False) -> FeedForward:
"""
Build the MLP layer associated with the clip model.
Contributor:

Not the CLIP model, right? I feel like we can be a bit more explicit here

Contributor (Author):

You mean CLIP ViT? I feel this is in line with naming for all our other model-specific sub-builders.
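
For context, a hedged usage sketch of the builder defined above (the import path is assumed from the coverage report and may differ; dimensions are illustrative):

import torch
import torch.nn as nn

# Assumed import path; the builder lives in the clip model directory per the diff.
from torchtune.models.clip._component_builders import clip_mlp

mlp = clip_mlp(in_dim=1280, out_dim=1280, hidden_dim=5120, activation=nn.GELU())
out = mlp(torch.randn(2, 10, 1280))   # -> shape [2, 10, 1280]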

@@ -338,6 +339,7 @@ def __init__(
self.num_heads = num_heads
self.head_dim = head_dim
self.causal_mask = None
self.cur_pos = None
Contributor:

Can you explain what this is about?

Contributor (Author):

This tracks position for the TransformerDecoder since the layers also track it during decoding. This is already obsolete, as Salman is updating how caching handles position, but I'll leave this in until that lands. It's basically a fix for the previous update around input_pos and the KV cache.

Comment on lines +221 to +236
attn = MultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
    q_norm=RMSNorm(dim=head_dim, eps=1e-05),
    k_norm=RMSNorm(dim=head_dim, eps=1e-05),
    pos_embeddings=None,
    max_seq_len=max_seq_len,
    is_causal=False,
    attn_dropout=0.0,
)
Contributor:

nbd yet but come LoRA time might want a builder for this kinda stuff

Contributor (Author):

I'm not sure what you're referring to

Contributor:

Ignore me

@ebsmothers (Contributor) left a comment

Sorry all I did was correct typos. But that just means the design is solid through and through, I have no major complaints there. Please do address my old comments too though. Preemptive stamp so that you're unblocked

in_dim=embed_dim,
hidden_dim=4 * embed_dim,
out_dim=embed_dim,
activation=activation(),
Contributor:

nit: if activation is truly a Callable as you've typed it this won't work

Contributor (Author):

What would this be called? A Class?
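
For what it's worth, a hedged illustration of the distinction being raised: typing the parameter as the class (e.g. Type[nn.Module] or Callable[[], nn.Module]) makes the activation() call valid, whereas a bare Callable annotation doesn't promise a zero-argument constructor. Names below are illustrative.

from typing import Type

import torch.nn as nn

# Illustrative only: annotate with the class so `activation()` instantiates it.
def build_mlp(embed_dim: int, activation: Type[nn.Module] = nn.GELU) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(embed_dim, 4 * embed_dim),
        activation(),                       # instantiate the activation class here
        nn.Linear(4 * embed_dim, embed_dim),
    )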

Comment on lines +294 to +295
num_layers (Optional[int]): Number of Transformer Decoder layers, only define when
layers is not a list.
Contributor:

What's our long-term plan for this? Will we continue to support both cases?

Contributor (Author):

This was in the TransformerLayer refactor. I left it in to support both because otherwise the builders get a lot less clean, but we could update that in the future.

tests/torchtune/models/flamingo/test_flamingo_decoder.py (outdated comment, resolved)
output.shape == expected_shape
), f"Expected shape {expected_shape}, but got {output.shape}"

assert_expected(output.mean(), torch.tensor(-9.47548e-5), atol=1e-3, rtol=1e-3)
Contributor:

Is it expected to be so close to zero? Whenever the tolerance is a couple orders of magnitude larger than the actual value it makes me a bit nervous

- i: number of images
- t: number of tiles (where a single image is broken into multiple tiles)
- e: number of embeds per tile (e.g. CLS embed + patch embeds, etc.)
- s: sequence length computed by i*t*e
Contributor:

nit: it just equals i*t*e, right?
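
A quick hedged sketch of the flattening these dimension names imply (sizes are arbitrary):

import torch

# [b, i, t, e, d] -> [b, s, d] with s = i * t * e (illustrative sizes only).
b, i, t, e, d = 2, 1, 4, 5, 8
x = torch.randn(b, i, t, e, d)
seq = x.view(b, i * t * e, d)
assert seq.shape == (b, i * t * e, d)   # s == 20 here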

torchtune/models/flamingo/_component_builders.py (outdated comment, resolved)
Comment on lines +221 to +236
attn = MultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
    q_norm=RMSNorm(dim=head_dim, eps=1e-05),
    k_norm=RMSNorm(dim=head_dim, eps=1e-05),
    pos_embeddings=None,
    max_seq_len=max_seq_len,
    is_causal=False,
    attn_dropout=0.0,
)
Contributor:

Ignore me

Comment on lines +89 to +92
clip (nn.Module): CLIP encoder vision model
projection_head (nn.Module): projection_head that takes embeddings
with dimension encoder_dim as input and outputs embeddings of
size decoder_dim.
Contributor:

nbd to leave these both as general nn.Modules but they do have very specific signatures that make it hard to plug in any old nn.Module. Maybe point to the relevant CLIP and projection classes as an example or something

Contributor (Author):

This being inside of the flamingo folder, it's not really meant to be reused for other purposes. I can just set the type to be the specific modules.
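
For readers following along, a hedged sketch of the composition being described (not the PR's actual module): the CLIP vision model emits embeddings of size encoder_dim, and the projection head maps them to decoder_dim for cross-attention in the decoder.

import torch
import torch.nn as nn

# Illustrative composition only; the real encoder also forwards CLIP hidden states.
class VisionEncoderSketch(nn.Module):
    def __init__(self, clip: nn.Module, projection_head: nn.Module):
        super().__init__()
        self.clip = clip
        self.projection_head = projection_head

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        embeds = self.clip(images)              # [..., encoder_dim]
        return self.projection_head(embeds)     # [..., decoder_dim]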

def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
"""
Apply image decoding and transformations to the "images" field in the sample
and tokenizization to the "messages" field in the sample. Also returns the
Contributor:

Suggested change
and tokenizization to the "messages" field in the sample. Also returns the
and tokenization to the "messages" field in the sample. Also returns the

The extra text will still get tokenized as normal text, not as special tokens. Default is None.

Examples:
>>> model_transform = FlamingoTransform("/path/to/tokenizer.model", tile_size=256)
Contributor:

nit: I think you also need to provide patch_size in this example
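
Taking the nit above, a hedged version of that docstring example with patch_size included (the value is illustrative and the constructor signature is assumed from the surrounding discussion):

>>> model_transform = FlamingoTransform(
...     "/path/to/tokenizer.model",
...     tile_size=256,
...     patch_size=16,   # illustrative value; required per the comment above
... )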

from torchtune.modules.transformer import _get_clones


class FlamingoProjectionHead(nn.Module):
Contributor:

nit: is Head the right term? I would think a Head is something that attaches on top of the hidden states of a transformer, not a full model. I thought we had referred to this as an Adapter?

Contributor (Author):

Projection heads have been around a long time and can vary a lot in architecture. The main point here is that it's learning a projection from the pretrained encoder to the pretrained decoder.

torchtune/models/flamingo/_encoder.py (comment resolved)


class FlamingoProjectionHead(nn.Module):
"""Projection transformer to adapt the output of a
Contributor:

Would like to see more details here, specifically on how this is used to map from encoder hidden dim to the decoder hidden dim in the cross attention layer
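
In the meantime, a hedged sketch of the shape of this module as described in the thread (illustrative, not the PR's code): a small stack of transformer layers over the encoder output, followed by an output projection from encoder_dim to decoder_dim that feeds the decoder's cross-attention layers.

import copy

import torch
import torch.nn as nn

# Illustrative projection-head skeleton; torchtune clones the layer via _get_clones.
class ProjectionHeadSketch(nn.Module):
    def __init__(self, layer: nn.Module, num_layers: int, output: nn.Module):
        super().__init__()
        self.layers = nn.ModuleList(copy.deepcopy(layer) for _ in range(num_layers))
        self.output = output   # e.g. nn.Linear(encoder_dim, decoder_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.output(x)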

self,
layer: nn.Module,
num_layers: int,
output: nn.Module,
Contributor:

should we just hardcode this to nn.Linear? I'm wondering if it makes more sense for a user to configure encoder_dim -> decoder_dim rather than the output module directly

Contributor (Author):

I can update this, I was just following our typical pattern but this would never need to be customized by a user anyway.


def forward(
self,
x: Tensor,
Contributor:

nit: we've been using torch.Tensor everywhere

Contributor (Author):

I lost that vote :(

torchtune/models/flamingo/_transform.py (outdated comment, resolved)

Args:
sample (Mapping[str, Any]): A sample with a "tokens", "mask",
"encoder_input" and "encoder_mask" field to feed directly into the model.
Contributor:

we don't expect encoder_input or encoder_mask to already be in sample; we should expect "images" though

Contributor (Author):

We do expect encoder_input to be in sample by this point; see FlamingoTransform. I agree there is some questionable generalization here, but I think we should address that when it comes up. In summary, to allow unpacking of any arbitrary input for the encoder, we treat it as a dictionary.
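
A hedged sketch of the packaging the author describes, with placeholder tensors and assumed key names (shapes are illustrative):

import torch

# Illustrative only: encoder-specific inputs live under "encoder_input" so they
# can be unpacked generically, e.g. model.encoder(**batch["encoder_input"]).
b, s = 2, 16
batch = {
    "tokens": torch.randint(0, 1000, (b, s)),                         # token ids
    "mask": torch.ones(b, s, s, dtype=torch.bool),                    # text attention mask
    "encoder_input": {"images": torch.randn(b, 1, 4, 3, 224, 224)},   # assumed [b, i, t, c, h, w]
    "encoder_mask": torch.ones(b, s, 4, dtype=torch.bool),            # cross-attention mask (illustrative)
}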

@@ -25,7 +25,7 @@
"<|eom_id|>": 128008,
"<|eot_id|>": 128009,
"<|python_tag|>": 128010,
"<|image|>": 128011,
"<|image|>": 128256,
Contributor:

have we verified that this is the official image token id...? When I first added this it was 128011

Contributor (Author):

Long story, but both numbers were supposed to be correct. One for finetuning and one for inference, but now it's just the one.

@@ -338,6 +339,7 @@ def __init__(
self.num_heads = num_heads
self.head_dim = head_dim
self.causal_mask = None
self.pos = None
Contributor:

is this for generation? if so, a quick comment would be helpful

where length of list == number of images in sample
- tokens (List[int]): original tokens
- images (List[torch.Tensor]): original images
Mapping[str, Any]: sample with a new key encoder_mask, with a mask per image with shape
Contributor:

why are we packaging multiple keys into encoder_input? also, you mention encoder_mask here but use encoder_input below. Would also be good to keep the bullets about tokens and images

pbontrager and others added 5 commits September 4, 2024 19:51
Co-authored-by: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: ebsmothers <ebs@meta.com>
@pbontrager pbontrager merged commit 7920dc8 into pytorch:main Sep 5, 2024
17 checks passed