[CUDA] dense_tensorcore/batch_matmul_tensorcore support int8/int4 #8402

Merged
merged 13 commits on Jul 9, 2021
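This PR lets the CUDA tensor-core schedules for `batch_matmul` and `dense` accept int8/uint8 and int4/uint4 inputs in addition to float16, adding an `out_dtype` argument so integer inputs can accumulate into int32. A minimal sketch of exercising the new compute entry point at the TOPI level; the shapes and `out_dtype="int32"` below are illustrative assumptions, not part of this PR:

```python
# Hedged sketch only: int8 inputs with shapes satisfying the tensor-core
# constraint (M % 16 == 0, K % 16 == 0, N % 16 == 0), accumulating into int32.
from tvm import te
from tvm.topi.cuda.batch_matmul_tensorcore import batch_matmul_tensorcore_cuda

batch, M, N, K = 4, 64, 64, 64
x = te.placeholder((batch, M, K), name="x", dtype="int8")
y = te.placeholder((batch, N, K), name="y", dtype="int8")

# New in this PR: pass out_dtype instead of relying on an fp16 cast stage.
out = batch_matmul_tensorcore_cuda(x, y, out_dtype="int32")
print(out.shape, out.dtype)  # (4, 64, 64), int32
```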
15 changes: 9 additions & 6 deletions python/tvm/relay/op/strategy/cuda.py
@@ -844,13 +844,16 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
x, y = inputs
_, M, K = get_const_tuple(x.shape)
_, N, K = get_const_tuple(y.shape)
if x.dtype in ["float16", "int8", "uint8"] and (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
):
if (
x.dtype in ["float16", "int8", "uint8"]
and (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
)
) or (x.dtype in ["int4", "uint4"] and K % 32 == 0 and M % 8 == 0 and N % 8 == 0):
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore),
wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore, need_out_dtype=True),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore),
name="batch_matmul_tensorcore.cuda",
plevel=20,
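The guard above decides when the tensor-core implementation is registered by the CUDA strategy. A hedged restatement of that dispatch condition as a standalone predicate (`can_use_tensorcore` is illustrative only, not a TVM API):

```python
def can_use_tensorcore(dtype, M, K, N):
    """Illustrative restatement of the strategy guard above (not a TVM API).

    float16/int8/uint8 can map onto the 16x16x16, 32x16x8, or 8x16x32 wmma
    shapes; int4/uint4 only have the 8x8x32 fragment, hence K % 32.
    """
    if dtype in ["float16", "int8", "uint8"]:
        return (
            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
        )
    if dtype in ["int4", "uint4"]:
        return M % 8 == 0 and K % 32 == 0 and N % 8 == 0
    return False

# Example: an int4 batch matmul with M = N = 8, K = 32 qualifies; K = 16 does not.
assert can_use_tensorcore("int4", 8, 32, 8)
assert not can_use_tensorcore("int4", 8, 16, 8)
```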
81 changes: 42 additions & 39 deletions python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -29,10 +29,10 @@


@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None):
"""batch matmul tensorcore operator on cuda"""
# todo: deal with out_shape for broadcast, liuxin.ai
return batch_matmul_tensorcore_cuda(x, y)
return batch_matmul_tensorcore_cuda(x, y, out_dtype)


@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
@@ -57,10 +57,8 @@ def _schedule(cfg, s, C):
A, B = s[C].op.input_tensors
batch, m_dim, k_dim = get_const_tuple(A.shape)
batch, n_dim, k_dim = get_const_tuple(B.shape)
data_dtype = A.dtype
out_dtype = C.dtype
# inline astype fp16
s[A].compute_inline()
s[B].compute_inline()

# Explicit memory access
AS = s.cache_read(A, "shared", [C])
@@ -94,32 +92,35 @@ def _schedule(cfg, s, C):
cfg.define_knob("vec", [1, 2, 4, 8])

# Ensure that the default parameters are applicable when autotvm is not in use
if m_dim % 32 == 0 and n_dim % 8 == 0:
cfg.define_knob("wmma_m", [32, 16, 8])
elif m_dim % 16 == 0 and n_dim % 16 == 0:
cfg.define_knob("wmma_m", [16, 8, 32])
elif m_dim % 8 == 0 and n_dim % 32 == 0:
cfg.define_knob("wmma_m", [8, 16, 32])
if data_dtype in ["float16", "uint8", "int8"]:
if m_dim % 32 == 0 and n_dim % 8 == 0:
cfg.define_knob("wmma_m", [32, 16, 8])
elif m_dim % 16 == 0 and n_dim % 16 == 0:
cfg.define_knob("wmma_m", [16, 8, 32])
elif m_dim % 8 == 0 and n_dim % 32 == 0:
cfg.define_knob("wmma_m", [8, 16, 32])
wmma_k = 16
wmma_m = cfg["wmma_m"].val
if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8
else:
wmma_m = wmma_n = 8
wmma_k = 32

warp_size = 32
wmma_k = 16
block_row_warps = cfg["block_row_warps"].val
block_col_warps = cfg["block_col_warps"].val
warp_row_tiles = cfg["warp_row_tiles"].val
warp_col_tiles = cfg["warp_col_tiles"].val
chunk = cfg["chunk"].val
offset = cfg["offset"].val
offsetCS = cfg["offsetCS"].val
wmma_m = cfg["wmma_m"].val
vec = cfg["vec"].val

if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8

# Define the stride of intrin functions
AS_align = chunk * wmma_k + offset
BS_align = chunk * wmma_k + offset
@@ -211,10 +212,8 @@ def shared_shedule(stage, strides):
shared_shedule(BS, BS_align)

shape = (wmma_m, wmma_n, wmma_k)
# TODO: add checking here, datatype casting may cause precision loss
in_dtype = "float16"
AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype)
BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype)
k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
CL_compute = te.compute(
(wmma_m, wmma_n),
@@ -236,7 +235,7 @@ def shared_shedule(stage, strides):
"row_major",
(wmma_m, wmma_k),
(wmma_m, wmma_k),
"float16",
data_dtype,
),
)
s[BF].tensorize(
@@ -248,7 +247,7 @@ def shared_shedule(stage, strides):
"col_major",
(wmma_n, wmma_k),
(wmma_n, wmma_k),
"float16",
data_dtype,
),
)
s[CF].tensorize(
@@ -270,7 +269,7 @@ def _callback(op):
return s


def batch_matmul_tensorcore_cuda(x, y):
def batch_matmul_tensorcore_cuda(x, y, out_dtype=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.

@@ -294,22 +293,26 @@ def batch_matmul_tensorcore_cuda(x, y):
assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent"
batch, M, K = x.shape
N = y.shape[1]
out_dtype = x.dtype

assert (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)"

x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype("float16"))
y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype("float16"))
if out_dtype is None:
out_dtype = x.dtype

assert x.dtype == y.dtype
assert x.dtype in ["float16", "uint8", "int8", "uint4", "int4"]
if x.dtype in ["float16", "uint8", "int8"]:
assert (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)"
else:
assert (
M % 8 == 0 and K % 32 == 0 and N % 8 == 0
), "The shape of (M, K, N) must be multiple of (8, 32, 8)"

k = te.reduce_axis((0, K), name="k")
return te.compute(
(batch, M, N),
lambda b, i, j: te.sum(
x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k
),
lambda b, i, j: te.sum(x[b, i, k].astype(out_dtype) * y[b, j, k].astype(out_dtype), axis=k),
tag="batch_matmul_tensorcore",
)
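With this change the compute no longer inserts an `astype("float16")` stage; it multiplies the inputs in their original dtype and casts to `out_dtype` inside the reduction. A hedged sketch of an int4 invocation under the new (8, 32, 8) constraint, which mirrors the m8n8k32 wmma fragment; the concrete shapes and `out_dtype="int32"` are assumptions for illustration:

```python
# Hedged sketch: int4 inputs require M % 8 == 0, K % 32 == 0, N % 8 == 0.
from tvm import te
from tvm.topi.cuda.batch_matmul_tensorcore import batch_matmul_tensorcore_cuda

x = te.placeholder((2, 8, 32), name="x", dtype="int4")
y = te.placeholder((2, 8, 32), name="y", dtype="int4")
out = batch_matmul_tensorcore_cuda(x, y, out_dtype="int32")  # shape (2, 8, 8), int32
```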
79 changes: 42 additions & 37 deletions python/tvm/topi/cuda/dense_tensorcore.py
@@ -60,22 +60,27 @@ def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None):
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
assert (
(batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0)
or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0)
or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0)
), (
"The shape of (batch, in_dim, out_dim) "
"must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
)

assert data.dtype == weight.dtype
assert data.dtype in ["float16", "int8", "uint8", "int4", "uint4"]
if data.dtype in ["float16", "int8", "uint8"]:
assert (
(batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0)
or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0)
or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0)
), (
"The shape of (batch, in_dim, out_dim) "
"must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
)
else:
assert (
batch % 8 == 0 and in_dim % 32 == 0 and out_dim % 8 == 0
), "The shape of (batch, in_dim, out_dim) must be multiple of (8, 32, 8)"

k = te.reduce_axis((0, in_dim), name="k")
data_16 = te.compute((batch, in_dim), lambda b, i: data[b, i].astype("float16"))
weight_16 = te.compute((out_dim, in_dim), lambda o, i: weight[o, i].astype("float16"))
matmul = te.compute(
(batch, out_dim),
lambda i, j: te.sum(
data_16[i, k].astype(out_dtype) * weight_16[j, k].astype(out_dtype), axis=k
),
lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k),
name="T_dense",
tag="dense_tensorcore",
)
@@ -92,9 +97,8 @@ def _schedule_dense_tensorcore(cfg, s, C):
"""Schedule dense operator using Tensorcore"""
A, B = s[C].op.input_tensors
batch, out_dim = get_const_tuple(C.shape)
data_dtype = A.dtype
out_dtype = C.dtype
s[A].compute_inline()
s[B].compute_inline()

# Explicit memory access
AS = s.cache_read(A, "shared", [C])
@@ -127,33 +131,36 @@ def _schedule_dense_tensorcore(cfg, s, C):
cfg.define_knob("offsetCS", [0, 8])
cfg.define_knob("vec", [1, 2, 4, 8])

# Ensure that the default parameters are applicable when autotvm is not in use
if batch % 32 == 0 and out_dim % 8 == 0:
cfg.define_knob("wmma_m", [32, 16, 8])
elif batch % 16 == 0 and out_dim % 16 == 0:
cfg.define_knob("wmma_m", [16, 8, 32])
elif batch % 8 == 0 and out_dim % 32 == 0:
cfg.define_knob("wmma_m", [8, 16, 32])
if data_dtype in ["float16", "int8", "uint8"]:
# Ensure that the default parameters are applicable when autotvm is not in use
if batch % 32 == 0 and out_dim % 8 == 0:
cfg.define_knob("wmma_m", [32, 16, 8])
elif batch % 16 == 0 and out_dim % 16 == 0:
cfg.define_knob("wmma_m", [16, 8, 32])
elif batch % 8 == 0 and out_dim % 32 == 0:
cfg.define_knob("wmma_m", [8, 16, 32])
wmma_k = 16
wmma_m = cfg["wmma_m"].val
if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8
else:
wmma_m = wmma_n = 8
wmma_k = 32

warp_size = 32
wmma_k = 16
block_row_warps = cfg["block_row_warps"].val
block_col_warps = cfg["block_col_warps"].val
warp_row_tiles = cfg["warp_row_tiles"].val
warp_col_tiles = cfg["warp_col_tiles"].val
chunk = cfg["chunk"].val
offset = cfg["offset"].val
offsetCS = cfg["offsetCS"].val
wmma_m = cfg["wmma_m"].val
vec = cfg["vec"].val

if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8

# Define the stride of intrin functions
AS_align = chunk * wmma_k + offset
BS_align = chunk * wmma_k + offset
@@ -245,10 +252,8 @@ def shared_shedule(stage, strides):
shared_shedule(BS, BS_align)

shape = (wmma_m, wmma_n, wmma_k)
# TODO: add checking here, datatype casting may cause precision loss
in_dtype = "float16"
AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype)
BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype)
k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
CL_compute = te.compute(
(wmma_m, wmma_n),
@@ -264,13 +269,13 @@ def shared_shedule(stage, strides):
s[AF].tensorize(
b_ii,
intrin_wmma_load_matrix_A(
AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), "float16"
AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), data_dtype
),
)
s[BF].tensorize(
o_ii,
intrin_wmma_load_matrix_W(
BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), "float16"
BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), data_dtype
),
)
s[CF].tensorize(
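The dense path receives the analogous change: the same dtype assertion, an int4 branch with the (8, 32, 8) constraint, and direct use of the input dtype in the wmma load intrinsics. A hedged sketch of calling the compute with uint8 data; the shapes and `out_dtype="int32"` are illustrative assumptions:

```python
# Hedged sketch of the dense path after this change. dense_tensorcore_cuda is
# the compute defined in python/tvm/topi/cuda/dense_tensorcore.py.
from tvm import te
from tvm.topi.cuda.dense_tensorcore import dense_tensorcore_cuda

batch, in_dim, out_dim = 16, 32, 16  # multiples of (16, 16, 16) for uint8
data = te.placeholder((batch, in_dim), name="data", dtype="uint8")
weight = te.placeholder((out_dim, in_dim), name="weight", dtype="uint8")
out = dense_tensorcore_cuda(data, weight, out_dtype="int32")  # (16, 16), int32
```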