[TOPI] Fix compiling batch_matmul and dense when two args are the same tensor (apache#9207)

* Add explicit copy stage for batch_matmul(x, x) case

* do copy in relay strategy to avoid dup

* add copy to dense op and schedules

* black

* add batch_matmul test

* add dense test

* fix cuda int8 dense test

* remove need_copy flag

* do not use tag to decide if tensors are same

* rename to copy_if_identical and add comment

* black

* one more fix missed

* add length check on input tensors

* one more length check

* fix variable name
masahi authored and ylc committed Jan 7, 2022
1 parent 17c190d commit f8f3005
Showing 9 changed files with 69 additions and 8 deletions.
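For context, a minimal sketch of the failing case this commit addresses (the shape mirrors the new batch_matmul test below; the snippet is illustrative and not part of the diff):

import tvm
from tvm import relay

# Both operands of batch_matmul are the same variable, i.e. batch_matmul(x, x).
x = relay.var("x", shape=(10, 27, 64), dtype="float32")
f = relay.Function([x], relay.nn.batch_matmul(x, x))
mod = tvm.IRModule.from_expr(f)

# Before this change, building could fail because TE saw a single input tensor
# and schedules applied cache_read to it twice; with the change it compiles.
lib = relay.build(mod, target="llvm")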
16 changes: 16 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -715,6 +715,19 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
    return strategy


def copy_if_identical(tensor_a, tensor_b):
    """
    When two inputs to batch_matmul or dense are the same tensor, e.g. batch_matmul(x, x),
    compilation fails because TE thinks there is only one input tensor x, and doing
    cache_read(x) on the same tensor twice results in an error.
    To prevent such errors, we make the second tensor a copy of the first one
    when the two input tensors are identical.
    """
    if tensor_a == tensor_b:
        return te.compute(tensor_a.shape, lambda *ind: tensor_a[ind])
    return tensor_b


# matmul
def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False):
"""wrap matmul topi compute"""
@@ -733,6 +746,7 @@ def _compute_matmul(attrs, inputs, out_type):
        ]
        if need_auto_scheduler_layout:
            args.append(get_auto_scheduler_rewritten_layout(attrs))
        args[1] = copy_if_identical(inputs[0], inputs[1])
        return [topi_compute(*args)]

    return _compute_matmul
@@ -762,6 +776,7 @@ def _compute_dense(attrs, inputs, out_type):
        args = [inputs[0], inputs[1], None, out_dtype]
        if need_auto_scheduler_layout:
            args.append(get_auto_scheduler_rewritten_layout(attrs))
        args[1] = copy_if_identical(inputs[0], inputs[1])
        return [topi_compute(*args)]

    return _compute_dense
@@ -804,6 +819,7 @@ def _compute_batch_matmul(attrs, inputs, out_type):
            args.append(attrs.transpose_b)
        if need_auto_scheduler_layout:
            args.append(get_auto_scheduler_rewritten_layout(attrs))
        args[1] = copy_if_identical(inputs[0], inputs[1])
        return [topi_compute(*args)]

    return _compute_batch_matmul
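A rough TE-level picture of what copy_if_identical does, as an illustrative sketch (the placeholder shape and names are assumptions, not part of the diff):

from tvm import te

x = te.placeholder((10, 27, 64), name="x")
a, b = x, x  # the same tensor arrives as both operands

# copy_if_identical gives the second operand its own compute stage in this case.
if a == b:
    b = te.compute(a.shape, lambda *ind: a[ind], name="x_copy")

# Downstream computes now read from two distinct tensors, so a schedule can
# cache_read each operand separately; the copy stage itself is folded away by
# the compute_inline calls added to the schedules below.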
4 changes: 4 additions & 0 deletions python/tvm/topi/cuda/batch_matmul.py
@@ -93,6 +93,8 @@ def schedule_batch_matmul(cfg, outs):
    def _schedule(cfg, op):
        C = op.output(0)
        A, B = s[C].op.input_tensors
        if len(B.op.input_tensors) == 1 and B.op.input_tensors[0] == A:
            s[B].compute_inline()
        _, M, N = get_const_tuple(C.shape)
        AA = s.cache_read(A, "shared", [C])
        AL = s.cache_read(AA, "local", [C])
@@ -336,6 +338,8 @@ def _callback(op):

def _schedule_batch_matmul_int8(cfg, s, output):
    input_x, input_y = s[output].op.input_tensors
    if len(input_y.op.input_tensors) == 1 and input_y.op.input_tensors[0] == input_x:
        s[input_y].compute_inline()

    B, M, K = get_const_tuple(input_x.shape)
    _, N, _ = get_const_tuple(input_y.shape)
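The guard added above recurs in every schedule touched by this commit: when the second operand is a single-input stage fed directly by the first operand, it is (in this context) the copy injected by copy_if_identical, and inlining it avoids materializing the copy. A commented restatement of the pattern, for reference only:

A, B = s[C].op.input_tensors
if len(B.op.input_tensors) == 1 and B.op.input_tensors[0] == A:
    # B is the identity copy of A added in the relay strategy; fold it into its
    # consumer so the copy costs no extra memory traffic.
    s[B].compute_inline()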
2 changes: 2 additions & 0 deletions python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -60,6 +60,8 @@ def schedule_batch_matmul_tensorcore(cfg, outs):

    def _schedule(cfg, s, C):
        A, B = s[C].op.input_tensors
        if len(B.op.input_tensors) == 1 and B.op.input_tensors[0] == A:
            s[B].compute_inline()
        batch, m_dim, k_dim = get_const_tuple(A.shape)
        batch, n_dim, k_dim = get_const_tuple(B.shape)
        data_dtype = A.dtype
2 changes: 2 additions & 0 deletions python/tvm/topi/cuda/dense.py
@@ -137,6 +137,8 @@ def _callback(op):

def _schedule_dense_int8(cfg, s, output):
    data, weight = s[output].op.input_tensors
    if len(weight.op.input_tensors) == 1 and weight.op.input_tensors[0] == data:
        s[weight].compute_inline()

    batch, in_dim = get_const_tuple(data.shape)
    out_dim, _ = get_const_tuple(weight.shape)
2 changes: 2 additions & 0 deletions python/tvm/topi/cuda/dense_tensorcore.py
@@ -96,6 +96,8 @@ def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None):

def _schedule_dense_tensorcore(cfg, s, C):
    """Schedule dense operator using Tensorcore"""
    A, B = s[C].op.input_tensors
    if len(B.op.input_tensors) == 1 and B.op.input_tensors[0] == A:
        s[B].compute_inline()
    batch, out_dim = get_const_tuple(C.shape)
    data_dtype = A.dtype
    out_dtype = C.dtype
5 changes: 5 additions & 0 deletions python/tvm/topi/gpu/dense.py
@@ -81,6 +81,9 @@ def _callback(op):

def _schedule_dense_small_batch(cfg, s, C):
    A, weights = C.op.input_tensors
    if len(weights.op.input_tensors) == 1 and weights.op.input_tensors[0] == A:
        s[weights].compute_inline()

    _, in_dim_weights = get_const_tuple(weights.shape)
    _, in_dim_A = get_const_tuple(A.shape)

@@ -141,6 +144,8 @@ def _callback(op):
def _schedule_dense_large_batch(cfg, s, C):
    """Schedule float32/64 dense with large batch size"""
    A, B = C.op.input_tensors
    if len(B.op.input_tensors) == 1 and B.op.input_tensors[0] == A:
        s[B].compute_inline()
    batch, in_dim = get_const_tuple(A.shape)
    out_dim, _ = get_const_tuple(B.shape)
    k = C.op.reduce_axis[0]
2 changes: 2 additions & 0 deletions python/tvm/topi/x86/batch_matmul.py
@@ -105,6 +105,8 @@ def _callback(op):
        if "batch_matmul" in op.tag:
            C = op.output(0)
            A, B = op.input_tensors
            if len(B.op.input_tensors) == 1 and B.op.input_tensors[0] == A:
                s[B].compute_inline()
            _, M, K = get_const_tuple(A.shape)
            _, _, N = get_const_tuple(C.shape)

12 changes: 12 additions & 0 deletions tests/python/relay/test_op_level1.py
@@ -551,6 +551,18 @@ def test_dense():
        tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5)


@tvm.testing.uses_gpu
def test_dense_same_args_compile():
    for dtype in ["float32", "int8"]:
        x = relay.var("x", shape=(32, 64), dtype=dtype)
        out_dtype = "int32" if dtype == "int8" else "float32"
        f = relay.Function([x], relay.nn.dense(x, x, out_dtype=out_dtype))
        m = tvm.IRModule.from_expr(f)

        for target, _ in tvm.testing.enabled_targets():
            tvm.relay.build(m, target=target)


def test_dense_dtype():
    data_dtype = "uint8"
    weight_dtype = "int8"
32 changes: 24 additions & 8 deletions tests/python/relay/test_op_level10.py
@@ -325,22 +325,34 @@ def verify_reverse_reshape(shape, newshape, oshape):
    verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12))


def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32", trans_x=False, trans_y=True):
    x = relay.var("x", relay.TensorType(x_shape, dtype))
    y = relay.var("y", relay.TensorType(y_shape, dtype))
def verify_batch_matmul_with_inputs(
    x, y, x_np, y_np, out_shape, dtype="float32", trans_x=False, trans_y=True
):
    z = relay.nn.batch_matmul(x, y, transpose_a=trans_x, transpose_b=trans_y)
    zz = run_infer_type(z)
    assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)

    func = relay.Function([x, y], z)
    x_np = np.random.uniform(size=x_shape).astype(dtype)
    y_np = np.random.uniform(size=y_shape).astype(dtype)
    input_vars = relay.analysis.free_vars(z)
    func = relay.Function(input_vars, z)
    z_np = tvm.topi.testing.batch_matmul(x_np, y_np, trans_x=trans_x, trans_y=trans_y)

    for target, dev in tvm.testing.enabled_targets():
        for kind in ["graph", "debug"]:
            z = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x_np, y_np)
            tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5)
            if len(input_vars) == 2:
                z = relay.create_executor(kind, device=dev, target=target).evaluate(func)(
                    x_np, y_np
                )
            else:
                z = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x_np)
            tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5, atol=1e-5)


def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32", trans_x=False, trans_y=True):
    x = relay.var("x", relay.TensorType(x_shape, dtype))
    y = relay.var("y", relay.TensorType(y_shape, dtype))
    x_np = np.random.uniform(size=x_shape).astype(dtype)
    y_np = np.random.uniform(size=y_shape).astype(dtype)
    verify_batch_matmul_with_inputs(x, y, x_np, y_np, out_shape, dtype, trans_x, trans_y)


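One detail worth noting in the refactored helper above: relay.analysis.free_vars deduplicates, so when x and y are the same Var the resulting function has a single parameter, which is why the executor call branches on len(input_vars). A small illustrative check (not part of the diff):

from tvm import relay

x = relay.var("x", shape=(10, 27, 64), dtype="float32")
z = relay.nn.batch_matmul(x, x)
print(relay.analysis.free_vars(z))  # one element: the shared variable x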
@tvm.testing.uses_gpu
@@ -360,6 +372,10 @@ def test_batch_matmul():
    verify_batch_matmul((5, 16, 32), (5, 32, 16), (5, 16, 16), trans_x=False, trans_y=False)
    verify_batch_matmul((5, 32, 16), (5, 32, 20), (5, 16, 20), trans_x=True, trans_y=False)

    x_np = np.random.randn(10, 27, 64).astype("float32")
    x = relay.var("x", shape=x_np.shape)
    verify_batch_matmul_with_inputs(x, x, x_np, x_np, (10, 27, 27))


@tvm.testing.uses_gpu
def test_shape_of():
