Resizable image positional embeddings #1695
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1695
Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 59f1996 with merge base 4e69db8. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -0,0 +1,104 @@
# Config for single device full finetuning in full_finetune_single_device.py
What's the purpose of adding this? Just so that we support the non-instruct version of the model? I'm a bit confused because I thought one big diff with instruct-tuned vs. not is the extra trainable special tokens on the text side, which this PR doesn't address.
This is just for testing; I will remove it before the PR is ready.
Regarding the special token, that's a question for @pbontrager. I am not sure.
tile_pos_emb_test_cases = [
    {
        "tgt_num_tiles": 1,
        # [max_num_tiles, max_num_tiles, -1, embed_dim] -> (2, 2, 2, 3)
Sorry I don't fully follow these comments in each test case
Just trying to help the reader understand the dimensions provided. I can make it better. The -1 is because the actual pos embedding has dim=1 there, but when I created the tests, I created it with 2.
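To make the comment concrete, here is a standalone sketch of what such a test case exercises, assuming bilinear interpolation over the tile grid (the tensor names and resize call below are illustrative, not the actual torchtune fixture or method):

```python
import torch
import torch.nn.functional as F

# Fake tile positional embedding of shape
# [max_num_tiles, max_num_tiles, n_tokens, embed_dim] -> (2, 2, 2, 3),
# using 2 tokens per tile instead of the real dim=1 so values are easier to track.
max_num_tiles, n_tokens, embed_dim, tgt_num_tiles = 2, 2, 3, 1
pos_emb = torch.arange(max_num_tiles * max_num_tiles * n_tokens * embed_dim, dtype=torch.float)
pos_emb = pos_emb.reshape(max_num_tiles, max_num_tiles, n_tokens, embed_dim)

# Move the tile dims last so they can be interpolated as a 2D grid.
resized = F.interpolate(
    pos_emb.permute(2, 3, 0, 1),   # (n_tokens, embed_dim, tiles_x, tiles_y)
    size=(tgt_num_tiles, tgt_num_tiles),
    mode="bilinear",
    align_corners=True,
).permute(2, 3, 0, 1)              # back to (tiles_x, tiles_y, n_tokens, embed_dim)

assert resized.shape == (tgt_num_tiles, tgt_num_tiles, n_tokens, embed_dim)
```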
@@ -100,23 +100,314 @@ def __init__(

        self.gate = nn.Parameter(torch.zeros(1))

        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)

    def _load_state_dict_hook(
I don't see any test case for this, which arguably is the important part
It mostly calls the other functions, but it makes sense to create a unit test for it.
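As a rough sketch of the kind of unit test being suggested (the toy 1D module and all names below are hypothetical, not torchtune's actual classes): load a state dict whose positional embedding was saved at a different size and assert that the pre-hook resizes it to match the target module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyPosEmbed(nn.Module):
    """1D stand-in for the real module: one learnable (n_tokens, embed_dim) embedding."""

    def __init__(self, n_tokens: int, embed_dim: int):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(n_tokens, embed_dim))
        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)

    def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
        key = prefix + "pos_embed"
        loaded = state_dict[key]
        tgt_n_tokens = self.pos_embed.shape[0]
        if loaded.shape[0] != tgt_n_tokens:
            # (n_tokens, embed_dim) -> (1, embed_dim, n_tokens) for 1D interpolation
            resized = F.interpolate(
                loaded.t().unsqueeze(0),
                size=tgt_n_tokens,
                mode="linear",
                align_corners=True,
            )
            state_dict[key] = resized.squeeze(0).t()

def test_load_hook_resizes_pos_embed():
    src = ToyPosEmbed(n_tokens=5, embed_dim=4)
    tgt = ToyPosEmbed(n_tokens=9, embed_dim=4)
    tgt.load_state_dict(src.state_dict())  # would fail without the resize hook
    assert tgt.pos_embed.shape == (9, 4)
```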
""" | ||
# inverse n_tokens_per_tile = patch_grid_size**2 + 1, where +1 is the cls token | ||
inpt_n_tokens_per_tile, inpt_embed_dim = inpt_pos_embed.shape | ||
inpt_patch_grid_size = int(math.sqrt(inpt_n_tokens_per_tile - 1)) |
Kind of a nit, but why do you compute it in the method for local pos embeddings and pass it to the method for global pos embeddings?
Will make it consistent. I don't think there is a reason.
Took a closer look; in both methods I am passing `tgt_patch_grid_size`:
inpt_local_pos_embed = self._resize_local_position_embedding(
    local_pos_embed=inpt_local_pos_embed,
    tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
)
inpt_global_pos_embed = self._resize_global_position_embedding(
    global_pos_embed=inpt_global_pos_embed,
    tgt_max_num_tiles=tgt_max_num_tiles_x,
    tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
)
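To make the `int(math.sqrt(... - 1))` expression concrete, here is a small worked example with generic numbers (not tied to any particular checkpoint):

```python
import math

# n_tokens_per_tile = patch_grid_size**2 + 1, where the +1 is the CLS token,
# so the patch grid size is recovered by inverting that relation.
n_tokens_per_tile = 257  # e.g. a 16x16 patch grid plus one CLS token
patch_grid_size = int(math.sqrt(n_tokens_per_tile - 1))
assert patch_grid_size == 16
assert patch_grid_size**2 + 1 == n_tokens_per_tile
```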
        pos_embed = pos_embed.reshape(
            max_num_tiles_x,
            max_num_tiles_y,
            inpt_patch_grid_size,
            inpt_patch_grid_size,
            embed_dim,
        )

        # combine max_num_tiles and patch_grid_size into one dimension
        pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
        pos_embed = pos_embed.reshape(
            max_num_tiles_x * inpt_patch_grid_size,
            max_num_tiles_y * inpt_patch_grid_size,
            embed_dim,
        )
Do we really need all 3 of these?
I think so. I don't see another way to get the same shape/order. I will ask metamate :P
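For intuition, a standalone sketch (not the torchtune code itself) of why the reshape/permute/reshape sequence is needed: to merge the tile grid and the per-tile patch grid into one spatial grid, the tile and patch axes for each spatial direction must be adjacent before the final reshape.

```python
import torch

max_num_tiles_x = max_num_tiles_y = 2
inpt_patch_grid_size = 3
embed_dim = 4

# (tiles_x, tiles_y, patch_x, patch_y, embed_dim)
pos_embed = torch.randn(
    max_num_tiles_x, max_num_tiles_y, inpt_patch_grid_size, inpt_patch_grid_size, embed_dim
)

# permute to (tiles_x, patch_x, tiles_y, patch_y, embed_dim), then merge per direction
merged = pos_embed.permute(0, 2, 1, 3, 4).contiguous().reshape(
    max_num_tiles_x * inpt_patch_grid_size,
    max_num_tiles_y * inpt_patch_grid_size,
    embed_dim,
)

# Row 4 of the merged 6x6 grid is patch row 1 of tile row 1; columns 0..2 come from tile column 0.
assert torch.equal(merged[4, :3], pos_embed[1, 0, 1, :, :])
```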
Looks good! Please add the TODOs from the comments and an additional one to support resizing for the non-tiled positional embedding.
@@ -100,23 +100,322 @@ def __init__(

        self.gate = nn.Parameter(torch.zeros(1))

        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)
Add a TODO here to switch to the public method after 2.5 is stable
@@ -166,16 +464,127 @@ def __init__(
        )
        self.gate = nn.Parameter(torch.zeros(1))

        # Register load hook to interpolate positional embeddings
        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)
Same comment for TODO
Co-authored-by: Felipe Mello <felipemello@fb.com>
Context
What is the purpose of this PR?
When loading a state dict for image models, the positional embeddings may have a different shape from the current model. This PR resizes the embeddings to match the target shape that the model was initialized with.
TLDR of the steps:
Unit tests and docstrings should help understand the numbers
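To illustrate the general idea, here is a minimal sketch of resizing a local positional embedding from one patch grid to another (assumptions: bilinear interpolation and a CLS token kept as-is; the function below is illustrative and not copied from the PR):

```python
import math

import torch
import torch.nn.functional as F

def resize_local_pos_embed(pos_embed: torch.Tensor, tgt_patch_grid_size: int) -> torch.Tensor:
    """pos_embed: (n_tokens_per_tile, embed_dim), with n_tokens_per_tile = grid**2 + 1."""
    n_tokens, embed_dim = pos_embed.shape
    grid = int(math.sqrt(n_tokens - 1))
    cls_token, patch_tokens = pos_embed[:1], pos_embed[1:]

    # (grid*grid, embed_dim) -> (1, embed_dim, grid, grid) for 2D interpolation
    patch_tokens = patch_tokens.reshape(grid, grid, embed_dim).permute(2, 0, 1).unsqueeze(0)
    patch_tokens = F.interpolate(
        patch_tokens,
        size=(tgt_patch_grid_size, tgt_patch_grid_size),
        mode="bilinear",
        align_corners=True,
    )
    patch_tokens = patch_tokens.squeeze(0).permute(1, 2, 0).reshape(-1, embed_dim)
    return torch.cat([cls_token, patch_tokens], dim=0)

# e.g. a checkpoint saved with a 4x4 patch grid (17 tokens) loaded into a model
# initialized with an 8x8 grid (65 tokens)
resized = resize_local_pos_embed(torch.randn(17, 32), tgt_patch_grid_size=8)
assert resized.shape == (65, 32)
```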
Changelog
Added resizing to match FAIR's implementation here:
Test plan