From 0476ea8f883a7a89778a38ef19683354ee489fc0 Mon Sep 17 00:00:00 2001 From: Insop Song Date: Wed, 6 Jan 2021 19:56:17 -0800 Subject: [PATCH 1/3] Added _npi_advanced_indexing_multiple - need to add test case --- python/tvm/relay/frontend/mxnet.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index b272ead9737d..733c1fcd2ab0 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -2342,6 +2342,11 @@ def _mx_npi_stack(inputs, attrs): else: return _op.stack(tuple(inputs), axis=int(axis)) +def mx_npi_advanced_indexing_multiple(inputs, attrs): + assert len(inputs) == 2 + data = inputs[0] + indices = inputs[1] + return _op.adv_index([data] + indices) def _mx_npx_reshape(inputs, attrs): shape = attrs.get_int_tuple("newshape") @@ -2709,6 +2714,7 @@ def _mx_npi_where_rscalar(inputs, attrs): "_npi_tanh": _rename(_op.tanh), "_npi_true_divide_scalar": _binop_scalar(_op.divide), "_npi_stack": _mx_npi_stack, + "_npi_advanced_indexing_multiple": mx_npi_advanced_indexing_multiple, } # set identity list From 6991f937e8384f3e605a4ba70308fcacf29b85eb Mon Sep 17 00:00:00 2001 From: Insop Song Date: Thu, 7 Jan 2021 22:29:15 -0800 Subject: [PATCH 2/3] Add test case for _npi_advanced_indexing_multiple - TODO: need to find a proper symbol for comparison - currently test function is NOT valid --- tests/python/frontend/mxnet/test_forward.py | 23 +++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 537349e073e1..8a1763266a7d 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1935,6 +1935,29 @@ def verify(data_shape, axis, use_length, length): verify((2, 3, 4), 2, True, np.array([[3, 4, 2], [1, 2, 1]]).astype("int32")) +@pytest.mark.parametrize( + "data_shape, row_sel, col", + [ + ((5, 7), (0, 1, 2, 3, 4,), 2), + ], +) +@pytest.mark.parametrize("dtype", ["float64", "float32"]) +@tvm.testing.parametrize_targets +@pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) +def test_forward_npi_advanced_indexing_multiple(data_shape, row_sel, col, dtype, target, ctx, kind): + data_np = np.random.uniform(size=data_shape).astype(dtype) + data = mx.sym.var("data") + ref_res = mx.np.array(data_np)[row_sel, col] + + # TODO need to add the proper symbol operator + mx_sym = mx.sym.np.(data.as_np_ndarray()[row_sel, col]) + mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, dtype=dtype) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5) + tvm.testing.assert_allclose(ref_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5) + + @pytest.mark.skipif(not hasattr(mx.sym.np, "pad"), reason="mx.sym.np.pad hasn't been publish yet") @pytest.mark.parametrize( "data_shape, pad_width", From 4c763cd443f9faf6654769f4d5528143fb772703 Mon Sep 17 00:00:00 2001 From: Insop Song Date: Tue, 12 Jan 2021 02:23:42 -0800 Subject: [PATCH 3/3] updated based on the PR discussion --- tests/python/frontend/mxnet/test_forward.py | 24 ++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 8a1763266a7d..eee0e95d2505 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1938,7 +1938,17 @@ def verify(data_shape, axis, use_length, length): @pytest.mark.parametrize( "data_shape, row_sel, col", [ - ((5, 7), (0, 1, 2, 3, 4,), 2), + ( + (5, 7), + ( + 0, + 1, + 2, + 3, + 4, + ), + 2, + ), ], ) @pytest.mark.parametrize("dtype", ["float64", "float32"]) @@ -1946,12 +1956,16 @@ def verify(data_shape, axis, use_length, length): @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_advanced_indexing_multiple(data_shape, row_sel, col, dtype, target, ctx, kind): data_np = np.random.uniform(size=data_shape).astype(dtype) - data = mx.sym.var("data") ref_res = mx.np.array(data_np)[row_sel, col] - # TODO need to add the proper symbol operator - mx_sym = mx.sym.np.(data.as_np_ndarray()[row_sel, col]) - mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, dtype=dtype) + row_sel_sym = mx.sym.var("row_sel").as_np_ndarray() + data_sym = mx.sym.var("data").as_np_ndarray() + col_sym = mx.sym.var("col").as_np_ndarray() + mx_sym = data_sym[row_sel_sym, col_sym] + + mod, _ = relay.frontend.from_mxnet( + mx_sym, shape={"data": data_shape, "row_sel": row_sel, "col": col}, dtype=dtype + ) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(data_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)