From c9f191a0b7f391594901a960671f2a199122ef48 Mon Sep 17 00:00:00 2001 From: Merve Noyan Date: Thu, 27 Jun 2024 12:46:36 +0300 Subject: [PATCH] Fix ONNX exports for Optimum compatible models (#31311) * fixed models * format with bumped ruff version on my local * fix copies * add tracing checks * format * Update src/transformers/utils/generic.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * format * style fix * Update modeling_mobilevit.py * add docstring and change name * Update __init__.py * Update __init__.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/clap/modeling_clap.py | 7 ++++-- .../models/donut/modeling_donut_swin.py | 7 ++++-- src/transformers/models/dpt/modeling_dpt.py | 4 ++-- .../models/imagegpt/modeling_imagegpt.py | 10 ++++++-- .../models/layoutlmv3/modeling_layoutlmv3.py | 12 +++++++--- .../models/mobilevit/modeling_mobilevit.py | 13 ++++++++-- src/transformers/models/sam/modeling_sam.py | 3 +-- src/transformers/models/swin/modeling_swin.py | 7 ++++-- src/transformers/utils/__init__.py | 2 ++ src/transformers/utils/generic.py | 24 +++++++++++++++++++ 10 files changed, 72 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 1c236d29d4e734..3e83daa942c022 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -37,6 +37,7 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig @@ -590,8 +591,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): def set_shift_and_window_size(self, input_resolution): if min(input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(input_resolution) + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 7e899f453f1c0f..115808a6b11a71 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -35,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, logging, + torch_int, ) from .configuration_donut_swin import DonutSwinConfig @@ -562,8 +563,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): def set_shift_and_window_size(self, input_resolution): if min(input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(input_resolution) + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index a7e554742f2de2..db5db0eae1189b 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -39,7 +39,7 @@ from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ModelOutput, logging +from ...utils import ModelOutput, logging, torch_int from ...utils.backbone_utils import load_backbone from .configuration_dpt import DPTConfig @@ -226,7 +226,7 @@ def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_ind posemb_tok = posemb[:, :start_index] posemb_grid = posemb[0, start_index:] - old_grid_size = int(math.sqrt(len(posemb_grid))) + old_grid_size = torch_int(posemb_grid.size(0) ** 0.5) posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2) posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear") diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index c0b0a83c24d66f..5d59a4ed90e4c9 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -33,7 +33,13 @@ ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_float, +) from .configuration_imagegpt import ImageGPTConfig @@ -229,7 +235,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) + attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 941ff860042adf..629490350c7dc3 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -33,7 +33,13 @@ ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) from .configuration_layoutlmv3 import LayoutLMv3Config @@ -910,8 +916,8 @@ def forward( patch_height = patch_width = None if pixel_values is not None: patch_height, patch_width = ( - int(pixel_values.shape[2] / self.config.patch_size), - int(pixel_values.shape[3] / self.config.patch_size), + torch_int(pixel_values.shape[2] / self.config.patch_size), + torch_int(pixel_values.shape[3] / self.config.patch_size), ) visual_embeddings = self.forward_image(pixel_values) visual_attention_mask = torch.ones( diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 551b4ee734b511..59c191b3789641 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -39,6 +39,7 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_mobilevit import MobileViTConfig @@ -437,8 +438,16 @@ def unfolding(self, features: torch.Tensor) -> Tuple[torch.Tensor, Dict]: batch_size, channels, orig_height, orig_width = features.shape - new_height = int(math.ceil(orig_height / patch_height) * patch_height) - new_width = int(math.ceil(orig_width / patch_width) * patch_width) + new_height = ( + torch_int(torch.ceil(orig_height / patch_height) * patch_height) + if torch.jit.is_tracing() + else int(math.ceil(orig_height / patch_height) * patch_height) + ) + new_width = ( + torch_int(torch.ceil(orig_width / patch_width) * patch_width) + if torch.jit.is_tracing() + else int(math.ceil(orig_width / patch_width) * patch_width) + ) interpolate = False if new_width != orig_width or new_height != orig_height: diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index f5baf5bcf3bfd0..c99fb9d7e869f8 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -15,7 +15,6 @@ """PyTorch SAM model.""" import collections -import math from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union @@ -232,7 +231,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarit # SamAttention _, _, _, c_per_head = query.shape attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens - attn = attn / math.sqrt(c_per_head) + attn = attn / (c_per_head**0.5) attn = torch.softmax(attn, dim=-1) if attention_similarity is not None: diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index f3f2dedeb6f3dd..8813d555968880 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -36,6 +36,7 @@ add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from ...utils.backbone_utils import BackboneMixin from .configuration_swin import SwinConfig @@ -639,8 +640,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): def set_shift_and_window_size(self, input_resolution): if min(input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(input_resolution) + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index ce87bc8623132e..ddd329817aff24 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -60,6 +60,8 @@ tensor_size, to_numpy, to_py_obj, + torch_float, + torch_int, transpose, working_or_temp_dir, ) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 4a3c1d970116ae..80232898ce4707 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -753,6 +753,30 @@ def infer_framework(model_class): raise TypeError(f"Could not infer framework from class {model_class}.") +def torch_int(x): + """ + Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int. + """ + if not is_torch_available(): + return int(x) + + import torch + + return x.to(torch.int64) if torch.jit.is_tracing() else int(x) + + +def torch_float(x): + """ + Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float. + """ + if not is_torch_available(): + return int(x) + + import torch + + return x.to(torch.float32) if torch.jit.is_tracing() else int(x) + + def filter_out_non_signature_kwargs(extra: Optional[list] = None): """ Decorator to filter out named arguments that are not in the function signature.