diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index b2a0ff466bc51..4b50937fc8387 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -172,6 +172,25 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
                     name="conv2d_nhwc_tensorcore.cuda",
                     plevel=20)
+        elif layout == "HWNC":
+            assert kernel_layout in ["HWOI", "HWOI16o16i", "HWOI8o32i", "HWOI32o16i"]
+            _, _, N, in_channels = get_const_tuple(data.shape)
+            pre_computed = len(kernel.shape) == 6
+            if pre_computed:
+                _, _, oc_chunk, _, oc_block_factor, _ = get_const_tuple(kernel.shape)
+                out_channels = oc_chunk * oc_block_factor
+            else:
+                _, _, out_channels, _ = get_const_tuple(kernel.shape)
+            if topi.cuda.is_shape_tensorcore_direct_qualified(
+                    batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype):
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_hwnc_tensorcore),
+                    name="conv2d_hwnc_tensorcore_direct.cuda",
+                    plevel=20)
+            else:
+                raise RuntimeError("Unsupported shape for conv2d HWNC.\
+                                    Need to satisfy tensor core schedule.")
         elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 90f4e6074ffcc..ed8037024635f 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -50,5 +50,6 @@
 from .conv2d_nhwc_tensorcore import *
 from .conv3d_ndhwc_tensorcore import *
 from .dense_tensorcore import *
+from .conv2d_hwnc_tensorcore import *
 from .correlation import *
 from .sparse import *
diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py
index c2a19054434e6..f07ef984025f9 100644
--- a/python/tvm/topi/cuda/conv2d_alter_op.py
+++ b/python/tvm/topi/cuda/conv2d_alter_op.py
@@ -171,6 +171,36 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.conv2d(*inputs, **new_attrs)
 
+    if topi_tmpl == "conv2d_HWNCnc_tensorcore.cuda":
+        assert data_layout == "HWNC" and kernel_layout == "HWOI"
+        assert float(tvm.gpu(0).compute_version) >= 7.5
+        H, W, N, CI = get_const_tuple(data.shape)
+        KH, KW, CO, _ = get_const_tuple(kernel.shape)
+
+        if kernel.dtype in ['int4', 'uint4'] and (CI % 32 != 0 or CO % 8 != 0) or \
+                kernel.dtype in ['int8', 'uint8'] and (CI % 16 != 0 or CO % 32 != 0):
+            return relay.nn.conv2d(*inputs, **new_attrs)
+
+        new_attrs["channels"] = CO
+        if kernel.dtype in ['int4', 'uint4']:
+            new_attrs['kernel_layout'] = 'HWOI8o32i'
+            ic_block_factor = 32
+            oc_block_factor = 8
+        else:
+            new_attrs['kernel_layout'] = 'HWOI32o16i'
+            ic_block_factor = 16
+            oc_block_factor = 32
+
+        new_kernel = te.placeholder((KH, KW, CO // oc_block_factor, CI // ic_block_factor,
+                                     oc_block_factor, ic_block_factor), dtype=kernel.dtype)
+
+        new_workload = autotvm.task.args_to_workload(
+            [data, new_kernel, strides, padding, dilation, out_dtype],
+            "conv2d_HWNCnc_tensorcore.cuda")
+
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
     return None
 
 @conv2d_legalize.register("cuda")
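# Illustrative sketch (not part of the patch): the alter-op branch above asks
# Relay to repack an HWOI kernel into the blocked HWOI8o32i (int4) or
# HWOI32o16i (int8) layout whose placeholder shape is
# (KH, KW, O // oc_block, I // ic_block, oc_block, ic_block).  A minimal numpy
# illustration of that index math follows; the helper name is hypothetical and
# the real repacking is performed by Relay's layout-transform pass.
import numpy as np

def pack_hwoi_to_blocked(kernel_hwoi, oc_block=8, ic_block=32):
    kh, kw, co, ci = kernel_hwoi.shape
    assert co % oc_block == 0 and ci % ic_block == 0
    blocked = kernel_hwoi.reshape(kh, kw, co // oc_block, oc_block,
                                  ci // ic_block, ic_block)
    # Put the block axes innermost: (H, W, O, I, o, i)
    return blocked.transpose(0, 1, 2, 4, 3, 5)

# e.g. a 3x3 kernel with 64 output and 64 input channels becomes (3, 3, 8, 2, 8, 32)
print(pack_hwoi_to_blocked(np.zeros((3, 3, 64, 64), dtype=np.int32)).shape)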
diff --git a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
new file mode 100644
index 0000000000000..592613ffcf920
--- /dev/null
+++ b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
@@ -0,0 +1,440 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-locals, too-many-function-args
+# pylint: disable=too-many-statements, unused-argument, too-many-arguments
+"""Tensorcore template for cuda backend"""
+import tvm
+from tvm import te
+from tvm import autotvm
+from tvm.topi.cuda.injective import schedule_injective_from_existing
+from ..util import get_const_tuple, traverse_inline, simplify, tag
+from ..nn.pad import pad
+from ..nn.util import get_pad_tuple
+from .tensor_intrin import intrin_wmma_load_matrix_A
+from .tensor_intrin import intrin_wmma_load_matrix_W
+from .tensor_intrin import intrin_wmma_store_matrix
+from .tensor_intrin import intrin_wmma_gemm
+
+
+def unpack_HWNCnc_to_hwnc(packed_out, out_dtype):
+    """Unpack conv2d_hwnc output from layout hwncnc to hwnc
+
+    Parameters
+    ----------
+    packed_out : tvm.te.Tensor
+        The output tensor of conv2d_hwnc.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    unpacked_out : tvm.te.Tensor
+        The unpacked output tensor in hwnc layout.
+    """
+    H, W, N, O, wmma_m, wmma_n = get_const_tuple(packed_out.shape)
+
+    idxmod = tvm.tir.indexmod
+    idxdiv = tvm.tir.indexdiv
+
+    oshape = (H, W, N * wmma_m, O * wmma_n)
+    unpacked_out = \
+        te.compute(oshape,
+                   lambda h, w, n, o:
+                   packed_out[h, w, idxdiv(n, wmma_m), idxdiv(o, wmma_n),
+                              idxmod(n, wmma_m), idxmod(o, wmma_n)]
+                   .astype(out_dtype),
+                   name='output_unpack',
+                   tag=tag.INJECTIVE + ",unpack_hwncc")
+    return unpacked_out
+
+
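# Illustrative sketch (not part of the patch): the unpack compute above is pure
# index arithmetic.  An equivalent numpy formulation, assuming a packed array of
# shape (H, W, N_chunk, O_chunk, wmma_m, wmma_n), is a transpose plus reshape:
import numpy as np

def unpack_hwncnc_numpy(packed):
    h, w, n_chunk, o_chunk, wmma_m, wmma_n = packed.shape
    out = packed.transpose(0, 1, 2, 4, 3, 5)
    return out.reshape(h, w, n_chunk * wmma_m, o_chunk * wmma_n)

# unpacked[h, w, n, o] == packed[h, w, n // wmma_m, o // wmma_n,
#                                n % wmma_m, o % wmma_n]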
+ """ + H, W, N, O, wmma_m, wmma_n = get_const_tuple(packed_out.shape) + + idxmod = tvm.tir.indexmod + idxdiv = tvm.tir.indexdiv + + oshape = (H, W, N * wmma_m, O * wmma_n) + unpacked_out = \ + te.compute(oshape, + lambda h, w, n, o: + packed_out[h, w, idxdiv(n, wmma_m), idxdiv(o, wmma_n), + idxmod(n, wmma_m), idxmod(o, wmma_n)] + .astype(out_dtype), + name='output_unpack', + tag=tag.INJECTIVE + ",unpack_hwncc") + return unpacked_out + + +def conv2d_hwnc_tensorcore(data, kernel, strides, padding, dilation, in_dtype, out_dtype='int32'): + """"Compute conv2d with tensorcore for HWNC layout with int8/int4""" + assert data.dtype in ('int4', 'uint4', 'int8', 'uint8') + assert kernel.dtype in ('int4', 'uint4', 'int8', 'uint8') + packed_out = hwnc_tensorcore_cuda( + data, kernel, strides, padding, dilation, out_dtype) + return unpack_HWNCnc_to_hwnc(packed_out, out_dtype) + + +@autotvm.register_topi_compute("conv2d_HWNCnc_tensorcore.cuda") +def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype='int32'): + """Compute declaration for tensorcore""" + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + in_dtype = Input.dtype + + if in_dtype in ['int4', 'uint4']: + wmma_n = wmma_m = 8 + wmma_k = 32 + else: + wmma_m = 8 + wmma_n = 32 + wmma_k = 16 + + pre_computed = len(Filter.shape) == 6 + in_height, in_width, batch, in_channels = get_const_tuple(Input.shape) + if pre_computed: + kernel_h, kernel_w, oc_chunk, _, oc_block_factor, _\ + = get_const_tuple(Filter.shape) + num_filter = oc_block_factor * oc_chunk + else: + kernel_h, kernel_w, num_filter, _ = get_const_tuple(Filter.shape) + + if in_dtype in ['int4', 'uint4']: + assert (batch % 8 == 0 and in_channels % + 32 == 0 and num_filter % 8 == 0) + else: + assert (batch % 8 == 0 and in_channels % 16 == 0 and num_filter % 32 == 0), \ + "The shape of (batch, in_channels, num_filter) "\ + "must be multiple of (8, 16, 32) for int8, "\ + "and (8, 32, 8) for int4" + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + + out_channels = num_filter + out_height = simplify( + (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + + pad_left + pad_right) // stride_w + 1) + + cfg.add_flop(2 * batch * out_height * out_width * + out_channels * in_channels * kernel_h * kernel_w) + + # Input feature map: (H, W, N, IC, n, ic) + data_shape = (in_height, + in_width, + batch // wmma_m, + in_channels // wmma_k, + wmma_m, + wmma_k) + + # Kernel: (H, W, OC, IC, oc, ic) + kernel_shape = (kernel_h, + kernel_w, + out_channels // wmma_n, + in_channels // wmma_k, + wmma_n, + wmma_k) + + # Reduction axes + kh = te.reduce_axis((0, kernel_h), name='kh') + kw = te.reduce_axis((0, kernel_w), name='kw') + ic = te.reduce_axis((0, in_channels // wmma_k), name='ic') + ii = te.reduce_axis((0, wmma_k), name='ii') + + if pre_computed: + packed_kernel = Filter + else: + packed_kernel = te.compute(kernel_shape, lambda kh, kw, o, i, oo, ii: + Filter[kh, kw, o * wmma_n + oo, i * wmma_k + ii], + name="packed_kernel" + ) + + packed_data = 
+    pad_before = [pad_top, pad_left, 0, 0, 0, 0]
+    pad_after = [pad_down, pad_right, 0, 0, 0, 0]
+    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
+
+    Conv = te.compute((out_height, out_width, batch // wmma_m,
+                       out_channels // wmma_n, wmma_m, wmma_n),
+                      lambda h, w, n, o, nn, oo: te.sum(
+                          (pad_data[h * stride_h + kh, w * stride_w + kw,
+                                    n, ic, nn, ii].astype('int32') *
+                           packed_kernel[kh, kw, o, ic, oo, ii].astype('int32')),
+                          axis=[ic, kh, kw, ii]),
+                      name="Conv", tag="conv2d_HWNCnc_tensorcore")
+    return Conv
+
+
+def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
+    """Schedule tensorcore template"""
+    packed_data, packed_kernel = s[Conv].op.input_tensors
+    ic, kh, kw, ii = s[Conv].op.reduce_axis
+    pad_data = s[packed_data].op.input_tensors[0]
+
+    block_x = te.thread_axis('blockIdx.x')
+    block_y = te.thread_axis('blockIdx.y')
+    block_z = te.thread_axis('blockIdx.z')
+    thread_x = te.thread_axis('threadIdx.x')
+    thread_y = te.thread_axis('threadIdx.y')
+    thread_z = te.thread_axis('threadIdx.z')
+
+    # Designate the memory hierarchy
+    AS = s.cache_read(packed_data, 'shared', [Conv])
+    WS = s.cache_read(packed_kernel, 'shared', [Conv])
+    AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
+    WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
+    ConvF = s.cache_write(Conv, 'wmma.accumulator')
+
+    if Conv.op in s.outputs:
+        output = Conv
+        ConvS = s.cache_read(ConvF, 'shared', [Conv])
+        OL = ConvS
+    else:
+        output = s.outputs[0].output(0)
+        s[Conv].set_scope('shared')
+        OL = Conv
+
+    out_dtype = Conv.dtype
+
+    if isinstance(packed_kernel.op, te.tensor.ComputeOp) and packed_kernel.name == "packed_kernel":
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            s[packed_kernel].pragma(
+                s[packed_kernel].op.axis[0], "debug_skip_region")
+        else:
+            with tvm.target.create('cuda'):
+                schedule_injective_from_existing(s, packed_kernel)
+
+    if isinstance(pad_data.op, te.tensor.ComputeOp) and "pad" in pad_data.op.tag:
+        s[pad_data].compute_inline()
+        data = pad_data.op.input_tensors[0]
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # skip this part during tuning to make records accurate
+            # this part will be pre-computed during NNVM's pre-compute optimization pass
+            s[pad_data].pragma(s[pad_data].op.axis[0], "debug_skip_region")
+    else:
+        data = pad_data
+        s[data].compute_inline()
+
+    data_dtype = data.dtype
+    kernel_dtype = packed_kernel.dtype
+
+    # Schedule for autotvm
+    cfg.define_knob("block_row_warps", [1, 2, 4])
+    cfg.define_knob("block_col_warps", [1, 2, 4])
+    cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16])
+    cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16])
+    cfg.define_knob("chunk", [1, 2, 4, 8])
+    cfg.define_knob("fuse_pack", [0, 1])
+    cfg.define_knob("split_block_k_nums", [1, 2, 4, 8, 16, 32])
+    cfg.define_knob("vector_ws", [1, 8])
+    cfg.define_knob("vector_as", [1, 8, 16])
+
+    block_row_warps = cfg["block_row_warps"].val
+    block_col_warps = cfg["block_col_warps"].val
+    warp_row_tiles = cfg["warp_row_tiles"].val
+    warp_col_tiles = cfg["warp_col_tiles"].val
+    chunk = cfg["chunk"].val
+    vector_as = cfg["vector_as"].val
+    vector_ws = cfg["vector_ws"].val
+    split_block_k_nums = cfg["split_block_k_nums"].val
+    fuse_pack = cfg["fuse_pack"].val
+
+    if not fuse_pack:
+        s[packed_data].compute_inline()
+    else:
+        with tvm.target.create('cuda'):
+            schedule_injective_from_existing(s, packed_data)
+
+    if data_dtype in ['int4', 'uint4']:
+        wmma_m = wmma_n = 8
+        wmma_k = 32
+    else:
+        wmma_m = 8
+        wmma_n = 32
+        wmma_k = 16
+
+    warp_size = 32
+
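# Illustrative sketch (not part of the patch): the (wmma_m, wmma_n, wmma_k)
# values chosen above match the nvcuda::wmma fragment shapes available for
# these operand types (8x8x32 for 4-bit, 8x32x16 for 8-bit).  A helper with a
# hypothetical name that makes the mapping explicit:
def wmma_shape_for(dtype):
    if dtype in ('int4', 'uint4'):
        return 8, 8, 32     # (m, n, k)
    if dtype in ('int8', 'uint8'):
        return 8, 32, 16
    raise ValueError("dtype %s is not handled by this schedule" % dtype)

assert wmma_shape_for('int4') == (8, 8, 32)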
+    # Schedule for output
+    if len(s[output].op.axis) == 4:
+        hc, wc, nc, oc, = output.op.axis
+        nc, nnc = s[output].split(nc, factor=wmma_m)
+        oc, ooc = s[output].split(oc, factor=wmma_n)
+    else:
+        hc, wc, nc, oc, nnc, ooc = output.op.axis
+
+    kernel_scope, hc = s[output].split(hc, nparts=1)
+
+    block_k = s[output].fuse(hc, wc)
+    block_k, split_block_k = s[output].split(
+        block_k, factor=split_block_k_nums)
+    nc, nci = s[output].split(nc, factor=warp_row_tiles)
+    block_i, nc = s[output].split(nc, factor=block_row_warps)
+    oc, oci = s[output].split(oc, factor=warp_col_tiles)
+    block_j, oc = s[output].split(oc, factor=block_col_warps)
+    s[output].reorder(block_k, split_block_k, block_i,
+                      block_j, nc, oc, nci, oci, nnc, ooc)
+    t = s[output].fuse(nnc, ooc)
+    _, tx = s[output].split(t, factor=warp_size)
+    s[output].bind(block_k, block_z)
+    s[output].bind(block_i, block_x)
+    s[output].bind(block_j, block_y)
+    s[output].bind(tx, thread_x)
+    s[output].bind(nc, thread_y)
+    s[output].bind(oc, thread_z)
+
+    # Schedule wmma store
+    s[OL].compute_at(s[output], block_j)
+    hc, wc, nc, oc, nnc, ooc = OL.op.axis
+    oc, oci = s[OL].split(oc, factor=warp_col_tiles)
+    _, oc = s[OL].split(oc, factor=block_col_warps)
+    nc, nci = s[OL].split(nc, factor=warp_row_tiles)
+    _, nc = s[OL].split(nc, factor=block_row_warps)
+    s[OL].reorder(nc, oc, nci, oci, nnc, ooc)
+    s[OL].bind(nc, thread_y)
+    s[OL].bind(oc, thread_z)
+
+    # Schedule local computation
+    s[ConvF].compute_at(s[OL], oc)
+    _, _, n, o, nnf, oof = ConvF.op.axis
+    ko, ki = s[ConvF].split(ic, factor=chunk)
+    s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)
+
+    cfg.define_reorder("reorder_inner", [ko, kh], policy="all")
+    cfg["reorder_inner"].apply(s, ConvF, [ko, kh])
+    cfg["reorder_inner"].apply(s, ConvF, [ki, kw])
+
+    cfg.define_knob("compute_at_AS", [0, 1, 2, 3])
+    cfg.define_knob("compute_at_WS", [0, 1, 2, 3])
+    compute_at_AS = cfg["compute_at_AS"].val
+    compute_at_WS = cfg["compute_at_WS"].val
+
+    # Move intermediate computation into each output compute tile
+    s[AF].compute_at(s[ConvF], kw)
+    s[WF].compute_at(s[ConvF], kw)
+
+    # Schedule for A's shared memory
+    if compute_at_AS == 0:
+        s[AS].compute_at(s[ConvF], ki)
+    elif compute_at_AS == 1:
+        s[AS].compute_at(s[ConvF], kw)
+    elif compute_at_AS == 2:
+        s[AS].compute_at(s[ConvF], ko)
+    else:
+        s[AS].compute_at(s[ConvF], kh)
+    _, _, n, _, nn, ii = AS.op.axis
+    tx, xo = s[AS].split(n, nparts=block_row_warps)
+    ty, _ = s[AS].split(xo, nparts=block_col_warps)
+    t = s[AS].fuse(nn, ii)
+    to, ti = s[AS].split(t, nparts=warp_size)
+    ti, _t = s[AS].split(ti, factor=vector_as)
+    s[AS].bind(tx, thread_y)
+    s[AS].bind(ty, thread_z)
+    s[AS].bind(to, thread_x)
+    s[AS].vectorize(_t)
+
+    # Schedule for W's shared memory
+    if compute_at_WS == 0:
+        s[WS].compute_at(s[ConvF], ki)
+    elif compute_at_WS == 1:
+        s[WS].compute_at(s[ConvF], kw)
+    elif compute_at_WS == 2:
+        s[WS].compute_at(s[ConvF], ko)
+    else:
+        s[WS].compute_at(s[ConvF], kh)
+    s[WS].compute_at(s[ConvF], kw)
+    kh, kw, ic, o, ii, oo = WS.op.axis
+    tx, xo = s[WS].split(o, nparts=block_row_warps)
+    ty, _ = s[WS].split(xo, nparts=block_col_warps)
+    t = s[WS].fuse(ii, oo)
+    to, ti = s[WS].split(t, nparts=warp_size)
+    ti, _t = s[WS].split(ti, factor=vector_ws)
+    s[WS].bind(tx, thread_y)
+    s[WS].bind(ty, thread_z)
+    s[WS].bind(to, thread_x)
+    s[WS].vectorize(ti)
+
+    # double buffer
+    cfg.define_knob('AS_double_buffer', [0, 1])
+    cfg.define_knob('WS_double_buffer', [0, 1])
+    if cfg['AS_double_buffer'].val:
+        s[AS].double_buffer()
+    if cfg['WS_double_buffer'].val:
+        s[WS].double_buffer()
+
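# Illustrative sketch (not part of the patch): with the bindings above, the CUDA
# thread block is (threadIdx.x, threadIdx.y, threadIdx.z) =
# (warp_size, block_row_warps, block_col_warps), and each warp accumulates
# warp_row_tiles x warp_col_tiles output fragments.  For example:
warp_size, block_row_warps, block_col_warps = 32, 2, 4               # sample knob values
threads_per_block = warp_size * block_row_warps * block_col_warps   # 256 threads, 8 warps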
+    # unroll
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
+                     cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit', False)
+
+    shape = (wmma_m, wmma_n, wmma_k)
+
+    AS_shape = (wmma_m, wmma_k)
+    AL_shape = (wmma_m, wmma_k)
+    WS_shape = (wmma_n, wmma_k)
+    WL_shape = (wmma_n, wmma_k)
+    CL_shape = (wmma_m, wmma_n)
+    CS_shape = (wmma_m, wmma_n)
+
+    AL_gemm = te.placeholder(AL_shape, name='A', dtype=data_dtype)
+    WL_gemm = te.placeholder(WL_shape, name='B', dtype=kernel_dtype)
+    k_gemm = te.reduce_axis((0, wmma_k), name="k")
+    CL_compute = te.compute(CL_shape, lambda ii, jj:
+                            te.sum((AL_gemm[ii, k_gemm].astype(
+                                'int32') * WL_gemm[jj, k_gemm].astype('int32')), axis=k_gemm),
+                            name='C')
+
+    AL_strides = [wmma_k, 1]
+    AS_strides = [wmma_k, 1]
+    WL_strides = [wmma_k, 1]
+    WS_strides = [wmma_k, 1]
+    CL_strides = [wmma_n, 1]
+    CS_strides = [wmma_n, 1]
+
+    s[AF].tensorize(AF.op.axis[-2],
+                    intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape,
+                                              "row_major", AS_shape, AL_shape, data_dtype))
+
+    s[WF].tensorize(WF.op.axis[-2],
+                    intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape,
+                                              "col_major", WS_shape, WL_shape, kernel_dtype))
+
+    s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides,
+                                                  shape, out_dtype, CL_shape, CS_shape))
+
+    s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides,
+                                             WL_strides, CL_strides, shape))
+
+    return s
+
+
+@autotvm.register_topi_schedule("conv2d_HWNCnc_tensorcore.cuda")
+def schedule_conv2d_hwnc_tensorcore(cfg, outs):
+    """TOPI schedule callback"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'conv2d_HWNCnc_tensorcore' in op.tag:
+            schedule_hwnc_tensorcore_cuda(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 3e6838ce1b0b3..2f19d6e126ada 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -617,9 +617,13 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
     CHECK(op->args.size() == 1 && l);
     os << "((";
     this->PrintType(l->dtype.element_of(), os);
-    os << " *)" << this->GetVarID(l->buffer_var.get()) << " + ";
+    os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "
+       << "(";
     this->PrintExpr(l->index, os);
-    os << ')';
+    if (l->dtype.bits() == 4 || (l->dtype.bits() == 1 && l->dtype.is_int())) {
+      os << " / " << (32 / l->dtype.bits());
+    }
+    os << "))";
   } else if (op->op.same_as(builtin::tvm_struct_get())) {
     CHECK_EQ(op->args.size(), 3U);
     os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value);
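# Illustrative sketch (not part of the patch): the codegen_c.cc change above
# rescales the index used by address_of for sub-byte types, since int4/int1
# elements are stored packed into 32-bit words.  With 4-bit elements,
# 32 / 4 = 8 values share one word, so element index 20 falls in word 2:
def packed_word_index(elem_index, bits):
    assert bits in (1, 4)
    lanes_per_word = 32 // bits
    return elem_index // lanes_per_word

assert packed_word_index(20, 4) == 2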
diff --git a/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py b/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py
new file mode 100644
index 0000000000000..2c071c9a266bd
--- /dev/null
+++ b/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-locals, too-many-arguments
+"""Example code to do convolution."""
+
+import numpy as np
+import tvm
+import os
+import tvm.topi.testing
+from tvm import te, autotvm, topi
+from tvm.contrib.pickle_memoize import memoize
+from tvm.contrib import nvcc
+from tvm.topi.nn.util import get_pad_tuple
+from tvm.topi.util import get_const_tuple
+
+_conv2d_hwnc_tensorcore_implement = {
+    "cuda": (topi.cuda.conv2d_hwnc_tensorcore, topi.cuda.schedule_conv2d_hwnc_tensorcore)
+}
+
+def verify_conv2d_hwnc(batch, in_channel, in_size, num_filter, kernel, stride,
+                       padding, dilation=1, devices='cuda', dtype='int4'):
+    """Test the conv2d with tensorcore for hwnc layout"""
+    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+    padding_sum = pad_top + pad_left + pad_bottom + pad_right
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
+        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+    # choose dtype from int4, int8
+    assert dtype in ['int4', 'int8']
+
+    in_height = in_width = in_size
+
+    A = te.placeholder((in_height, in_width, batch, in_channel), name='A', dtype=dtype)
+    W = te.placeholder((kernel, kernel, num_filter, in_channel), name='W', dtype=dtype)
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    @memoize("topi.tests.test_topi_conv2d_hwnc.verify_conv2d_hwnc")
+    def get_ref_data():
+        if dtype == 'int4':
+            a_np = np.random.randint(low=-8, high=7, size=a_shape).transpose((2, 0, 1, 3))
+            w_np = np.random.randint(low=-8, high=7, size=w_shape)
+            dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation))
+        elif dtype == 'int8':
+            a_np = np.random.randint(low=-128, high=127, size=a_shape).transpose((2, 0, 1, 3)).astype(dtype)
+            w_np = np.random.randint(low=-128, high=127, size=w_shape).astype(dtype)
+            dw_np = topi.testing.dilate_python(w_np.transpose((0, 1, 3, 2)), (1, 1, dilation, dilation))
+
+        c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
+        return a_np, w_np, c_np
+
+    def convert_int32_into_int4(a_int32):
+        """ convert int32 values into int4
+        Parameters
+        ----------
+        a_int32 : numpy.ndarray
+
+        Return
+        ------
+        a_int4 : numpy.ndarray
+        """
+        I, J, K, L = a_int32.shape
+        a_int4 = np.zeros(shape=(I, J, K, L // 8), dtype=np.int32)
+        for i in range(I):
+            for j in range(J):
+                for k in range(K):
+                    for l in range(L // 8):
+                        for m in range(min(8, L-l*8)):
+                            a_int4[i, j, k, l] = a_int4[i, j, k, l] | ((a_int32[i, j, k, l * 8 + m] & 0xf) << ((7 - m) * 4))
+        return a_int4
+
+    a_np, w_np, c_np = get_ref_data()
+    if dtype == 'int4':
+        a_np = convert_int32_into_int4(a_np)
+        w_np = convert_int32_into_int4(w_np)
+
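# Illustrative sketch (not part of the patch): the nested-loop packer above can
# also be written with vectorized numpy.  This is only an equivalent
# formulation for reference (the function name is hypothetical); it produces
# the same bit pattern, with value m of each group of eight placed at bit
# offset (7 - m) * 4:
import numpy as np

def convert_int32_into_int4_np(a_int32):
    I, J, K, L = a_int32.shape
    nibbles = (a_int32.astype(np.int64) & 0xf).reshape(I, J, K, L // 8, 8)
    shifts = (7 - np.arange(8)) * 4                               # 28, 24, ..., 0
    packed = np.bitwise_or.reduce(nibbles << shifts, axis=-1)
    return packed.astype(np.uint32).view(np.int32)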
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        if not nvcc.have_tensorcore(ctx.compute_version):
+            print("skip because gpu does not support Tensor Cores")
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            fcompute, fschedule = topi.testing.dispatch(device, _conv2d_hwnc_tensorcore_implement)
+            C = fcompute(A, W, stride, padding, dilation, dtype, 'int32')
+            s = fschedule([C])
+
+        a = tvm.nd.array(a_np.transpose((1, 2, 0, 3)), ctx)
+        w = tvm.nd.array(w_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+
+        func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
+            batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+        func(a, w, c)
+
+        rtol = 1e-3
+        tvm.testing.assert_allclose(c.asnumpy().transpose((2, 0, 1, 3)), c_np, rtol=rtol)
+
+    check_device(devices)
+
+
+def test_conv2d_hwnc_tensorcore():
+    """Test the conv2d with tensorcore for hwnc layout"""
+    verify_conv2d_hwnc(8, 64, 56, 64, 3, 1, 1, dtype='int8')
+    verify_conv2d_hwnc(8, 64, 56, 64, 1, 1, 0, dtype='int4')
+    verify_conv2d_hwnc(8, 64, 56, 128, 3, 2, 1)
+    verify_conv2d_hwnc(8, 64, 56, 64, 1, 2, 0)
+    verify_conv2d_hwnc(8, 128, 28, 128, 3, 1, 1)
+    verify_conv2d_hwnc(8, 128, 28, 256, 3, 2, 1)
+    verify_conv2d_hwnc(8, 128, 28, 256, 1, 2, 0)
+    verify_conv2d_hwnc(8, 256, 14, 256, 3, 1, 1)
+    verify_conv2d_hwnc(8, 256, 14, 512, 3, 2, 1)
+    verify_conv2d_hwnc(8, 256, 14, 512, 1, 2, 0)
+    verify_conv2d_hwnc(8, 512, 9, 512, 3, 1, 1)
+
+if __name__ == "__main__":
+    test_conv2d_hwnc_tensorcore()