Skip to content

Commit

Permalink
Use single block and thrust to fix flaky behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun authored and zhiics committed Nov 6, 2020
1 parent 411eaec commit a358f3a
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 86 deletions.
203 changes: 120 additions & 83 deletions python/tvm/topi/cuda/argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,34 @@
# pylint: disable=too-many-arguments, invalid-name
"""Argwhere operator"""

import logging

import tvm
from tvm import te
from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
from .nms import atomic_add
from .sort import topk, argsort
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze

logger = logging.getLogger("topi")


def _get_sort_func(mode=0):
    """Get the sort function for argwhere.

    Prefers the thrust-backed implementations when the thrust sort extern
    ("tvm.contrib.thrust.sort") is registered; otherwise falls back to the
    plain TE implementations and warns, since the fallback can be slow or
    produce incorrect results for argwhere on CUDA targets.

    Parameters
    ----------
    mode : int
        0 selects a topk-style function; any other value selects an
        argsort-style function.

    Returns
    -------
    callable
        One of ``topk_thrust``/``argsort_thrust`` (thrust available) or
        ``topk``/``argsort`` (fallback).
    """
    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
        ret = topk_thrust if mode == 0 else argsort_thrust
    else:
        # logger.warn is a deprecated alias of logger.warning (since Python 3.3).
        logger.warning(
            "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
            " when compiling argwhere for cuda target. Otherwise, it can result in"
            " significant performance degradation or incorrect result"
        )
        ret = topk if mode == 0 else argsort

    return ret


def argwhere_1d_ir(condition, out):
"""Low level IR for argwhere 1D
Expand All @@ -48,31 +68,31 @@ def argwhere_1d_ir(condition, out):
condition = ib.buffer_ptr(condition)
out = ib.buffer_ptr(out)

valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = a0 // max_threads + 1
# Limit threads to a single block to make sure atomic_add works normally.
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
len_inner_for = a0 // nthread_tx + 1
valid_index[0] = 0

with ib.if_scope(tid < a0):
with ib.if_scope(condition[tid] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0]] = tid
with ib.for_range(0, len_inner_for, name="i") as i:
idx = tx * len_inner_for + i
with ib.if_scope(idx < a0):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0]] = idx

return ib.get()

Expand Down Expand Up @@ -111,7 +131,10 @@ def argwhere_1d(output_shape, condition):
tag="argwhere1d_gpu",
)

sorted_out = topk(
if out.shape[0] <= 1:
return out

sorted_out = _get_sort_func()(
out, k=0, axis=0, ret_type="values", is_ascend="True", dtype="int32"
)

Expand All @@ -138,6 +161,8 @@ def argwhere_2d_ir(condition, out):
a0 = condition.shape[0]
a1 = condition.shape[1]

out_len = out.shape[0] * out.shape[1]

condition = ib.buffer_ptr(condition)
out = ib.buffer_ptr(out)

Expand All @@ -149,25 +174,26 @@ def argwhere_2d_ir(condition, out):
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = (a0 * a1) // max_threads + 1

# Limit threads to a single block to make sure atomic_add works normally.
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
len_inner_for = (a0 * a1) // nthread_tx + 1

valid_index[0] = 0

with ib.if_scope(tid < (a0 * a1)):
with ib.if_scope(condition[tid] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0] * 2] = tvm.tir.floordiv(tid, a1)
out[tmp[0] * 2 + 1] = tvm.tir.floormod(tid, a1)
with ib.for_range(0, len_inner_for, name="i") as i:
idx = tx * len_inner_for + i
with ib.if_scope(idx < (a0 * a1)):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0] * 2] = tvm.tir.floordiv(idx, a1)
out[tmp[0] * 2 + 1] = tvm.tir.floormod(idx, a1)

return ib.get()

Expand Down Expand Up @@ -209,15 +235,17 @@ def argwhere_2d(output_shape, condition):
if out.shape[0] <= 1:
return out

sort_func = _get_sort_func(1)

# sort the output from the least significant to the most significant
# column.
out1 = strided_slice(out, [0, 1], [out.shape[0], 2])
out2 = argsort(out1, axis=0, dtype="int32")
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

out1 = strided_slice(out, [0, 0], [out.shape[0], 1])
out2 = argsort(out1, axis=0, dtype="int32")
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)

return adv_index(out, [out3])
Expand Down Expand Up @@ -257,28 +285,30 @@ def argwhere_3d_ir(condition, out):
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = s0 // max_threads + 1

# Limit threads to a single block to make sure atomic_add works normally.
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
len_inner_for = s0 // nthread_tx + 1

fdiv = tvm.tir.floordiv
fmod = tvm.tir.floormod

valid_index[0] = 0

with ib.if_scope(tid < s0):
with ib.if_scope(condition[tid] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0] * 3] = fdiv(tid, s1)
out[tmp[0] * 3 + 1] = fdiv(fmod(tid, s1), a2)
out[tmp[0] * 3 + 2] = fmod(tid, a2)
with ib.for_range(0, len_inner_for, name="i") as i:
idx = tx * len_inner_for + i
with ib.if_scope(idx < s0):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0] * 3] = fdiv(idx, s1)
out[tmp[0] * 3 + 1] = fdiv(fmod(idx, s1), a2)
out[tmp[0] * 3 + 2] = fmod(idx, a2)

return ib.get()

Expand Down Expand Up @@ -322,9 +352,10 @@ def argwhere_3d(output_shape, condition):

# sort the output from the least significant to the most significant
# column.
sort_func = _get_sort_func(1)
for i in reversed(range(3)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = argsort(out1, axis=0, dtype="int32")
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

Expand Down Expand Up @@ -367,29 +398,31 @@ def argwhere_4d_ir(condition, out):
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = s0 // max_threads + 1

# Limit threads to a single block to make sure atomic_add works normally.
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
len_inner_for = s0 // nthread_tx + 1

fdiv = tvm.tir.floordiv
fmod = tvm.tir.floormod

valid_index[0] = 0

with ib.if_scope(tid < s0):
with ib.if_scope(condition[tid] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0] * 4] = fdiv(tid, s2)
out[tmp[0] * 4 + 1] = fdiv(fmod(tid, s2), s1)
out[tmp[0] * 4 + 2] = fdiv(fmod(tid, s1), a3)
out[tmp[0] * 4 + 3] = fmod(tid, a3)
with ib.for_range(0, len_inner_for, name="i") as i:
idx = tx * len_inner_for + i
with ib.if_scope(idx < s0):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0] * 4] = fdiv(idx, s2)
out[tmp[0] * 4 + 1] = fdiv(fmod(idx, s2), s1)
out[tmp[0] * 4 + 2] = fdiv(fmod(idx, s1), a3)
out[tmp[0] * 4 + 3] = fmod(idx, a3)

return ib.get()

Expand Down Expand Up @@ -433,9 +466,10 @@ def argwhere_4d(output_shape, condition):

# sort the output from the least significant to the most significant
# column.
sort_func = _get_sort_func(1)
for i in reversed(range(4)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = argsort(out1, axis=0, dtype="int32")
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

Expand Down Expand Up @@ -480,30 +514,32 @@ def argwhere_5d_ir(condition, out):
tvm.target.Target.current(allow_none=False).max_num_threads
)
nthread_tx = max_threads
nthread_bx = s0 // max_threads + 1

# Limit threads to a single block to make sure atomic_add works normally.
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
len_inner_for = s0 // nthread_tx + 1

fdiv = tvm.tir.floordiv
fmod = tvm.tir.floormod

valid_index[0] = 0

with ib.if_scope(tid < s0):
with ib.if_scope(condition[tid] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0] * 5] = fdiv(tid, s3)
out[tmp[0] * 5 + 1] = fdiv(fmod(tid, s3), s2)
out[tmp[0] * 5 + 2] = fdiv(fmod(tid, s2), s1)
out[tmp[0] * 5 + 3] = fdiv(fmod(tid, s1), a4)
out[tmp[0] * 5 + 4] = fmod(tid, a4)
with ib.for_range(0, len_inner_for, name="i") as i:
idx = tx * len_inner_for + i
with ib.if_scope(idx < s0):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
one_count,
)
out[tmp[0] * 5] = fdiv(idx, s3)
out[tmp[0] * 5 + 1] = fdiv(fmod(idx, s3), s2)
out[tmp[0] * 5 + 2] = fdiv(fmod(idx, s2), s1)
out[tmp[0] * 5 + 3] = fdiv(fmod(idx, s1), a4)
out[tmp[0] * 5 + 4] = fmod(idx, a4)

return ib.get()

Expand Down Expand Up @@ -547,9 +583,10 @@ def argwhere_5d(output_shape, condition):

# sort the output from the least significant to the most significant
# column.
sort_func = _get_sort_func(1)
for i in reversed(range(5)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = argsort(out1, axis=0, dtype="int32")
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,8 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8),
]

is_ascend = 1 if is_ascend else 0

out = te.extern(
[data.shape, data.shape],
[data],
Expand Down
4 changes: 1 addition & 3 deletions tests/python/topi/python/test_topi_argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,10 @@ def test_argwhere():
verify_argwhere((1, 1))
verify_argwhere((5, 3))
verify_argwhere((32, 64))
# TODO(zhiics) This test is flaky because nothing is sorted.
verify_argwhere((128, 65))
verify_argwhere((200, 500))
verify_argwhere((6, 5, 3))
verify_argwhere((1, 1, 1))
# TODO(zhiics) This test is flaky.
# verify_argwhere((32, 32, 8))
verify_argwhere((1, 1, 1, 1))
verify_argwhere((6, 4, 5, 3))
verify_argwhere((1, 1, 1, 1, 1))
Expand Down

0 comments on commit a358f3a

Please sign in to comment.