Skip to content

Commit

Permalink
[Frontend][Relay] Fix MXNet frontend to support NLP backbones in Gluo…
Browse files Browse the repository at this point in the history
…nNLP (apache#6699)

* update

Update type_relations.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

update

Update mxnet.py

debug

Update generic.py

Update topi_integration.py

fix bug

update

Update test_forward.py

Update test_forward.py

fix test case

Update mxnet.py

update

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

debug

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

* address comments

* Update mxnet.py

* Update mxnet.py

* fix

* improve where test

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* update

* Update mxnet.py

* Update mxnet.py

* Update mxnet.py

debug

Update common.py

update

Update mxnet.py

update

Update test_forward.py

Update test_forward.py

* update

* fix lint

* Update mxnet.py

* Update test_op_level1.py

* fix lint
  • Loading branch information
sxjscience authored and Trevor Morris committed Oct 28, 2020
1 parent b7ab5ab commit 7217762
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 43 deletions.
133 changes: 102 additions & 31 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,14 +790,27 @@ def _mx_dot(inputs, attrs):
def _mx_batch_dot(inputs, attrs):
assert len(inputs) == 2
a, b = inputs
a_shape = _infer_type(a).checked_type.shape
batch_shapes = None
if len(a_shape) > 3:
batch_shapes = a_shape[:-2]
a = _op.reverse_reshape(a, newshape=(-1, 0, 0))
b_shape = _infer_type(b).checked_type.shape
if len(b_shape) > 3:
if batch_shapes is None:
batch_shapes = b_shape[:-2]
b = _op.reverse_reshape(b, newshape=(-1, 0, 0))
transpose_a = attrs.get_bool("transpose_a", False)
transpose_b = attrs.get_bool("transpose_b", False)
if transpose_a is True:
msg = 'Value {} in attribute "transpose_a" of operator batch_dot ' "is not valid."
raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
if transpose_b is False:
b = _op.transpose(b, axes=[0, 2, 1])
return _op.nn.batch_matmul(a, b)
out = _op.nn.batch_matmul(a, b)
if batch_shapes is not None:
out = _op.reverse_reshape(out, newshape=tuple(batch_shapes) + (0, 0))
return out


def _mx_arange(inputs, attrs):
Expand Down Expand Up @@ -2294,18 +2307,16 @@ def _mx_npi_pad(inputs, attrs):
raise tvm.error.OpAttributeRequired('Attribute "mode" not found in operator pad.')
if pad_mode not in ["constant", "edge", "reflect"]:
raise tvm.error.OpAttributeInvalid("Value " + mode + ' in attribute "mode" is not valid')
pad_width = attrs.get_int_tuple("pad_width", None)
if pad_width is None:
if "pad_width" not in attrs.attrs:
raise tvm.error.OpAttributeRequired('Attribute "pad_width" not found in operator pad.')
if None in pad_width:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "pad_width" of operator Slice is not valid.'
)
# Begin to parse tuple of tuple, we cannot use get_int_tuple here because it's a tuple of tuple.
pad_width = attrs.attrs["pad_width"]
pad_width = pad_width.replace("(", "[")
pad_width = pad_width.replace(")", "]")
pad_width = json.loads(pad_width)
constant_values = attrs.get_float("constant_values", 0.0)
padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], pad_width[1::2]))

return _op.nn.pad(
data=inputs[0], pad_width=padding, pad_value=constant_values, pad_mode=pad_mode
data=inputs[0], pad_width=pad_width, pad_value=constant_values, pad_mode=pad_mode
)


Expand All @@ -2321,24 +2332,74 @@ def _mx_npx_reshape(inputs, attrs):
shape = attrs.get_int_tuple("newshape")
reverse = attrs.get_bool("reverse", False)
shape_list = list(shape)
new_shape_list = []
for num in shape_list:
if num > 0 or num == -1:
new_shape_list.append(num)
elif num == -2:
new_shape_list.append(0)
elif num == -4:
new_shape_list.append(-2)
elif num == -5:
new_shape_list.append(-3)
elif num == -6:
new_shape_list.append(-4)
old_shape = get_const_tuple(_infer_type(inputs[0]).checked_type.shape)
new_shape = []
if reverse:
old_shape = old_shape[::-1]
shape_list = shape_list[::-1]
ptr = 0
unknown_axis = None
src_ptr = 0
while src_ptr < len(shape_list):
ele = shape_list[src_ptr]
src_ptr += 1
if ele > 0:
new_shape.append(ele)
ptr += 1
elif ele == -1:
new_shape.append(-1)
if unknown_axis is not None:
raise tvm.error.OpAttributeInvalid("Can only have one -1 in the input shape.")
unknown_axis = len(new_shape)
ptr += 1
elif ele == -2:
new_shape.append(old_shape[ptr])
ptr += 1
elif ele == -3:
if old_shape[ptr] != 1:
raise tvm.error.OpAttributeInvalid(
"Dimension of the original shape "
"that corresponds to -3 must be 1. Received"
" {}".format(old_shape[ptr])
)
ptr += 1
elif ele == -4:
new_shape += old_shape[ptr:]
break
elif ele == -5:
new_shape.append(old_shape[ptr] * old_shape[ptr + 1])
ptr += 2
elif ele == -6:
# Split axis
lhs = shape_list[src_ptr]
rhs = shape_list[src_ptr + 1]
src_ptr += 2
if lhs == -1 and rhs == -1:
raise tvm.error.OpAttributeInvalid("The lhs and rhs can not both be -1.")
if lhs == -1:
if old_shape[ptr] % rhs != 0:
raise tvm.error.OpAttributeInvalid(
"When splitting the axis, "
"the dimension of the split axis must "
"be divisible by the splitted values."
)
lhs = old_shape[ptr] // rhs
if rhs == -1:
if old_shape[ptr] % lhs != 0:
raise tvm.error.OpAttributeInvalid(
"When splitting the axis, "
"the dimension of the split axis must "
"be divisible by the splitted values."
)
rhs = old_shape[ptr] // lhs
new_shape.append(lhs)
new_shape.append(rhs)
ptr += 1
else:
raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
shape = tuple(new_shape_list)
raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % ele)
if reverse:
return _op.reverse_reshape(inputs[0], newshape=shape)
return _op.reshape(inputs[0], newshape=shape)
new_shape = new_shape[::-1]
return _op.reshape(inputs[0], newshape=new_shape)


def _mx_split_v2(inputs, attrs):
Expand All @@ -2356,12 +2417,21 @@ def _mx_split_v2(inputs, attrs):


def _mx_npi_where_rscalar(inputs, attrs):
cond, dat = inputs
scalar = attrs.get_float("scalar")
dtype = _infer_type(inputs[1]).checked_type.dtype
cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape)
dat_shape = get_const_tuple(_infer_type(dat).checked_type.shape)
dtype = _infer_type(dat).checked_type.dtype
# Check for broadcasting
out_shape = np.broadcast(np.empty(cond_shape), np.empty(dat_shape)).shape
if out_shape != cond_shape:
cond = _op.broadcast_to(cond, out_shape)
if out_shape != dat_shape:
dat = _op.broadcast_to(dat, out_shape)
scalar = _expr.const(scalar, dtype=dtype)
ones = _op.ones_like(inputs[1])
ones = _op.ones_like(dat)
scalar = _op.multiply(ones, scalar)
return _op.where(inputs[0], inputs[1], scalar)
return _op.where(cond, dat, scalar)


# Note: due to attribute conversion constraint
Expand All @@ -2382,13 +2452,13 @@ def _mx_npi_where_rscalar(inputs, attrs):
"reshape_like",
"zeros_like",
"ones_like",
"where",
"cos",
"cosh",
"sin",
"sinh",
"tan",
"tanh",
"where",
]

_convert_map = {
Expand Down Expand Up @@ -2609,6 +2679,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
"_npi_concatenate": _mx_npi_concatenate,
"_npx_reshape": _mx_npx_reshape,
"_np_copy": _rename(_op.copy),
"_npi_copy": _rename(_op.copy),
"_npi_power": _rename(_op.power),
"_npi_power_scalar": _binop_scalar(_op.power),
"_npi_multiply": _rename(_op.multiply),
Expand All @@ -2617,6 +2688,7 @@ def _mx_npi_where_rscalar(inputs, attrs):
"_npi_add_scalar": _binop_scalar(_op.add),
"_npi_where_rscalar": _mx_npi_where_rscalar,
"_npi_less": _rename(_op.less),
"_npi_less_equal": _mx_compare(_op.less_equal, _rename),
"_npi_tanh": _rename(_op.tanh),
"_npi_true_divide_scalar": _binop_scalar(_op.divide),
}
Expand Down Expand Up @@ -2728,7 +2800,6 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
else:
raise RuntimeError("unexpected type %s" % type(res))
node_map[nid] = res

outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _function.Function(analysis.free_vars(outputs), outputs)
Expand Down
12 changes: 11 additions & 1 deletion python/tvm/topi/x86/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def batch_matmul(cfg, x, y, out_shape=None):
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the outputs
Returns
-------
output : tvm.te.Tensor
Expand Down Expand Up @@ -135,7 +138,7 @@ def _default_batch_matmul_config(cfg, M, N, K):


@autotvm.register_topi_compute("batch_matmul_cblas.x86")
def batch_matmul_cblas(cfg, x, y):
def batch_matmul_cblas(cfg, x, y, out_shape=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Expand All @@ -147,6 +150,9 @@ def batch_matmul_cblas(cfg, x, y):
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output
Returns
-------
output : tvm.te.Tensor
Expand All @@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
YB, N, YK = get_const_tuple(y.shape)
assert XB == YB, "batch dimension doesn't match"
assert XK == YK, "shapes of x and y is inconsistant"
if out_shape is not None:
assert out_shape[0] == XB, "got invalid output shape"
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
cfg.add_flop(XB * M * N * XK * 2)
return cblas.batch_matmul(x, y, False, True)

Expand Down
30 changes: 20 additions & 10 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,7 +1932,10 @@ def verify(data_shape, axis, use_length, length):
@pytest.mark.skipif(not hasattr(mx.sym.np, "pad"), reason="mx.sym.np.pad hasn't been publish yet")
@pytest.mark.parametrize(
"data_shape, pad_width",
[((1, 1, 3, 5), (0, 0, 0, 0, 1, 2, 3, 4)), ((1, 1, 3, 5, 7), (0, 0, 0, 0, 1, 2, 3, 4, 5, 6))],
[
((1, 1, 3, 5), ((0, 0), (0, 0), (1, 2), (3, 4))),
((1, 1, 3, 5, 7), ((0, 0), (0, 0), (1, 2), (3, 4), (5, 6))),
],
)
@pytest.mark.parametrize("mode", ["constant", "edge", "reflect"])
@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"])
Expand All @@ -1943,19 +1946,17 @@ def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value, tar
data_np = np.random.uniform(size=data_shape).astype(dtype)
data = mx.sym.var("data")
if mode == "constant":
ref_res = mx.ndarray.pad(
mx.nd.array(data_np), mode=mode, pad_width=pad_width, constant_value=constant_value
)
ref_res = np.pad(data_np, mode=mode, pad_width=pad_width, constant_values=constant_value)
mx_sym = mx.sym.np.pad(
data.as_np_ndarray(), mode=mode, pad_width=pad_width, constant_values=constant_value
)
else:
ref_res = mx.ndarray.pad(mx.nd.array(data_np), mode=mode, pad_width=pad_width)
ref_res = np.pad(data_np, mode=mode, pad_width=pad_width)
mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)


@pytest.mark.skipif(
Expand Down Expand Up @@ -2029,8 +2030,12 @@ def test_forward_np_copy(data_shape, dtype, target, ctx, kind):
((2, 3, 8), (-2, -2, 2, -1), False),
((8, 3, 3, 3, 4, 4), (-6, 2, -1, -4), False),
((8, 3, 3, 3, 4, 4), (-5, -4), False),
((1, 8, 3, 3, 3, 4, 4), (-3, -5, -4), False),
((8, 1, 3, 4), (-2, -3, -1), False),
((8, 3, 3, 3, 3, 8), (-4, -5), True),
((8, 3, 2, 4, 8), (-4, -1, 2, -6), True),
((3, 2, 4, 8, 1, 1), (-4, -1, 2, -6, -5, -3), True),
((2, 4, 1, 8), (-4, -3, -1, 2, -6), True),
],
)
def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, ctx, kind):
Expand Down Expand Up @@ -2117,16 +2122,21 @@ def test_forward_npi_tanh(data_shape, dtype, target, ctx, kind):


@pytest.mark.skipif(not hasattr(mx.np, "where"), reason="mx.np.where hasn't been publish yet")
@pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (1, 8), (2, 2), (1, 3)])
@pytest.mark.parametrize(
"data_shape,cond_shape",
[[(2, 2, 2), (2, 2, 2)], [(2, 7, 2), (7, 2)], [(2, 2), (1, 2)], [(1, 3), (3, 3)]],
)
@pytest.mark.parametrize("data_dtype", ["float64", "float32", "int64", "int32", "bool"])
@pytest.mark.parametrize("cond_dtype", ["float64", "float32", "int64", "int32", "bool"])
@pytest.mark.parametrize("scalar", [1.0, 2.0])
@tvm.testing.parametrize_targets
@pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, target, ctx, kind):
def test_forward_npi_where_rscalar(
data_shape, cond_shape, data_dtype, cond_dtype, scalar, target, ctx, kind
):
if data_dtype == "bool":
scalar = scalar == 0.0
cond_np = np.random.uniform(size=data_shape).astype(cond_dtype)
cond_np = np.random.uniform(size=cond_shape).astype(cond_dtype)
data_np = np.random.uniform(size=data_shape).astype(data_dtype)
cond = mx.sym.var("condition")
data = mx.sym.var("x")
Expand All @@ -2136,7 +2146,7 @@ def test_forward_npi_where_rscalar(data_shape, cond_dtype, data_dtype, scalar, t
dtypeDic["condition"] = cond_dtype
dtypeDic["x"] = data_dtype
mod, _ = relay.frontend.from_mxnet(
mx_sym, shape={"condition": data_shape, "x": data_shape}, dtype=dtypeDic
mx_sym, shape={"condition": cond_shape, "x": data_shape}, dtype=dtypeDic
)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(cond_np, data_np)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def check_binary_op(opfunc, ref, dtype):
continue
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01, atol=1e-3)

for opfunc, ref in [
(relay.add, np.add),
Expand Down

0 comments on commit 7217762

Please sign in to comment.