Skip to content

Commit

Permalink
Fix for roi_align export (#1988)
Browse files Browse the repository at this point in the history
  • Loading branch information
neginraoof authored Mar 16, 2020
1 parent 2875315 commit e96f2d5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchvision/ops/_register_onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
nms_out = g.op('NonMaxSuppression', boxes, scores, max_output_per_class, iou_threshold)
return squeeze(g, select(g, nms_out, 1, g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))), 1)

@parse_args('v', 'v', 'f', 'i', 'i', 'i')
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
@parse_args('v', 'v', 'f', 'i', 'i', 'i', 'i')
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
if(aligned):
raise RuntimeError('Unsupported: ONNX export of roi_align with aligned')
batch_indices = _cast_Long(g, squeeze(g, select(g, rois, 1, g.op('Constant',
value_t=torch.tensor([0], dtype=torch.long))), 1), False)
rois = select(g, rois, 1, g.op('Constant', value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
Expand Down

0 comments on commit e96f2d5

Please sign in to comment.