Flamingo Model Components #1357
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1357
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit ce16f38 with merge base f437639.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@             Coverage Diff             @@
##             main    #1357       +/-   ##
===========================================
- Coverage   72.72%   27.59%    -45.14%
===========================================
  Files          271      277         +6
  Lines        12811    13040       +229
===========================================
- Hits          9317     3598      -5719
- Misses        3494     9442      +5948

☔ View full report in Codecov by Sentry.
if key.startswith("layer"):
    new_key = key.replace("layer.", "")
local_key = key[len(prefix) :]
if local_key.startswith("layer"):
I think this should be:
if local_key.startswith("layer.layer"):
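For readers following the thread, here is a minimal, self-contained sketch of the kind of prefix and key remapping being discussed; the function name and the exact key layout are illustrative assumptions, not the actual torchtune state-dict hook.

```python
# Hypothetical sketch of the key-remapping discussion above; names and the
# exact key layout are assumptions, not the actual torchtune hook.
def _remap_keys(state_dict: dict, prefix: str) -> dict:
    remapped = {}
    for key, value in state_dict.items():
        local_key = key[len(prefix):]
        # Checking "layer.layer" (rather than just "layer") avoids rewriting
        # keys such as "layernorm.weight" that merely begin with "layer".
        if local_key.startswith("layer.layer"):
            local_key = local_key.replace("layer.layer.", "layer.", 1)
        remapped[prefix + local_key] = value
    return remapped

sd = {"decoder.layer.layer.attn.q_proj.weight": 0, "decoder.layernorm.weight": 1}
print(_remap_keys(sd, "decoder."))
# {'decoder.layer.attn.q_proj.weight': 0, 'decoder.layernorm.weight': 1}
```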
) -> Tensor:
    """
    Args:
        x (Tensor): input tensor with shape [b x i x t x e x d]
sorry this hurt my tiny brain
Suggested change:
-        x (Tensor): input tensor with shape [b x i x t x e x d]
+        x (Tensor): input tensor with shape [b, i, t, e, d]
            from the encoder. Each hidden state has the same shape as x.

    Returns:
        Tensor: output tensor of a sequence of embedings [b x s x d]
Suggested change:
-        Tensor: output tensor of a sequence of embedings [b x s x d]
+        Tensor: output tensor of a sequence of embedings [b, s, d]
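To make the notation concrete, here is a small illustrative sketch of how a [b, i, t, e, d] input collapses to a [b, s, d] sequence with s = i * t * e; the dimension sizes below are made up.

```python
import torch

# Made-up sizes: batch, images, tiles, embeds per tile, embed dim
b, i, t, e, d = 2, 1, 4, 257, 512
x = torch.randn(b, i, t, e, d)

s = i * t * e                 # sequence length is just the product of i, t, e
out = x.reshape(b, s, d)      # [b, s, d]
assert out.shape == (b, 1 * 4 * 257, d)
```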
Mostly a bunch of nits for now, will give a more thorough pass tomorrow though
@@ -1,6 +1,6 @@
 from torchtune.models.clip._transforms import CLIPImageTransform

-def _clip_vit_224_transform():
+def clip_vit_224_transform():
Should be added to public API now?
There is zero reason for a builder to exist and not be public. It's not a utility but a model. The only reason it's not added to the docs yet is because there should be a model builder as well, not just the transform.
stitch these building blocks into higher-level components. This design has
two benefits:
- The building blocks themselves are very flexible. For example, ``GroupedQueryAttention``
  can take either nn.Linear or nn.LoRALinear for ``q_proj``.
- Builder functions expose a set of configurable params which keep the constructors of
  the building blocks simple.
imo this is not the place to sell the design
num_heads (int): The number of attention heads in each transformer layer.
clip_embed_dim (int): The dimensionality of each patch embedding in CLIP.
clip_num_layers (int): The number of transformer layers.
clip_hidden_states (Optional[List[int]]): The indices of CLIP hidden layers to return
In the function it's List[int], not Optional.
@@ -14,7 +16,7 @@ class FeedForward(nn.Module):
     gate_proj (nn.Module): Projection from input dim to hidden dim, fed through activation
         and multiplied by up_proj.
     down_proj (nn.Module): Final projection to output dim.
-    up_proj (nn.Module): Projection from input dim to hidden dim, multiplied by
+    up_proj (Optional[nn.Module]): Projection from input dim to hidden dim, multiplied by
Good thing this class wasn't super well-defined anyways. But might be worth updating the docstring to explain the case of no up_proj
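For context, a simplified sketch (an assumption based on the docstring above, not the exact torchtune class) of how an optional up_proj changes the forward pass:

```python
import torch
from torch import nn
from typing import Optional

class FeedForwardSketch(nn.Module):
    """Simplified stand-in for torchtune's FeedForward to show the no-up_proj case."""

    def __init__(self, gate_proj: nn.Module, down_proj: nn.Module,
                 activation: nn.Module, up_proj: Optional[nn.Module] = None):
        super().__init__()
        self.gate_proj, self.down_proj, self.up_proj = gate_proj, down_proj, up_proj
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.activation(self.gate_proj(x))
        if self.up_proj is not None:
            # SwiGLU-style MLP: gated branch multiplied by the up projection
            h = h * self.up_proj(x)
        # With up_proj=None this degenerates to a plain two-layer MLP
        # (e.g. a CLIP-style GELU MLP).
        return self.down_proj(h)

mlp = FeedForwardSketch(nn.Linear(16, 64), nn.Linear(64, 16), nn.GELU())
print(mlp(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```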
class TestFlamingoVisionEncoder:
    def test_flamingo_text_decoder_initialization(self, decoder_config):
        # Attempt to instantiate the Flamingo text decoder
        try:
            decoder = flamingo_decoder(**decoder_config)
            print("Flamingo text decoder instantiated successfully.")
        except Exception as e:
            pytest.fail(f"Failed to instantiate Flamingo text decoder: {str(e)}")
Maybe I'm being dense but what is the point of this test?
It was a placeholder that's been updated.
def clip_mlp(in_dim: int, out_dim: int, hidden_dim: int, activation: nn.Module, quantize_base: bool = False) -> FeedForward:
    """
    Build the MLP layer associated with the clip model.
Not the CLIP model, right? I feel like we can be a bit more explicit here
You mean CLIP ViT? I feel this is in line with the naming for all our other model-specific sub-builders.
torchtune/modules/transformer.py (outdated)
@@ -338,6 +339,7 @@ def __init__(
     self.num_heads = num_heads
     self.head_dim = head_dim
     self.causal_mask = None
+    self.cur_pos = None
Can you explain what this is about?
This tracks position for the TransformerDecoder since the layers also track it during decoding. This is already obsolete as Salaman is updating how caching handles position, but I'll leave this in until that is updated. This is basically a fix for the previous update around input_pos and the kv-cache.
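For readers unfamiliar with the kv-cache flow, here is a toy sketch of why a decoder might track a current position internally during generation; this illustrates the idea only and is not the actual TransformerDecoder code.

```python
import torch
from torch import nn

class PositionTracker(nn.Module):
    """Toy illustration: with a kv-cache, each incremental forward call needs to
    know which absolute positions the newly decoded tokens occupy."""

    def __init__(self):
        super().__init__()
        self.cur_pos = None  # reset whenever caches are set up again

    def next_positions(self, num_new_tokens: int) -> torch.Tensor:
        start = 0 if self.cur_pos is None else self.cur_pos
        self.cur_pos = start + num_new_tokens
        return torch.arange(start, self.cur_pos)

tracker = PositionTracker()
print(tracker.next_positions(4))  # prompt of 4 tokens -> positions [0, 1, 2, 3]
print(tracker.next_positions(1))  # next decoded token  -> position [4]
```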
attn = MultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
    q_norm=RMSNorm(dim=head_dim, eps=1e-05),
    k_norm=RMSNorm(dim=head_dim, eps=1e-05),
    pos_embeddings=None,
    max_seq_len=max_seq_len,
    is_causal=False,
    attn_dropout=0.0,
)
nbd yet but come LoRA time might want a builder for this kinda stuff
I'm not sure what you're referring to
Ignore me
Sorry all I did was correct typos. But that just means the design is solid through and through, I have no major complaints there. Please do address my old comments too though. Preemptive stamp so that you're unblocked
    in_dim=embed_dim,
    hidden_dim=4 * embed_dim,
    out_dim=embed_dim,
    activation=activation(),
nit: if activation is truly a Callable as you've typed it, this won't work
What would this be called? A Class?
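A small sketch of the typing point (the builder name here is hypothetical): if the builder calls activation(), the annotation should describe a class or zero-argument factory, e.g. Callable[[], nn.Module], rather than an instantiated nn.Module.

```python
from typing import Callable
from torch import nn

def build_mlp(embed_dim: int, activation: Callable[[], nn.Module] = nn.GELU) -> nn.Module:
    # Hypothetical builder: activation is a class (or any zero-arg factory),
    # so calling activation() constructs a fresh module instance here.
    return nn.Sequential(
        nn.Linear(embed_dim, 4 * embed_dim),
        activation(),
        nn.Linear(4 * embed_dim, embed_dim),
    )

mlp_gelu = build_mlp(64)            # default nn.GELU
mlp_silu = build_mlp(64, nn.SiLU)   # passing a different class also works
```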
    num_layers (Optional[int]): Number of Transformer Decoder layers, only define when
        layers is not a list.
What's our long-term plan for this? Will we continue to support both cases?
This was in the TransformerLayer refactor. I left it in to support both because otherwise the builders get a lot less clean, but we could update that in the future.
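A hedged sketch of what supporting both call patterns can look like (simplified; not the actual TransformerDecoder logic):

```python
import copy
from typing import List, Optional, Union
from torch import nn

def build_layers(
    layers: Union[nn.Module, List[nn.Module]],
    num_layers: Optional[int] = None,
) -> nn.ModuleList:
    if isinstance(layers, list):
        # Explicit list (e.g. interleaved self-attention / cross-attention
        # layers); num_layers is ignored in this case.
        return nn.ModuleList(layers)
    # Single template layer: num_layers must be given and the layer is cloned.
    if num_layers is None:
        raise ValueError("num_layers must be set when a single layer is passed")
    return nn.ModuleList([copy.deepcopy(layers) for _ in range(num_layers)])
```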
    output.shape == expected_shape
), f"Expected shape {expected_shape}, but got {output.shape}"

assert_expected(output.mean(), torch.tensor(-9.47548e-5), atol=1e-3, rtol=1e-3)
Is it expected to be so close to zero? Whenever the tolerance is a couple orders of magnitude larger than the actual value it makes me a bit nervous
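To spell out the concern with numbers: because the expected mean is on the order of 1e-4, an atol of 1e-3 dominates the check and lets much larger values slip through.

```python
import torch
from torch.testing import assert_close

expected = torch.tensor(-9.47548e-5)
# Passes even though the "actual" value is off by roughly 5x in magnitude
# (and has the wrong sign), since |actual - expected| < atol + rtol * |expected| ~= 1e-3.
assert_close(torch.tensor(5e-4), expected, atol=1e-3, rtol=1e-3)
```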
- i: number of images
- t: number of tiles (where a single image is broken into multiple tiles)
- e: number of embeds per tile (e.g. CLS embed + patch embeds, etc.)
- s: sequence length computed by i*t*e
nit: it just equals i*t*e, right?
    clip (nn.Module): CLIP encoder vision model
    projection_head (nn.Module): projection_head that takes embeddings
        with dimension encoder_dim as input and outputs embeddings of
        size decoder_dim.
nbd to leave these both as general nn.Modules but they do have very specific signatures that make it hard to plug in any old nn.Module. Maybe point to the relevant CLIP and projection classes as an example or something
This being inside of the flamingo folder, it's not really meant to be reused for other purposes. I can just set the type to be the specific modules.
def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
    """
    Apply image decoding and transformations to the "images" field in the sample
    and tokenizization to the "messages" field in the sample. Also returns the
Suggested change:
-    and tokenizization to the "messages" field in the sample. Also returns the
+    and tokenization to the "messages" field in the sample. Also returns the
        The extra text will still get tokenized as normal text, not as special tokens. Default is None.

Examples:
    >>> model_transform = FlamingoTransform("/path/to/tokenizer.model", tile_size=256)
nit: I think you also need to provide patch_size in this example
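i.e. something along the lines of the example below, which mirrors the docstring example under discussion; the patch_size value shown is purely illustrative, not necessarily what the docstring intends.

```python
>>> model_transform = FlamingoTransform(
...     "/path/to/tokenizer.model",
...     tile_size=256,
...     patch_size=40,  # illustrative value only
... )
```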
from torchtune.modules.transformer import _get_clones


class FlamingoProjectionHead(nn.Module):
nit: is Head the right term? I would think a Head is something that attaches on top of the hidden states of a transformer and not a full model. I thought we had referred to this as an Adapter?
Projection heads have been around a long time and can vary a lot in architecture. The main point here is that it's learning a projection from the pretrained encoder to the pretrained decoder.
class FlamingoProjectionHead(nn.Module):
    """Projection transformer to adapt the output of a
Would like to see more details here, specifically on how this is used to map from encoder hidden dim to the decoder hidden dim in the cross attention layer
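A rough, hedged sketch (not the actual implementation) of the shape flow being asked about: the projection layers operate at the encoder dimension, and a final linear maps encoder_dim to decoder_dim so the decoder's cross-attention layers can consume the result.

```python
import copy
import torch
from torch import nn

class ProjectionHeadSketch(nn.Module):
    def __init__(self, layer: nn.Module, num_layers: int,
                 encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.output = nn.Linear(encoder_dim, decoder_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [b, s, encoder_dim] image embeddings from the vision encoder
        for layer in self.layers:
            x = layer(x)
        # [b, s, decoder_dim]: ready to be cross-attended to by the text decoder
        return self.output(x)

head = ProjectionHeadSketch(nn.Identity(), num_layers=2, encoder_dim=512, decoder_dim=1024)
print(head(torch.randn(2, 1028, 512)).shape)  # torch.Size([2, 1028, 1024])
```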
    self,
    layer: nn.Module,
    num_layers: int,
    output: nn.Module,
should we just hardcode this to nn.Linear? I'm wondering if it makes more sense for a user to configure encoder_dim -> decoder_dim rather than the output module directly
I can update this; I was just following our typical pattern, but this would never need to be customized by a user anyway.
def forward(
    self,
    x: Tensor,
nit: we've been using torch.Tensor everywhere
I lost that vote :(
Args:
    sample (Mapping[str, Any]): A sample with a "tokens", "mask",
        "encoder_input" and "encoder_mask" field to feed directly into the model.
we don't expect encoder_input, encoder_mask to already be in sample, we should expect "images" though
We do expect encoder_input to be in sample by this point, look at FlamingoTransform. I agree there is some questionable generalization here, but I think we should address that when it comes up. In summary, to allow unpacking of any arbitrary input for the encoder, we treat it as a dictionary.
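To illustrate the pattern being described (field names and shapes here are illustrative assumptions): because encoder_input is a dict, the fusion model can unpack whatever kwargs a given encoder needs.

```python
import torch

# Illustrative batch produced by the transform; shapes are made up.
batch = {
    "tokens": torch.randint(0, 128_256, (2, 32)),
    "encoder_input": {
        "images": torch.randn(2, 1, 4, 3, 224, 224),
        "aspect_ratio": torch.tensor([[2, 2], [2, 2]]),
    },
    "encoder_mask": torch.ones(2, 32, 4 * 257, dtype=torch.bool),
}

def toy_encoder(images: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
    # Stand-in for the vision encoder: any callable whose signature matches the
    # keys of encoder_input works, which is the point of packaging them as a dict.
    return torch.randn(images.shape[0], 4 * 257, 512)

encoder_embed = toy_encoder(**batch["encoder_input"])
print(encoder_embed.shape)  # torch.Size([2, 1028, 512])
```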
@@ -25,7 +25,7 @@
     "<|eom_id|>": 128008,
     "<|eot_id|>": 128009,
     "<|python_tag|>": 128010,
-    "<|image|>": 128011,
+    "<|image|>": 128256,
have we verified that this is the official image token id...? When I first added this it was 128011
Long story, but both numbers were supposed to be correct. One for finetuning and one for inference, but now it's just the one.
@@ -338,6 +339,7 @@ def __init__(
     self.num_heads = num_heads
     self.head_dim = head_dim
     self.causal_mask = None
+    self.pos = None
is this for generation? if so, a quick comment would be helpful
            where length of list == number of images in sample
        - tokens (List[int]): original tokens
        - images (List[torch.Tensor]): original images
    Mapping[str, Any]: sample with a new key encoder_mask, with a mask per image with shape
why are we packaging multiple keys into encoder_input? also, you mention encoder_mask here but use encoder_input below. Would also be good to keep the bullets about tokens and images
Co-authored-by: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Context
What is the purpose of this PR? Is it to
Reimplementation of #1150 based on refactor
Changelog
Adds new flamingo folder
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- pre-commit install
- pytest tests
- pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/torchtune/modules/vision_transformer.py, line 285 in 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models