Skip to content

Commit

Permalink
add doc
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 21, 2021
1 parent 4179bf1 commit 63469a6
Showing 1 changed file with 41 additions and 2 deletions.
43 changes: 41 additions & 2 deletions python/tvm/topi/cuda/argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,29 @@


def compact_nonzero_indices_ir(condition, write_indices, out, do_write_func):
"""Copy nonzero indices to the corresponding write locations.
Parameters
----------
condition : Buffer
The input condition.
write_indices : Buffer
The result of exclusive scan on a boolean array, where True indicates that
the condition is non zero at that position.
out : Buffer
The output buffer to copy indices to.
do_write_func : a function
A callback that accepts an output buffer, a dst index to write to, and a src index.
Returns
-------
stmt : Stmt
The result IR statement.
"""
TODO
"""

ib = tvm.tir.ir_builder.create()
size_1d = prod(condition.shape)

Expand All @@ -64,6 +84,25 @@ def compact_nonzero_indices_ir(condition, write_indices, out, do_write_func):


def argwhere_common(output_shape, condition, do_write_func):
"""A common compute used by argwhere of various ranks.
Parameters
----------
output_shape : list of int or tvm.tir.Any
Tensor with output shape info.
condition : tvm.te.Tensor
The input condition.
do_write_func : a function
A callback that accepts an output buffer, a dst index to write to, and a src index.
Returns
-------
out : tvm.te.Tensor
Indices of non-zero elements.
"""

flags = not_equal(condition, tvm.tir.const(0))
flags_1d = reshape(flags, (prod(flags.shape),))
write_indices = exclusive_scan(cast(flags_1d, dtype="int32"))
Expand Down

0 comments on commit 63469a6

Please sign in to comment.