From 86a8504a7ccb956591ffd1a529f625df7d20b520 Mon Sep 17 00:00:00 2001 From: insop Date: Sun, 3 Jan 2021 02:46:55 -0800 Subject: [PATCH] [Frontend][MXNet] add _npi_subtract_scalar (#7191) * [Frontend][MXNet] add _npi_subtract_scalar - add mxnet numpy operator, subtract - https://github.com/apache/tvm/issues/7186 - https://mxnet.apache.org/versions/master/api/python/docs/api/np/generated/mxnet.np.subtract.html * Fix python style using black --- 3rdparty/vta-hw | 2 +- python/tvm/relay/frontend/mxnet.py | 2 ++ tests/python/frontend/mxnet/test_forward.py | 20 ++++++++++++++++---- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/3rdparty/vta-hw b/3rdparty/vta-hw index 57db5a718c74..87ce9acfae55 160000 --- a/3rdparty/vta-hw +++ b/3rdparty/vta-hw @@ -1 +1 @@ -Subproject commit 57db5a718c74a788c98120ebbe1230797be698c8 +Subproject commit 87ce9acfae550d1a487746e9d06c2e250076e54c diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index f2330c72e1f4..1085e904c386 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -2693,6 +2693,8 @@ def _mx_npi_where_rscalar(inputs, attrs): "_npi_multiply_scalar": _binop_scalar(_op.multiply), "_npi_add": _rename(_op.add), "_npi_add_scalar": _binop_scalar(_op.add), + "_npi_subtract": _rename(_op.subtract), + "_npi_subtract_scalar": _binop_scalar(_op.subtract), "_npi_where_rscalar": _mx_npi_where_rscalar, "_npi_less": _rename(_op.less), "_npi_less_equal": _mx_compare(_op.less_equal, _rename), diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index f076a27755ad..d3be8c0506ba 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -2062,8 +2062,14 @@ def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, ctx, @tvm.testing.parametrize_targets @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_binary(data_shape, dtype, target, ctx, kind): - ref_ops = [mx.np.power, mx.np.multiply, mx.np.add, mx.np.less] - mx_ops = [mx.sym.np.power, mx.sym.np.multiply, mx.sym.np.add, mx.sym.np.less] + ref_ops = [mx.np.power, mx.np.multiply, mx.np.add, mx.np.subtract, mx.np.less] + mx_ops = [ + mx.sym.np.power, + mx.sym.np.multiply, + mx.sym.np.add, + mx.sym.np.subtract, + mx.sym.np.less, + ] for i in range(len(ref_ops)): ref_op = ref_ops[i] mx_op = mx_ops[i] @@ -2092,8 +2098,14 @@ def test_forward_npi_binary(data_shape, dtype, target, ctx, kind): @pytest.mark.parametrize("scalar", [1.0, 2.0, 3.0, 4.0]) @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_binary_scalar(data_shape, dtype, scalar, target, ctx, kind): - ref_ops = [mx.np.power, mx.np.multiply, mx.np.add, mx.np.true_divide] - mx_ops = [mx.sym.np.power, mx.sym.np.multiply, mx.sym.np.add, mx.sym.np.true_divide] + ref_ops = [mx.np.power, mx.np.multiply, mx.np.add, mx.np.subtract, mx.np.true_divide] + mx_ops = [ + mx.sym.np.power, + mx.sym.np.multiply, + mx.sym.np.add, + mx.sym.np.subtract, + mx.sym.np.true_divide, + ] for i in range(len(ref_ops)): ref_op = ref_ops[i] mx_op = mx_ops[i]