Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add ONNX export support to torch.roll #1194

Merged
merged 5 commits into from
Jul 20, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion mmcv/onnx/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def topk(g, self, k, dim, largest, sorted, out=None):


def masked_select(g, self, mask):
from torch.onnx.symbolic_opset9 import nonzero, expand_as
from torch.onnx.symbolic_opset9 import expand_as, nonzero
index = nonzero(g, expand_as(g, mask, self))
return g.op('GatherND', self, index)

Expand Down Expand Up @@ -406,6 +406,55 @@ def cummin(g, input, dim):
return g.op('mmcv::cummin', input, dim_i=dim, outputs=2)


@parse_args('v', 'v', 'is')
def roll(g, input, shifts, dims):
from torch.onnx.symbolic_opset9 import squeeze
input_shape = g.op('Shape', input)

need_flatten = len(dims) == 0
grimoire marked this conversation as resolved.
Show resolved Hide resolved
if need_flatten:
resize_shape = input_shape
input = g.op('Reshape', input,
g.op('Constant', value_t=torch.LongTensor([1, -1])))
input_shape = g.op('Shape', input)
dims = [1]

for index, dim in enumerate(dims):
end_size = sym_help._slice_helper(
g, input_shape, axes=[0], ends=[dim + 1], starts=[dim])
shift_size = sym_help._slice_helper(
g, shifts, axes=[0], ends=[index + 1], starts=[index])
slice_size = g.op('Sub', end_size, shift_size)

# Can not use Mod because tensorrt does not support
div_size = g.op('Div', slice_size, end_size)
slice_size = g.op('Sub', slice_size, g.op('Mul', end_size, div_size))

end_size = squeeze(g, end_size)
slice_size = squeeze(g, slice_size)
input_slice0 = sym_help._slice_helper(
g,
input,
axes=dim,
starts=torch.LongTensor([0]),
ends=slice_size,
dynamic_slice=True)
input_slice1 = sym_help._slice_helper(
g,
input,
axes=dim,
ends=end_size,
starts=slice_size,
dynamic_slice=True)

input = g.op('Concat', input_slice1, input_slice0, axis_i=dim)

if need_flatten:
input = g.op('Reshape', input, resize_shape)

return input


def register_extra_symbolics(opset=11):
register_op('one_hot', one_hot, '', opset)
register_op('im2col', im2col, '', opset)
Expand Down Expand Up @@ -433,3 +482,4 @@ def register_extra_symbolics(opset=11):
register_op('grid_sampler', grid_sampler, '', opset)
register_op('cummax', cummax, '', opset)
register_op('cummin', cummin, '', opset)
register_op('roll', roll, '', opset)
71 changes: 61 additions & 10 deletions tests/test_ops/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@
onnx_file = 'tmp.onnx'


@pytest.fixture(autouse=True)
def clear_tmpfile_after_test():
grimoire marked this conversation as resolved.
Show resolved Hide resolved
# clear onnx_file before test
if os.path.exists(onnx_file):
os.remove(onnx_file)

yield

# clear onnx_file after test
if os.path.exists(onnx_file):
os.remove(onnx_file)


class WrapFunction(nn.Module):

def __init__(self, wrapped_function):
Expand Down Expand Up @@ -56,7 +69,6 @@ def process_grid_sample(func, input, grid, ort_custom_op_path=''):
'grid': grid.detach().numpy()
})
pytorch_results = wrapped_model(input.clone(), grid.clone())
os.remove(onnx_file)
assert np.allclose(pytorch_results, ort_result, atol=1e-3)


Expand Down Expand Up @@ -149,7 +161,6 @@ def test_nms():
'boxes': boxes.detach().numpy()
})
onnx_score = onnx_dets[:, 4]
os.remove(onnx_file)
assert np.allclose(pytorch_score, onnx_score, atol=1e-3)


Expand Down Expand Up @@ -225,7 +236,7 @@ def test_softnms():
'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy()
})
os.remove(onnx_file)

assert np.allclose(pytorch_dets, onnx_dets, atol=1e-3)
assert np.allclose(onnx_inds, onnx_inds, atol=1e-3)

Expand Down Expand Up @@ -299,7 +310,7 @@ def warpped_function(torch_input, torch_rois):
onnx_output = onnx_output[0]

# allclose
os.remove(onnx_file)

assert np.allclose(pytorch_output, onnx_output, atol=1e-3)


Expand Down Expand Up @@ -378,7 +389,7 @@ def warpped_function(torch_input, torch_rois):
onnx_output = onnx_output[0]

# allclose
os.remove(onnx_file)

assert np.allclose(pytorch_output, onnx_output, atol=1e-3)


Expand Down Expand Up @@ -443,7 +454,6 @@ def warpped_function(torch_input, torch_rois):
onnx_output = onnx_output[0]

# allclose
os.remove(onnx_file)
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)


Expand All @@ -468,8 +478,7 @@ def func(feat, scale_factor=2):
sess = rt.InferenceSession(onnx_file)
onnx_result = sess.run(None, {'input': dummy_input.detach().numpy()})
pytorch_result = func(dummy_input).detach().numpy()
if os.path.exists(onnx_file):
os.remove(onnx_file)

assert np.allclose(pytorch_result, onnx_result, atol=1e-3)


Expand Down Expand Up @@ -515,7 +524,7 @@ def corner_pool_func(input):
sess = rt.InferenceSession(onnx_file, session_options)
ort_result = sess.run(None, {'input': input.detach().numpy()})
pytorch_results = wrapped_model(input.clone())
os.remove(onnx_file)

assert np.allclose(pytorch_results, ort_result, atol=1e-5)


Expand Down Expand Up @@ -591,4 +600,46 @@ def test_cummax_cummin(key, opset=11):
pytorch_inds = pytorch_inds.detach().numpy()
assert np.allclose(pytorch_output, ort_output, atol=1e-5)
assert np.all(pytorch_inds == ort_inds)
os.remove(onnx_file)


@pytest.mark.parametrize('shifts_dims_pair', [([-3, 5], [2, 0]), (5, None)])
def test_roll(shifts_dims_pair):
opset = 11
from mmcv.onnx.symbolic import register_extra_symbolics
register_extra_symbolics(opset)

from mmcv.ops import get_onnxruntime_op_path
grimoire marked this conversation as resolved.
Show resolved Hide resolved
ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('custom ops for onnxruntime are not compiled.')

input = torch.arange(0, 4 * 5 * 6, dtype=torch.float32).view(4, 5, 6)

shifts, dims = shifts_dims_pair
func = partial(torch.roll, shifts=shifts, dims=dims)
wrapped_model = WrapFunction(func).eval()

with torch.no_grad():
torch.onnx.export(
wrapped_model,
input,
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input'],
output_names=['output'],
opset_version=opset)

onnx_model = onnx.load(onnx_file)
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)

sess = rt.InferenceSession(onnx_file)
ort_output = sess.run(None, {'input': input.detach().numpy()})[0]

with torch.no_grad():
pytorch_output = wrapped_model(input.clone())

torch.testing.assert_allclose(ort_output, pytorch_output)