
Commit: format

zhiics committed Nov 6, 2020
1 parent a358f3a commit ffbd2ac
Showing 4 changed files with 37 additions and 65 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -857,12 +857,14 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target):
)
return strategy


@argwhere_strategy.register(["cuda", "gpu"])
def argwhere_strategy_cuda(attrs, inputs, out_type, target):
"""argwhere cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_argwhere(topi.cuda.argwhere),
wrap_topi_schedule(topi.cuda.schedule_argwhere),
name="argwhere.cuda")
name="argwhere.cuda",
)
return strategy
12 changes: 9 additions & 3 deletions python/tvm/relay/op/strategy/generic.py
@@ -1183,9 +1183,11 @@ def correlation_strategy(attrs, inputs, out_type, target):
)
return strategy


# argwhere
def wrap_compute_argwhere(topi_compute):
"""wrap argwhere topi compute"""

def _compute_argwhere(attrs, inputs, out_type):
output_shape = []
for s in out_type.shape:
@@ -1195,13 +1197,17 @@ def _compute_argwhere(attrs, inputs, out_type):
output_shape.append(te.var("any_dim", "int32"))
new_output_type = ir.TensorType(output_shape, "int32")
return [topi_compute(new_output_type, inputs[0])]

return _compute_argwhere


@override_native_generic_func("argwhere_strategy")
def argwhere_strategy(attrs, inputs, out_type, target):
"""argwhere generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_argwhere(topi.argwhere),
wrap_topi_schedule(topi.generic.schedule_argwhere),
name="argwhere.generic")
strategy.add_implementation(
wrap_compute_argwhere(topi.argwhere),
wrap_topi_schedule(topi.generic.schedule_argwhere),
name="argwhere.generic",
)
return strategy
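For context, and not part of this commit, here is a minimal sketch of how the argwhere strategy registered above gets exercised through Relay. The target, executor kind, and input shape below are illustrative assumptions rather than anything taken from the diff:

# Hypothetical usage sketch: relay.argwhere dispatches to argwhere_strategy
# (generic) or argwhere_strategy_cuda, depending on the build target.
import numpy as np
import tvm
from tvm import relay

cond = relay.var("cond", shape=(4, 4), dtype="bool")
mod = tvm.IRModule.from_expr(relay.Function([cond], relay.argwhere(cond)))

data = np.random.rand(4, 4) > 0.5
# The VM executor handles argwhere's data-dependent output shape; switching
# target to "cuda" would pick up the CUDA strategy from cuda.py instead.
result = relay.create_executor("vm", mod=mod, target="llvm").evaluate()(data)
print(result)  # indices of non-zero elements, shape (num_nonzero, 2)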
78 changes: 23 additions & 55 deletions python/tvm/topi/cuda/argwhere.py
@@ -33,14 +33,14 @@

def _get_sort_func(mode=0):
"""Get sort function for argwhere. mode 0 for topk and others for argsort."""
if get_global_func(
"tvm.contrib.thrust.sort", allow_missing=True
):
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
ret = topk_thrust if mode == 0 else argsort_thrust
else:
logger.warn("It's highly recommended to enable thrust library with set(USE_THRUST ON)"
" when compiling argwhere for cuda target. Otherwise, it can result in"
" significant performance degradation or incorrect result")
logger.warn(
"It's highly recommended to enable thrust library with set(USE_THRUST ON)"
" when compiling argwhere for cuda target. Otherwise, it can result in"
" significant performance degradation or incorrect result"
)
ret = topk if mode == 0 else argsort

return ret
@@ -72,9 +72,7 @@ def argwhere_1d_ir(condition, out):
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
# Limit threads to a single block to make sure atomic_add works normally.
tx = te.thread_axis("threadIdx.x")
@@ -87,9 +85,7 @@ def argwhere_1d_ir(condition, out):
with ib.if_scope(idx < a0):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
one_count,
)
out[tmp[0]] = idx
@@ -116,9 +112,7 @@ def argwhere_1d(output_shape, condition):
condition_buf = tvm.tir.decl_buffer(
condition.shape, condition.dtype, "data_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(
output_shape, "int32", "out_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

out = te.extern(
[output_shape],
@@ -161,18 +155,14 @@ def argwhere_2d_ir(condition, out):
a0 = condition.shape[0]
a1 = condition.shape[1]

out_len = out.shape[0] * out.shape[1]

condition = ib.buffer_ptr(condition)
out = ib.buffer_ptr(out)

valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads

# Limit threads to a single block to make sure atomic_add works normally.
@@ -187,9 +177,7 @@ def argwhere_2d_ir(condition, out):
with ib.if_scope(idx < (a0 * a1)):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
one_count,
)
out[tmp[0] * 2] = tvm.tir.floordiv(idx, a1)
@@ -217,9 +205,7 @@ def argwhere_2d(output_shape, condition):
condition_buf = tvm.tir.decl_buffer(
condition.shape, condition.dtype, "data_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(
output_shape, "int32", "out_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

out = te.extern(
[output_shape],
@@ -281,9 +267,7 @@ def argwhere_3d_ir(condition, out):
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads

# Limit threads to a single block to make sure atomic_add works normally.
@@ -301,9 +285,7 @@ def argwhere_3d_ir(condition, out):
with ib.if_scope(idx < s0):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
one_count,
)
out[tmp[0] * 3] = fdiv(idx, s1)
@@ -332,9 +314,7 @@ def argwhere_3d(output_shape, condition):
condition_buf = tvm.tir.decl_buffer(
condition.shape, condition.dtype, "data_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(
output_shape, "int32", "out_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

out = te.extern(
[output_shape],
@@ -394,15 +374,13 @@ def argwhere_4d_ir(condition, out):
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads

# Limit threads to a single block to make sure atomic_add works normally.
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
len_inner_for = s0 // nthread_tx + 1
len_inner_for = s0 // nthread_tx + 1

fdiv = tvm.tir.floordiv
fmod = tvm.tir.floormod
@@ -414,9 +392,7 @@ def argwhere_4d_ir(condition, out):
with ib.if_scope(idx < s0):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
one_count,
)
out[tmp[0] * 4] = fdiv(idx, s2)
@@ -446,9 +422,7 @@ def argwhere_4d(output_shape, condition):
condition_buf = tvm.tir.decl_buffer(
condition.shape, condition.dtype, "data_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(
output_shape, "int32", "out_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

out = te.extern(
[output_shape],
@@ -510,9 +484,7 @@ def argwhere_5d_ir(condition, out):
tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
one_count = tvm.tir.const(1, dtype="int32")

max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads
)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads

# Limit threads to a single block to make sure atomic_add works normally.
@@ -530,9 +502,7 @@ def argwhere_5d_ir(condition, out):
with ib.if_scope(idx < s0):
with ib.if_scope(condition[idx] != 0):
tmp[0] = atomic_add(
tvm.tir.call_intrin(
"handle", "tir.address_of", valid_index[0]
),
tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
one_count,
)
out[tmp[0] * 5] = fdiv(idx, s3)
@@ -563,9 +533,7 @@ def argwhere_5d(output_shape, condition):
condition_buf = tvm.tir.decl_buffer(
condition.shape, condition.dtype, "data_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(
output_shape, "int32", "out_buf", data_alignment=8
)
out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

out = te.extern(
[output_shape],
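As an aside, not part of the diff: the index arithmetic that the *_ir builders above emit with tir.floordiv and tir.floormod can be illustrated in plain NumPy for the 2-D case. The array contents here are made up for illustration:

# Sketch of the decomposition used in argwhere_2d_ir: a flattened index idx over
# an (a0, a1) condition tensor becomes row = idx // a1 and col = idx % a1,
# matching out[tmp[0] * 2] = floordiv(idx, a1) in the kernel above.
import numpy as np

cond = np.array([[0, 1, 0],
                 [2, 0, 3]])
a0, a1 = cond.shape
flat_idx = np.flatnonzero(cond)          # linear indices of non-zero entries
rows, cols = flat_idx // a1, flat_idx % a1
print(np.stack([rows, cols], axis=1))    # same pairs as np.argwhere(cond)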
8 changes: 2 additions & 6 deletions tests/python/topi/python/test_topi_argwhere.py
@@ -37,9 +37,7 @@ def verify_argwhere(data_shape):
out_shape = np_out.shape[0]
np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)

out_shape = te.placeholder(
shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype
)
out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype)
condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)

def check_device(device, ctx):
@@ -52,9 +50,7 @@ def check_device(device, ctx):
s_func = tvm.topi.testing.dispatch(device, _argwhere_schedule)
sch = s_func(out)

func = tvm.build(
sch, [out_shape, condition, out], device, name="argwhere"
)
func = tvm.build(sch, [out_shape, condition, out], device, name="argwhere")

# print(func.imported_modules[0].get_source())

