
[Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP #6699

Merged
18 commits merged on Oct 18, 2020
141 changes: 118 additions & 23 deletions python/tvm/relay/frontend/mxnet.py
@@ -58,6 +58,11 @@
_activation_map = {"sigmoid": _op.sigmoid, "tanh": _op.tanh, "relu": _op.nn.relu}


def get_tuple_shape(shape_expr):
"""Get the tuple shape from a shape expression"""
return tuple([ele.value for ele in shape_expr])
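
For reference, a minimal sketch (not part of the diff) of what this helper unwraps: the shape expressions returned by _infer_type are sequences of tir.IntImm nodes rather than plain Python ints.

import tvm

# Illustrative only: .value unwraps each IntImm into a plain Python int.
shape_expr = [tvm.tir.IntImm("int64", d) for d in (2, 3, 4)]
assert tuple(ele.value for ele in shape_expr) == (2, 3, 4)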


def _mx_fully_connected(inputs, attrs):
import mxnet as mx # pylint: disable=import-outside-toplevel

@@ -627,6 +632,21 @@ def _mx_expand_dims(inputs, attrs):
return _op.expand_dims(inputs[0], axis=axis)


def _mx_where(inputs, attrs):
cond, lhs, rhs = inputs
cond_shape = get_tuple_shape(_infer_type(cond).checked_type.shape)
lhs_shape = get_tuple_shape(_infer_type(lhs).checked_type.shape)
rhs_shape = get_tuple_shape(_infer_type(rhs).checked_type.shape)
out_shape = np.broadcast(np.empty(cond_shape), np.empty(lhs_shape), np.empty(rhs_shape)).shape
if out_shape != cond_shape:
cond = _op.broadcast_to(cond, out_shape)
if out_shape != lhs_shape:
lhs = _op.broadcast_to(lhs, out_shape)
if out_shape != rhs_shape:
rhs = _op.broadcast_to(rhs, out_shape)
return _op.where(cond, lhs, rhs)
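
As an illustration (not part of the diff), the check above relies on np.broadcast to compute the common output shape without materializing the arrays; a shape-only sketch:

import numpy as np

# Shape-only sketch of the check in _mx_where: np.broadcast yields the
# common broadcast shape of its (empty) argument arrays.
cond_shape, lhs_shape, rhs_shape = (7, 2), (2, 7, 2), (1, 2)
out_shape = np.broadcast(
    np.empty(cond_shape), np.empty(lhs_shape), np.empty(rhs_shape)
).shape
assert out_shape == (2, 7, 2)  # only cond and rhs need broadcast_to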


def _mx_pad(inputs, attrs):
pad_mode = attrs.get_str("mode", None)
if pad_mode is None:
@@ -790,14 +810,27 @@ def _mx_dot(inputs, attrs):
def _mx_batch_dot(inputs, attrs):
assert len(inputs) == 2
a, b = inputs
a_shape = _infer_type(a).checked_type.shape
batch_shapes = None
if len(a_shape) > 3:
batch_shapes = a_shape[:-2]
a = _op.reverse_reshape(a, newshape=(-1, 0, 0))
b_shape = _infer_type(b).checked_type.shape
if len(b_shape) > 3:
if batch_shapes is None:
batch_shapes = b_shape[:-2]
b = _op.reverse_reshape(b, newshape=(-1, 0, 0))
transpose_a = attrs.get_bool("transpose_a", False)
transpose_b = attrs.get_bool("transpose_b", False)
if transpose_a is True:
msg = 'Value {} in attribute "transpose_a" of operator batch_dot ' "is not valid."
raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
if transpose_b is False:
b = _op.transpose(b, axes=[0, 2, 1])
return _op.nn.batch_matmul(a, b)
out = _op.nn.batch_matmul(a, b)
if batch_shapes is not None:
out = _op.reverse_reshape(out, newshape=tuple(batch_shapes) + (0, 0))
return out
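
A shape-only sketch (illustrative, using NumPy in place of the Relay ops) of how the >3-D case is handled: leading batch axes are folded into one, batch_matmul runs on rank-3 tensors, and the result is unfolded again.

import numpy as np

# a: (batch..., M, K), b: (batch..., N, K); batch_matmul contracts over K.
a = np.random.rand(2, 4, 5, 6)
b = np.random.rand(2, 4, 7, 6)
a3 = a.reshape((-1,) + a.shape[-2:])              # reverse_reshape(-1, 0, 0)
b3 = b.reshape((-1,) + b.shape[-2:])              # -> (8, 5, 6) and (8, 7, 6)
out = np.einsum("bmk,bnk->bmn", a3, b3)           # (8, 5, 7)
out = out.reshape(a.shape[:-2] + out.shape[-2:])  # restore batch axes
assert out.shape == (2, 4, 5, 7)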


def _mx_arange(inputs, attrs):
@@ -2312,23 +2345,76 @@ def _mx_npx_reshape(inputs, attrs):
reverse = attrs.get_bool("reverse", False)
shape_list = list(shape)
new_shape_list = []
for num in shape_list:
if num > 0 or num == -1:
new_shape_list.append(num)
elif num == -2:
new_shape_list.append(0)
elif num == -4:
new_shape_list.append(-2)
elif num == -5:
new_shape_list.append(-3)
elif num == -6:
new_shape_list.append(-4)
else:
raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
shape = tuple(new_shape_list)
if reverse:
return _op.reverse_reshape(inputs[0], newshape=shape)
return _op.reshape(inputs[0], newshape=shape)
if -3 not in shape_list:
for num in shape_list:
if num > 0 or num == -1:
new_shape_list.append(num)
elif num == -2:
new_shape_list.append(0)
elif num == -4:
new_shape_list.append(-2)
elif num == -5:
new_shape_list.append(-3)
elif num == -6:
new_shape_list.append(-4)
else:
raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
shape = tuple(new_shape_list)
if reverse:
return _op.reverse_reshape(inputs[0], newshape=shape)
return _op.reshape(inputs[0], newshape=shape)
else:
old_shape = get_tuple_shape(_infer_type(inputs[0]).checked_type.shape)
new_shape = []
if reverse:
old_shape = old_shape[::-1]
shape_list = shape_list[::-1]
ptr = 0
unknown_axis = None
src_ptr = 0
while src_ptr < len(shape_list):
ele = shape_list[src_ptr]
src_ptr += 1
if ele > 0:
new_shape.append(ele)
ptr += 1
elif ele == -1:
new_shape.append(-1)
assert unknown_axis is None, "Can only have one unknown axis."
unknown_axis = len(new_shape)
ptr += 1
elif ele == -2:
new_shape.append(old_shape[ptr])
ptr += 1
elif ele == -3:
assert old_shape[ptr] == 1
ptr += 1
elif ele == -4:
new_shape += old_shape[ptr:]
break
elif ele == -5:
new_shape.append(old_shape[ptr] * old_shape[ptr + 1])
ptr += 2
elif ele == -6:
# Split axis
lhs = shape_list[src_ptr]
rhs = shape_list[src_ptr + 1]
src_ptr += 2
assert not (lhs == -1 and rhs == -1)
if lhs == -1:
assert old_shape[ptr] % rhs == 0
lhs = old_shape[ptr] // rhs
if rhs == -1:
assert old_shape[ptr] % lhs == 0
rhs = old_shape[ptr] // lhs
new_shape.append(lhs)
new_shape.append(rhs)
ptr += 1
else:
raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % ele)
if reverse:
new_shape = new_shape[::-1]
return _op.reshape(inputs[0], newshape=new_shape)
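
Two worked examples (illustrative only, mirroring test cases below) of the npx.reshape special codes handled here. Without -3, the codes are remapped to Relay's reshape codes; with -3, the output shape is computed explicitly from the inferred input shape.

import numpy as np

# Case 1: no -3 in the target shape. (-2, -2, 2, -1) maps to Relay's
# (0, 0, 2, -1): -2 ("copy this axis") becomes 0, -1 stays "infer".
x = np.zeros((2, 3, 8))
assert x.reshape((2, 3, 2, -1)).shape == (2, 3, 2, 4)

# Case 2: -3 present, so the shape is resolved eagerly. For input
# (8, 1, 3, 4) and target (-2, -3, -1): -2 copies the axis (8), -3 drops
# an axis that must be 1, and -1 absorbs the rest.
y = np.zeros((8, 1, 3, 4))
assert y.reshape((8, -1)).shape == (8, 12)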


def _mx_split_v2(inputs, attrs):
@@ -2346,12 +2432,21 @@ def _mx_split_v2(inputs, attrs):


def _mx_npi_where_rscalar(inputs, attrs):
cond, dat = inputs
scalar = attrs.get_float("scalar")
dtype = _infer_type(inputs[1]).checked_type.dtype
cond_shape = get_tuple_shape(_infer_type(cond).checked_type.shape)
dat_shape = get_tuple_shape(_infer_type(dat).checked_type.shape)
dtype = _infer_type(dat).checked_type.dtype
# Check for broadcasting
out_shape = np.broadcast(np.empty(cond_shape), np.empty(dat_shape)).shape
if out_shape != cond_shape:
cond = _op.broadcast_to(cond, out_shape)
if out_shape != dat_shape:
dat = _op.broadcast_to(dat, out_shape)
scalar = _expr.const(scalar, dtype=dtype)
ones = _op.ones_like(inputs[1])
ones = _op.ones_like(dat)
scalar = _op.multiply(ones, scalar)
return _op.where(inputs[0], inputs[1], scalar)
return _op.where(cond, dat, scalar)
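
A NumPy sketch (illustrative only) of the lowering above: the scalar right-hand side is expanded to a full tensor via ones_like so a plain element-wise where applies, after cond and dat are broadcast to a common shape.

import numpy as np

cond = np.array([[True, False, True]])             # shape (1, 3)
dat = np.arange(9, dtype="float32").reshape(3, 3)  # shape (3, 3)
scalar = 2.0
out_shape = np.broadcast(np.empty(cond.shape), np.empty(dat.shape)).shape
cond_b = np.broadcast_to(cond, out_shape)          # (3, 3)
out = np.where(cond_b, dat, np.ones_like(dat) * scalar)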


# Note: due to attribute conversion constraint
@@ -2372,7 +2467,6 @@ def _mx_npi_where_rscalar(inputs, attrs):
"reshape_like",
"zeros_like",
"ones_like",
"where",
"cos",
"cosh",
"sin",
@@ -2384,6 +2478,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
_convert_map = {
"_copy": _rename(_op.copy),
"relu": _rename(_op.nn.relu),
"where": _mx_where,
"broadcast_add": _rename(_op.add),
"broadcast_plus": _rename(_op.add),
"broadcast_sub": _rename(_op.subtract),
@@ -2598,6 +2693,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
"_npi_concatenate": _mx_npi_concatenate,
"_npx_reshape": _mx_npx_reshape,
"_np_copy": _rename(_op.copy),
"_npi_copy": _rename(_op.copy),
"_npi_power": _rename(_op.power),
"_npi_power_scalar": _binop_scalar(_op.power),
"_npi_multiply": _rename(_op.multiply),
@@ -2717,7 +2813,6 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
else:
raise RuntimeError("unexpected type %s" % type(res))
node_map[nid] = res

outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _function.Function(analysis.free_vars(outputs), outputs)
12 changes: 11 additions & 1 deletion python/tvm/topi/x86/batch_matmul.py
@@ -37,6 +37,9 @@ def batch_matmul(cfg, x, y, out_shape=None):
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output

Returns
-------
output : tvm.te.Tensor
@@ -135,7 +138,7 @@ def _default_batch_matmul_config(cfg, M, N, K):


@autotvm.register_topi_compute("batch_matmul_cblas.x86")
def batch_matmul_cblas(cfg, x, y):
def batch_matmul_cblas(cfg, x, y, out_shape=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.

@@ -147,6 +150,9 @@ def batch_matmul_cblas(cfg, x, y):
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output

Returns
-------
output : tvm.te.Tensor
@@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
YB, N, YK = get_const_tuple(y.shape)
assert XB == YB, "batch dimension doesn't match"
assert XK == YK, "shapes of x and y is inconsistant"
if out_shape is not None:
assert out_shape[0] == XB, "got invalid output shape"
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
cfg.add_flop(XB * M * N * XK * 2)
return cblas.batch_matmul(x, y, False, True)
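
A small sketch (illustrative only) of the shape contract checked above: for x of shape (B, M, K) and y of shape (B, N, K), the optional out_shape must equal (B, M, N).

import numpy as np

x = np.zeros((4, 5, 6))   # (B, M, K)
y = np.zeros((4, 7, 6))   # (B, N, K)
XB, M, XK = x.shape
YB, N, YK = y.shape
assert XB == YB and XK == YK
out_shape = (XB, M, N)    # the only shape batch_matmul_cblas will accept
assert out_shape == (4, 5, 7)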

22 changes: 18 additions & 4 deletions tests/python/frontend/mxnet/test_forward.py
@@ -2011,8 +2011,12 @@ def test_forward_np_copy(data_shape, dtype, target, ctx, kind):
((2, 3, 8), (-2, -2, 2, -1), False),
((8, 3, 3, 3, 4, 4), (-6, 2, -1, -4), False),
((8, 3, 3, 3, 4, 4), (-5, -4), False),
((1, 8, 3, 3, 3, 4, 4), (-3, -5, -4), False),
((8, 1, 3, 4), (-2, -3, -1), False),
((8, 3, 3, 3, 3, 8), (-4, -5), True),
((8, 3, 2, 4, 8), (-4, -1, 2, -6), True),
((3, 2, 4, 8, 1, 1), (-4, -1, 2, -6, -5, -3), True),
((2, 4, 1, 8), (-4, -3, -1, 2, -6), True),
],
)
def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, ctx, kind):
@@ -2099,16 +2103,26 @@ def test_forward_npi_tanh(data_shape, dtype, target, ctx, kind):


@pytest.mark.skipif(not hasattr(mx.np, "where"), reason="mx.np.where hasn't been published yet")
@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (1, 8), (2, 2), (1, 3)])
@pytest.mark.parametrize(
"data_shape,cond_shape",
[
[(2, 2, 2), (2, 2, 2)],
[(2, 7, 2), (7, 2)],
[(2, 2), (1, 2)],
[(1, 3), (3, 3)]
]
)
@pytest.mark.parametrize("data_dtype", ["float64", "float32", "int64", "int32", "bool"])
@pytest.mark.parametrize("cond_dtype", ["float64", "float32", "int64", "int32", "bool"])
@pytest.mark.parametrize("scalar", [1.0, 2.0])
@tvm.testing.parametrize_targets
@pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, target, ctx, kind):
def test_forward_npi_where_rscalar(
data_shape, cond_shape, data_dtype, cond_dtype, scalar, target, ctx, kind
):
if data_dtype == "bool":
scalar = scalar == 0.0
cond_np = np.random.uniform(size=data_shape).astype(cond_dtype)
cond_np = np.random.uniform(size=cond_shape).astype(cond_dtype)
data_np = np.random.uniform(size=data_shape).astype(data_dtype)
cond = mx.sym.var("condition")
data = mx.sym.var("x")
@@ -2118,7 +2132,7 @@ def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, t
dtypeDic["condition"] = cond_dtype
dtypeDic["x"] = data_dtype
mod, _ = relay.frontend.from_mxnet(
mx_sym, shape={"condition": data_shape, "x": data_shape}, dtype=dtypeDic
mx_sym, shape={"condition": cond_shape, "x": data_shape}, dtype=dtypeDic
)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(cond_np, data_np)