From 85a91e93c5e9a2a0ec14d5fa6f41e43ac18b683b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 19:31:29 +0900 Subject: [PATCH] add doc --- python/tvm/topi/cuda/argwhere.py | 43 ++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py index b70ddb8a3abaf..cc6c4c26eddbc 100644 --- a/python/tvm/topi/cuda/argwhere.py +++ b/python/tvm/topi/cuda/argwhere.py @@ -37,9 +37,29 @@ def compact_nonzero_indices_ir(condition, write_indices, out, do_write_func): + """Copy nonzero indices to the corresponding write locations. + + Parameters + ---------- + condition : Buffer + The input condition. + + write_indices : Buffer + The result of exclusive scan on a boolean array, where True indicates that + the condition is nonzero at that position. + + out : Buffer + The output buffer to copy indices to. + + do_write_func : a function + A callback that accepts an output buffer, a dst index to write to, and a src index. + + Returns + ------- + stmt : Stmt + The resulting IR statement. """ - TODO - """ + ib = tvm.tir.ir_builder.create() size_1d = prod(condition.shape) @@ -64,6 +84,25 @@ def compact_nonzero_indices_ir(condition, write_indices, out, do_write_func): def argwhere_common(output_shape, condition, do_write_func): + """A common compute used by argwhere of various ranks. + + Parameters + ---------- + output_shape : list of int or tvm.tir.Any + The shape of the output tensor. + + condition : tvm.te.Tensor + The input condition. + + do_write_func : a function + A callback that accepts an output buffer, a dst index to write to, and a src index. + + Returns + ------- + out : tvm.te.Tensor + Indices of nonzero elements. + """ + flags = not_equal(condition, tvm.tir.const(0)) flags_1d = reshape(flags, (prod(flags.shape),)) write_indices = exclusive_scan(cast(flags_1d, dtype="int32"))