Skip to content

Commit

Permalink
[Frontend][MXNet] add _npi_stack, issue apache#7186
Browse files Browse the repository at this point in the history
  • Loading branch information
insop committed Jan 5, 2021
1 parent d052752 commit faeebb0
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2335,6 +2335,14 @@ def _mx_npi_concatenate(inputs, attrs):
return _op.concatenate(tuple(inputs), axis=int(axis))


def _mx_npi_stack(inputs, attrs):
axis = attrs.get_str("axis", "0")
if axis == "None":
return _op.reshape(_op.stack(tuple(inputs), axis=0), (-1,))
else:
return _op.stack(tuple(inputs), axis=int(axis))


def _mx_npx_reshape(inputs, attrs):
shape = attrs.get_int_tuple("newshape")
reverse = attrs.get_bool("reverse", False)
Expand Down Expand Up @@ -2700,6 +2708,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
"_npi_less_equal": _mx_compare(_op.less_equal, _rename),
"_npi_tanh": _rename(_op.tanh),
"_npi_true_divide_scalar": _binop_scalar(_op.divide),
"_npi_stack": _mx_npi_stack,
}

# set identity list
Expand Down
28 changes: 28 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2012,6 +2012,34 @@ def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype, target,
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)


@pytest.mark.parametrize(
"data_shape1, data_shape2, axis",
[
((3,), (3,), 0),
((3,), (3,), -1),
((1, 3, 2), (1, 3, 2), 2),
((1, 3, 3), (1, 3, 3), 1),
((1, 3), (1, 3), 0),
],
)
@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
@tvm.testing.parametrize_targets
@pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
def test_forward_npi_stack(data_shape1, data_shape2, axis, dtype, target, ctx, kind):
data_np1 = np.random.uniform(size=data_shape1).astype(dtype)
data_np2 = np.random.uniform(size=data_shape2).astype(dtype)
data1 = mx.sym.var("data1")
data2 = mx.sym.var("data2")
ref_res = mx.np.stack([mx.np.array(data_np1), mx.np.array(data_np2)], axis=axis)
mx_sym = mx.sym.np.stack([data1.as_np_ndarray(), data2.as_np_ndarray()], axis=axis)
mod, _ = relay.frontend.from_mxnet(
mx_sym, shape={"data1": data_shape1, "data2": data_shape2}, dtype=dtype
)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data_np1, data_np2)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)


@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (2, 2, 2, 1, 2, 3, 1), (1, 8)])
@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32", "bool"])
@tvm.testing.parametrize_targets
Expand Down

0 comments on commit faeebb0

Please sign in to comment.