From ffbd2ac15284bab0164add35d0571c8ef7c63ff2 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 6 Nov 2020 19:19:04 +0000 Subject: [PATCH] format --- python/tvm/relay/op/strategy/cuda.py | 4 +- python/tvm/relay/op/strategy/generic.py | 12 ++- python/tvm/topi/cuda/argwhere.py | 78 ++++++------------- .../python/topi/python/test_topi_argwhere.py | 8 +- 4 files changed, 37 insertions(+), 65 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 9e32c47e8d639..11ff36fa80b5b 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -857,6 +857,7 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target): ) return strategy + @argwhere_strategy.register(["cuda", "gpu"]) def argwhere_strategy_cuda(attrs, inputs, out_type, target): """argwhere cuda strategy""" @@ -864,5 +865,6 @@ def argwhere_strategy_cuda(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_argwhere(topi.cuda.argwhere), wrap_topi_schedule(topi.cuda.schedule_argwhere), - name="argwhere.cuda") + name="argwhere.cuda", + ) return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 417d48bc7ce37..10dd1dac94d66 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1183,9 +1183,11 @@ def correlation_strategy(attrs, inputs, out_type, target): ) return strategy + # argwhere def wrap_compute_argwhere(topi_compute): """wrap argwhere topi compute""" + def _compute_argwhere(attrs, inputs, out_type): output_shape = [] for s in out_type.shape: @@ -1195,13 +1197,17 @@ def _compute_argwhere(attrs, inputs, out_type): output_shape.append(te.var("any_dim", "int32")) new_output_type = ir.TensorType(output_shape, "int32") return [topi_compute(new_output_type, inputs[0])] + return _compute_argwhere + @override_native_generic_func("argwhere_strategy") def argwhere_strategy(attrs, inputs, out_type, target): """argwhere generic strategy""" strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_argwhere(topi.argwhere), - wrap_topi_schedule(topi.generic.schedule_argwhere), - name="argwhere.generic") + strategy.add_implementation( + wrap_compute_argwhere(topi.argwhere), + wrap_topi_schedule(topi.generic.schedule_argwhere), + name="argwhere.generic", + ) return strategy diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py index 4791eb488c354..17d2410a517f4 100644 --- a/python/tvm/topi/cuda/argwhere.py +++ b/python/tvm/topi/cuda/argwhere.py @@ -33,14 +33,14 @@ def _get_sort_func(mode=0): """Get sort function for argwhere. mode 0 for topk and others for argsort.""" - if get_global_func( - "tvm.contrib.thrust.sort", allow_missing=True - ): + if get_global_func("tvm.contrib.thrust.sort", allow_missing=True): ret = topk_thrust if mode == 0 else argsort_thrust else: - logger.warn("It's highly recommended to enable thrust library with set(USE_THRUST ON)" - " when compiling argwhere for cuda target. Otherwise, it can result in" - " significant performance degradation or incorrect result") + logger.warn( + "It's highly recommended to enable thrust library with set(USE_THRUST ON)" + " when compiling argwhere for cuda target. Otherwise, it can result in" + " significant performance degradation or incorrect result" + ) ret = topk if mode == 0 else argsort return ret @@ -72,9 +72,7 @@ def argwhere_1d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads # Limit threads to a single block to make sure atomic_add works normally. tx = te.thread_axis("threadIdx.x") @@ -87,9 +85,7 @@ def argwhere_1d_ir(condition, out): with ib.if_scope(idx < a0): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0]] = idx @@ -116,9 +112,7 @@ def argwhere_1d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], @@ -161,8 +155,6 @@ def argwhere_2d_ir(condition, out): a0 = condition.shape[0] a1 = condition.shape[1] - out_len = out.shape[0] * out.shape[1] - condition = ib.buffer_ptr(condition) out = ib.buffer_ptr(out) @@ -170,9 +162,7 @@ def argwhere_2d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads # Limit threads to a single block to make sure atomic_add works normally. @@ -187,9 +177,7 @@ def argwhere_2d_ir(condition, out): with ib.if_scope(idx < (a0 * a1)): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0] * 2] = tvm.tir.floordiv(idx, a1) @@ -217,9 +205,7 @@ def argwhere_2d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], @@ -281,9 +267,7 @@ def argwhere_3d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads # Limit threads to a single block to make sure atomic_add works normally. @@ -301,9 +285,7 @@ def argwhere_3d_ir(condition, out): with ib.if_scope(idx < s0): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0] * 3] = fdiv(idx, s1) @@ -332,9 +314,7 @@ def argwhere_3d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], @@ -394,15 +374,13 @@ def argwhere_4d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads - + # Limit threads to a single block to make sure atomic_add works normally. tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - len_inner_for = s0 // nthread_tx + 1 + len_inner_for = s0 // nthread_tx + 1 fdiv = tvm.tir.floordiv fmod = tvm.tir.floormod @@ -414,9 +392,7 @@ def argwhere_4d_ir(condition, out): with ib.if_scope(idx < s0): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0] * 4] = fdiv(idx, s2) @@ -446,9 +422,7 @@ def argwhere_4d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], @@ -510,9 +484,7 @@ def argwhere_5d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads # Limit threads to a single block to make sure atomic_add works normally. @@ -530,9 +502,7 @@ def argwhere_5d_ir(condition, out): with ib.if_scope(idx < s0): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0] * 5] = fdiv(idx, s3) @@ -563,9 +533,7 @@ def argwhere_5d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py index 5181e45b8b5d0..b9555ff48f087 100644 --- a/tests/python/topi/python/test_topi_argwhere.py +++ b/tests/python/topi/python/test_topi_argwhere.py @@ -37,9 +37,7 @@ def verify_argwhere(data_shape): out_shape = np_out.shape[0] np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype) - out_shape = te.placeholder( - shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype - ) + out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype) condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype) def check_device(device, ctx): @@ -52,9 +50,7 @@ def check_device(device, ctx): s_func = tvm.topi.testing.dispatch(device, _argwhere_schedule) sch = s_func(out) - func = tvm.build( - sch, [out_shape, condition, out], device, name="argwhere" - ) + func = tvm.build(sch, [out_shape, condition, out], device, name="argwhere") # print(func.imported_modules[0].get_source())