Skip to content

Commit

Permalink
ONNX import: use Conv pad attribute for symmetrical padding (apache#1…
Browse files Browse the repository at this point in the history
…8675)

Signed-off-by: Serge Panev <spanev@nvidia.com>
  • Loading branch information
Kh4L committed Jul 24, 2020
1 parent e31ad77 commit 06b5d22
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,25 @@ def conv(attrs, inputs, proto_obj):
no_bias = new_attrs['no_bias'] if 'no_bias' in new_attrs else 0
bias = None if no_bias is True else inputs[2]

# Unlike ONNX, MXNet's convolution operator does not support asymmetric padding, so we first
# use 'Pad' operator, which supports asymmetric padding. Then use the convolution operator.
pad_width = (0, 0, 0, 0) + translation_utils._pad_sequence_fix(padding, kernel_dim=len(kernel))
pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width)

conv_op = symbol.Convolution(pad_op, inputs[1], bias,
kernel=kernel, stride=stride, dilate=dilations,
num_filter=num_filter, num_group=num_group, no_bias=no_bias)
mxnet_pad = translation_utils._pad_sequence_fix(padding, kernel_dim=len(kernel))

left_pads = mxnet_pad[0::2]
right_pads = mxnet_pad[1::2]
is_pad_sym = left_pads == right_pads

if not is_pad_sym:
# Unlike ONNX, MXNet's convolution operator does not support asymmetric padding, so we first
# use 'Pad' operator, which supports asymmetric padding. Then use the convolution operator.
pad_width = (0, 0, 0, 0) + mxnet_pad
pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width)
conv_op = symbol.Convolution(pad_op, inputs[1], bias,
kernel=kernel, stride=stride, dilate=dilations,
num_filter=num_filter, num_group=num_group, no_bias=no_bias)
else:
pad_width = left_pads
conv_op = symbol.Convolution(inputs[0], inputs[1], bias,
kernel=kernel, stride=stride, dilate=dilations, pad=pad_width,
num_filter=num_filter, num_group=num_group, no_bias=no_bias)

return conv_op, new_attrs, inputs

Expand Down

0 comments on commit 06b5d22

Please sign in to comment.