Skip to content

Commit

Permalink
[RELAY][MXNET][FRONTEND] add support for MXNET numpy operators (#6054)
Browse files Browse the repository at this point in the history
* [RELAY][MXNET][FRONTEND] add supports for OPs in numpy from mxnet

* Update test_forward.py

* Update mxnet.py

* Update mxnet.py

* Update test_forward.py

* update and bugfix

* test for multiple dtypes

* Update test_forward.py

* add data type and optimize coding style

* replace pytest.skip with @pytest.mark.skipif

* Update test_forward.py

* update pytest style

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-169.ap-northeast-1.compute.internal>
  • Loading branch information
sandyhu533 and Ubuntu authored Aug 21, 2020
1 parent 061bb01 commit 4c728d5
Show file tree
Hide file tree
Showing 2 changed files with 318 additions and 77 deletions.
105 changes: 105 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2166,6 +2166,93 @@ def impl(inputs, input_types):
return impl


def _mx_npi_transpose(inputs, attrs):
axes = attrs.get_int_tuple("axes", None)
# translate default case
axes = None if len(axes) == 0 or axes[0] is None else axes
return _op.transpose(inputs[0], axes=axes)


def _mx_npi_pad(inputs, attrs):
pad_mode = attrs.get_str('mode', None)
if pad_mode is None:
raise tvm.error.OpAttributeRequired(
'Attribute "mode" not found in operator pad.')
if pad_mode not in ['constant', 'edge', 'reflect']:
raise tvm.error.OpAttributeInvalid(
'Value ' + mode + ' in attribute "mode" is not valid')
pad_width = attrs.get_int_tuple('pad_width', None)
if pad_width is None:
raise tvm.error.OpAttributeRequired(
'Attribute "pad_width" not found in operator pad.')
if None in pad_width:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "pad_width" of operator Slice is not valid.')
constant_values = attrs.get_float('constant_values', 0.0)
padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], pad_width[1::2]))

return _op.nn.pad(data=inputs[0],
pad_width=padding,
pad_value=constant_values,
pad_mode=pad_mode)


def _mx_npi_concatenate(inputs, attrs):
axis = attrs.get_str("axis", "0")
if axis == "None":
return _op.reshape(_op.concatenate(tuple(inputs), axis=0), (-1,))
else:
return _op.concatenate(tuple(inputs), axis=int(axis))


def _mx_npx_reshape(inputs, attrs):
shape = attrs.get_int_tuple("newshape")
reverse = attrs.get_bool("reverse", False)
shape_list = list(shape)
new_shape_list = []
for num in shape_list:
if num > 0 or num == -1:
new_shape_list.append(num)
elif num == -2:
new_shape_list.append(0)
elif num == -4:
new_shape_list.append(-2)
elif num == -5:
new_shape_list.append(-3)
elif num == -6:
new_shape_list.append(-4)
else:
raise tvm.error.OpAttributeInvalid(
'Shape dimension %d is not supported' % num)
shape = tuple(new_shape_list)
if reverse:
return _op.reverse_reshape(inputs[0], newshape=shape)
return _op.reshape(inputs[0], newshape=shape)


def _mx_split_v2(inputs, attrs):
axis = attrs.get_int("axis")
indices = list(attrs.get_int_tuple("indices", []))
# remove the prefix '0'
if len(indices) != 0 and indices[0] == 0:
indices.remove(0)
sections = attrs.get_int("sections", 0)
indices_or_sections = list(indices) if len(indices) != 0 else sections
res = _op.split(inputs[0], indices_or_sections=indices_or_sections, axis=axis)
if attrs.get_bool("squeeze_axis", False):
res = tuple([_op.squeeze(x, axis=[axis]) for x in res])
return res


def _mx_npi_where_rscalar(inputs, attrs):
scalar = attrs.get_float("scalar")
dtype = _infer_type(inputs[1]).checked_type.dtype
scalar = _expr.const(scalar, dtype=dtype)
ones = _op.ones_like(inputs[1])
scalar = _op.multiply(ones, scalar)
return _op.where(inputs[0], inputs[1], scalar)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -2322,6 +2409,7 @@ def impl(inputs, input_types):
"slice_axis" : _mx_slice_axis,
"SliceChannel" : _mx_split,
"split" : _mx_split,
"_split_v2" : _mx_split_v2,
"SwapAxis" : _mx_swap_axis,
"expand_dims" : _mx_expand_dims,
"Concat" : _mx_concat,
Expand Down Expand Up @@ -2400,6 +2488,23 @@ def impl(inputs, input_types):
"_contrib_quantized_pooling": _qnn_pooling,
"_contrib_quantized_batch_norm" : _qnn_batch_norm,
"_sg_mkldnn_fully_connected": _qnn_fully_connected,
# numpy
"_np_transpose" : _mx_npi_transpose,
"_npi_transpose" : _mx_npi_transpose,
"_npi_pad" : _mx_npi_pad,
"_npi_concatenate" : _mx_npi_concatenate,
"_npx_reshape" : _mx_npx_reshape,
"_np_copy" : _rename(_op.copy),
"_npi_power" : _rename(_op.power),
"_npi_power_scalar" : _binop_scalar(_op.power),
"_npi_multiply" : _rename(_op.multiply),
"_npi_multiply_scalar" : _binop_scalar(_op.multiply),
"_npi_add" : _rename(_op.add),
"_npi_add_scalar" : _binop_scalar(_op.add),
"_npi_where_rscalar" : _mx_npi_where_rscalar,
"_npi_less" : _rename(_op.less),
"_npi_tanh" : _rename(_op.tanh),
"_npi_true_divide_scalar" : _binop_scalar(_op.divide),
}

# set identity list
Expand Down
Loading

0 comments on commit 4c728d5

Please sign in to comment.