Skip to content

Commit

Permalink
Update modeling_altclip.py by adding interpolate_pos_encoding
Browse files Browse the repository at this point in the history
`interpolate_pos_encoding` function to the `altclip` vision models. It allows for high resolution images into the model for finetunning irrespective of the pre-trained image configuration.

issue huggingface#30579
  • Loading branch information
bhuvanmdev authored May 2, 2024
1 parent 5995299 commit ad208ea
Showing 1 changed file with 54 additions and 4 deletions.
58 changes: 54 additions & 4 deletions src/transformers/models/altclip/modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*):
Whether to interpolate the pre-trained position encodings. Defaults to `False`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
Expand Down Expand Up @@ -139,6 +141,8 @@
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*):
Whether to interpolate the pre-trained position encodings. Defaults to `False`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
Expand Down Expand Up @@ -1013,15 +1017,54 @@ def __init__(self, config: AltCLIPVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""

num_patches = embeddings.shape[1] - 1
temp_pos_embed = self.position_embedding(self.position_ids)
num_positions = temp_pos_embed.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embedding(self.position_ids)

class_pos_embed = temp_pos_embed[:, 0,:]
patch_pos_embed = temp_pos_embed[:, 1:,:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.FloatTensor,interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, pixel_values.shape[-2], pixel_values.shape[-1])
else:
embeddings = embeddings + self.position_embedding(self.position_ids)

return embeddings


Expand Down Expand Up @@ -1102,6 +1145,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Expand All @@ -1116,7 +1160,7 @@ def forward(
if pixel_values is None:
raise ValueError("You have to specify pixel_values")

hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values,interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)

encoder_outputs = self.encoder(
Expand Down Expand Up @@ -1162,6 +1206,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Expand Down Expand Up @@ -1192,6 +1237,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)


Expand Down Expand Up @@ -1552,6 +1598,7 @@ def get_image_features(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> torch.FloatTensor:
r"""
Returns:
Expand Down Expand Up @@ -1584,6 +1631,7 @@ def get_image_features(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

pooled_output = vision_outputs[1] # pooled_output
Expand All @@ -1604,6 +1652,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> Union[Tuple, AltCLIPOutput]:
r"""
Returns:
Expand Down Expand Up @@ -1648,6 +1697,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

image_embeds = vision_outputs[1]
Expand Down

0 comments on commit ad208ea

Please sign in to comment.