From 98449395040bcaf6ea24fd6f11dfb42096a4b772 Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Thu, 29 Oct 2020 10:36:58 +0800
Subject: [PATCH 01/12] add batch_matmul_tensorcore

---
 python/tvm/relay/op/strategy/cuda.py          |  14 +
 python/tvm/topi/cuda/__init__.py              |   1 +
 .../tvm/topi/cuda/batch_matmul_tensorcore.py  | 274 ++++++++++++++++++
 3 files changed, 289 insertions(+)
 create mode 100644 python/tvm/topi/cuda/batch_matmul_tensorcore.py

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 1229a71569d0..d685986c048b 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -657,6 +657,20 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+        x, y = inputs
+        B, M, K = get_const_tuple(x.shape)
+        B, N, K = get_const_tuple(y.shape)
+        # "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+        if ((M % 8 == 0 and K % 16 == 0 and N % 32 == 0) or \
+            (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) or \
+            (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)):
+            strategy.add_implementation(
+                wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore),
+                wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore),
+                name="batch_matmul_tensorcore.cuda",
+                plevel=20)
     return strategy
 
 
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 3ff544f4bb3e..9f2dcd66cfaa 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -42,6 +42,7 @@
 from .pooling import *
 from .nn import schedule_lrn
 from .batch_matmul import *
+from .batch_matmul_tensorcore import *
 from .vision import *
 from .ssd import *
 from .nms import get_valid_counts, non_max_suppression
diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
new file mode 100644
index 000000000000..fdb445449028
--- /dev/null
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
+"""cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..util import traverse_inline, get_const_tuple
+from .tensor_intrin import intrin_wmma_load_matrix_A, \
+    intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+
+@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
+def batch_matmul_tensorcore(cfg, x, y):
+    """batch matmul tensorcore operator on cuda"""
+    return batch_matmul_tensorcore_cuda(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
+def schedule_batch_matmul_tensorcore(cfg, outs):
+    """Schedule for batch_matmul operator using Tensorcore
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of batch_matmul
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(cfg, s, C):
+        A, B = s[C].op.input_tensors
+        batch, m_dim, k_dim = get_const_tuple(A.shape)
+        batch, n_dim, k_dim = get_const_tuple(B.shape)
+        out_dtype = C.dtype
+        # inline astype fp16
+        s[A].compute_inline()
+        s[B].compute_inline()
+
+        # Explicit memory access
+        AS = s.cache_read(A, 'shared', [C])
+        BS = s.cache_read(B, 'shared', [C])
+        AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+        BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+        CF = s.cache_write(C, 'wmma.accumulator')
+        CS = s.cache_read(CF, 'shared', [C])
+
+        # fallback support
+        target = tvm.target.Target.current()
+        if cfg.is_fallback:
+            ref_log = autotvm.tophub.load_reference_log(
+                target.kind.name, target.model, 'batch_matmul_tensorcore.cuda')
+            cfg.fallback_with_reference_log(ref_log)
+
+        # Deal with op fusion, such as bias and relu
+        # Deal with slice after padding
+        if C.op not in s.outputs and "injective" in s.outputs[0].tag:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
+
+        # create tuning space
+        cfg.define_knob("block_row_warps", [1, 2, 4])
+        cfg.define_knob("block_col_warps", [1, 2, 4])
+        cfg.define_knob("warp_row_tiles", [1, 2, 4])
+        cfg.define_knob("warp_col_tiles", [1, 2, 4])
+        cfg.define_knob("chunk", [1, 2, 4, 8])
+        cfg.define_knob("offset", [0, 8])
+        cfg.define_knob("offsetCS", [0, 8])
+        cfg.define_knob("vec", [1, 2, 4, 8])
+
+        # Ensure that the default parameters are applicable when autotvm is not in use
+        if (m_dim % 32 == 0 and n_dim % 8 == 0):
+            cfg.define_knob("wmma_m", [32, 16, 8])
+        elif (m_dim % 16 == 0 and n_dim % 16 == 0):
+            cfg.define_knob("wmma_m", [16, 8, 32])
+        elif (m_dim % 8 == 0 and n_dim % 32 == 0):
+            cfg.define_knob("wmma_m", [8, 16, 32])
+
+        warp_size = 32
+        wmma_k = 16
+        block_row_warps = cfg["block_row_warps"].val
+        block_col_warps = cfg["block_col_warps"].val
+        warp_row_tiles = cfg["warp_row_tiles"].val
+        warp_col_tiles = cfg["warp_col_tiles"].val
+        chunk = cfg["chunk"].val
+        offset = cfg["offset"].val
+        offsetCS = cfg["offsetCS"].val
+        wmma_m = cfg["wmma_m"].val
+        vec = cfg["vec"].val
+
+        if wmma_m == 16:
+            wmma_n = 16
+        elif wmma_m == 8:
+            wmma_n = 32
+        elif wmma_m == 32:
+            wmma_n = 8
+
+        # Define the stride of intrin functions
+        AS_align = chunk * wmma_k + offset
+        BS_align = chunk * wmma_k + offset
+        CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+        AS_stride = [AS_align, 1]
+        BS_stride = [BS_align, 1]
+        AF_stride = [wmma_k, 1]
+        BF_stride = [wmma_k, 1]
+        CF_stride = [warp_col_tiles * wmma_n, 1]
+        CS_stride = [CS_align, 1]
+
+        block_x = te.thread_axis('blockIdx.x')
+        block_y = te.thread_axis('blockIdx.y')
+        block_z = te.thread_axis('blockIdx.z')
+        thread_x = te.thread_axis('threadIdx.x')
+        thread_y = te.thread_axis('threadIdx.y')
+        thread_z = te.thread_axis('threadIdx.z')
+
+        # Schedule for dense computation
+        block_factor_m = wmma_m * warp_row_tiles * block_row_warps
+        block_factor_n = wmma_n * warp_col_tiles * block_col_warps
+        b, m, n = C.op.axis
+        block_i, bc = s[C].split(m, factor=block_factor_m)
+        block_j, oc = s[C].split(n, factor=block_factor_n)
+        s[C].reorder(b, block_i, block_j, bc, oc)
+        t = s[C].fuse(bc, oc)
+        t, vi = s[C].split(t, factor=vec)
+        t, tx = s[C].split(t, factor=warp_size)
+        t, ty = s[C].split(t, factor=block_row_warps)
+        t, tz = s[C].split(t, factor=block_col_warps)
+        s[C].bind(block_i, block_x)
+        s[C].bind(block_j, block_y)
+        s[C].bind(b, block_z)
+        s[C].bind(tz, thread_z)
+        s[C].bind(ty, thread_y)
+        s[C].bind(tx, thread_x)
+        s[C].vectorize(vi)
+
+        # Schedule for wmma store
+        s[CS].compute_at(s[C], block_j)
+        bs, bb, oo = CS.op.axis
+        s[CS].storage_align(bb, CS_align - 1, CS_align)
+        bb, bbi = s[CS].split(bb, factor=wmma_m)
+        oo, ooi = s[CS].split(oo, factor=wmma_n)
+        bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+        oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+        s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi)
+
+        # Schedule for wmma computation
+        s[CF].compute_at(s[CS], oo)
+        bs, warp_i, warp_j = CF.op.axis
+        warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+        warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+        k, = CF.op.reduce_axis
+        k, _k = s[CF].split(k, factor=wmma_k)
+        ko, ki = s[CF].split(k, factor=chunk)
+        s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+        # Schedule for wmma_matrix_a load
+        s[AF].compute_at(s[CF], ki)
+        bs, b, i = AF.op.axis
+        b, b_ii = s[AF].split(b, factor=wmma_m)
+        i, i_jj = s[AF].split(i, factor=wmma_k)
+        s[AF].reorder(bs, b, i, b_ii, i_jj)
+
+        # Schedule for wmma_matrix_b load
+        s[BF].compute_at(s[CF], ki)
+        bs, o, i = BF.op.axis
+        o, o_ii = s[BF].split(o, factor=wmma_n)
+        i, i_ii = s[BF].split(i, factor=wmma_k)
+        s[BF].reorder(bs, o, i, o_ii, i_ii)
+
+        # Schedule for A's(B's) shared memory load
+        def shared_shedule(stage, strides):
+            s[stage].compute_at(s[CF], ko)
+            bs, xo, yo = stage.op.axis
+            s[stage].storage_align(xo, strides - 1, strides)
+            t = s[stage].fuse(xo, yo)
+            t, vi = s[stage].split(t, factor=vec)
+            t, tx = s[stage].split(t, factor=warp_size)
+            t, ty = s[stage].split(t, factor=block_row_warps)
+            _, tz = s[stage].split(t, factor=block_col_warps)
+            s[stage].bind(ty, thread_y)
+            s[stage].bind(tz, thread_z)
+            s[stage].bind(tx, thread_x)
+            s[stage].vectorize(vi)
+
+        shared_shedule(AS, AS_align)
+        shared_shedule(BS, BS_align)
+
+        shape = (wmma_m, wmma_n, wmma_k)
+        in_dtype = 'float16'
+        AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
+        BL_gemm = te.placeholder((wmma_n, wmma_k), name='BL_gemm', dtype=in_dtype)
+        k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
+        CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
+                                te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) * \
+                                       BL_gemm[jj, k_gemm].astype(out_dtype), \
+                                       axis=k_gemm), name='CL_compute')
+
+        # lower the computation loops down to TensorCore hardware intrinsics
+        # by mapping the dense tensorcore to tensor intrinsics
+        s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A( \
+            AF_stride, AS_stride, shape, "row_major", \
+            (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
+        s[BF].tensorize(o_ii, intrin_wmma_load_matrix_W( \
+            BF_stride, BS_stride, shape, "col_major", \
+            (wmma_n, wmma_k), (wmma_n, wmma_k), 'float16'))
+        s[CF].tensorize(_ii, intrin_wmma_gemm( \
+            AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape))
+        s[CS].tensorize(bbi, intrin_wmma_store_matrix( \
+            CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)))
+
+    def _callback(op):
+        if "batch_matmul_tensorcore" in op.tag:
+            _schedule(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+def batch_matmul_tensorcore_cuda(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.te.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
+    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
+    x_shape = get_const_tuple(x.shape)
+    y_shape = get_const_tuple(y.shape)
+    assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
+    assert x_shape[2] == y_shape[2], "shapes of x and y are inconsistent"
+    batch, M, K = x.shape
+    N = y.shape[1]
+    out_dtype = x.dtype
+
+    assert ((M % 8 == 0 and K % 16 == 0 and N % 32 == 0) or \
+            (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) or \
+            (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)), \
+            "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+
+    x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype('float16'))
+    y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype('float16'))
+
+    k = te.reduce_axis((0, K), name='k')
+    return te.compute((batch, M, N),
+                      lambda b, i, j: te.sum(x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k),
+                      tag='batch_matmul_tensorcore')

From 1ea5ea84c46d4ed392570ed0cec850bbe949dc33 Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Thu, 29 Oct 2020 11:02:09 +0800
Subject: [PATCH 02/12] add bmm cublas autotune

---
 python/tvm/topi/cuda/batch_matmul.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py
index 8d34b2996593..006b866d6bad 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -21,7 +21,7 @@
 from tvm import te
 from tvm.contrib import cublas
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
-from .. import nn
+from .. import nn, generic
 from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor
 
 
@@ -138,7 +138,8 @@ def _callback(op):
     return s
 
 
-def batch_matmul_cublas(x, y, out_shape=None):
+@autotvm.register_topi_compute("batch_matmul_cublas.cuda")
+def batch_matmul_cublas(cfg, x, y, out_shape=None):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
@@ -158,4 +159,13 @@ def batch_matmul_cublas(x, y, out_shape=None):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
+    b, m, k = x.shape
+    b, n, k = y.shape
+    cfg.add_flop(b * m * k * n * 2)
     return cublas.batch_matmul(x, y, False, True)
+
+
+@autotvm.register_topi_schedule("batch_matmul_cublas.cuda")
+def schedule_batch_matmul_cublas(_, outs):
+    """Schedule batch_matmul operator using CUBLAS"""
+    return generic.schedule_extern(outs)

From debec15e09e26209243832a56d9fe03f71c4ad2b Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Mon, 9 Nov 2020 12:38:07 +0800
Subject: [PATCH 03/12] add bmm tests

---
 .../tvm/topi/cuda/batch_matmul_tensorcore.py  |  2 +-
 .../test_topi_batch_matmul_tensorcore.py      | 75 +++++++++++++++++++
 2 files changed, 76 insertions(+), 1 deletion(-)
 create mode 100644 tests/python/topi/python/test_topi_batch_matmul_tensorcore.py

diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
index fdb445449028..10efb37d3117 100644
--- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -19,7 +19,7 @@
 import tvm
 from tvm import autotvm
 from tvm import te
-from ..util import traverse_inline, get_const_tuple
+from ..utils import traverse_inline, get_const_tuple
 from .tensor_intrin import intrin_wmma_load_matrix_A, \
     intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
 
diff --git a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
new file mode 100644
index 000000000000..60f4bef3a855
--- /dev/null
+++ b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for batch_matmul operator"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+from tvm.contrib.pickle_memoize import memoize
+
+import tvm.testing
+
+_batch_matmul_implement = {
+    "gpu": (topi.cuda.batch_matmul_tensorcore, topi.cuda.schedule_batch_matmul_tensorcore),
+}
+
+
+def verify_batch_matmul(x_batch, y_batch, M, N, K):
+    x = te.placeholder((x_batch, M, K), name="x")
+    y = te.placeholder((y_batch, N, K), name="y")
+    dtype = x.dtype
+
+    # use memoize to pickle the test data for next time use
+    @memoize("topi.tests.test_topi_batch_matmul_tensorcore")
+    def get_ref_data():
+        a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
+        b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
+        c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
+        return (a_np, b_np, c_np)
+
+    # get the test data
+    a_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        print("Running on target: %s" % device)
+        with tvm.target.Target(device):
+            fcompute, fschedule = tvm.topi.testing.dispatch(device, _batch_matmul_implement)
+            out = fcompute(x, y)
+            s = fschedule([out])
+            a = tvm.nd.array(a_np, ctx)
+            b = tvm.nd.array(b_np, ctx)
+            c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), ctx)
+            f = tvm.build(s, [x, y, out], device, name="dense")
+            f(a, b, c)
+            tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
+
+    check_device("cuda")
+
+
+@tvm.testing.uses_gpu
+def test_batch_matmul():
+    verify_batch_matmul(1, 1, 16, 16, 32)
+    verify_batch_matmul(5, 5, 16, 16, 32)
+    verify_batch_matmul(5, 5, 16, 32, 32)
+    verify_batch_matmul(30, 30, 16, 32, 32)
+
+
+if __name__ == "__main__":
+    test_batch_matmul()

From a1808b4bd78af6d74bcf9a669a9c013199c07557 Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Mon, 9 Nov 2020 21:05:43 +0800
Subject: [PATCH 04/12] out_shape for bmm_tensorcore

---
 python/tvm/topi/cuda/batch_matmul_tensorcore.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
index 10efb37d3117..a197e6338e50 100644
--- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -24,8 +24,9 @@
     intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
 
 @autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
-def batch_matmul_tensorcore(cfg, x, y):
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
     """batch matmul tensorcore operator on cuda"""
+    # todo: deal with out_shape for broadcast, liuxin.ai
     return batch_matmul_tensorcore_cuda(x, y)
 

From a2eee2ac97a1c986f81bb0b0e58e4230d06630ec Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Tue, 22 Dec 2020 14:05:41 +0800
Subject: [PATCH 05/12] fix comments

---
 python/tvm/topi/cuda/batch_matmul_tensorcore.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
index a197e6338e50..b3025871b580 100644
--- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -72,8 +72,7 @@ def _schedule(cfg, s, C):
                 target.kind.name, target.model, 'batch_matmul_tensorcore.cuda')
             cfg.fallback_with_reference_log(ref_log)
 
-        # Deal with op fusion, such as bias and relu
-        # Deal with slice after padding
+        # Deal with op fusion, such as bias/relu and slice after padding
         if C.op not in s.outputs and "injective" in s.outputs[0].tag:
             s[C].compute_inline()
             C = s.outputs[0].output(0)

From bf8936a884eed561d5c92a32b6ec0c62dcba778d Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Wed, 30 Dec 2020 16:59:50 +0800
Subject: [PATCH 06/12] code format

---
 python/tvm/relay/op/strategy/cuda.py          |  11 +-
 .../tvm/topi/cuda/batch_matmul_tensorcore.py  | 138 +++++++++++-------
 2 files changed, 96 insertions(+), 53 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index d685986c048b..db894d9d87fa 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -662,14 +662,17 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         B, M, K = get_const_tuple(x.shape)
         B, N, K = get_const_tuple(y.shape)
         # "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
-        if ((M % 8 == 0 and K % 16 == 0 and N % 32 == 0) or \
-            (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) or \
-            (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)):
+        if (
+            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
+            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
+            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+        ):
             strategy.add_implementation(
                 wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore),
                 wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore),
                 name="batch_matmul_tensorcore.cuda",
-                plevel=20)
+                plevel=20,
+            )
     return strategy
 
 
diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
index b3025871b580..30c3e892b091 100644
--- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -20,8 +20,13 @@
 from tvm import autotvm
 from tvm import te
 from ..utils import traverse_inline, get_const_tuple
-from .tensor_intrin import intrin_wmma_load_matrix_A, \
-    intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+from .tensor_intrin import (
+    intrin_wmma_load_matrix_A,
+    intrin_wmma_load_matrix_W,
+    intrin_wmma_store_matrix,
+    intrin_wmma_gemm,
+)
+
 
 @autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
 def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
@@ -58,18 +63,19 @@ def _schedule(cfg, s, C):
         s[B].compute_inline()
 
         # Explicit memory access
-        AS = s.cache_read(A, 'shared', [C])
-        BS = s.cache_read(B, 'shared', [C])
-        AF = s.cache_read(AS, 'wmma.matrix_a', [C])
-        BF = s.cache_read(BS, 'wmma.matrix_b', [C])
-        CF = s.cache_write(C, 'wmma.accumulator')
-        CS = s.cache_read(CF, 'shared', [C])
+        AS = s.cache_read(A, "shared", [C])
+        BS = s.cache_read(B, "shared", [C])
+        AF = s.cache_read(AS, "wmma.matrix_a", [C])
+        BF = s.cache_read(BS, "wmma.matrix_b", [C])
+        CF = s.cache_write(C, "wmma.accumulator")
+        CS = s.cache_read(CF, "shared", [C])
 
         # fallback support
         target = tvm.target.Target.current()
         if cfg.is_fallback:
             ref_log = autotvm.tophub.load_reference_log(
-                target.kind.name, target.model, 'batch_matmul_tensorcore.cuda')
+                target.kind.name, target.model, "batch_matmul_tensorcore.cuda"
+            )
             cfg.fallback_with_reference_log(ref_log)
 
         # Deal with op fusion, such as bias/relu and slice after padding
@@ -88,11 +94,11 @@ def _schedule(cfg, s, C):
         cfg.define_knob("vec", [1, 2, 4, 8])
 
         # Ensure that the default parameters are applicable when autotvm is not in use
-        if (m_dim % 32 == 0 and n_dim % 8 == 0):
+        if m_dim % 32 == 0 and n_dim % 8 == 0:
             cfg.define_knob("wmma_m", [32, 16, 8])
-        elif (m_dim % 16 == 0 and n_dim % 16 == 0):
+        elif m_dim % 16 == 0 and n_dim % 16 == 0:
             cfg.define_knob("wmma_m", [16, 8, 32])
-        elif (m_dim % 8 == 0 and n_dim % 32 == 0):
+        elif m_dim % 8 == 0 and n_dim % 32 == 0:
             cfg.define_knob("wmma_m", [8, 16, 32])
 
         warp_size = 32
@@ -125,12 +131,12 @@ def _schedule(cfg, s, C):
         CF_stride = [warp_col_tiles * wmma_n, 1]
         CS_stride = [CS_align, 1]
 
-        block_x = te.thread_axis('blockIdx.x')
-        block_y = te.thread_axis('blockIdx.y')
-        block_z = te.thread_axis('blockIdx.z')
-        thread_x = te.thread_axis('threadIdx.x')
-        thread_y = te.thread_axis('threadIdx.y')
-        thread_z = te.thread_axis('threadIdx.z')
+        block_x = te.thread_axis("blockIdx.x")
+        block_y = te.thread_axis("blockIdx.y")
+        block_z = te.thread_axis("blockIdx.z")
+        thread_x = te.thread_axis("threadIdx.x")
+        thread_y = te.thread_axis("threadIdx.y")
+        thread_z = te.thread_axis("threadIdx.z")
 
         # Schedule for dense computation
@@ -167,7 +173,7 @@ def _schedule(cfg, s, C):
         bs, warp_i, warp_j = CF.op.axis
         warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
         warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
-        k, = CF.op.reduce_axis
+        (k,) = CF.op.reduce_axis
         k, _k = s[CF].split(k, factor=wmma_k)
         ko, ki = s[CF].split(k, factor=chunk)
         s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)
@@ -205,27 +211,55 @@ def shared_shedule(stage, strides):
         shared_shedule(AS, AS_align)
         shared_shedule(BS, BS_align)
 
         shape = (wmma_m, wmma_n, wmma_k)
-        in_dtype = 'float16'
-        AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
-        BL_gemm = te.placeholder((wmma_n, wmma_k), name='BL_gemm', dtype=in_dtype)
-        k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
-        CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
-                                te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) * \
-                                       BL_gemm[jj, k_gemm].astype(out_dtype), \
-                                       axis=k_gemm), name='CL_compute')
+        in_dtype = "float16"
+        AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
+        BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
+        k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
+        CL_compute = te.compute(
+            (wmma_m, wmma_n),
+            lambda ii, jj: te.sum(
+                AL_gemm[ii, k_gemm].astype(out_dtype) * BL_gemm[jj, k_gemm].astype(out_dtype),
+                axis=k_gemm,
+            ),
+            name="CL_compute",
+        )
 
         # lower the computation loops down to TensorCore hardware intrinsics
         # by mapping the dense tensorcore to tensor intrinsics
-        s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A( \
-            AF_stride, AS_stride, shape, "row_major", \
-            (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
-        s[BF].tensorize(o_ii, intrin_wmma_load_matrix_W( \
-            BF_stride, BS_stride, shape, "col_major", \
-            (wmma_n, wmma_k), (wmma_n, wmma_k), 'float16'))
-        s[CF].tensorize(_ii, intrin_wmma_gemm( \
-            AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape))
-        s[CS].tensorize(bbi, intrin_wmma_store_matrix( \
-            CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)))
+        s[AF].tensorize(
+            b_ii,
+            intrin_wmma_load_matrix_A(
+                AF_stride,
+                AS_stride,
+                shape,
+                "row_major",
+                (wmma_m, wmma_k),
+                (wmma_m, wmma_k),
+                "float16",
+            ),
+        )
+        s[BF].tensorize(
+            o_ii,
+            intrin_wmma_load_matrix_W(
+                BF_stride,
+                BS_stride,
+                shape,
+                "col_major",
+                (wmma_n, wmma_k),
+                (wmma_n, wmma_k),
+                "float16",
+            ),
+        )
+        s[CF].tensorize(
+            _ii,
+            intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape),
+        )
+        s[CS].tensorize(
+            bbi,
+            intrin_wmma_store_matrix(
+                CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)
+            ),
+        )
 
     def _callback(op):
         if "batch_matmul_tensorcore" in op.tag:
             _schedule(cfg, s, op.output(0))
@@ -234,6 +268,7 @@ def _callback(op):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 def batch_matmul_tensorcore_cuda(x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
@@ -260,15 +295,20 @@ def batch_matmul_tensorcore_cuda(x, y):
     N = y.shape[1]
     out_dtype = x.dtype
 
-    assert ((M % 8 == 0 and K % 16 == 0 and N % 32 == 0) or \
-            (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) or \
-            (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)), \
-            "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
-
-    x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype('float16'))
-    y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype('float16'))
-
-    k = te.reduce_axis((0, K), name='k')
-    return te.compute((batch, M, N),
-                      lambda b, i, j: te.sum(x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k),
-                      tag='batch_matmul_tensorcore')
+    assert (
+        (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
+        or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
+        or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+    ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+
+    x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype("float16"))
+    y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype("float16"))
+
+    k = te.reduce_axis((0, K), name="k")
+    return te.compute(
+        (batch, M, N),
+        lambda b, i, j: te.sum(
+            x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k
+        ),
+        tag="batch_matmul_tensorcore",
+    )

From 8ce932c5ae22f211f2b7e8af8cf77454a1a4d553 Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Wed, 30 Dec 2020 17:06:37 +0800
Subject: [PATCH 07/12] add todos for tensorcore datatype checking

---
 python/tvm/topi/cuda/batch_matmul_tensorcore.py | 1 +
 python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py  | 1 +
 python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py | 1 +
 python/tvm/topi/cuda/dense_tensorcore.py        | 1 +
 4 files changed, 4 insertions(+)

diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
index 30c3e892b091..ea6f6860605b 100644
--- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -211,6 +211,7 @@ def shared_shedule(stage, strides):
         shared_shedule(BS, BS_align)
 
         shape = (wmma_m, wmma_n, wmma_k)
+        # TODO: add checking here, datatype casting may cause precision loss
         in_dtype = "float16"
         AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
         BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
diff --git a/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py b/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
index f665cc779dc5..76f082f07b44 100644
--- a/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
+++ b/python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
@@ -72,6 +72,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
     ry = te.reduce_axis((0, kernel_h), name="ry")
     rx = te.reduce_axis((0, kernel_w), name="rx")
     # convert data type of input feature maps and weights
+    # TODO: add checking here, datatype casting may cause precision loss
     TransPaddedInput = te.compute(
         PaddedInput.shape, lambda n, h, w, c: PaddedInput[n, h, w, c].astype("float16")
     )
diff --git a/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py b/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
index a5c4e81a4dc3..efb25744b802 100644
--- a/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
+++ b/python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
@@ -75,6 +75,7 @@ def ndhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dty
     ry = te.reduce_axis((0, kernel_h), name="ry")
     rx = te.reduce_axis((0, kernel_w), name="rx")
     # convert data type of input feature maps and weights
+    # TODO: add checking here, datatype casting may cause precision loss
     TransPaddedInput = te.compute(
         PaddedInput.shape, lambda n, d, h, w, c: PaddedInput[n, d, h, w, c].astype("float16")
     )
diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py
index a59ebd7347bb..430f8044528c 100644
--- a/python/tvm/topi/cuda/dense_tensorcore.py
+++ b/python/tvm/topi/cuda/dense_tensorcore.py
@@ -245,6 +245,7 @@ def shared_shedule(stage, strides):
     shared_shedule(BS, BS_align)
 
     shape = (wmma_m, wmma_n, wmma_k)
+    # TODO: add checking here, datatype casting may cause precision loss
    in_dtype = "float16"
     AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
     BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)

From 595b9f1313b8b453cb59f6da33ec471e2c68edaf Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Thu, 31 Dec 2020 11:19:15 +0800
Subject: [PATCH 08/12] fix lint

---
 python/tvm/relay/op/strategy/cuda.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index db894d9d87fa..b211ca144bf8 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -659,9 +659,8 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     )
     if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
         x, y = inputs
-        B, M, K = get_const_tuple(x.shape)
-        B, N, K = get_const_tuple(y.shape)
-        # "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+        _, M, K = get_const_tuple(x.shape)
+        _, N, K = get_const_tuple(y.shape)
         if (
             (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
             or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)

From 3da829919fc073cf9077ecf7d8e1ab8cb7c3eead Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Thu, 31 Dec 2020 11:33:11 +0800
Subject: [PATCH 09/12] fix lint

---
 python/tvm/topi/cuda/batch_matmul_tensorcore.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
index ea6f6860605b..59b92ec9e623 100644
--- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py
+++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -300,7 +300,7 @@ def batch_matmul_tensorcore_cuda(x, y):
         (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
         or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
         or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
-    ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+    ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)"
 
     x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype("float16"))
     y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype("float16"))

From cdca880295ee1b17a80360a15ac3378c6b36fd9a Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Thu, 31 Dec 2020 14:28:15 +0800
Subject: [PATCH 10/12] fix have_tensorcore

---
 python/tvm/relay/op/strategy/cuda.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index b211ca144bf8..97c1a71ed89d 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -657,7 +657,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         name="batch_matmul_cublas.cuda",
         plevel=15,
     )
-    if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
         x, y = inputs
         _, M, K = get_const_tuple(x.shape)
         _, N, K = get_const_tuple(y.shape)

From 72a5885411a2d6121b0750603893bbc38332e1c2 Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Fri, 8 Jan 2021 10:21:59 +0800
Subject: [PATCH 11/12] add dtype check for batch_matmul_tensorcore

---
 python/tvm/relay/op/strategy/cuda.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 97c1a71ed89d..3df094f02cb9 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -661,7 +661,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         x, y = inputs
         _, M, K = get_const_tuple(x.shape)
         _, N, K = get_const_tuple(y.shape)
-        if (
+        if data.dtype in ["float16", "int8", "uint8"] and (
             (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
             or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
             or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)

From 618b4bfc854c4c778173fc42aa4ea18e374d220d Mon Sep 17 00:00:00 2001
From: "liuxin.ai"
Date: Fri, 8 Jan 2021 10:40:21 +0800
Subject: [PATCH 12/12] fix

---
 python/tvm/relay/op/strategy/cuda.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 3df094f02cb9..254e23e66abc 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -661,7 +661,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         x, y = inputs
         _, M, K = get_const_tuple(x.shape)
         _, N, K = get_const_tuple(y.shape)
-        if data.dtype in ["float16", "int8", "uint8"] and (
+        if x.dtype in ["float16", "int8", "uint8"] and (
             (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
             or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
             or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
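
The dispatch condition that patches 08-12 converge on is compact enough to check in isolation. Below is a minimal, self-contained sketch in plain Python (no TVM required); the helper name qualifies_for_tensorcore is illustrative and not part of the patch, and the real strategy additionally requires nvcc.have_tensorcore(target=target) to confirm a TensorCore-capable GPU:

def qualifies_for_tensorcore(M, K, N, dtype="float16"):
    """Mirror of the shape/dtype predicate added to batch_matmul_strategy_cuda.

    TensorCore WMMA supports three fragment shapes (m16n16k16, m32n8k16,
    m8n32k16), so (M, K, N) must be divisible by (16, 16, 16), (32, 16, 8),
    or (8, 16, 32), and the inputs must use a dtype the intrinsics accept.
    """
    shape_ok = (
        (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
        or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
        or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
    )
    return shape_ok and dtype in ("float16", "int8", "uint8")


# The shapes exercised by test_batch_matmul above all qualify
# (verify_batch_matmul takes x_batch, y_batch, M, N, K):
assert qualifies_for_tensorcore(M=16, K=32, N=16)      # verify_batch_matmul(1, 1, 16, 16, 32)
assert qualifies_for_tensorcore(M=16, K=32, N=32)      # verify_batch_matmul(5, 5, 16, 32, 32)
assert not qualifies_for_tensorcore(M=15, K=32, N=16)  # unaligned M falls back to the other impls

When the predicate is false, strategy.add_implementation is simply not called for the tensorcore kernel, and the existing batch_matmul.cuda / batch_matmul_cublas.cuda implementations (plevel 10/15, versus 20 here) win the dispatch.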