Skip to content

Commit

Permalink
sort argwhere result
Browse the repository at this point in the history
  • Loading branch information
zhiics committed Nov 6, 2020
1 parent 5beabe3 commit 411eaec
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 22 deletions.
96 changes: 86 additions & 10 deletions python/tvm/topi/cuda/argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@

import tvm
from tvm import te
from .injective import schedule_injective_from_existing
from .nms import atomic_add
from .sort import topk, argsort
from .. import tag
from ..transform import strided_slice, adv_index, squeeze


def argwhere_1d_ir(condition, out):
Expand Down Expand Up @@ -48,7 +52,9 @@ def argwhere_1d_ir(condition, out):
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = 1024
max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = a0 // max_threads + 1
tx = te.thread_axis("threadIdx.x")
Expand Down Expand Up @@ -105,7 +111,11 @@ def argwhere_1d(output_shape, condition):
tag="argwhere1d_gpu",
)

return out
sorted_out = topk(
out, k=0, axis=0, ret_type="values", is_ascend="True", dtype="int32"
)

return sorted_out


def argwhere_2d_ir(condition, out):
Expand Down Expand Up @@ -135,7 +145,9 @@ def argwhere_2d_ir(condition, out):
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = 1024
max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = (a0 * a1) // max_threads + 1
tx = te.thread_axis("threadIdx.x")
Expand Down Expand Up @@ -194,7 +206,21 @@ def argwhere_2d(output_shape, condition):
tag="argwhere2d_gpu",
)

return out
if out.shape[0] <= 1:
return out

# sort the output from the least significant to the most significant
# column.
out1 = strided_slice(out, [0, 1], [out.shape[0], 2])
out2 = argsort(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

out1 = strided_slice(out, [0, 0], [out.shape[0], 1])
out2 = argsort(out1, axis=0, dtype="int32")
out3 = squeeze(out2)

return adv_index(out, [out3])


def argwhere_3d_ir(condition, out):
Expand Down Expand Up @@ -227,7 +253,9 @@ def argwhere_3d_ir(condition, out):
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = 1024
max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = s0 // max_threads + 1
tx = te.thread_axis("threadIdx.x")
Expand Down Expand Up @@ -289,6 +317,17 @@ def argwhere_3d(output_shape, condition):
tag="argwhere3d_gpu",
)

if out.shape[0] <= 1:
return out

# sort the output from the least significant to the most significant
# column.
for i in reversed(range(3)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = argsort(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

return out


Expand Down Expand Up @@ -324,7 +363,9 @@ def argwhere_4d_ir(condition, out):
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = 1024
max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = s0 // max_threads + 1
tx = te.thread_axis("threadIdx.x")
Expand Down Expand Up @@ -387,6 +428,17 @@ def argwhere_4d(output_shape, condition):
tag="argwhere4d_gpu",
)

if out.shape[0] <= 1:
return out

# sort the output from the least significant to the most significant
# column.
for i in reversed(range(4)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = argsort(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

return out


Expand Down Expand Up @@ -424,7 +476,9 @@ def argwhere_5d_ir(condition, out):
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = 1024
max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = s0 // max_threads + 1
tx = te.thread_axis("threadIdx.x")
Expand Down Expand Up @@ -488,6 +542,17 @@ def argwhere_5d(output_shape, condition):
tag="argwhere5d_gpu",
)

if out.shape[0] <= 1:
return out

# sort the output from the least significant to the most significant
# column.
for i in reversed(range(5)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = argsort(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

return out


Expand Down Expand Up @@ -535,6 +600,17 @@ def schedule_argwhere(outs):
The computation schedule for argwhere
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
sch = te.create_schedule([x.op for x in outs])

return sch
s = te.create_schedule([x.op for x in outs])
scheduled_ops = []

def traverse(op):
if tag.is_injective(op.tag):
schedule_injective_from_existing(s, op.output(0))
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)

for out in outs:
traverse(out.op)
return s
35 changes: 23 additions & 12 deletions tests/python/topi/python/test_topi_argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@
"gpu": topi.cuda.schedule_argwhere,
}

_argwhere_compute = {
"llvm": topi.argwhere,
"cuda": topi.cuda.argwhere
}
# Dispatch table mapping a target/backend name to its argwhere compute
# implementation ("llvm" -> generic topi, "cuda" -> the GPU version).
_argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere}


def verify_argwhere(data_shape):
dtype = "int32"
Expand All @@ -39,27 +37,33 @@ def verify_argwhere(data_shape):
out_shape = np_out.shape[0]
np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)

out_shape = te.placeholder(shape=(out_shape, len(data_shape)),
name="out_shape", dtype=dtype)
out_shape = te.placeholder(
shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype
)
condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)

def check_device(device, ctx):
ctx = tvm.context(device, 0)
if not ctx.exist or device not in _argwhere_compute:
return

out = _argwhere_compute[device](out_shape, condition)
with tvm.target.create(device):
with tvm.target.Target(device):
out = _argwhere_compute[device](out_shape, condition)
s_func = tvm.topi.testing.dispatch(device, _argwhere_schedule)
sch = s_func(out)

func = tvm.build(sch, [out_shape, condition, out], device,
name="argwhere")
func = tvm.build(
sch, [out_shape, condition, out], device, name="argwhere"
)

# print(func.imported_modules[0].get_source())

args = [tvm.nd.array(np_shape, ctx)]
args.append(tvm.nd.array(np_data, ctx))
args.append(tvm.nd.empty(out.shape, ctx=ctx, dtype=condition.dtype))
func(*args)
np.set_printoptions(threshold=np.inf)
# print(args[-1].asnumpy())
tvm.testing.assert_allclose(args[-1].asnumpy(), np.array(np_out))

for target, ctx in tvm.testing.enabled_targets():
Expand All @@ -70,11 +74,18 @@ def check_device(device, ctx):
def test_argwhere():
    """Run argwhere verification over 1-D through 5-D condition shapes."""
    shapes = [
        (1,),
        (100,),
        (1, 1),
        (5, 3),
        (100, 100),
        (32, 64),
        # TODO(zhiics) This test is flaky because nothing is sorted.
        (128, 65),
        (6, 5, 3),
        (32, 32, 16),
        (1, 1, 1),
        # TODO(zhiics) This test is flaky; kept disabled:
        # (32, 32, 8),
        (1, 1, 1, 1),
        (6, 4, 5, 3),
        (1, 1, 1, 1, 1),
        (6, 4, 5, 3, 7),
    ]
    for shape in shapes:
        verify_argwhere(shape)


Expand Down

0 comments on commit 411eaec

Please sign in to comment.