Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resizable image positional embeddings #1695

Merged
merged 13 commits into from
Oct 1, 2024

Conversation

felipemello1
Copy link
Contributor

@felipemello1 felipemello1 commented Sep 26, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

When loading state dict for image models, the positional embeddings may have different shape. This PR allows the reshaping of the embeddings to match the target desired shape that the model was initialized with.

TLDR of the steps:

  1. permute shapes of the input embedding, so that the shapes you are permuting are the last two in the tensor
  2. pass to F.interpolate the shape of your instantiated embedding
  3. now your input embedding have the same shape as the instantiated one

Unit tests and docstrings should help understand the numbers

Changelog

added resizing to match fairs implementation here:

  1. https://github.com/fairinternal/internal-llama-models/blob/c14773ebe252064fa88d7f613d3bee1757480ae7/models/llama3/reference_impl/multimodal/model.py#L411
  2. https://github.com/fairinternal/internal-llama-models/blob/c14773ebe252064fa88d7f613d3bee1757480ae7/models/llama3/reference_impl/multimodal/model.py#L829
  3. added rst for clip

Test plan

  • New unit tests
  • Checkpoint loads for instruct and pretraining
  • NOT ABLE TO ACTUALLY RUN AND SEE LOSS, GIVEN MY SDPA ISSUE

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 26, 2024
Copy link

pytorch-bot bot commented Sep 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1695

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 59f1996 with merge base 4e69db8 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@@ -0,0 +1,104 @@
# Config for single device full finetuning in full_finetune_single_device.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of adding this? Just so that we support the non-instruct version of the model? I'm a bit confused cause I thought one big diff with instruct-tuned vs not is the extra trainable special tokens on the text size, which this PR doesn't address

Copy link
Contributor Author

@felipemello1 felipemello1 Sep 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just for testing. I will remove it before the PR is ready

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

regarding the special token, thats a question for @pbontrager . I am not sure.

@felipemello1 felipemello1 marked this pull request as ready for review September 30, 2024 18:34
@felipemello1 felipemello1 changed the title [WIP] Resizable image positional embeddings Resizable image positional embeddings Sep 30, 2024
tile_pos_emb_test_cases = [
{
"tgt_num_tiles": 1,
# [max_num_tiles, max_num_tiles, -1, embed_dim] -> (2, 2, 2, 3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I don't fully follow these comments in each test case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just trying to help the reader understand the dimensions provided. I can make it better. -1 is because the actual pos embedding has dim=1 there, but when i created the tests, i created with 2.

@@ -100,23 +100,314 @@ def __init__(

self.gate = nn.Parameter(torch.zeros(1))

self._register_load_state_dict_pre_hook(self._load_state_dict_hook)

def _load_state_dict_hook(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any test case for this, which arguably is the important part

Copy link
Contributor Author

@felipemello1 felipemello1 Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it mostly calls the other functions. But it makes sense to create an unit test for it.

torchtune/models/clip/_position_embeddings.py Outdated Show resolved Hide resolved
torchtune/models/clip/_position_embeddings.py Outdated Show resolved Hide resolved
torchtune/models/clip/_position_embeddings.py Outdated Show resolved Hide resolved
"""
# inverse n_tokens_per_tile = patch_grid_size**2 + 1, where +1 is the cls token
inpt_n_tokens_per_tile, inpt_embed_dim = inpt_pos_embed.shape
inpt_patch_grid_size = int(math.sqrt(inpt_n_tokens_per_tile - 1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of a nit, but why do you compute it in the method for local pos embeddings and pass it to the method for global pos embeddings?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will make it consistent. I dont think that there is a reason

Copy link
Contributor Author

@felipemello1 felipemello1 Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

took a closer look, in both methods i am passing "tgt_patch_grid_size"

inpt_local_pos_embed = self._resize_local_position_embedding(
                local_pos_embed=inpt_local_pos_embed,
                tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
            )

inpt_global_pos_embed = self._resize_global_position_embedding(
                global_pos_embed=inpt_global_pos_embed,
                tgt_max_num_tiles=tgt_max_num_tiles_x,
                tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
            )

torchtune/models/clip/_position_embeddings.py Show resolved Hide resolved
torchtune/models/clip/_position_embeddings.py Outdated Show resolved Hide resolved
Comment on lines +326 to +340
pos_embed = pos_embed.reshape(
max_num_tiles_x,
max_num_tiles_y,
inpt_patch_grid_size,
inpt_patch_grid_size,
embed_dim,
)

# combine max_num_tiles and patch_grid_size into one dimension
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
pos_embed = pos_embed.reshape(
max_num_tiles_x * inpt_patch_grid_size,
max_num_tiles_y * inpt_patch_grid_size,
embed_dim,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need all 3 of these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think so. I dont see other way to get the same shape/order. I will ask metamate :P

Copy link
Contributor

@pbontrager pbontrager left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Please add the TODOs from the comments and an additional one to support resizing for the none-tiled positional embedding.

@@ -100,23 +100,322 @@ def __init__(

self.gate = nn.Parameter(torch.zeros(1))

self._register_load_state_dict_pre_hook(self._load_state_dict_hook)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a TODO here to switch to the public method after 2.5 is stable

@@ -166,16 +464,127 @@ def __init__(
)
self.gate = nn.Parameter(torch.zeros(1))

# Register load hook to interpolate positional embeddings
self._register_load_state_dict_pre_hook(self._load_state_dict_hook)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment for TODO

@felipemello1 felipemello1 merged commit 55b4814 into pytorch:main Oct 1, 2024
14 checks passed
@felipemello1 felipemello1 deleted the reshape_pos_emb branch October 1, 2024 18:01
RdoubleA pushed a commit that referenced this pull request Oct 2, 2024
Co-authored-by: Felipe Mello <felipemello@fb.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants