From 24bd3783142fc3d28234d3e4627c3a5647c31177 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 6 Nov 2020 19:19:04 +0000 Subject: [PATCH] format --- python/tvm/topi/cuda/argwhere.py | 78 ++++++------------- .../python/topi/python/test_topi_argwhere.py | 8 +- 2 files changed, 25 insertions(+), 61 deletions(-) diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py index 4791eb488c354..17d2410a517f4 100644 --- a/python/tvm/topi/cuda/argwhere.py +++ b/python/tvm/topi/cuda/argwhere.py @@ -33,14 +33,14 @@ def _get_sort_func(mode=0): """Get sort function for argwhere. mode 0 for topk and others for argsort.""" - if get_global_func( - "tvm.contrib.thrust.sort", allow_missing=True - ): + if get_global_func("tvm.contrib.thrust.sort", allow_missing=True): ret = topk_thrust if mode == 0 else argsort_thrust else: - logger.warn("It's highly recommended to enable thrust library with set(USE_THRUST ON)" - " when compiling argwhere for cuda target. Otherwise, it can result in" - " significant performance degradation or incorrect result") + logger.warn( + "It's highly recommended to enable thrust library with set(USE_THRUST ON)" + " when compiling argwhere for cuda target. Otherwise, it can result in" + " significant performance degradation or incorrect result" + ) ret = topk if mode == 0 else argsort return ret @@ -72,9 +72,7 @@ def argwhere_1d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads # Limit threads to a single block to make sure atomic_add works normally. tx = te.thread_axis("threadIdx.x") @@ -87,9 +85,7 @@ def argwhere_1d_ir(condition, out): with ib.if_scope(idx < a0): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0]] = idx @@ -116,9 +112,7 @@ def argwhere_1d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], @@ -161,8 +155,6 @@ def argwhere_2d_ir(condition, out): a0 = condition.shape[0] a1 = condition.shape[1] - out_len = out.shape[0] * out.shape[1] - condition = ib.buffer_ptr(condition) out = ib.buffer_ptr(out) @@ -170,9 +162,7 @@ def argwhere_2d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads # Limit threads to a single block to make sure atomic_add works normally. @@ -187,9 +177,7 @@ def argwhere_2d_ir(condition, out): with ib.if_scope(idx < (a0 * a1)): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0] * 2] = tvm.tir.floordiv(idx, a1) @@ -217,9 +205,7 @@ def argwhere_2d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], @@ -281,9 +267,7 @@ def argwhere_3d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads # Limit threads to a single block to make sure atomic_add works normally. @@ -301,9 +285,7 @@ def argwhere_3d_ir(condition, out): with ib.if_scope(idx < s0): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0] * 3] = fdiv(idx, s1) @@ -332,9 +314,7 @@ def argwhere_3d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], @@ -394,15 +374,13 @@ def argwhere_4d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads - + # Limit threads to a single block to make sure atomic_add works normally. tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - len_inner_for = s0 // nthread_tx + 1 + len_inner_for = s0 // nthread_tx + 1 fdiv = tvm.tir.floordiv fmod = tvm.tir.floormod @@ -414,9 +392,7 @@ def argwhere_4d_ir(condition, out): with ib.if_scope(idx < s0): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0] * 4] = fdiv(idx, s2) @@ -446,9 +422,7 @@ def argwhere_4d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], @@ -510,9 +484,7 @@ def argwhere_5d_ir(condition, out): tmp = ib.allocate("int32", (1,), name="tmp", scope="local") one_count = tvm.tir.const(1, dtype="int32") - max_threads = int( - tvm.target.Target.current(allow_none=False).max_num_threads - ) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads # Limit threads to a single block to make sure atomic_add works normally. @@ -530,9 +502,7 @@ def argwhere_5d_ir(condition, out): with ib.if_scope(idx < s0): with ib.if_scope(condition[idx] != 0): tmp[0] = atomic_add( - tvm.tir.call_intrin( - "handle", "tir.address_of", valid_index[0] - ), + tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]), one_count, ) out[tmp[0] * 5] = fdiv(idx, s3) @@ -563,9 +533,7 @@ def argwhere_5d(output_shape, condition): condition_buf = tvm.tir.decl_buffer( condition.shape, condition.dtype, "data_buf", data_alignment=8 ) - out_buf = tvm.tir.decl_buffer( - output_shape, "int32", "out_buf", data_alignment=8 - ) + out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8) out = te.extern( [output_shape], diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py index 5181e45b8b5d0..b9555ff48f087 100644 --- a/tests/python/topi/python/test_topi_argwhere.py +++ b/tests/python/topi/python/test_topi_argwhere.py @@ -37,9 +37,7 @@ def verify_argwhere(data_shape): out_shape = np_out.shape[0] np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype) - out_shape = te.placeholder( - shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype - ) + out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype) condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype) def check_device(device, ctx): @@ -52,9 +50,7 @@ def check_device(device, ctx): s_func = tvm.topi.testing.dispatch(device, _argwhere_schedule) sch = s_func(out) - func = tvm.build( - sch, [out_shape, condition, out], device, name="argwhere" - ) + func = tvm.build(sch, [out_shape, condition, out], device, name="argwhere") # print(func.imported_modules[0].get_source())