Skip to content

Commit

Permalink
[ONNX] Export ROIAlign with aligned=True (#2613)
Browse files Browse the repository at this point in the history
* Add support for export ROIAlign

* Fix for feedback

* flake8
  • Loading branch information
neginraoof authored Aug 27, 2020
1 parent 6f02821 commit 279fca5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
29 changes: 29 additions & 0 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def to_numpy(tensor):
# compute onnxruntime output prediction
ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))
ort_outs = ort_session.run(None, ort_inputs)

for i in range(0, len(outputs)):
try:
torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
Expand Down Expand Up @@ -121,6 +122,34 @@ def test_roi_align(self):
model = ops.RoIAlign((5, 5), 1, 2)
self.run_model(model, [(x, single_roi)])

def test_roi_align_aligned(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
self.run_model(model, [(x, single_roi)])

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
self.run_model(model, [(x, single_roi)])

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
self.run_model(model, [(x, single_roi)])

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
self.run_model(model, [(x, single_roi)])

@unittest.skip # Issue in exporting ROIAlign with aligned = True for malformed boxes
def test_roi_align_malformed_boxes(self):
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
self.run_model(model, [(x, single_roi)])

def test_roi_pool(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
Expand Down
8 changes: 6 additions & 2 deletions torchvision/ops/_register_onnx_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import torch
import warnings

_onnx_opset_version = 11

Expand All @@ -20,11 +21,14 @@ def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):

@parse_args('v', 'v', 'f', 'i', 'i', 'i', 'i')
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
if(aligned):
raise RuntimeError('Unsupported: ONNX export of roi_align with aligned')
batch_indices = _cast_Long(g, squeeze(g, select(g, rois, 1, g.op('Constant',
value_t=torch.tensor([0], dtype=torch.long))), 1), False)
rois = select(g, rois, 1, g.op('Constant', value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
if aligned:
warnings.warn("ONNX export of ROIAlign with aligned=True does not match PyTorch when using malformed boxes,"
" ONNX forces ROIs to be 1x1 or larger.")
scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float)
rois = g.op("Sub", rois, scale)
return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale,
output_height_i=pooled_height, output_width_i=pooled_width, sampling_ratio_i=sampling_ratio)

Expand Down

0 comments on commit 279fca5

Please sign in to comment.