From dc3929a030ada757579e39660fd4cc336d4eb79e Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 9 Dec 2020 10:03:53 +0900 Subject: [PATCH] [TOPI] GPU scatter 1D via sorting based approach (#7056) * add thrust stable sort * rename * scatter via sort working * correctly handles negative indices * clean up, add some comments * add doc string * remove scatter benchmark stuff * add more doc * fix typo * lint fix * silence lint * fix py format * check for thrust availablity before test Co-authored-by: masa --- cmake/modules/CUDA.cmake | 1 + python/tvm/topi/cuda/scatter.py | 106 ++++++++++++++++++++++++++- python/tvm/topi/cuda/sort.py | 59 ++++++++++++++- src/runtime/contrib/thrust/thrust.cu | 73 ++++++++++++++++++ tests/python/contrib/test_sort.py | 34 +++++++++ 5 files changed, 271 insertions(+), 2 deletions(-) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index 2583e8f3c9ca..3a0d56a7bb1e 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -59,6 +59,7 @@ if(USE_CUDA) message(STATUS "Build with Thrust support") cmake_minimum_required(VERSION 3.13) # to compile CUDA code enable_language(CUDA) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda") file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu) list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC}) endif(USE_THRUST) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 89c5cd23111b..9916e2a7fa6d 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -20,6 +20,7 @@ from tvm import te from ..scatter import _verify_scatter_nd_inputs from .nms import atomic_add +from .sort import stable_sort_by_key_thrust, is_thrust_available def ceil_div(a, b): @@ -416,6 +417,97 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func): return ib.get() +def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _): + """Generate scatter ir for 1d inputs, using a sorting based approach. + By sorting indices and comparing neighboring two indices, we can tell which + of elements in the indices tensor can scatter its update value into the output. + Sorting of indices, and sorting of updates with respect to indices, can be done + at the same time by thrust's sort_by_key function. It is important that sorting + be done in a "stable" way via stable_sort, to guarantee deterministic output. + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. + + indices_sorted : tir.Tensor + The sorted index locations to update. + + updates : tir.Tensor + The values to update, sorted by indices. + + axis : int + The axis to scatter on. It must be 0 for this function. + + out : tir.Tensor + The output tensor. + + Returns + ------- + ret : tir + The computational ir. 
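+
+    For example, stable-sorting indices = [1, 4, 1] together with
+    updates = [1.0, 2.0, 3.0] gives indices_sorted = [1, 1, 4] and
+    updates_sorted = [1.0, 3.0, 2.0]; only the last duplicate of each index
+    writes, so out[1] = 3.0 and out[4] = 2.0, matching the sequential
+    "last write wins" scatter result.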
+ """ + assert axis == 0 + n = data.shape[0] + + ib = tvm.tir.ir_builder.create() + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + + with ib.new_scope(): + nthread_bx = ceil_div(n, nthread_tx) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + with ib.if_scope(tid < n): + out_ptr[tid] = data_ptr[tid] + + indices_ptr = ib.buffer_ptr(indices_sorted) + updates_ptr = ib.buffer_ptr(updates_sorted) + + ni = indices_sorted.shape[0] + + def do_update(ib, index, update): + with ib.if_scope(index < 0): + out_ptr[index + n] = update + with ib.else_scope(): + out_ptr[index] = update + + with ib.new_scope(): + nthread_bx = ceil_div(ni, nthread_tx) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + with ib.if_scope(tid == ni - 1): + # The last element can always update. + index = indices_ptr[tid] + update = updates_ptr[tid] + do_update(ib, index, update) + + with ib.else_scope(): + with ib.if_scope(tid < ni - 1): + index = indices_ptr[tid] + index_next = indices_ptr[tid + 1] + + # If the next neighbor in the sorted list of indices has a different index, + # that means thread tid is the last one to have this index. + # This thread can update the output. + with ib.if_scope(index != index_next): + update = updates_ptr[tid] + do_update(ib, index, update) + + return ib.get() + + def scatter(data, indices, updates, axis=0): """Update data at positions defined by indices with values in updates @@ -458,9 +550,21 @@ def update_func(dst_ptr, dst_index, update): out_shape = data.shape out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") + + in_bufs = [data] + + if rank == 1 and is_thrust_available(): + ir_funcs[1] = gen_scatter_1d_thrust + indices_sorted, updates_sorted = stable_sort_by_key_thrust( + indices, updates, for_scatter=True + ) + in_bufs += [indices_sorted, updates_sorted] + else: + in_bufs += [indices, updates] + out = te.extern( [out_shape], - [data, indices, updates], + in_bufs, lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func), dtype=data.dtype, out_buffers=[out_buf], diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index f28d1cba096c..0094ef1adf11 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument -"""Argsort operator """ +"""Sort related operators """ import tvm from tvm import te +from tvm._ffi import get_global_func from .injective import schedule_injective_from_existing from ..math import identity @@ -597,3 +598,59 @@ def schedule_topk(outs): The computation schedule for the op. """ return _schedule_sort(outs) + + +def stable_sort_by_key_thrust(keys, values, for_scatter=False): + """Sort values with respect to keys using thrust. + Both keys and values will be sorted and returned. + Sorting is done via stable sort, so relative ordering among + ties are preserved. + + Parameters + ---------- + keys: tvm.te.Tensor + The 1D input keys. 
+ + values : tvm.te.Tensor, + The 1D input values. + + for_scatter: bool, optional + If True, negative keys are interpreted as negative indices. + Before sorting, negative indices are converted to corresponding positive indices. + The output keys (indices) are all positive. + This option is introduced to optimize the scatter implementation. + + Returns + ------- + keys_sorted : tvm.te.Tensor + The sorted keys + + values_sorted : tvm.te.Tensor + The values sorted with respect to the keys + """ + keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8) + values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8) + out_bufs = [ + tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8), + tvm.tir.decl_buffer(keys.shape, values.dtype, "values_buf", data_alignment=8), + ] + out = te.extern( + [keys.shape, values.shape], + [keys, values], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.thrust.stable_sort_by_key", ins[0], ins[1], outs[0], outs[1], for_scatter + ), + in_buffers=[keys_buf, values_buf], + out_buffers=out_bufs, + dtype=[keys.dtype, values.dtype], + name="stable_sort_by_key", + tag="stable_sort_by_key", + ) + return out[0], out[1] + + +def is_thrust_available(): + """ + Test if thrust based sorting ops are available. + """ + return get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 2054db710b6d..8ccefc5ee7d2 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -163,5 +163,78 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len, data_dtype, out_dtype); }); + +template +void thrust_stable_sort_by_key(DLTensor* keys_in, + DLTensor* values_in, + DLTensor* keys_out, + DLTensor* values_out, + bool for_scatter) { + const auto size = keys_in->shape[0]; + thrust::device_ptr keys_in_ptr(static_cast(keys_in->data)); + thrust::device_ptr values_in_ptr(static_cast(values_in->data)); + thrust::device_ptr keys_out_ptr(static_cast(keys_out->data)); + thrust::device_ptr values_out_ptr(static_cast(values_out->data)); + + if (for_scatter) { + thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) { + if (k < 0) return k + static_cast(size); + return k; + }); + } else { + thrust::copy(keys_in_ptr, keys_in_ptr + size, keys_out_ptr); + } + thrust::copy(values_in_ptr, values_in_ptr + size, values_out_ptr); + + thrust::stable_sort_by_key(keys_out_ptr, keys_out_ptr + size, values_out_ptr); +} + +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_GE(args.num_args, 5); + DLTensor* keys_in = args[0]; + DLTensor* values_in = args[1]; + DLTensor* keys_out = args[2]; + DLTensor* values_out = args[3]; + bool for_scatter = args[4]; + + auto key_dtype = DLDataType2String(keys_in->dtype); + auto value_dtype = DLDataType2String(values_in->dtype); + + if (key_dtype == "int32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "int64") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, 
values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "float32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else { + LOG(FATAL) << "Unsupported key dtype: " << key_dtype; + } +}); + } // namespace contrib } // namespace tvm diff --git a/tests/python/contrib/test_sort.py b/tests/python/contrib/test_sort.py index 7bd3a9cb55b8..9d6eb7cb3a1e 100644 --- a/tests/python/contrib/test_sort.py +++ b/tests/python/contrib/test_sort.py @@ -17,6 +17,7 @@ import tvm import tvm.testing from tvm import te +from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available import numpy as np @@ -90,6 +91,39 @@ def test_sort_np(): tvm.testing.assert_allclose(c.asnumpy(), np_out, rtol=1e-5) +def test_thrust_stable_sort_by_key(): + if not is_thrust_available(): + print("skip because thrust is not enabled...") + return + + size = 6 + keys = te.placeholder((size,), name="keys", dtype="int32") + values = te.placeholder((size,), name="values", dtype="int32") + + keys_out, values_out = stable_sort_by_key_thrust(keys, values) + + ctx = tvm.gpu(0) + target = "cuda" + s = te.create_schedule([keys_out.op, values_out.op]) + f = tvm.build(s, [keys, values, keys_out, values_out], target) + + keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) + values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) + keys_np_out = np.zeros(keys_np.shape, np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + keys_in = tvm.nd.array(keys_np, ctx) + values_in = tvm.nd.array(values_np, ctx) + keys_out = tvm.nd.array(keys_np_out, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(keys_in, values_in, keys_out, values_out) + + ref_keys_out = np.sort(keys_np) + ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) + tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + + if __name__ == "__main__": test_sort() test_sort_np() + test_thrust_stable_sort_by_key()
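
For reference, a minimal NumPy sketch of the approach this patch implements on the GPU via Thrust: remap negative indices (the for_scatter=True key transform), stable-sort the updates by index, and let only the last duplicate of each index write its value, which reproduces the deterministic sequential "last write wins" scatter. The helper name and the toy inputs below are illustrative only and are not part of the patch.

import numpy as np


def scatter_1d_via_sort(data, indices, updates):
    """Illustrative sketch: data[indices[i]] = updates[i], last write wins."""
    n = data.shape[0]
    out = data.copy()

    # for_scatter=True path: map negative indices to their positive equivalents.
    keys = np.where(indices < 0, indices + n, indices)

    # Stable sort the keys and reorder the updates by the same permutation,
    # mirroring thrust::stable_sort_by_key on (keys, updates).
    order = np.argsort(keys, kind="stable")
    keys_sorted = keys[order]
    updates_sorted = updates[order]

    # Only the last occurrence of each duplicated index writes its update,
    # which matches the neighbor-comparison test in gen_scatter_1d_thrust.
    for i in range(len(keys_sorted)):
        is_last = i == len(keys_sorted) - 1 or keys_sorted[i] != keys_sorted[i + 1]
        if is_last:
            out[keys_sorted[i]] = updates_sorted[i]
    return out


data = np.zeros(6, dtype=np.float32)
indices = np.array([1, -2, 1, 4], dtype=np.int32)  # -2 wraps to index 4
updates = np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32)

# Sequential reference: later writes overwrite earlier ones.
ref = data.copy()
for idx, upd in zip(indices, updates):
    ref[idx] = upd

assert np.array_equal(scatter_1d_via_sort(data, indices, updates), ref)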