Skip to content

Commit

Permalink
Try remove eager scripting calls (pytorch#2248)
Browse files Browse the repository at this point in the history
* Try remove eager scripting calls

* remove script call

Co-authored-by: eellison <eellison@fb.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
  • Loading branch information
3 people committed Jun 29, 2020
1 parent 6631b74 commit f539313
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __call__(self, matched_idxs):
return pos_idx, neg_idx


@torch.jit.script
@torch.jit._script_if_tracing
def encode_boxes(reference_boxes, proposals, weights):
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
return xy_preds_i, end_scores_i


@torch.jit.script
@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil,
widths, heights, offset_x, offset_y, num_keypoints):
xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
Expand Down Expand Up @@ -451,7 +451,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
return im_mask


@torch.jit.script
@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
res_append = torch.zeros(0, im_h, im_w)
for i in range(masks.size(0)):
Expand Down
3 changes: 1 addition & 2 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torchvision


@torch.jit.script
def nms(boxes, scores, iou_threshold):
# type: (Tensor, Tensor, float) -> Tensor
"""
Expand Down Expand Up @@ -41,7 +40,7 @@ def nms(boxes, scores, iou_threshold):
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)


@torch.jit.script
@torch.jit._script_if_tracing
def batched_nms(boxes, scores, idxs, iou_threshold):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
"""
Expand Down

0 comments on commit f539313

Please sign in to comment.