diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py index 9949ccfe0c485..4791eb488c354 100644 --- a/python/tvm/topi/cuda/argwhere.py +++ b/python/tvm/topi/cuda/argwhere.py @@ -17,14 +17,34 @@ # pylint: disable=too-many-arguments, invalid-name """Argwhere operator""" +import logging + import tvm from tvm import te +from tvm._ffi import get_global_func from .injective import schedule_injective_from_existing from .nms import atomic_add -from .sort import topk, argsort +from .sort import topk, topk_thrust, argsort, argsort_thrust from .. import tag from ..transform import strided_slice, adv_index, squeeze +logger = logging.getLogger("topi") + + +def _get_sort_func(mode=0): + """Get sort function for argwhere. mode 0 for topk and others for argsort.""" + if get_global_func( + "tvm.contrib.thrust.sort", allow_missing=True + ): + ret = topk_thrust if mode == 0 else argsort_thrust + else: + logger.warn("It's highly recommended to enable thrust library with set(USE_THRUST ON)" + " when compiling argwhere for cuda target. Otherwise, it can result in" + " significant performance degradation or incorrect result") + ret = topk if mode == 0 else argsort + + return ret + def argwhere_1d_ir(condition, out): """Low level IR for argwhere 1D @@ -48,7 +68,7 @@ def argwhere_1d_ir(condition, out): condition = ib.buffer_ptr(condition) out = ib.buffer_ptr(out) - valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local") + valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global") tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") @@ -56,23 +76,23 @@ def argwhere_1d_ir(condition, out): tvm.target.Target.current(allow_none=False).max_num_threads ) nthread_tx = max_threads - nthread_bx = a0 // max_threads + 1 + # Limit threads to a single block to make sure atomic_add works normally. tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx + len_inner_for = a0 // nthread_tx + 1 valid_index[0] = 0 - with ib.if_scope(tid < a0): - with ib.if_scope(condition[tid] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), - one_count, - ) - out[tmp[0]] = tid + with ib.for_range(0, len_inner_for, name="i") as i: + idx = tx * len_inner_for + i + with ib.if_scope(idx < a0): + with ib.if_scope(condition[idx] != 0): + tmp[0] = atomic_add( + tvm.tir.call_intrin( + "handle", "tir.address_of", valid_index[0] + ), + one_count, + ) + out[tmp[0]] = idx return ib.get() @@ -111,7 +131,10 @@ def argwhere_1d(output_shape, condition): tag="argwhere1d_gpu", ) - sorted_out = topk( + if out.shape[0] <= 1: + return out + + sorted_out = _get_sort_func()( out, k=0, axis=0, ret_type="values", is_ascend="True", dtype="int32" ) @@ -138,6 +161,8 @@ def argwhere_2d_ir(condition, out): a0 = condition.shape[0] a1 = condition.shape[1] + out_len = out.shape[0] * out.shape[1] + condition = ib.buffer_ptr(condition) out = ib.buffer_ptr(out) @@ -149,25 +174,26 @@ def argwhere_2d_ir(condition, out): tvm.target.Target.current(allow_none=False).max_num_threads ) nthread_tx = max_threads - nthread_bx = (a0 * a1) // max_threads + 1 + + # Limit threads to a single block to make sure atomic_add works normally. tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx + len_inner_for = (a0 * a1) // nthread_tx + 1 valid_index[0] = 0 - with ib.if_scope(tid < (a0 * a1)): - with ib.if_scope(condition[tid] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), - one_count, - ) - out[tmp[0] * 2] = tvm.tir.floordiv(tid, a1) - out[tmp[0] * 2 + 1] = tvm.tir.floormod(tid, a1) + with ib.for_range(0, len_inner_for, name="i") as i: + idx = tx * len_inner_for + i + with ib.if_scope(idx < (a0 * a1)): + with ib.if_scope(condition[idx] != 0): + tmp[0] = atomic_add( + tvm.tir.call_intrin( + "handle", "tir.address_of", valid_index[0] + ), + one_count, + ) + out[tmp[0] * 2] = tvm.tir.floordiv(idx, a1) + out[tmp[0] * 2 + 1] = tvm.tir.floormod(idx, a1) return ib.get() @@ -209,15 +235,17 @@ def argwhere_2d(output_shape, condition): if out.shape[0] <= 1: return out + sort_func = _get_sort_func(1) + # sort the output from the least significant to the most significant # column. out1 = strided_slice(out, [0, 1], [out.shape[0], 2]) - out2 = argsort(out1, axis=0, dtype="int32") + out2 = sort_func(out1, axis=0, dtype="int32") out3 = squeeze(out2) out = adv_index(out, [out3]) out1 = strided_slice(out, [0, 0], [out.shape[0], 1]) - out2 = argsort(out1, axis=0, dtype="int32") + out2 = sort_func(out1, axis=0, dtype="int32") out3 = squeeze(out2) return adv_index(out, [out3]) @@ -257,28 +285,30 @@ def argwhere_3d_ir(condition, out): tvm.target.Target.current(allow_none=False).max_num_threads ) nthread_tx = max_threads - nthread_bx = s0 // max_threads + 1 + + # Limit threads to a single block to make sure atomic_add works normally. tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx + len_inner_for = s0 // nthread_tx + 1 + fdiv = tvm.tir.floordiv fmod = tvm.tir.floormod valid_index[0] = 0 - with ib.if_scope(tid < s0): - with ib.if_scope(condition[tid] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), - one_count, - ) - out[tmp[0] * 3] = fdiv(tid, s1) - out[tmp[0] * 3 + 1] = fdiv(fmod(tid, s1), a2) - out[tmp[0] * 3 + 2] = fmod(tid, a2) + with ib.for_range(0, len_inner_for, name="i") as i: + idx = tx * len_inner_for + i + with ib.if_scope(idx < s0): + with ib.if_scope(condition[idx] != 0): + tmp[0] = atomic_add( + tvm.tir.call_intrin( + "handle", "tir.address_of", valid_index[0] + ), + one_count, + ) + out[tmp[0] * 3] = fdiv(idx, s1) + out[tmp[0] * 3 + 1] = fdiv(fmod(idx, s1), a2) + out[tmp[0] * 3 + 2] = fmod(idx, a2) return ib.get() @@ -322,9 +352,10 @@ def argwhere_3d(output_shape, condition): # sort the output from the least significant to the most significant # column. + sort_func = _get_sort_func(1) for i in reversed(range(3)): out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) - out2 = argsort(out1, axis=0, dtype="int32") + out2 = sort_func(out1, axis=0, dtype="int32") out3 = squeeze(out2) out = adv_index(out, [out3]) @@ -367,29 +398,31 @@ def argwhere_4d_ir(condition, out): tvm.target.Target.current(allow_none=False).max_num_threads ) nthread_tx = max_threads - nthread_bx = s0 // max_threads + 1 + + # Limit threads to a single block to make sure atomic_add works normally. tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx + len_inner_for = s0 // nthread_tx + 1 + fdiv = tvm.tir.floordiv fmod = tvm.tir.floormod valid_index[0] = 0 - with ib.if_scope(tid < s0): - with ib.if_scope(condition[tid] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), - one_count, - ) - out[tmp[0] * 4] = fdiv(tid, s2) - out[tmp[0] * 4 + 1] = fdiv(fmod(tid, s2), s1) - out[tmp[0] * 4 + 2] = fdiv(fmod(tid, s1), a3) - out[tmp[0] * 4 + 3] = fmod(tid, a3) + with ib.for_range(0, len_inner_for, name="i") as i: + idx = tx * len_inner_for + i + with ib.if_scope(idx < s0): + with ib.if_scope(condition[idx] != 0): + tmp[0] = atomic_add( + tvm.tir.call_intrin( + "handle", "tir.address_of", valid_index[0] + ), + one_count, + ) + out[tmp[0] * 4] = fdiv(idx, s2) + out[tmp[0] * 4 + 1] = fdiv(fmod(idx, s2), s1) + out[tmp[0] * 4 + 2] = fdiv(fmod(idx, s1), a3) + out[tmp[0] * 4 + 3] = fmod(idx, a3) return ib.get() @@ -433,9 +466,10 @@ def argwhere_4d(output_shape, condition): # sort the output from the least significant to the most significant # column. + sort_func = _get_sort_func(1) for i in reversed(range(4)): out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) - out2 = argsort(out1, axis=0, dtype="int32") + out2 = sort_func(out1, axis=0, dtype="int32") out3 = squeeze(out2) out = adv_index(out, [out3]) @@ -480,30 +514,32 @@ def argwhere_5d_ir(condition, out): tvm.target.Target.current(allow_none=False).max_num_threads ) nthread_tx = max_threads - nthread_bx = s0 // max_threads + 1 + + # Limit threads to a single block to make sure atomic_add works normally. tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx + len_inner_for = s0 // nthread_tx + 1 + fdiv = tvm.tir.floordiv fmod = tvm.tir.floormod valid_index[0] = 0 - with ib.if_scope(tid < s0): - with ib.if_scope(condition[tid] != 0): - tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), - one_count, - ) - out[tmp[0] * 5] = fdiv(tid, s3) - out[tmp[0] * 5 + 1] = fdiv(fmod(tid, s3), s2) - out[tmp[0] * 5 + 2] = fdiv(fmod(tid, s2), s1) - out[tmp[0] * 5 + 3] = fdiv(fmod(tid, s1), a4) - out[tmp[0] * 5 + 4] = fmod(tid, a4) + with ib.for_range(0, len_inner_for, name="i") as i: + idx = tx * len_inner_for + i + with ib.if_scope(idx < s0): + with ib.if_scope(condition[idx] != 0): + tmp[0] = atomic_add( + tvm.tir.call_intrin( + "handle", "tir.address_of", valid_index[0] + ), + one_count, + ) + out[tmp[0] * 5] = fdiv(idx, s3) + out[tmp[0] * 5 + 1] = fdiv(fmod(idx, s3), s2) + out[tmp[0] * 5 + 2] = fdiv(fmod(idx, s2), s1) + out[tmp[0] * 5 + 3] = fdiv(fmod(idx, s1), a4) + out[tmp[0] * 5 + 4] = fmod(idx, a4) return ib.get() @@ -547,9 +583,10 @@ def argwhere_5d(output_shape, condition): # sort the output from the least significant to the most significant # column. + sort_func = _get_sort_func(1) for i in reversed(range(5)): out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) - out2 = argsort(out1, axis=0, dtype="int32") + out2 = sort_func(out1, axis=0, dtype="int32") out3 = squeeze(out2) out = adv_index(out, [out3]) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 465299a5bc8f7..d9c6ffdd37cd5 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -549,6 +549,8 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8), ] + is_ascend = 1 if is_ascend else 0 + out = te.extern( [data.shape, data.shape], [data], diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py index 817585fef2c47..5181e45b8b5d0 100644 --- a/tests/python/topi/python/test_topi_argwhere.py +++ b/tests/python/topi/python/test_topi_argwhere.py @@ -77,12 +77,10 @@ def test_argwhere(): verify_argwhere((1, 1)) verify_argwhere((5, 3)) verify_argwhere((32, 64)) - # TODO(zhiics) This test is flaky because nothing is sorted. verify_argwhere((128, 65)) + verify_argwhere((200, 500)) verify_argwhere((6, 5, 3)) verify_argwhere((1, 1, 1)) - # TODO(zhiics) This test is flaky. - # verify_argwhere((32, 32, 8)) verify_argwhere((1, 1, 1, 1)) verify_argwhere((6, 4, 5, 3)) verify_argwhere((1, 1, 1, 1, 1))