Commit a08a28e

[fbsync] [ViT] Support fine-tuning with different image resolution (#5025)

Summary:
* add from_checkpoint method for vit

* remove useless change

* Making interpolate_embeddings a utility function

* remove logging

* fix type hint

* fix return type check

* add returns in docstring & unify type hints

* remove useless import

* fix issue: 'type' object is not subscriptable

* Fixing typing issues

* Making interpolation mode configurable

* formatting

Reviewed By: prabhat00155

Differential Revision: D33253466

fbshipit-source-id: 79bf6855f2dcee3c2fef6c05c243a0dc8dfee25e

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
2 people authored and facebook-github-bot committed Dec 21, 2021
1 parent 66a5b76 commit a08a28e
Showing 1 changed file with 75 additions and 0 deletions.
torchvision/prototype/models/vision_transformer.py (+75, −0)
@@ -379,3 +379,78 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
progress=progress,
**kwargs,
)


def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolate positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with a different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, the state of the classification heads is not copied. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the position embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            # Drop the classification head weights so the new model starts with fresh heads.
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state
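For readers landing on this commit, here is a minimal usage sketch (not part of the diff) of how the new utility is meant to be wired up. The checkpoint path is hypothetical, and it assumes a vit_b_16 builder in this file that forwards image_size through **kwargs, as vit_l_32 above does.

import torch
from torchvision.prototype.models.vision_transformer import interpolate_embeddings, vit_b_16

# Hypothetical path to a ViT-B/16 state dict pre-trained at 224x224.
model_state = torch.load("vit_b_16_224.pth")

# ViT-B/16 at 224x224 has (224 // 16) ** 2 + 1 = 197 positions (196 patches + class token);
# at 384x384 it needs (384 // 16) ** 2 + 1 = 577, so the 14x14 grid is resized to 24x24.
new_state = interpolate_embeddings(image_size=384, patch_size=16, model_state=model_state)

model = vit_b_16(image_size=384)  # assumes image_size is forwarded via **kwargs
model.load_state_dict(new_state)

Bicubic resizing is the default; pass interpolation_mode="bilinear" (or any mode accepted by nn.functional.interpolate) to change it, and reset_heads=True to drop the pre-trained classification heads.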
