From f539313012debae786896623c897cd09d34a1659 Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 26 Jun 2020 06:27:47 -0700 Subject: [PATCH] Try remove eager scripting calls (#2248) * Try remove eager scripting calls * remove script call Co-authored-by: eellison Co-authored-by: Francisco Massa --- torchvision/models/detection/_utils.py | 2 +- torchvision/models/detection/roi_heads.py | 4 ++-- torchvision/ops/boxes.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 4b65ffa4a4e..3595114f24d 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -75,7 +75,7 @@ def __call__(self, matched_idxs): return pos_idx, neg_idx -@torch.jit.script +@torch.jit._script_if_tracing def encode_boxes(reference_boxes, proposals, weights): # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor """ diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 19cc15a8cc0..82ba6e8b5c0 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -205,7 +205,7 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height, return xy_preds_i, end_scores_i -@torch.jit.script +@torch.jit._script_if_tracing def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints): xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device) @@ -451,7 +451,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): return im_mask -@torch.jit.script +@torch.jit._script_if_tracing def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w): res_append = torch.zeros(0, im_h, im_w) for i in range(masks.size(0)): diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index e7442f57352..c7d74db4500 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -4,7 +4,6 @@ import torchvision -@torch.jit.script def nms(boxes, scores, iou_threshold): # type: (Tensor, Tensor, float) -> Tensor """ @@ -41,7 +40,7 @@ def nms(boxes, scores, iou_threshold): return torch.ops.torchvision.nms(boxes, scores, iou_threshold) -@torch.jit.script +@torch.jit._script_if_tracing def batched_nms(boxes, scores, idxs, iou_threshold): # type: (Tensor, Tensor, Tensor, float) -> Tensor """