From 72177628a9fc3e4fb3208dc942947363be236155 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Sun, 18 Oct 2020 15:27:46 -0700
Subject: [PATCH] [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP (#6699)

* update

Update type_relations.cc
Update transform.cc
Update transform.cc
Update transform.cc
Update transform.cc
Update transform.cc
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
update
Update mxnet.py
debug
Update generic.py
Update topi_integration.py
fix bug
update
Update test_forward.py
Update test_forward.py
fix test case
Update mxnet.py
update
Update mxnet.py
Update mxnet.py
Update test_forward.py
Update mxnet.py
Update mxnet.py
Update test_forward.py
Update mxnet.py
Update mxnet.py
Update mxnet.py
debug
Update mxnet.py
Update mxnet.py
Update test_forward.py
Update mxnet.py

* address comments

* Update mxnet.py

* Update mxnet.py

* fix

* improve where test

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* update

* Update mxnet.py

* Update mxnet.py

* Update mxnet.py

debug
Update common.py
update
Update mxnet.py
update
Update test_forward.py
Update test_forward.py

* update

* fix lint

* Update mxnet.py

* Update test_op_level1.py

* fix lint
---
 python/tvm/relay/frontend/mxnet.py          | 133 +++++++++++++++-----
 python/tvm/topi/x86/batch_matmul.py         |  12 +-
 tests/python/frontend/mxnet/test_forward.py |  30 +++--
 tests/python/relay/test_op_level1.py        |   2 +-
 4 files changed, 134 insertions(+), 43 deletions(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 984945f718688..a543f78bd9493 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -790,6 +790,16 @@ def _mx_dot(inputs, attrs):
 def _mx_batch_dot(inputs, attrs):
     assert len(inputs) == 2
     a, b = inputs
+    a_shape = _infer_type(a).checked_type.shape
+    batch_shapes = None
+    if len(a_shape) > 3:
+        batch_shapes = a_shape[:-2]
+        a = _op.reverse_reshape(a, newshape=(-1, 0, 0))
+    b_shape = _infer_type(b).checked_type.shape
+    if len(b_shape) > 3:
+        if batch_shapes is None:
+            batch_shapes = b_shape[:-2]
+        b = _op.reverse_reshape(b, newshape=(-1, 0, 0))
     transpose_a = attrs.get_bool("transpose_a", False)
     transpose_b = attrs.get_bool("transpose_b", False)
     if transpose_a is True:
@@ -797,7 +807,10 @@ def _mx_batch_dot(inputs, attrs):
         raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
     if transpose_b is False:
         b = _op.transpose(b, axes=[0, 2, 1])
-    return _op.nn.batch_matmul(a, b)
+    out = _op.nn.batch_matmul(a, b)
+    if batch_shapes is not None:
+        out = _op.reverse_reshape(out, newshape=tuple(batch_shapes) + (0, 0))
+    return out
 
 
 def _mx_arange(inputs, attrs):
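Relay's `nn.batch_matmul` is strictly 3-D, so the hunk above folds every extra leading dimension of a 4-D-or-higher input into the single batch axis with `reverse_reshape(x, (-1, 0, 0))`, runs the 3-D kernel, and unfolds the batch axes afterwards. A minimal NumPy sketch of that shape bookkeeping (illustrative only: `batch_dot_nd` is a hypothetical helper, `reverse_reshape`'s `0`/`-1` codes are emulated with plain reshapes, and `b` is laid out as `[batch..., N, K]`, i.e. after the converter's transpose step):

```python
import numpy as np

def batch_dot_nd(a, b):
    # Fold all leading batch axes into one, mirroring
    # reverse_reshape(x, newshape=(-1, 0, 0)) in the hunk above.
    batch_shape = a.shape[:-2]
    a3 = a.reshape((-1,) + a.shape[-2:])
    b3 = b.reshape((-1,) + b.shape[-2:])
    # 3-D batch matmul with b in [batch, N, K] layout (b is contracted
    # over its last axis, like Relay's nn.batch_matmul).
    out = np.einsum("bmk,bnk->bmn", a3, b3)
    # Unfold the batch axes, mirroring reverse_reshape(out, batch_shapes + (0, 0)).
    return out.reshape(batch_shape + out.shape[-2:])

a = np.random.rand(2, 3, 4, 5)  # [batch..., M=4, K=5]
b = np.random.rand(2, 3, 6, 5)  # [batch..., N=6, K=5]
assert batch_dot_nd(a, b).shape == (2, 3, 4, 6)
```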
@@ -2294,18 +2307,16 @@ def _mx_npi_pad(inputs, attrs):
         raise tvm.error.OpAttributeRequired('Attribute "mode" not found in operator pad.')
     if pad_mode not in ["constant", "edge", "reflect"]:
         raise tvm.error.OpAttributeInvalid("Value " + mode + ' in attribute "mode" is not valid')
-    pad_width = attrs.get_int_tuple("pad_width", None)
-    if pad_width is None:
+    if "pad_width" not in attrs.attrs:
         raise tvm.error.OpAttributeRequired('Attribute "pad_width" not found in operator pad.')
-    if None in pad_width:
-        raise tvm.error.OpAttributeInvalid(
-            'Value None in attribute "pad_width" of operator Slice is not valid.'
-        )
+    # Begin to parse the tuple of tuples; we cannot use get_int_tuple here because pad_width is a tuple of tuples.
+    pad_width = attrs.attrs["pad_width"]
+    pad_width = pad_width.replace("(", "[")
+    pad_width = pad_width.replace(")", "]")
+    pad_width = json.loads(pad_width)
     constant_values = attrs.get_float("constant_values", 0.0)
-    padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], pad_width[1::2]))
-
     return _op.nn.pad(
-        data=inputs[0], pad_width=padding, pad_value=constant_values, pad_mode=pad_mode
+        data=inputs[0], pad_width=pad_width, pad_value=constant_values, pad_mode=pad_mode
     )
 
 
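The replacement parsing works because MXNet serializes `pad_width` into the attribute dictionary as a Python tuple-of-tuples literal; swapping parentheses for brackets turns that string into valid JSON. A standalone sketch of the trick (the `raw` value is a made-up example in the same format the updated tests use):

```python
import json

raw = "((0, 0), (0, 0), (1, 2), (3, 4))"  # hypothetical attrs.attrs["pad_width"]
pad_width = json.loads(raw.replace("(", "[").replace(")", "]"))
assert pad_width == [[0, 0], [0, 0], [1, 2], [3, 4]]
```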
@@ -2321,24 +2332,74 @@ def _mx_npx_reshape(inputs, attrs):
     shape = attrs.get_int_tuple("newshape")
     reverse = attrs.get_bool("reverse", False)
     shape_list = list(shape)
-    new_shape_list = []
-    for num in shape_list:
-        if num > 0 or num == -1:
-            new_shape_list.append(num)
-        elif num == -2:
-            new_shape_list.append(0)
-        elif num == -4:
-            new_shape_list.append(-2)
-        elif num == -5:
-            new_shape_list.append(-3)
-        elif num == -6:
-            new_shape_list.append(-4)
+    old_shape = get_const_tuple(_infer_type(inputs[0]).checked_type.shape)
+    new_shape = []
+    if reverse:
+        old_shape = old_shape[::-1]
+        shape_list = shape_list[::-1]
+    ptr = 0
+    unknown_axis = None
+    src_ptr = 0
+    while src_ptr < len(shape_list):
+        ele = shape_list[src_ptr]
+        src_ptr += 1
+        if ele > 0:
+            new_shape.append(ele)
+            ptr += 1
+        elif ele == -1:
+            new_shape.append(-1)
+            if unknown_axis is not None:
+                raise tvm.error.OpAttributeInvalid("Can only have one -1 in the input shape.")
+            unknown_axis = len(new_shape)
+            ptr += 1
+        elif ele == -2:
+            new_shape.append(old_shape[ptr])
+            ptr += 1
+        elif ele == -3:
+            if old_shape[ptr] != 1:
+                raise tvm.error.OpAttributeInvalid(
+                    "Dimension of the original shape "
+                    "that corresponds to -3 must be 1. Received"
+                    " {}".format(old_shape[ptr])
+                )
+            ptr += 1
+        elif ele == -4:
+            new_shape += old_shape[ptr:]
+            break
+        elif ele == -5:
+            new_shape.append(old_shape[ptr] * old_shape[ptr + 1])
+            ptr += 2
+        elif ele == -6:
+            # Split axis
+            lhs = shape_list[src_ptr]
+            rhs = shape_list[src_ptr + 1]
+            src_ptr += 2
+            if lhs == -1 and rhs == -1:
+                raise tvm.error.OpAttributeInvalid("The lhs and rhs cannot both be -1.")
+            if lhs == -1:
+                if old_shape[ptr] % rhs != 0:
+                    raise tvm.error.OpAttributeInvalid(
+                        "When splitting the axis, "
+                        "the dimension of the split axis must "
+                        "be divisible by the split values."
+                    )
+                lhs = old_shape[ptr] // rhs
+            if rhs == -1:
+                if old_shape[ptr] % lhs != 0:
+                    raise tvm.error.OpAttributeInvalid(
+                        "When splitting the axis, "
+                        "the dimension of the split axis must "
+                        "be divisible by the split values."
+                    )
+                rhs = old_shape[ptr] // lhs
+            new_shape.append(lhs)
+            new_shape.append(rhs)
+            ptr += 1
         else:
-            raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
-    shape = tuple(new_shape_list)
+            raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % ele)
     if reverse:
-        return _op.reverse_reshape(inputs[0], newshape=shape)
-    return _op.reshape(inputs[0], newshape=shape)
+        new_shape = new_shape[::-1]
+    return _op.reshape(inputs[0], newshape=new_shape)
 
 
 def _mx_split_v2(inputs, attrs):
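The rewritten converter resolves MXNet's `npx.reshape` special codes against the inferred input shape: positive entries are kept, `-1` infers a single unknown axis, `-2` copies the corresponding input axis, `-3` drops an axis that must have size 1, `-4` keeps all remaining axes, `-5` merges two adjacent axes, and `-6` splits an axis into two factors (one of which may be `-1`). A NumPy illustration of two of the new test cases added further below; the target shapes shown are the already-resolved outputs the converter would hand to `reshape`:

```python
import numpy as np

# ((2, 3, 8), (-2, -2, 2, -1)): copy 2 and 3, split 8 into 2 * (inferred 4).
x = np.empty((2, 3, 8))
assert x.reshape((2, 3, 2, -1)).shape == (2, 3, 2, 4)

# ((8, 3, 3, 3, 4, 4), (-5, -4)): merge 8 * 3 = 24, keep the remaining axes.
y = np.empty((8, 3, 3, 3, 4, 4))
assert y.reshape((24, 3, 3, 4, 4)).shape == (24, 3, 3, 4, 4)
```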
@@ -2356,12 +2417,21 @@ def _mx_split_v2(inputs, attrs):
 
 
 def _mx_npi_where_rscalar(inputs, attrs):
+    cond, dat = inputs
     scalar = attrs.get_float("scalar")
-    dtype = _infer_type(inputs[1]).checked_type.dtype
+    cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape)
+    dat_shape = get_const_tuple(_infer_type(dat).checked_type.shape)
+    dtype = _infer_type(dat).checked_type.dtype
+    # Check for broadcasting
+    out_shape = np.broadcast(np.empty(cond_shape), np.empty(dat_shape)).shape
+    if out_shape != cond_shape:
+        cond = _op.broadcast_to(cond, out_shape)
+    if out_shape != dat_shape:
+        dat = _op.broadcast_to(dat, out_shape)
     scalar = _expr.const(scalar, dtype=dtype)
-    ones = _op.ones_like(inputs[1])
+    ones = _op.ones_like(dat)
     scalar = _op.multiply(ones, scalar)
-    return _op.where(inputs[0], inputs[1], scalar)
+    return _op.where(cond, dat, scalar)
 
 
 # Note: due to attribute conversion constraint
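Relay's `where` is applied here to operands of one common shape, so the converter pre-broadcasts `cond` and `dat`; `np.broadcast` is used only at conversion time to compute the static result shape, never to materialize data. A quick check with one of the shape pairs from the updated test:

```python
import numpy as np

cond_shape, dat_shape = (7, 2), (2, 7, 2)  # a case from the new parametrization
out_shape = np.broadcast(np.empty(cond_shape), np.empty(dat_shape)).shape
assert out_shape == (2, 7, 2)  # cond gets broadcast_to this shape
```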
@@ -2382,13 +2452,13 @@ def _mx_npi_where_rscalar(inputs, attrs):
     "reshape_like",
     "zeros_like",
     "ones_like",
-    "where",
     "cos",
     "cosh",
     "sin",
     "sinh",
     "tan",
     "tanh",
+    "where",
 ]
 
 _convert_map = {
@@ -2609,6 +2679,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
     "_npi_concatenate": _mx_npi_concatenate,
     "_npx_reshape": _mx_npx_reshape,
     "_np_copy": _rename(_op.copy),
+    "_npi_copy": _rename(_op.copy),
     "_npi_power": _rename(_op.power),
     "_npi_power_scalar": _binop_scalar(_op.power),
     "_npi_multiply": _rename(_op.multiply),
@@ -2617,6 +2688,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
     "_npi_add_scalar": _binop_scalar(_op.add),
     "_npi_where_rscalar": _mx_npi_where_rscalar,
     "_npi_less": _rename(_op.less),
+    "_npi_less_equal": _mx_compare(_op.less_equal, _rename),
     "_npi_tanh": _rename(_op.tanh),
     "_npi_true_divide_scalar": _binop_scalar(_op.divide),
 }
@@ -2728,7 +2800,6 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
         else:
             raise RuntimeError("unexpected type %s" % type(res))
         node_map[nid] = res
-
     outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
     outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
     func = _function.Function(analysis.free_vars(outputs), outputs)
diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py
index e3f08160509ef..4e5f6efc815ad 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -37,6 +37,9 @@ def batch_matmul(cfg, x, y, out_shape=None):
         3-D with shape [batch, M, K]
     y : tvm.te.Tensor
         3-D with shape [batch, N, K]
+    out_shape : tuple or None
+        Shape of the outputs
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -135,7 +138,7 @@ def _default_batch_matmul_config(cfg, M, N, K):
 
 
 @autotvm.register_topi_compute("batch_matmul_cblas.x86")
-def batch_matmul_cblas(cfg, x, y):
+def batch_matmul_cblas(cfg, x, y, out_shape=None):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
@@ -147,6 +150,9 @@ def batch_matmul_cblas(cfg, x, y):
         3-D with shape [batch, M, K]
     y : tvm.te.Tensor
         3-D with shape [batch, N, K]
+    out_shape : tuple or None
+        Shape of the output
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
     YB, N, YK = get_const_tuple(y.shape)
     assert XB == YB, "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistant"
+    if out_shape is not None:
+        assert out_shape[0] == XB, "got invalid output shape"
+        assert out_shape[1] == M, "got invalid output shape"
+        assert out_shape[2] == N, "got invalid output shape"
     cfg.add_flop(XB * M * N * XK * 2)
     return cblas.batch_matmul(x, y, False, True)
 
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 44307f4e60fe6..79c587fc7f9e4 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -1932,7 +1932,10 @@ def verify(data_shape, axis, use_length, length):
 @pytest.mark.skipif(not hasattr(mx.sym.np, "pad"), reason="mx.sym.np.pad hasn't been publish yet")
 @pytest.mark.parametrize(
     "data_shape, pad_width",
-    [((1, 1, 3, 5), (0, 0, 0, 0, 1, 2, 3, 4)), ((1, 1, 3, 5, 7), (0, 0, 0, 0, 1, 2, 3, 4, 5, 6))],
+    [
+        ((1, 1, 3, 5), ((0, 0), (0, 0), (1, 2), (3, 4))),
+        ((1, 1, 3, 5, 7), ((0, 0), (0, 0), (1, 2), (3, 4), (5, 6))),
+    ],
 )
 @pytest.mark.parametrize("mode", ["constant", "edge", "reflect"])
 @pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
@@ -1943,19 +1946,17 @@ def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value, tar
     data_np = np.random.uniform(size=data_shape).astype(dtype)
     data = mx.sym.var("data")
     if mode == "constant":
-        ref_res = mx.ndarray.pad(
-            mx.nd.array(data_np), mode=mode, pad_width=pad_width, constant_value=constant_value
-        )
+        ref_res = np.pad(data_np, mode=mode, pad_width=pad_width, constant_values=constant_value)
         mx_sym = mx.sym.np.pad(
             data.as_np_ndarray(), mode=mode, pad_width=pad_width, constant_values=constant_value
         )
     else:
-        ref_res = mx.ndarray.pad(mx.nd.array(data_np), mode=mode, pad_width=pad_width)
+        ref_res = np.pad(data_np, mode=mode, pad_width=pad_width)
         mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width)
     mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
     intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
     op_res = intrp.evaluate()(data_np)
-    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
+    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
 
 @pytest.mark.skipif(
@@ -2029,8 +2030,12 @@ def test_forward_np_copy(data_shape, dtype, target, ctx, kind):
         ((2, 3, 8), (-2, -2, 2, -1), False),
         ((8, 3, 3, 3, 4, 4), (-6, 2, -1, -4), False),
         ((8, 3, 3, 3, 4, 4), (-5, -4), False),
+        ((1, 8, 3, 3, 3, 4, 4), (-3, -5, -4), False),
+        ((8, 1, 3, 4), (-2, -3, -1), False),
         ((8, 3, 3, 3, 3, 8), (-4, -5), True),
         ((8, 3, 2, 4, 8), (-4, -1, 2, -6), True),
+        ((3, 2, 4, 8, 1, 1), (-4, -1, 2, -6, -5, -3), True),
+        ((2, 4, 1, 8), (-4, -3, -1, 2, -6), True),
     ],
 )
 def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, ctx, kind):
@@ -2117,16 +2122,21 @@ def test_forward_npi_tanh(data_shape, dtype, target, ctx, kind):
 
 
 @pytest.mark.skipif(not hasattr(mx.np, "where"), reason="mx.np.where hasn't been publish yet")
-@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (1, 8), (2, 2), (1, 3)])
+@pytest.mark.parametrize(
+    "data_shape,cond_shape",
+    [[(2, 2, 2), (2, 2, 2)], [(2, 7, 2), (7, 2)], [(2, 2), (1, 2)], [(1, 3), (3, 3)]],
+)
 @pytest.mark.parametrize("data_dtype", ["float64", "float32", "int64", "int32", "bool"])
 @pytest.mark.parametrize("cond_dtype", ["float64", "float32", "int64", "int32", "bool"])
 @pytest.mark.parametrize("scalar", [1.0, 2.0])
 @tvm.testing.parametrize_targets
 @pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
-def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, target, ctx, kind):
+def test_forward_npi_where_rscalar(
+    data_shape, cond_shape, data_dtype, cond_dtype, scalar, target, ctx, kind
+):
     if data_dtype == "bool":
         scalar = scalar == 0.0
-    cond_np = np.random.uniform(size=data_shape).astype(cond_dtype)
+    cond_np = np.random.uniform(size=cond_shape).astype(cond_dtype)
     data_np = np.random.uniform(size=data_shape).astype(data_dtype)
     cond = mx.sym.var("condition")
     data = mx.sym.var("x")
@@ -2136,7 +2146,7 @@ def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, t
     dtypeDic["condition"] = cond_dtype
     dtypeDic["x"] = data_dtype
     mod, _ = relay.frontend.from_mxnet(
-        mx_sym, shape={"condition": data_shape, "x": data_shape}, dtype=dtypeDic
+        mx_sym, shape={"condition": cond_shape, "x": data_shape}, dtype=dtypeDic
     )
     intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
     op_res = intrp.evaluate()(cond_np, data_np)
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 8c724daaa9d04..37a59c30f4107 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -134,7 +134,7 @@ def check_binary_op(opfunc, ref, dtype):
                 continue
             intrp = relay.create_executor("graph", ctx=ctx, target=target)
             op_res = intrp.evaluate(func)(x_data, y_data)
-            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01, atol=1e-3)
 
     for opfunc, ref in [
         (relay.add, np.add),
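On the final tolerance change: `np.testing.assert_allclose` accepts a result when `|actual - desired| <= atol + rtol * |desired|`, so with the default `atol=0` any nonzero error against an exactly-zero reference fails no matter how small it is. The added `atol=1e-3` gives near-zero outputs (for example in the float16 path guarded a few lines above) an absolute margin. A tiny demonstration:

```python
import numpy as np

# Passes only because of atol: |1e-6 - 0| <= 1e-3 + 0.01 * |0|.
np.testing.assert_allclose(1e-6, 0.0, rtol=0.01, atol=1e-3)
```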