From 06b5d227bb5a8b35246f46b151cfda0d57e5cef8 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Fri, 24 Jul 2020 14:22:42 -0700 Subject: [PATCH] ONNX import: use Conv pad attribute for symmetrical padding (#18675) Signed-off-by: Serge Panev --- .../contrib/onnx/onnx2mx/_op_translations.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 60ca44df387f..1bf60a02160b 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -289,14 +289,25 @@ def conv(attrs, inputs, proto_obj): no_bias = new_attrs['no_bias'] if 'no_bias' in new_attrs else 0 bias = None if no_bias is True else inputs[2] - # Unlike ONNX, MXNet's convolution operator does not support asymmetric padding, so we first - # use 'Pad' operator, which supports asymmetric padding. Then use the convolution operator. - pad_width = (0, 0, 0, 0) + translation_utils._pad_sequence_fix(padding, kernel_dim=len(kernel)) - pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width) - - conv_op = symbol.Convolution(pad_op, inputs[1], bias, - kernel=kernel, stride=stride, dilate=dilations, - num_filter=num_filter, num_group=num_group, no_bias=no_bias) + mxnet_pad = translation_utils._pad_sequence_fix(padding, kernel_dim=len(kernel)) + + left_pads = mxnet_pad[0::2] + right_pads = mxnet_pad[1::2] + is_pad_sym = left_pads == right_pads + + if not is_pad_sym: + # Unlike ONNX, MXNet's convolution operator does not support asymmetric padding, so we first + # use 'Pad' operator, which supports asymmetric padding. Then use the convolution operator. + pad_width = (0, 0, 0, 0) + mxnet_pad + pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width) + conv_op = symbol.Convolution(pad_op, inputs[1], bias, + kernel=kernel, stride=stride, dilate=dilations, + num_filter=num_filter, num_group=num_group, no_bias=no_bias) + else: + pad_width = left_pads + conv_op = symbol.Convolution(inputs[0], inputs[1], bias, + kernel=kernel, stride=stride, dilate=dilations, pad=pad_width, + num_filter=num_filter, num_group=num_group, no_bias=no_bias) return conv_op, new_attrs, inputs