🚨 Fix torch.jit.trace for interpolate_pos_encoding in all vision models (huggingface#33226)

* Fix `torch.jit.tracing` for `interpolate_pos_encoding` in all vision models

* Apply formatting

* Add missing `self.config = config`

* Fix copies

* Fix hiera interpolation unit test

* Formatting

* Update `_import_structure`

* make style

* Fix docstring

* Use `# Copied from` instead of utils

* DeiT variable renaming (`class_and_dist_pos_embed`)

* Fix Hiera `interpolate_pos_encoding`
xenova authored and dataKim1201 committed Oct 7, 2024
1 parent cc7ccd5 commit c20948e
Showing 26 changed files with 558 additions and 369 deletions.
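
For context, here is a minimal, illustrative usage sketch (not part of the commit) of what the fix is meant to enable: tracing a vision model with `interpolate_pos_encoding=True` and reusing the traced graph at a different resolution. The checkpoint name and image sizes are assumptions for illustration only.

import torch
from transformers import ViTModel

model = ViTModel.from_pretrained("google/vit-base-patch16-224").eval()

class Wrapper(torch.nn.Module):
    # Wrap the model so the traced graph takes pixel_values only and always interpolates.
    def __init__(self, vit):
        super().__init__()
        self.vit = vit

    def forward(self, pixel_values):
        return self.vit(pixel_values, interpolate_pos_encoding=True).last_hidden_state

example = torch.randn(1, 3, 256, 256)
traced = torch.jit.trace(Wrapper(model), example)

# Before this fix the trace could bake in the 256x256 position-embedding grid;
# with the interpolation path always recorded while tracing, other sizes should work too.
print(traced(torch.randn(1, 3, 384, 384)).shape)  # expected: (1, 577, 768) for patch size 16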
40 changes: 23 additions & 17 deletions src/transformers/models/beit/modeling_beit.py
@@ -41,6 +41,7 @@
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from ...utils.backbone_utils import BackboneMixin
 from .configuration_beit import BeitConfig
@@ -150,41 +151,46 @@ def __init__(self, config: BeitConfig) -> None:
         self.position_embeddings = None
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows the model to interpolate the pre-trained position encodings so that it can be used on
-        higher resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
 
         num_patches = embeddings.shape[1] - 1
         num_positions = self.position_embeddings.shape[1] - 1
-        if num_patches == num_positions and height == width:
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
             return self.position_embeddings
 
-        class_pos_embed = self.position_embeddings[:, 0]
+        class_pos_embed = self.position_embeddings[:, :1]
         patch_pos_embed = self.position_embeddings[:, 1:]
 
         dim = embeddings.shape[-1]
-        h = height // self.patch_size
-        w = width // self.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h, w = h + 0.1, w + 0.1
 
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
 
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
-        if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
-            raise ValueError("Width or height does not match with the interpolated position embeddings")
 
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def forward(
         self,
@@ -566,7 +572,7 @@ def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=
 
         old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
         new_sub_table = nn.functional.interpolate(
-            old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
+            old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
         )
         new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
 
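The `torch_int` helper imported in these files is what keeps values such as `sqrt_num_positions` usable both eagerly and while tracing. A simplified sketch of its behaviour (the actual utility lives in `transformers.utils` and may differ in detail):

import torch

def torch_int(x):
    # Behave like int() in eager mode, but keep the value as a tensor while
    # torch.jit.trace is recording, so the computed size stays part of the graph.
    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)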
44 changes: 26 additions & 18 deletions src/transformers/models/blip/modeling_blip.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 """PyTorch BLIP model."""
 
-import math
 import warnings
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union
@@ -33,6 +32,7 @@
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
 from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
@@ -232,38 +232,46 @@ def __init__(self, config: BlipVisionConfig):
 
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embedding.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
 
-        if num_patches == num_positions and height == width:
-            return self.position_embedding
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
 
-        class_pos_embed = self.position_embedding[:, 0, :]
-        patch_pos_embed = self.position_embedding[:, 1:, :]
         dim = embeddings.shape[-1]
-        h0 = height // self.config.patch_size
-        w0 = width // self.config.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h0, w0 = h0 + 0.1, w0 + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
 
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
 
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
 
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
         batch_size, _, height, width = pixel_values.shape
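A note on why `size=(new_height, new_width)` replaces `scale_factor=...` throughout this commit: with `scale_factor`, the output size is derived as roughly floor(input_size * scale), and float rounding can land one pixel short, which is what the old `+ 0.1` workaround papered over. Passing the target grid directly is exact and plays better with tracing. Illustrative sketch with assumed shapes:

import torch
import torch.nn as nn

grid = torch.randn(1, 768, 14, 14)  # a 14x14 pretrained position-embedding grid
resized = nn.functional.interpolate(grid, size=(24, 24), mode="bicubic", align_corners=False)
print(resized.shape)  # torch.Size([1, 768, 24, 24]) -- the requested grid, exactly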
43 changes: 26 additions & 17 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -38,6 +38,7 @@
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
 from .configuration_blip_2 import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
@@ -198,38 +199,46 @@ def __init__(self, config: Blip2VisionConfig):
 
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embedding.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
 
-        if num_patches == num_positions and height == width:
-            return self.position_embedding
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
 
-        class_pos_embed = self.position_embedding[:, 0, :]
-        patch_pos_embed = self.position_embedding[:, 1:, :]
         dim = embeddings.shape[-1]
-        h0 = height // self.config.patch_size
-        w0 = width // self.config.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h0, w0 = h0 + 0.1, w0 + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
 
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
 
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
 
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
         batch_size, _, height, width = pixel_values.shape
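With the flag threaded through the vision tower, higher-resolution inputs can be used directly. A hedged usage sketch (tiny randomly initialised config so it is self-contained; config values and sizes are illustrative, and it assumes the vision model's forward exposes `interpolate_pos_encoding` as this commit intends):

import torch
from transformers import Blip2VisionConfig, Blip2VisionModel

config = Blip2VisionConfig(image_size=224, patch_size=14, hidden_size=64, intermediate_size=128,
                           num_hidden_layers=2, num_attention_heads=4)
model = Blip2VisionModel(config).eval()

pixel_values = torch.randn(1, 3, 364, 364)  # larger than the configured 224
with torch.no_grad():
    out = model(pixel_values, interpolate_pos_encoding=True)
print(out.last_hidden_state.shape)  # (1, 1 + (364 // 14) ** 2, hidden_size)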
40 changes: 23 additions & 17 deletions src/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -39,6 +39,7 @@
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_data2vec_vision import Data2VecVisionConfig
 
@@ -149,41 +150,46 @@ def __init__(self, config: Data2VecVisionConfig) -> None:
         self.position_embeddings = None
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows the model to interpolate the pre-trained position encodings so that it can be used on
-        higher resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
 
         num_patches = embeddings.shape[1] - 1
         num_positions = self.position_embeddings.shape[1] - 1
-        if num_patches == num_positions and height == width:
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
             return self.position_embeddings
 
-        class_pos_embed = self.position_embeddings[:, 0]
+        class_pos_embed = self.position_embeddings[:, :1]
         patch_pos_embed = self.position_embeddings[:, 1:]
 
         dim = embeddings.shape[-1]
-        h = height // self.patch_size
-        w = width // self.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h, w = h + 0.1, w + 0.1
 
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
 
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
-        if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
-            raise ValueError("Width or height does not match with the interpolated position embeddings")
 
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def forward(
         self,
@@ -575,7 +581,7 @@ def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=
 
         old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
         new_sub_table = nn.functional.interpolate(
-            old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
+            old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
        )
         new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
 
39 changes: 22 additions & 17 deletions src/transformers/models/deit/modeling_deit.py
@@ -40,6 +40,7 @@
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_deit import DeiTConfig
 
@@ -77,39 +78,43 @@ def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
 
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing and 2 class embeddings.
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
 
         # return self.position_embeddings
         num_patches = embeddings.shape[1] - 2
         num_positions = self.position_embeddings.shape[1] - 2
 
-        if num_patches == num_positions and height == width:
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
             return self.position_embeddings
 
-        class_pos_embed = self.position_embeddings[:, 0, :]
-        dist_pos_embed = self.position_embeddings[:, 1, :]
-        patch_pos_embed = self.position_embeddings[:, 2:, :]
+        class_and_dist_pos_embed = self.position_embeddings[:, :2]
+        patch_pos_embed = self.position_embeddings[:, 2:]
 
         dim = embeddings.shape[-1]
-        h0 = height // self.patch_size
-        w0 = width // self.patch_size
-        # # we add a small number to avoid floating point error in the interpolation
-        # # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h0, w0 = h0 + 0.1, w0 + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
 
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
 
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
 
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
 
-        return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+        return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)
 
     def forward(
         self,
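
DeiT is the one model here with two prefix tokens (class and distillation), hence the `class_and_dist_pos_embed` rename. A small, self-contained sketch of the same interpolation with dummy tensors (all sizes are illustrative):

import torch
import torch.nn as nn

dim, patch_size = 768, 16
pos = torch.randn(1, 2 + 14 * 14, dim)  # pretrained table: 2 prefix tokens + 14x14 grid (224x224 images)
height = width = 336                     # new input resolution

class_and_dist_pos_embed = pos[:, :2]
patch_pos_embed = pos[:, 2:]

new_height, new_width = height // patch_size, width // patch_size  # 21 x 21
side = int(patch_pos_embed.shape[1] ** 0.5)                         # 14
patch_pos_embed = patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

out = torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)
print(out.shape)  # torch.Size([1, 443, 768]) -> 2 prefix tokens + 21 * 21 patches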
