Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] GPU scatter 1D via sorting based approach #7056

Merged
merged 13 commits into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/modules/CUDA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ if(USE_CUDA)
message(STATUS "Build with Thrust support")
cmake_minimum_required(VERSION 3.13) # to compile CUDA code
enable_language(CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda")
file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
endif(USE_THRUST)
Expand Down
106 changes: 105 additions & 1 deletion python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm import te
from ..scatter import _verify_scatter_nd_inputs
from .nms import atomic_add
from .sort import stable_sort_by_key_thrust, is_thrust_available


def ceil_div(a, b):
Expand Down Expand Up @@ -416,6 +417,97 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
return ib.get()


def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
"""Generate scatter ir for 1d inputs, using a sorting based approach.
By sorting indices and comparing neighboring two indices, we can tell which
of elements in the indices tensor can scatter its update value into the output.
Sorting of indices, and sorting of updates with respect to indices, can be done
at the same time by thrust's sort_by_key function. It is important that sorting
be done in a "stable" way via stable_sort, to guarantee deterministic output.

Parameters
----------
data : tir.Tensor
The input data to the operator.

indices_sorted : tir.Tensor
The sorted index locations to update.

updates : tir.Tensor
The values to update, sorted by indices.

axis : int
The axis to scatter on. It must be 0 for this function.

out : tir.Tensor
The output tensor.

Returns
-------
ret : tir
The computational ir.
"""
assert axis == 0
n = data.shape[0]

ib = tvm.tir.ir_builder.create()

out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads

with ib.new_scope():
nthread_bx = ceil_div(n, nthread_tx)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx
with ib.if_scope(tid < n):
out_ptr[tid] = data_ptr[tid]

indices_ptr = ib.buffer_ptr(indices_sorted)
updates_ptr = ib.buffer_ptr(updates_sorted)

ni = indices_sorted.shape[0]

def do_update(ib, index, update):
with ib.if_scope(index < 0):
out_ptr[index + n] = update
with ib.else_scope():
out_ptr[index] = update

with ib.new_scope():
nthread_bx = ceil_div(ni, nthread_tx)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

with ib.if_scope(tid == ni - 1):
# The last element can always update.
index = indices_ptr[tid]
update = updates_ptr[tid]
do_update(ib, index, update)

with ib.else_scope():
with ib.if_scope(tid < ni - 1):
index = indices_ptr[tid]
index_next = indices_ptr[tid + 1]

# If the next neighbor in the sorted list of indices has a different index,
# that means thread tid is the last one to have this index.
# This thread can update the output.
with ib.if_scope(index != index_next):
update = updates_ptr[tid]
do_update(ib, index, update)

return ib.get()


def scatter(data, indices, updates, axis=0):
"""Update data at positions defined by indices with values in updates

Expand Down Expand Up @@ -458,9 +550,21 @@ def update_func(dst_ptr, dst_index, update):

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")

in_bufs = [data]

if rank == 1 and is_thrust_available():
ir_funcs[1] = gen_scatter_1d_thrust
indices_sorted, updates_sorted = stable_sort_by_key_thrust(
indices, updates, for_scatter=True
)
in_bufs += [indices_sorted, updates_sorted]
else:
in_bufs += [indices, updates]

out = te.extern(
[out_shape],
[data, indices, updates],
in_bufs,
lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
dtype=data.dtype,
out_buffers=[out_buf],
Expand Down
59 changes: 58 additions & 1 deletion python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Argsort operator """
"""Sort related operators """
import tvm
from tvm import te
from tvm._ffi import get_global_func

from .injective import schedule_injective_from_existing
from ..math import identity
Expand Down Expand Up @@ -597,3 +598,59 @@ def schedule_topk(outs):
The computation schedule for the op.
"""
return _schedule_sort(outs)


def stable_sort_by_key_thrust(keys, values, for_scatter=False):
"""Sort values with respect to keys using thrust.
Both keys and values will be sorted and returned.
Sorting is done via stable sort, so relative ordering among
ties are preserved.

Parameters
----------
keys: tvm.te.Tensor
The 1D input keys.

values : tvm.te.Tensor,
The 1D input values.

for_scatter: bool, optional
If True, negative keys are interpreted as negative indices.
Before sorting, negative indices are converted to corresponding positive indices.
The output keys (indices) are all positive.
This option is introduced to optimize the scatter implementation.

Returns
-------
keys_sorted : tvm.te.Tensor
The sorted keys

values_sorted : tvm.te.Tensor
The values sorted with respect to the keys
"""
keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8)
values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8)
out_bufs = [
tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8),
tvm.tir.decl_buffer(keys.shape, values.dtype, "values_buf", data_alignment=8),
]
out = te.extern(
[keys.shape, values.shape],
[keys, values],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.stable_sort_by_key", ins[0], ins[1], outs[0], outs[1], for_scatter
),
in_buffers=[keys_buf, values_buf],
out_buffers=out_bufs,
dtype=[keys.dtype, values.dtype],
name="stable_sort_by_key",
tag="stable_sort_by_key",
)
return out[0], out[1]


def is_thrust_available():
"""
Test if thrust based sorting ops are available.
"""
return get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None
73 changes: 73 additions & 0 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,78 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
data_dtype, out_dtype);
});

template<typename KeyType, typename ValueType>
void thrust_stable_sort_by_key(DLTensor* keys_in,
DLTensor* values_in,
DLTensor* keys_out,
DLTensor* values_out,
bool for_scatter) {
const auto size = keys_in->shape[0];
thrust::device_ptr<KeyType> keys_in_ptr(static_cast<KeyType *>(keys_in->data));
thrust::device_ptr<ValueType> values_in_ptr(static_cast<ValueType *>(values_in->data));
thrust::device_ptr<KeyType> keys_out_ptr(static_cast<KeyType *>(keys_out->data));
thrust::device_ptr<ValueType> values_out_ptr(static_cast<ValueType *>(values_out->data));

if (for_scatter) {
thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) {
if (k < 0) return k + static_cast<KeyType>(size);
return k;
});
} else {
thrust::copy(keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
}
thrust::copy(values_in_ptr, values_in_ptr + size, values_out_ptr);

thrust::stable_sort_by_key(keys_out_ptr, keys_out_ptr + size, values_out_ptr);
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 5);
DLTensor* keys_in = args[0];
DLTensor* values_in = args[1];
DLTensor* keys_out = args[2];
DLTensor* values_out = args[3];
bool for_scatter = args[4];

auto key_dtype = DLDataType2String(keys_in->dtype);
auto value_dtype = DLDataType2String(values_in->dtype);

if (key_dtype == "int32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "int64") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "float32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else {
LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
}
});

} // namespace contrib
} // namespace tvm
34 changes: 34 additions & 0 deletions tests/python/contrib/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
import tvm.testing
from tvm import te
from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available
import numpy as np


Expand Down Expand Up @@ -90,6 +91,39 @@ def test_sort_np():
tvm.testing.assert_allclose(c.asnumpy(), np_out, rtol=1e-5)


def test_thrust_stable_sort_by_key():
if not is_thrust_available():
print("skip because thrust is not enabled...")
return

size = 6
keys = te.placeholder((size,), name="keys", dtype="int32")
values = te.placeholder((size,), name="values", dtype="int32")

keys_out, values_out = stable_sort_by_key_thrust(keys, values)

ctx = tvm.gpu(0)
target = "cuda"
s = te.create_schedule([keys_out.op, values_out.op])
f = tvm.build(s, [keys, values, keys_out, values_out], target)

keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
keys_np_out = np.zeros(keys_np.shape, np.int32)
values_np_out = np.zeros(values_np.shape, np.int32)
keys_in = tvm.nd.array(keys_np, ctx)
values_in = tvm.nd.array(values_np, ctx)
keys_out = tvm.nd.array(keys_np_out, ctx)
values_out = tvm.nd.array(values_np_out, ctx)
f(keys_in, values_in, keys_out, values_out)

ref_keys_out = np.sort(keys_np)
ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)])
tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5)
tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)


if __name__ == "__main__":
test_sort()
test_sort_np()
test_thrust_stable_sort_by_key()