[Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP #6699

Merged: 18 commits, Oct 18, 2020
133 changes: 102 additions & 31 deletions python/tvm/relay/frontend/mxnet.py
@@ -790,14 +790,27 @@ def _mx_dot(inputs, attrs):
def _mx_batch_dot(inputs, attrs):
assert len(inputs) == 2
a, b = inputs
a_shape = _infer_type(a).checked_type.shape
batch_shapes = None
if len(a_shape) > 3:
batch_shapes = a_shape[:-2]
a = _op.reverse_reshape(a, newshape=(-1, 0, 0))
b_shape = _infer_type(b).checked_type.shape
if len(b_shape) > 3:
if batch_shapes is None:
batch_shapes = b_shape[:-2]
b = _op.reverse_reshape(b, newshape=(-1, 0, 0))
transpose_a = attrs.get_bool("transpose_a", False)
transpose_b = attrs.get_bool("transpose_b", False)
if transpose_a is True:
msg = 'Value {} in attribute "transpose_a" of operator batch_dot ' "is not valid."
raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
if transpose_b is False:
b = _op.transpose(b, axes=[0, 2, 1])
return _op.nn.batch_matmul(a, b)
out = _op.nn.batch_matmul(a, b)
if batch_shapes is not None:
out = _op.reverse_reshape(out, newshape=tuple(batch_shapes) + (0, 0))
return out
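
For context, a minimal NumPy sketch of the new ndim > 3 handling: leading batch axes are collapsed into one, a 3-D batched matmul runs, and the result is expanded back. The array names and shapes below are hypothetical, and plain reshape/matmul stand in for the Relay reverse_reshape/batch_matmul calls.

import numpy as np

a_np = np.random.rand(2, 3, 4, 5)                    # (batch..., M, K)
b_np = np.random.rand(2, 3, 5, 6)                    # (batch..., K, N)
batch_shape = a_np.shape[:-2]                        # leading batch axes, kept to restore the output
a3 = a_np.reshape((-1,) + a_np.shape[-2:])           # collapse to (6, 4, 5)
b3 = b_np.reshape((-1,) + b_np.shape[-2:])           # collapse to (6, 5, 6)
out3 = np.matmul(a3, b3)                             # (6, 4, 6)
out_np = out3.reshape(batch_shape + out3.shape[-2:]) # restore to (2, 3, 4, 6)
assert out_np.shape == (2, 3, 4, 6)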


def _mx_arange(inputs, attrs):
@@ -2284,18 +2297,16 @@ def _mx_npi_pad(inputs, attrs):
raise tvm.error.OpAttributeRequired('Attribute "mode" not found in operator pad.')
if pad_mode not in ["constant", "edge", "reflect"]:
raise tvm.error.OpAttributeInvalid("Value " + mode + ' in attribute "mode" is not valid')
pad_width = attrs.get_int_tuple("pad_width", None)
if pad_width is None:
if "pad_width" not in attrs.attrs:
raise tvm.error.OpAttributeRequired('Attribute "pad_width" not found in operator pad.')
if None in pad_width:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "pad_width" of operator Slice is not valid.'
)
# Parse the tuple of tuples here; get_int_tuple cannot be used because pad_width is a tuple of tuples.
pad_width = attrs.attrs["pad_width"]
pad_width = pad_width.replace("(", "[")
pad_width = pad_width.replace(")", "]")
pad_width = json.loads(pad_width)
constant_values = attrs.get_float("constant_values", 0.0)
padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], pad_width[1::2]))

return _op.nn.pad(
data=inputs[0], pad_width=padding, pad_value=constant_values, pad_mode=pad_mode
data=inputs[0], pad_width=pad_width, pad_value=constant_values, pad_mode=pad_mode
)
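
A small sketch of the tuple-of-tuples parsing added above, using a hypothetical raw attribute string of the form MXNet serializes:

import json

raw = "((0, 0), (0, 0), (1, 2), (3, 4))"  # hypothetical attrs.attrs["pad_width"] value
pad_width = json.loads(raw.replace("(", "[").replace(")", "]"))
assert pad_width == [[0, 0], [0, 0], [1, 2], [3, 4]]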


@@ -2311,24 +2322,74 @@ def _mx_npx_reshape(inputs, attrs):
shape = attrs.get_int_tuple("newshape")
reverse = attrs.get_bool("reverse", False)
shape_list = list(shape)
new_shape_list = []
for num in shape_list:
if num > 0 or num == -1:
new_shape_list.append(num)
elif num == -2:
new_shape_list.append(0)
elif num == -4:
new_shape_list.append(-2)
elif num == -5:
new_shape_list.append(-3)
elif num == -6:
new_shape_list.append(-4)
old_shape = get_const_tuple(_infer_type(inputs[0]).checked_type.shape)
new_shape = []
if reverse:
old_shape = old_shape[::-1]
shape_list = shape_list[::-1]
ptr = 0
unknown_axis = None
src_ptr = 0
while src_ptr < len(shape_list):
ele = shape_list[src_ptr]
src_ptr += 1
if ele > 0:
new_shape.append(ele)
ptr += 1
elif ele == -1:
new_shape.append(-1)
if unknown_axis is not None:
raise tvm.error.OpAttributeInvalid("Can only have one -1 in the input shape.")
unknown_axis = len(new_shape)
ptr += 1
elif ele == -2:
new_shape.append(old_shape[ptr])
ptr += 1
elif ele == -3:
if old_shape[ptr] != 1:
raise tvm.error.OpAttributeInvalid(
"Dimension of the original shape "
"that corresponds to -3 must be 1. Received"
" {}".format(old_shape[ptr])
)
ptr += 1
elif ele == -4:
new_shape += old_shape[ptr:]
break
elif ele == -5:
new_shape.append(old_shape[ptr] * old_shape[ptr + 1])
ptr += 2
elif ele == -6:
# Split axis
lhs = shape_list[src_ptr]
rhs = shape_list[src_ptr + 1]
src_ptr += 2
if lhs == -1 and rhs == -1:
raise tvm.error.OpAttributeInvalid("The lhs and rhs can not both be -1.")
if lhs == -1:
if old_shape[ptr] % rhs != 0:
raise tvm.error.OpAttributeInvalid(
"When splitting the axis, "
"the dimension of the split axis must "
"be divisible by the splitted values."
)
lhs = old_shape[ptr] // rhs
if rhs == -1:
if old_shape[ptr] % lhs != 0:
raise tvm.error.OpAttributeInvalid(
"When splitting the axis, "
"the dimension of the split axis must "
"be divisible by the splitted values."
)
rhs = old_shape[ptr] // lhs
new_shape.append(lhs)
new_shape.append(rhs)
ptr += 1
else:
raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
shape = tuple(new_shape_list)
raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % ele)
if reverse:
return _op.reverse_reshape(inputs[0], newshape=shape)
return _op.reshape(inputs[0], newshape=shape)
new_shape = new_shape[::-1]
return _op.reshape(inputs[0], newshape=new_shape)
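
As a sanity check, two of the special codes handled above can be mirrored with plain NumPy reshapes: -2 copies a source axis and -1 infers one axis, while -6 splits an axis into two factors and -4 keeps all remaining source axes. The shapes below are illustrative only.

import numpy as np

x = np.zeros((2, 3, 8))
# (-2, -2, 2, -1): copy the first two axes, then split the trailing 8 as 2 x inferred
assert x.reshape((2, 3, 2, -1)).shape == (2, 3, 2, 4)

y = np.zeros((8, 3, 3, 3, 4, 4))
# (-6, 2, -1, -4): split the leading 8 into (2, 4), then keep every remaining axis
assert y.reshape((2, 4, 3, 3, 3, 4, 4)).shape == (2, 4, 3, 3, 3, 4, 4)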


def _mx_split_v2(inputs, attrs):
@@ -2346,12 +2407,21 @@ def _mx_split_v2(inputs, attrs):


def _mx_npi_where_rscalar(inputs, attrs):
cond, dat = inputs
scalar = attrs.get_float("scalar")
dtype = _infer_type(inputs[1]).checked_type.dtype
cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape)
dat_shape = get_const_tuple(_infer_type(dat).checked_type.shape)
dtype = _infer_type(dat).checked_type.dtype
# Check for broadcasting
out_shape = np.broadcast(np.empty(cond_shape), np.empty(dat_shape)).shape
if out_shape != cond_shape:
cond = _op.broadcast_to(cond, out_shape)
if out_shape != dat_shape:
dat = _op.broadcast_to(dat, out_shape)
scalar = _expr.const(scalar, dtype=dtype)
ones = _op.ones_like(inputs[1])
ones = _op.ones_like(dat)
scalar = _op.multiply(ones, scalar)
return _op.where(inputs[0], inputs[1], scalar)
return _op.where(cond, dat, scalar)
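
The broadcasting handled above mirrors NumPy's rules; a minimal sketch with hypothetical shapes:

import numpy as np

cond_np = np.random.uniform(size=(7, 2)) > 0.5  # condition with smaller rank
data_np = np.random.uniform(size=(2, 7, 2))     # data tensor
out_np = np.where(cond_np, data_np, 2.0)        # cond broadcasts to (2, 7, 2)
assert out_np.shape == (2, 7, 2)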


# Note: due to attribute conversion constraint
@@ -2372,13 +2442,13 @@ def _mx_npi_where_rscalar(inputs, attrs):
"reshape_like",
"zeros_like",
"ones_like",
"where",
"cos",
"cosh",
"sin",
"sinh",
"tan",
"tanh",
"where",
]

_convert_map = {
@@ -2598,6 +2668,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
"_npi_concatenate": _mx_npi_concatenate,
"_npx_reshape": _mx_npx_reshape,
"_np_copy": _rename(_op.copy),
"_npi_copy": _rename(_op.copy),
"_npi_power": _rename(_op.power),
"_npi_power_scalar": _binop_scalar(_op.power),
"_npi_multiply": _rename(_op.multiply),
@@ -2606,6 +2677,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
"_npi_add_scalar": _binop_scalar(_op.add),
"_npi_where_rscalar": _mx_npi_where_rscalar,
"_npi_less": _rename(_op.less),
"_npi_less_equal": _mx_compare(_op.less_equal, _rename),
"_npi_tanh": _rename(_op.tanh),
"_npi_true_divide_scalar": _binop_scalar(_op.divide),
}
@@ -2717,7 +2789,6 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
else:
raise RuntimeError("unexpected type %s" % type(res))
node_map[nid] = res

outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _function.Function(analysis.free_vars(outputs), outputs)
12 changes: 11 additions & 1 deletion python/tvm/topi/x86/batch_matmul.py
@@ -37,6 +37,9 @@ def batch_matmul(cfg, x, y, out_shape=None):
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output
Returns
-------
output : tvm.te.Tensor
@@ -135,7 +138,7 @@ def _default_batch_matmul_config(cfg, M, N, K):


@autotvm.register_topi_compute("batch_matmul_cblas.x86")
def batch_matmul_cblas(cfg, x, y):
def batch_matmul_cblas(cfg, x, y, out_shape=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
@@ -147,6 +150,9 @@ def batch_matmul_cblas(cfg, x, y):
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output
Returns
-------
output : tvm.te.Tensor
@@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
YB, N, YK = get_const_tuple(y.shape)
assert XB == YB, "batch dimension doesn't match"
assert XK == YK, "shapes of x and y are inconsistent"
if out_shape is not None:
assert out_shape[0] == XB, "got invalid output shape"
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
cfg.add_flop(XB * M * N * XK * 2)
return cblas.batch_matmul(x, y, False, True)
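
For reference, the shape contract asserted above follows the usual batched-matmul convention, where y is stored as [batch, N, K] and used transposed. A hedged NumPy sketch with illustrative sizes:

import numpy as np

x = np.random.rand(4, 8, 16)              # [batch, M, K]
y = np.random.rand(4, 32, 16)             # [batch, N, K]
out = np.matmul(x, np.swapaxes(y, 1, 2))  # [batch, M, N]
assert out.shape == (4, 8, 32)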

30 changes: 20 additions & 10 deletions tests/python/frontend/mxnet/test_forward.py
@@ -1914,7 +1914,10 @@ def verify(data_shape, axis, use_length, length):
@pytest.mark.skipif(not hasattr(mx.sym.np, "pad"), reason="mx.sym.np.pad hasn't been published yet")
@pytest.mark.parametrize(
"data_shape, pad_width",
[((1, 1, 3, 5), (0, 0, 0, 0, 1, 2, 3, 4)), ((1, 1, 3, 5, 7), (0, 0, 0, 0, 1, 2, 3, 4, 5, 6))],
[
((1, 1, 3, 5), ((0, 0), (0, 0), (1, 2), (3, 4))),
((1, 1, 3, 5, 7), ((0, 0), (0, 0), (1, 2), (3, 4), (5, 6))),
],
)
@pytest.mark.parametrize("mode", ["constant", "edge", "reflect"])
@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
@@ -1925,19 +1928,17 @@ def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value, tar
data_np = np.random.uniform(size=data_shape).astype(dtype)
data = mx.sym.var("data")
if mode == "constant":
ref_res = mx.ndarray.pad(
mx.nd.array(data_np), mode=mode, pad_width=pad_width, constant_value=constant_value
)
ref_res = np.pad(data_np, mode=mode, pad_width=pad_width, constant_values=constant_value)
mx_sym = mx.sym.np.pad(
data.as_np_ndarray(), mode=mode, pad_width=pad_width, constant_values=constant_value
)
else:
ref_res = mx.ndarray.pad(mx.nd.array(data_np), mode=mode, pad_width=pad_width)
ref_res = np.pad(data_np, mode=mode, pad_width=pad_width)
mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)


@pytest.mark.skipif(
@@ -2011,8 +2012,12 @@ def test_forward_np_copy(data_shape, dtype, target, ctx, kind):
((2, 3, 8), (-2, -2, 2, -1), False),
((8, 3, 3, 3, 4, 4), (-6, 2, -1, -4), False),
((8, 3, 3, 3, 4, 4), (-5, -4), False),
((1, 8, 3, 3, 3, 4, 4), (-3, -5, -4), False),
((8, 1, 3, 4), (-2, -3, -1), False),
((8, 3, 3, 3, 3, 8), (-4, -5), True),
((8, 3, 2, 4, 8), (-4, -1, 2, -6), True),
((3, 2, 4, 8, 1, 1), (-4, -1, 2, -6, -5, -3), True),
((2, 4, 1, 8), (-4, -3, -1, 2, -6), True),
],
)
def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, ctx, kind):
@@ -2099,16 +2104,21 @@ def test_forward_npi_tanh(data_shape, dtype, target, ctx, kind):


@pytest.mark.skipif(not hasattr(mx.np, "where"), reason="mx.np.where hasn't been published yet")
@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (1, 8), (2, 2), (1, 3)])
@pytest.mark.parametrize(
"data_shape,cond_shape",
[[(2, 2, 2), (2, 2, 2)], [(2, 7, 2), (7, 2)], [(2, 2), (1, 2)], [(1, 3), (3, 3)]],
)
@pytest.mark.parametrize("data_dtype", ["float64", "float32", "int64", "int32", "bool"])
@pytest.mark.parametrize("cond_dtype", ["float64", "float32", "int64", "int32", "bool"])
@pytest.mark.parametrize("scalar", [1.0, 2.0])
@tvm.testing.parametrize_targets
@pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, target, ctx, kind):
def test_forward_npi_where_rscalar(
data_shape, cond_shape, data_dtype, cond_dtype, scalar, target, ctx, kind
):
if data_dtype == "bool":
scalar = scalar == 0.0
cond_np = np.random.uniform(size=data_shape).astype(cond_dtype)
cond_np = np.random.uniform(size=cond_shape).astype(cond_dtype)
data_np = np.random.uniform(size=data_shape).astype(data_dtype)
cond = mx.sym.var("condition")
data = mx.sym.var("x")
@@ -2118,7 +2128,7 @@ def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, t
dtypeDic["condition"] = cond_dtype
dtypeDic["x"] = data_dtype
mod, _ = relay.frontend.from_mxnet(
mx_sym, shape={"condition": data_shape, "x": data_shape}, dtype=dtypeDic
mx_sym, shape={"condition": cond_shape, "x": data_shape}, dtype=dtypeDic
)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(cond_np, data_np)
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level1.py
@@ -134,7 +134,7 @@ def check_binary_op(opfunc, ref, dtype):
continue
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01, atol=1e-3)

for opfunc, ref in [
(relay.add, np.add),