From 113331447d6d8d79b1cec89d6c4c1814c0ac3795 Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 21 Jan 2021 15:05:44 +0900 Subject: [PATCH] [TOPI] Rewrite GPU argwhere using exclusive scan (#7314) * use ex scan to write argwhere * add doc --- python/tvm/contrib/nvcc.py | 2 +- python/tvm/topi/cuda/argwhere.py | 524 ++++++------------------------- 2 files changed, 98 insertions(+), 428 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index bc11e4a867e4..5886760934fb 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -186,7 +186,7 @@ def find_libdevice_path(arch): selected_ver = 0 selected_path = None cuda_ver = get_cuda_version(cuda_path) - if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1): + if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2): path = os.path.join(lib_path, "libdevice.10.bc") else: for fn in os.listdir(lib_path): diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py index e39004dc76a9..cc6c4c26eddb 100644 --- a/python/tvm/topi/cuda/argwhere.py +++ b/python/tvm/topi/cuda/argwhere.py @@ -21,169 +21,135 @@ import tvm from tvm import te -from tvm._ffi import get_global_func from .injective import schedule_injective_from_existing -from .nms import atomic_add -from .sort import topk, topk_thrust, argsort, argsort_thrust +from .scan import exclusive_scan from .. import tag -from ..transform import strided_slice, adv_index, squeeze - -logger = logging.getLogger("topi") +from ..utils import ceil_div, prod +from ..transform import reshape +from ..broadcast import not_equal +from ..math import cast -def _get_sort_func(mode=0): - """Get sort function for argwhere. mode 0 for topk and others for argsort.""" - if get_global_func("tvm.contrib.thrust.sort", allow_missing=True): - ret = topk_thrust if mode == 0 else argsort_thrust - else: - logger.warning( - "It's highly recommended to enable thrust library with set(USE_THRUST ON)" - " when compiling argwhere for cuda target. Otherwise, it can result in" - " significant performance degradation or incorrect result" - ) - ret = topk if mode == 0 else argsort +logger = logging.getLogger("topi") - return ret +fdiv = tvm.tir.floordiv +fmod = tvm.tir.floormod -def argwhere_1d_ir(condition, out): - """Low level IR for argwhere 1D +def compact_nonzero_indices_ir(condition, write_indices, out, do_write_func): + """Copy nonzero indices to the corresponding write locations. Parameters ---------- condition : Buffer - The condition buffer. + The input condition. + + write_indices : Buffer + The result of exclusive scan on a boolean array, where True indicates that + the condition is non zero at that position. out : Buffer - The output buffer. + The output buffer to copy indices to. + + do_write_func : a function + A callback that accepts an output buffer, a dst index to write to, and a src index. Returns ------- stmt : Stmt The result IR statement. """ + ib = tvm.tir.ir_builder.create() - a0 = condition.shape[0] + size_1d = prod(condition.shape) condition = ib.buffer_ptr(condition) + write_indices = ib.buffer_ptr(write_indices) out = ib.buffer_ptr(out) - valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global") - tmp = ib.allocate("int32", (1,), name="tmp", scope="local") - one_count = tvm.tir.const(1, dtype="int32") - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - # Limit threads to a single block to make sure atomic_add works normally. + nthread_tx = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_bx = ceil_div(size_1d, nthread_tx) tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - len_inner_for = a0 // nthread_tx + 1 - valid_index[0] = 0 + ib.scope_attr(bx, "thread_extent", nthread_bx) - with ib.for_range(0, len_inner_for, name="i") as i: - idx = tx * len_inner_for + i - with ib.if_scope(idx < a0): + with ib.new_scope(): + idx = bx * nthread_tx + tx + with ib.if_scope(idx < size_1d): with ib.if_scope(condition[idx] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), - one_count, - ) - out[tmp[0]] = idx + do_write_func(out, write_indices[idx], idx) return ib.get() -def argwhere_1d(output_shape, condition): - """Compute for argwhere 1D +def argwhere_common(output_shape, condition, do_write_func): + """A common compute used by argwhere of various ranks. Parameters ---------- - condition : list of int or tvm.tir.Any - The output shape + output_shape : list of int or tvm.tir.Any + Tensor with output shape info. - out : tvm.te.Tensor - Tensor with boolean values. + condition : tvm.te.Tensor + The input condition. + + do_write_func : a function + A callback that accepts an output buffer, a dst index to write to, and a src index. Returns ------- - stmt : Stmt - The result IR statement. + out : tvm.te.Tensor + Indices of non-zero elements. """ + + flags = not_equal(condition, tvm.tir.const(0)) + flags_1d = reshape(flags, (prod(flags.shape),)) + write_indices = exclusive_scan(cast(flags_1d, dtype="int32")) + condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) + write_indices_buf = tvm.tir.decl_buffer( + write_indices.shape, write_indices.dtype, "write_indices_buf", data_alignment=8 + ) out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], - [condition], - lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]), + [condition, write_indices], + lambda ins, outs: compact_nonzero_indices_ir(ins[0], ins[1], outs[0], do_write_func), dtype=["int32"], - in_buffers=[condition_buf], + in_buffers=[condition_buf, write_indices_buf], out_buffers=[out_buf], - name="argwhere_1d", - tag="argwhere1d_gpu", + name="argwhere", + tag="argwhere_gpu", ) - if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1: - return out - - sorted_out = _get_sort_func()( - out, k=0, axis=0, ret_type="values", is_ascend="True", dtype="int32" - ) - - return sorted_out + return out -def argwhere_2d_ir(condition, out): - """Low level IR for argwhere 2D +def argwhere_1d(output_shape, condition): + """Compute for argwhere 1D Parameters ---------- - condition : Buffer - The condition buffer. + condition : list of int or tvm.tir.Any + The output shape - out : Buffer - The output buffer. + out : tvm.te.Tensor + Tensor with boolean values. Returns ------- stmt : Stmt The result IR statement. """ - ib = tvm.tir.ir_builder.create() - a0 = condition.shape[0] - a1 = condition.shape[1] - condition = ib.buffer_ptr(condition) - out = ib.buffer_ptr(out) + def do_write(out, write_index, idx): + out[write_index] = idx - valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local") - tmp = ib.allocate("int32", (1,), name="tmp", scope="local") - one_count = tvm.tir.const(1, dtype="int32") - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - - # Limit threads to a single block to make sure atomic_add works normally. - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - len_inner_for = (a0 * a1) // nthread_tx + 1 - - valid_index[0] = 0 - - with ib.for_range(0, len_inner_for, name="i") as i: - idx = tx * len_inner_for + i - with ib.if_scope(idx < (a0 * a1)): - with ib.if_scope(condition[idx] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), - one_count, - ) - out[tmp[0] * 2] = tvm.tir.floordiv(idx, a1) - out[tmp[0] * 2 + 1] = tvm.tir.floormod(idx, a1) - - return ib.get() + return argwhere_common(output_shape, condition, do_write) def argwhere_2d(output_shape, condition): @@ -202,109 +168,13 @@ def argwhere_2d(output_shape, condition): stmt : Stmt The result IR statement. """ - condition_buf = tvm.tir.decl_buffer( - condition.shape, condition.dtype, "data_buf", data_alignment=8 - ) - out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) - - out = te.extern( - [output_shape], - [condition], - lambda ins, outs: argwhere_2d_ir(ins[0], outs[0]), - dtype=["int32"], - in_buffers=[condition_buf], - out_buffers=[out_buf], - name="argwhere_2d", - tag="argwhere2d_gpu", - ) - - if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1: - return out - - sort_func = _get_sort_func(1) - - # sort the output from the least significant to the most significant - # column. - if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)): - out1 = strided_slice(out, [0, 1], [out.shape[0], 2]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) - - out1 = strided_slice(out, [0, 0], [out.shape[0], 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - - out = adv_index(out, [out3]) - else: - out1 = strided_slice(out, [0, 1], [out.shape[0], 2], [1, 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) - - out1 = strided_slice(out, [0, 0], [out.shape[0], 1], [1, 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) - return out - - -def argwhere_3d_ir(condition, out): - """Low level IR for argwhere 3D - - Parameters - ---------- - condition : Buffer - The condition buffer. - - out : Buffer - The output buffer. - - Returns - ------- - stmt : Stmt - The result IR statement. - """ - ib = tvm.tir.ir_builder.create() - a0 = condition.shape[0] - a1 = condition.shape[1] - a2 = condition.shape[2] - s1 = a1 * a2 - s0 = a0 * s1 - - condition = ib.buffer_ptr(condition) - out = ib.buffer_ptr(out) - - valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local") - tmp = ib.allocate("int32", (1,), name="tmp", scope="local") - one_count = tvm.tir.const(1, dtype="int32") - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads + def do_write(out, write_index, idx): + a1 = condition.shape[1] + out[write_index * 2] = tvm.tir.floordiv(idx, a1) + out[write_index * 2 + 1] = tvm.tir.floormod(idx, a1) - # Limit threads to a single block to make sure atomic_add works normally. - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - len_inner_for = s0 // nthread_tx + 1 - - fdiv = tvm.tir.floordiv - fmod = tvm.tir.floormod - - valid_index[0] = 0 - - with ib.for_range(0, len_inner_for, name="i") as i: - idx = tx * len_inner_for + i - with ib.if_scope(idx < s0): - with ib.if_scope(condition[idx] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), - one_count, - ) - out[tmp[0] * 3] = fdiv(idx, s1) - out[tmp[0] * 3 + 1] = fdiv(fmod(idx, s1), a2) - out[tmp[0] * 3 + 2] = fmod(idx, a2) - - return ib.get() + return argwhere_common(output_shape, condition, do_write) def argwhere_3d(output_shape, condition): @@ -323,103 +193,15 @@ def argwhere_3d(output_shape, condition): stmt : Stmt The result IR statement. """ - condition_buf = tvm.tir.decl_buffer( - condition.shape, condition.dtype, "data_buf", data_alignment=8 - ) - out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) - - out = te.extern( - [output_shape], - [condition], - lambda ins, outs: argwhere_3d_ir(ins[0], outs[0]), - dtype=["int32"], - in_buffers=[condition_buf], - out_buffers=[out_buf], - name="argwhere_3d", - tag="argwhere3d_gpu", - ) - - if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1: - return out - - # sort the output from the least significant to the most significant - # column. - sort_func = _get_sort_func(1) - - if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)): - for i in reversed(range(3)): - out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) - else: - for i in reversed(range(3)): - out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) - return out - - -def argwhere_4d_ir(condition, out): - """Low level IR for argwhere 4D - - Parameters - ---------- - condition : Buffer - The condition buffer. - - out : Buffer - The output buffer. - - Returns - ------- - stmt : Stmt - The result IR statement. - """ - ib = tvm.tir.ir_builder.create() - a0 = condition.shape[0] - a1 = condition.shape[1] - a2 = condition.shape[2] - a3 = condition.shape[3] - s1 = a2 * a3 - s2 = a1 * s1 - s0 = a0 * s2 - - condition = ib.buffer_ptr(condition) - out = ib.buffer_ptr(out) - - valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local") - tmp = ib.allocate("int32", (1,), name="tmp", scope="local") - one_count = tvm.tir.const(1, dtype="int32") - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - - # Limit threads to a single block to make sure atomic_add works normally. - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - len_inner_for = s0 // nthread_tx + 1 - - fdiv = tvm.tir.floordiv - fmod = tvm.tir.floormod - valid_index[0] = 0 + def do_write(out, write_index, idx): + _, a1, a2 = condition.shape + s1 = a1 * a2 + out[write_index * 3] = fdiv(idx, s1) + out[write_index * 3 + 1] = fdiv(fmod(idx, s1), a2) + out[write_index * 3 + 2] = fmod(idx, a2) - with ib.for_range(0, len_inner_for, name="i") as i: - idx = tx * len_inner_for + i - with ib.if_scope(idx < s0): - with ib.if_scope(condition[idx] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), - one_count, - ) - out[tmp[0] * 4] = fdiv(idx, s2) - out[tmp[0] * 4 + 1] = fdiv(fmod(idx, s2), s1) - out[tmp[0] * 4 + 2] = fdiv(fmod(idx, s1), a3) - out[tmp[0] * 4 + 3] = fmod(idx, a3) - - return ib.get() + return argwhere_common(output_shape, condition, do_write) def argwhere_4d(output_shape, condition): @@ -438,106 +220,17 @@ def argwhere_4d(output_shape, condition): stmt : Stmt The result IR statement. """ - condition_buf = tvm.tir.decl_buffer( - condition.shape, condition.dtype, "data_buf", data_alignment=8 - ) - out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) - - out = te.extern( - [output_shape], - [condition], - lambda ins, outs: argwhere_4d_ir(ins[0], outs[0]), - dtype=["int32"], - in_buffers=[condition_buf], - out_buffers=[out_buf], - name="argwhere_4d", - tag="argwhere4d_gpu", - ) - - if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1: - return out - - # sort the output from the least significant to the most significant - # column. - sort_func = _get_sort_func(1) - if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)): - for i in reversed(range(4)): - out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) - else: - for i in reversed(range(4)): - out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) - - return out - - -def argwhere_5d_ir(condition, out): - """Low level IR for argwhere 5D - - Parameters - ---------- - condition : Buffer - The condition buffer. - - out : Buffer - The output buffer. - - Returns - ------- - stmt : Stmt - The result IR statement. - """ - ib = tvm.tir.ir_builder.create() - a0 = condition.shape[0] - a1 = condition.shape[1] - a2 = condition.shape[2] - a3 = condition.shape[3] - a4 = condition.shape[4] - s1 = a3 * a4 - s2 = a2 * s1 - s3 = a1 * s2 - s0 = a0 * s3 - - condition = ib.buffer_ptr(condition) - out = ib.buffer_ptr(out) - - valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local") - tmp = ib.allocate("int32", (1,), name="tmp", scope="local") - one_count = tvm.tir.const(1, dtype="int32") - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - # Limit threads to a single block to make sure atomic_add works normally. - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - len_inner_for = s0 // nthread_tx + 1 + def do_write(out, write_index, idx): + _, a1, a2, a3 = condition.shape + s1 = a2 * a3 + s2 = a1 * s1 + out[write_index * 4] = fdiv(idx, s2) + out[write_index * 4 + 1] = fdiv(fmod(idx, s2), s1) + out[write_index * 4 + 2] = fdiv(fmod(idx, s1), a3) + out[write_index * 4 + 3] = fmod(idx, a3) - fdiv = tvm.tir.floordiv - fmod = tvm.tir.floormod - - valid_index[0] = 0 - - with ib.for_range(0, len_inner_for, name="i") as i: - idx = tx * len_inner_for + i - with ib.if_scope(idx < s0): - with ib.if_scope(condition[idx] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), - one_count, - ) - out[tmp[0] * 5] = fdiv(idx, s3) - out[tmp[0] * 5 + 1] = fdiv(fmod(idx, s3), s2) - out[tmp[0] * 5 + 2] = fdiv(fmod(idx, s2), s1) - out[tmp[0] * 5 + 3] = fdiv(fmod(idx, s1), a4) - out[tmp[0] * 5 + 4] = fmod(idx, a4) - - return ib.get() + return argwhere_common(output_shape, condition, do_write) def argwhere_5d(output_shape, condition): @@ -556,42 +249,19 @@ def argwhere_5d(output_shape, condition): stmt : Stmt The result IR statement. """ - condition_buf = tvm.tir.decl_buffer( - condition.shape, condition.dtype, "data_buf", data_alignment=8 - ) - out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) - out = te.extern( - [output_shape], - [condition], - lambda ins, outs: argwhere_5d_ir(ins[0], outs[0]), - dtype=["int32"], - in_buffers=[condition_buf], - out_buffers=[out_buf], - name="argwhere_5d", - tag="argwhere5d_gpu", - ) - - if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1: - return out - - # sort the output from the least significant to the most significant - # column. - sort_func = _get_sort_func(1) - if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)): - for i in reversed(range(5)): - out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) - else: - for i in reversed(range(5)): - out1 = strided_slice(out, [0, i], [out.shape[0], i + 1], [1, 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) - - return out + def do_write(out, write_index, idx): + _, a1, a2, a3, a4 = condition.shape + s1 = a3 * a4 + s2 = a2 * s1 + s3 = a1 * s2 + out[write_index * 5] = fdiv(idx, s3) + out[write_index * 5 + 1] = fdiv(fmod(idx, s3), s2) + out[write_index * 5 + 2] = fdiv(fmod(idx, s2), s1) + out[write_index * 5 + 3] = fdiv(fmod(idx, s1), a4) + out[write_index * 5 + 4] = fmod(idx, a4) + + return argwhere_common(output_shape, condition, do_write) def argwhere(output_shape, condition):