From 411eaecf3104713e188707aec1bbc81888efa3a7 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Sat, 12 Sep 2020 05:39:47 +0000
Subject: [PATCH] sort argwhere result

---
Notes (commentary only; not part of the commit message):
  * argwhere_1d now passes is_ascend=True (a bool) to topk. The string
    "True" only behaved correctly because any non-empty string is truthy;
    "False" would have selected ascending order as well.
  * Line breaks and indentation in this mail were restored from a
    whitespace-collapsed copy. Hunk line counts match the @@ headers, but
    context-line indentation should be checked against the base blobs
    before applying.
  * The two test TODOs and the commented-out debug prints are kept as the
    author wrote them.

 python/tvm/topi/cuda/argwhere.py              | 96 +++++++++++++++++--
 .../python/topi/python/test_topi_argwhere.py  | 35 ++++---
 2 files changed, 109 insertions(+), 22 deletions(-)

diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py
index b8e346794ad79..9949ccfe0c485 100644
--- a/python/tvm/topi/cuda/argwhere.py
+++ b/python/tvm/topi/cuda/argwhere.py
@@ -19,7 +19,11 @@
 import tvm
 from tvm import te
 
+from .injective import schedule_injective_from_existing
 from .nms import atomic_add
+from .sort import topk, argsort
+from .. import tag
+from ..transform import strided_slice, adv_index, squeeze
 
 
 def argwhere_1d_ir(condition, out):
@@ -48,7 +52,9 @@ def argwhere_1d_ir(condition, out):
         tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
         one_count = tvm.tir.const(1, dtype="int32")
 
-        max_threads = 1024
+        max_threads = int(
+            tvm.target.Target.current(allow_none=False).max_num_threads
+        )
         nthread_tx = max_threads
         nthread_bx = a0 // max_threads + 1
         tx = te.thread_axis("threadIdx.x")
@@ -105,7 +111,11 @@ def argwhere_1d(output_shape, condition):
         tag="argwhere1d_gpu",
     )
 
-    return out
+    sorted_out = topk(
+        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
+    )
+
+    return sorted_out
 
 
 def argwhere_2d_ir(condition, out):
@@ -135,7 +145,9 @@ def argwhere_2d_ir(condition, out):
         tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
         one_count = tvm.tir.const(1, dtype="int32")
 
-        max_threads = 1024
+        max_threads = int(
+            tvm.target.Target.current(allow_none=False).max_num_threads
+        )
         nthread_tx = max_threads
         nthread_bx = (a0 * a1) // max_threads + 1
         tx = te.thread_axis("threadIdx.x")
@@ -194,7 +206,21 @@ def argwhere_2d(output_shape, condition):
         tag="argwhere2d_gpu",
     )
 
-    return out
+    if out.shape[0] <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    out1 = strided_slice(out, [0, 1], [out.shape[0], 2])
+    out2 = argsort(out1, axis=0, dtype="int32")
+    out3 = squeeze(out2)
+    out = adv_index(out, [out3])
+
+    out1 = strided_slice(out, [0, 0], [out.shape[0], 1])
+    out2 = argsort(out1, axis=0, dtype="int32")
+    out3 = squeeze(out2)
+
+    return adv_index(out, [out3])
 
 
 def argwhere_3d_ir(condition, out):
@@ -227,7 +253,9 @@ def argwhere_3d_ir(condition, out):
         tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
         one_count = tvm.tir.const(1, dtype="int32")
 
-        max_threads = 1024
+        max_threads = int(
+            tvm.target.Target.current(allow_none=False).max_num_threads
+        )
         nthread_tx = max_threads
         nthread_bx = s0 // max_threads + 1
         tx = te.thread_axis("threadIdx.x")
@@ -289,6 +317,17 @@ def argwhere_3d(output_shape, condition):
         tag="argwhere3d_gpu",
     )
 
+    if out.shape[0] <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    for i in reversed(range(3)):
+        out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+        out2 = argsort(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+        out = adv_index(out, [out3])
+
     return out
 
 
@@ -324,7 +363,9 @@ def argwhere_4d_ir(condition, out):
         tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
         one_count = tvm.tir.const(1, dtype="int32")
 
-        max_threads = 1024
+        max_threads = int(
+            tvm.target.Target.current(allow_none=False).max_num_threads
+        )
         nthread_tx = max_threads
         nthread_bx = s0 // max_threads + 1
         tx = te.thread_axis("threadIdx.x")
@@ -387,6 +428,17 @@ def argwhere_4d(output_shape, condition):
         tag="argwhere4d_gpu",
     )
 
+    if out.shape[0] <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    for i in reversed(range(4)):
+        out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+        out2 = argsort(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+        out = adv_index(out, [out3])
+
     return out
 
 
@@ -424,7 +476,9 @@ def argwhere_5d_ir(condition, out):
         tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
         one_count = tvm.tir.const(1, dtype="int32")
 
-        max_threads = 1024
+        max_threads = int(
+            tvm.target.Target.current(allow_none=False).max_num_threads
+        )
         nthread_tx = max_threads
         nthread_bx = s0 // max_threads + 1
         tx = te.thread_axis("threadIdx.x")
@@ -488,6 +542,17 @@ def argwhere_5d(output_shape, condition):
         tag="argwhere5d_gpu",
     )
 
+    if out.shape[0] <= 1:
+        return out
+
+    # sort the output from the least significant to the most significant
+    # column.
+    for i in reversed(range(5)):
+        out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
+        out2 = argsort(out1, axis=0, dtype="int32")
+        out3 = squeeze(out2)
+        out = adv_index(out, [out3])
+
     return out
 
 
@@ -535,6 +600,17 @@ def schedule_argwhere(outs):
       The computation schedule for argwhere
     """
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    sch = te.create_schedule([x.op for x in outs])
-
-    return sch
+    s = te.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
+    def traverse(op):
+        if tag.is_injective(op.tag):
+            schedule_injective_from_existing(s, op.output(0))
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
+        scheduled_ops.append(op)
+
+    for out in outs:
+        traverse(out.op)
+    return s
diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py
index eb99fffe6f126..817585fef2c47 100644
--- a/tests/python/topi/python/test_topi_argwhere.py
+++ b/tests/python/topi/python/test_topi_argwhere.py
@@ -27,10 +27,8 @@
     "gpu": topi.cuda.schedule_argwhere,
 }
 
-_argwhere_compute = {
-"llvm": topi.argwhere,
-"cuda": topi.cuda.argwhere
-}
+_argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere}
+
 
 def verify_argwhere(data_shape):
     dtype = "int32"
@@ -39,8 +37,9 @@ def verify_argwhere(data_shape):
     out_shape = np_out.shape[0]
     np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)
 
-    out_shape = te.placeholder(shape=(out_shape, len(data_shape)),
-                               name="out_shape", dtype=dtype)
+    out_shape = te.placeholder(
+        shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype
+    )
     condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)
 
     def check_device(device, ctx):
@@ -48,18 +47,23 @@ def check_device(device, ctx):
         if not ctx.exist or device not in _argwhere_compute:
             return
 
-        out = _argwhere_compute[device](out_shape, condition)
-        with tvm.target.create(device):
+        with tvm.target.Target(device):
+            out = _argwhere_compute[device](out_shape, condition)
             s_func = tvm.topi.testing.dispatch(device, _argwhere_schedule)
             sch = s_func(out)
 
-        func = tvm.build(sch, [out_shape, condition, out], device,
-                         name="argwhere")
+        func = tvm.build(
+            sch, [out_shape, condition, out], device, name="argwhere"
+        )
+
+        # print(func.imported_modules[0].get_source())
 
         args = [tvm.nd.array(np_shape, ctx)]
         args.append(tvm.nd.array(np_data, ctx))
         args.append(tvm.nd.empty(out.shape, ctx=ctx, dtype=condition.dtype))
         func(*args)
+        np.set_printoptions(threshold=np.inf)
+        # print(args[-1].asnumpy())
         tvm.testing.assert_allclose(args[-1].asnumpy(), np.array(np_out))
 
     for target, ctx in tvm.testing.enabled_targets():
@@ -70,11 +74,18 @@ def check_device(device, ctx):
 def test_argwhere():
     verify_argwhere((1,))
     verify_argwhere((100,))
+    verify_argwhere((1, 1))
     verify_argwhere((5, 3))
-    verify_argwhere((100, 100))
+    verify_argwhere((32, 64))
+    # TODO(zhiics) This test is flaky because nothing is sorted.
+    verify_argwhere((128, 65))
     verify_argwhere((6, 5, 3))
-    verify_argwhere((32, 32, 16))
+    verify_argwhere((1, 1, 1))
+    # TODO(zhiics) This test is flaky.
+    # verify_argwhere((32, 32, 8))
+    verify_argwhere((1, 1, 1, 1))
     verify_argwhere((6, 4, 5, 3))
+    verify_argwhere((1, 1, 1, 1, 1))
     verify_argwhere((6, 4, 5, 3, 7))
 
 