Fix ONNX exports for Optimum compatible models (#31311)
* fixed models

* format with bumped ruff version on my local

* fix copies

* add tracing checks

* format

* Update src/transformers/utils/generic.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* format

* style fix

* Update modeling_mobilevit.py

* add docstring and change name

* Update __init__.py

* Update __init__.py

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
merveenoyan and amyeroberts authored Jun 27, 2024
1 parent dc76e9f commit c9f191a
Showing 10 changed files with 72 additions and 17 deletions.
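
The common thread in these diffs: Python-level casts such as int(...), float(...), math.sqrt(...) and math.ceil(...) applied to shape-derived values are evaluated once during torch.jit tracing, so an ONNX graph exported through Optimum hard-codes them as constants and stops adapting to the runtime input size. The fix routes those casts through tracing-aware helpers (torch_int, torch_float) or explicit torch.jit.is_tracing() branches, so the computation stays in the graph while tracing but still returns plain Python numbers in eager mode. A minimal sketch of the failure mode and of the pattern applied here (toy functions, not taken from the commit):

import torch


def bad(x):
    # int() pulls the value out of the graph at trace time; the traced function
    # keeps returning the constant recorded for the example input (PyTorch also
    # emits a TracerWarning about this).
    return torch.ones(1) * int(x / 2)


def good(x):
    # The pattern this commit applies: keep the value as a tensor while tracing,
    # fall back to a plain Python int in eager mode.
    half = x / 2
    half = half.to(torch.int64) if torch.jit.is_tracing() else int(half)
    return torch.ones(1) * half


example = torch.tensor(10.0)
print(torch.jit.trace(bad, example)(torch.tensor(20.0)))   # tensor([5.])  <- stale constant
print(torch.jit.trace(good, example)(torch.tensor(20.0)))  # tensor([10.]) <- follows the input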
src/transformers/models/clap/modeling_clap.py (7 changes: 5 additions & 2 deletions)
@@ -37,6 +37,7 @@
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig

@@ -590,8 +591,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
     def set_shift_and_window_size(self, input_resolution):
         if min(input_resolution) <= self.window_size:
             # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(input_resolution)
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )

     def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
src/transformers/models/donut/modeling_donut_swin.py (7 changes: 5 additions & 2 deletions)
@@ -35,6 +35,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
+    torch_int,
 )
 from .configuration_donut_swin import DonutSwinConfig

@@ -562,8 +563,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
     def set_shift_and_window_size(self, input_resolution):
         if min(input_resolution) <= self.window_size:
             # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(input_resolution)
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )

     def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
src/transformers/models/dpt/modeling_dpt.py (4 changes: 2 additions & 2 deletions)
@@ -39,7 +39,7 @@
 from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import ModelOutput, logging
+from ...utils import ModelOutput, logging, torch_int
 from ...utils.backbone_utils import load_backbone
 from .configuration_dpt import DPTConfig

@@ -226,7 +226,7 @@ def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_ind
         posemb_tok = posemb[:, :start_index]
         posemb_grid = posemb[0, start_index:]

-        old_grid_size = int(math.sqrt(len(posemb_grid)))
+        old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)

         posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
         posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
src/transformers/models/imagegpt/modeling_imagegpt.py (10 changes: 8 additions & 2 deletions)
@@ -33,7 +33,13 @@
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_float,
+)
 from .configuration_imagegpt import ImageGPTConfig


@@ -229,7 +235,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))

         if self.scale_attn_weights:
-            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+            attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)

         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
src/transformers/models/layoutlmv3/modeling_layoutlmv3.py (12 changes: 9 additions & 3 deletions)
@@ -33,7 +33,13 @@
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_int,
+)
 from .configuration_layoutlmv3 import LayoutLMv3Config


@@ -910,8 +916,8 @@ def forward(
         patch_height = patch_width = None
         if pixel_values is not None:
             patch_height, patch_width = (
-                int(pixel_values.shape[2] / self.config.patch_size),
-                int(pixel_values.shape[3] / self.config.patch_size),
+                torch_int(pixel_values.shape[2] / self.config.patch_size),
+                torch_int(pixel_values.shape[3] / self.config.patch_size),
             )
             visual_embeddings = self.forward_image(pixel_values)
             visual_attention_mask = torch.ones(
src/transformers/models/mobilevit/modeling_mobilevit.py (13 changes: 11 additions & 2 deletions)
@@ -39,6 +39,7 @@
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_mobilevit import MobileViTConfig

@@ -437,8 +438,16 @@ def unfolding(self, features: torch.Tensor) -> Tuple[torch.Tensor, Dict]:

         batch_size, channels, orig_height, orig_width = features.shape

-        new_height = int(math.ceil(orig_height / patch_height) * patch_height)
-        new_width = int(math.ceil(orig_width / patch_width) * patch_width)
+        new_height = (
+            torch_int(torch.ceil(orig_height / patch_height) * patch_height)
+            if torch.jit.is_tracing()
+            else int(math.ceil(orig_height / patch_height) * patch_height)
+        )
+        new_width = (
+            torch_int(torch.ceil(orig_width / patch_width) * patch_width)
+            if torch.jit.is_tracing()
+            else int(math.ceil(orig_width / patch_width) * patch_width)
+        )

         interpolate = False
         if new_width != orig_width or new_height != orig_height:
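
For a concrete sense of the arithmetic being guarded in unfolding: the feature map is rounded up to the next multiple of the patch size before being split into patches. A small eager-mode illustration (the numbers are made up, not from the model):

import math

orig_height, patch_height = 29, 8
new_height = int(math.ceil(orig_height / patch_height) * patch_height)
print(new_height)  # 32 -> the 29-pixel-high feature map is interpolated up to 32

Under tracing, the same rounding is computed with torch.ceil on tensor values via torch_int, so the exported graph repeats the computation for whatever resolution it receives at runtime.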
src/transformers/models/sam/modeling_sam.py (3 changes: 1 addition & 2 deletions)
@@ -15,7 +15,6 @@
 """PyTorch SAM model."""

 import collections
-import math
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union

@@ -232,7 +231,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarit
         # SamAttention
         _, _, _, c_per_head = query.shape
         attn = query @ key.permute(0, 1, 3, 2)  # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
-        attn = attn / math.sqrt(c_per_head)
+        attn = attn / (c_per_head**0.5)
         attn = torch.softmax(attn, dim=-1)

         if attention_similarity is not None:
src/transformers/models/swin/modeling_swin.py (7 changes: 5 additions & 2 deletions)
@@ -36,6 +36,7 @@
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from ...utils.backbone_utils import BackboneMixin
 from .configuration_swin import SwinConfig
@@ -639,8 +640,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
     def set_shift_and_window_size(self, input_resolution):
         if min(input_resolution) <= self.window_size:
             # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(input_resolution)
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )

     def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
src/transformers/utils/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -60,6 +60,8 @@
     tensor_size,
     to_numpy,
     to_py_obj,
+    torch_float,
+    torch_int,
     transpose,
     working_or_temp_dir,
 )
src/transformers/utils/generic.py (24 changes: 24 additions & 0 deletions)
@@ -753,6 +753,30 @@ def infer_framework(model_class):
     raise TypeError(f"Could not infer framework from class {model_class}.")


+def torch_int(x):
+    """
+    Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
+    """
+    if not is_torch_available():
+        return int(x)
+
+    import torch
+
+    return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
+
+
+def torch_float(x):
+    """
+    Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float.
+    """
+    if not is_torch_available():
+        return int(x)
+
+    import torch
+
+    return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
+
+
 def filter_out_non_signature_kwargs(extra: Optional[list] = None):
     """
     Decorator to filter out named arguments that are not in the function signature.
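
With the __init__.py change above, the helpers become importable from transformers.utils. A quick eager-mode sketch of how they behave outside a trace, assuming a transformers install that includes this commit (under torch.jit.trace or ONNX export they return int64/float32 tensors instead, per the docstrings):

import torch
from transformers.utils import torch_int

grid = torch.arange(49.0)
side = torch_int(grid.size(0) ** 0.5)  # same expression as the DPT change above
print(side, type(side))                # 7 <class 'int'> in eager mode
print(torch.jit.is_tracing())          # False outside of a tracing context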
