From 4caad85c8842f322c8aa473c03123ece413ffba1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Feb 2021 16:19:41 +0900 Subject: [PATCH 01/10] enable rocm thrust, confrimed to work on sort and scan --- cmake/modules/ROCM.cmake | 12 + python/tvm/topi/cuda/nms.py | 3 +- python/tvm/topi/cuda/scan.py | 5 +- src/runtime/contrib/rocthrust/thrust.cc | 394 ++++++++++++++++++++++++ tests/python/contrib/test_thrust.py | 127 ++++---- 5 files changed, 483 insertions(+), 58 deletions(-) create mode 100644 src/runtime/contrib/rocthrust/thrust.cc diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake index ec348f8b57f6..03196c9ab0dc 100644 --- a/cmake/modules/ROCM.cmake +++ b/cmake/modules/ROCM.cmake @@ -48,6 +48,18 @@ if(USE_ROCM) list(APPEND RUNTIME_SRCS ${ROCBLAS_CONTRIB_SRCS}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY}) endif(USE_ROCBLAS) + + if(USE_THRUST) + message(STATUS "Build with Thrust support") + # Override CXX to hipcc. This is required by rocthrust + set(CMAKE_CXX_COMPILER hipcc) + find_package(rocprim REQUIRED) + find_package(rocthrust REQUIRED) + file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/rocthrust/*.cc) + list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC}) + list(APPEND TVM_RUNTIME_LINKER_LIBS roc::rocthrust) + endif(USE_THRUST) + else(USE_ROCM) list(APPEND COMPILER_SRCS src/target/opt/build_rocm_off.cc) endif(USE_ROCM) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2d6e1e464ef8..98cb6750408a 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -610,7 +610,8 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape): ) target = tvm.target.Target.current() - if target and target.kind.name == "cuda" and is_thrust_available(): + # TODO(masahi): Check -libs=thrust option + if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available(): sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32") else: sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32") diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 0bdab100b429..65d23365dc15 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -221,7 +221,7 @@ def ir(data, data_ex_scan, reduction): with ib.if_scope(scan_axis_size > 0): reduction[tid] = binop( data_ex_scan[tid * scan_axis_size + scan_axis_size - 1], - data[tid, scan_axis_size - 1], + data[tid * scan_axis_size + scan_axis_size - 1], ) with ib.else_scope(): reduction[tid] = 0 @@ -352,7 +352,8 @@ def exclusive_scan( def do_scan(data, output_dtype): target = tvm.target.Target.current() - if target and target.kind.name == "cuda" and is_thrust_available(): + # TODO(masahi): Check -libs=thrust option + if target and target.kind.name in ["cuda", "rocm"] and is_thrust_available(): return scan_thrust( data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop ) diff --git a/src/runtime/contrib/rocthrust/thrust.cc b/src/runtime/contrib/rocthrust/thrust.cc new file mode 100644 index 000000000000..df83b57847a0 --- /dev/null +++ b/src/runtime/contrib/rocthrust/thrust.cc @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file Use external Thrust library call + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace contrib { + +using namespace runtime; + +// Performs sorting along axis -1 and returns both sorted values and indices. +template +void thrust_sort(DLTensor* input, + DLTensor* out_values, + DLTensor* out_indices, + bool is_ascend, + int n_values) { + thrust::device_ptr data_ptr(static_cast(input->data)); + thrust::device_ptr values_ptr(static_cast(out_values->data)); + thrust::device_ptr indices_ptr(static_cast(out_indices->data)); + + size_t size = 1; + for (int i = 0; i < input->ndim; ++i) { + size *= input->shape[i]; + } + thrust::copy(data_ptr, data_ptr + size, values_ptr); + + if (size == static_cast(input->shape[input->ndim - 1])) { + // A fast path for single segment case + thrust::sequence(indices_ptr, indices_ptr + n_values); + if (is_ascend) { + thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr); + } else { + thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr, + thrust::greater()); + } + } else { + // segmented sort by key + // Follow the back-to-back stable_sort_by_key strategy explained below + // https://groups.google.com/g/thrust-users/c/BoLsxO6b4FY + thrust::device_vector argsort_order(size); + thrust::sequence(argsort_order.begin(), argsort_order.end()); + + // First, sort values and store the sorted order in argsort_order. + if (is_ascend) { + thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin()); + } else { + thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(), + thrust::greater()); + } + + // The following is to create the indices array 0, 1, 2, 0, 1, 2 ... 0, 1, 2 + // without materializing it + auto counting_iter = thrust::counting_iterator(0); + auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) { + return i % n_values; + }; // NOLINT(*) + auto init_indices_iter = thrust::make_transform_iterator(counting_iter, + linear_index_to_sort_axis_index); + + // This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr + thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr); + + thrust::device_vector segment_ids(size); + auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) { + return i / n_values; + }; // NOLINT(*) + // We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr + thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(), + linear_index_to_segment_id); + + // The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ... 
+ // values_ptr and indices_ptr will also be sorted in the order of segmend_ids above + // Since sorting has been done in a stable way, relative orderings of values and indices + // in the segment do not change and hence they remain sorted. + auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr)); + thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip); + } +} + +void thrust_sort_common(DLTensor* input, + DLTensor* values_out, + DLTensor* indices_out, + bool is_ascend, + int sort_len, + std::string data_dtype, + std::string out_dtype) { + if (data_dtype == "float32") { + if (out_dtype == "int32") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "int64") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float32") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float64") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float64") { + if (out_dtype == "int32") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "int64") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float32") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float64") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int32") { + if (out_dtype == "int32") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "int64") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float32") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float64") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int64") { + if (out_dtype == "int32") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "int64") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float32") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float64") { + thrust_sort(input, values_out, indices_out, is_ascend, sort_len); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } +} + +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_GE(args.num_args, 4); + DLTensor* input = args[0]; + DLTensor* values_out = args[1]; + DLTensor* indices_out = args[2]; + bool is_ascend = args[3]; + + auto data_dtype = DLDataType2String(input->dtype); + auto out_dtype = DLDataType2String(indices_out->dtype); + + int n_values = input->shape[input->ndim - 1]; + thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, + data_dtype, out_dtype); +}); + +template +void thrust_stable_sort_by_key(DLTensor* keys_in, + DLTensor* values_in, + DLTensor* keys_out, + DLTensor* values_out, + bool for_scatter) { + const auto size = keys_in->shape[0]; + thrust::device_ptr 
keys_in_ptr(static_cast(keys_in->data)); + thrust::device_ptr values_in_ptr(static_cast(values_in->data)); + thrust::device_ptr keys_out_ptr(static_cast(keys_out->data)); + thrust::device_ptr values_out_ptr(static_cast(values_out->data)); + + if (for_scatter) { + thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) { + if (k < 0) return k + static_cast(size); + return k; + }); + } else { + thrust::copy(keys_in_ptr, keys_in_ptr + size, keys_out_ptr); + } + thrust::copy(values_in_ptr, values_in_ptr + size, values_out_ptr); + + thrust::stable_sort_by_key(keys_out_ptr, keys_out_ptr + size, values_out_ptr); +} + +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_GE(args.num_args, 5); + DLTensor* keys_in = args[0]; + DLTensor* values_in = args[1]; + DLTensor* keys_out = args[2]; + DLTensor* values_out = args[3]; + bool for_scatter = args[4]; + + auto key_dtype = DLDataType2String(keys_in->dtype); + auto value_dtype = DLDataType2String(values_in->dtype); + + if (key_dtype == "int32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "int64") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "float32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else { + LOG(FATAL) << "Unsupported key dtype: " << key_dtype; + } +}); + +template +void thrust_scan(DLTensor* data, + DLTensor* output, + bool exclusive) { + thrust::device_ptr data_ptr(static_cast(data->data)); + thrust::device_ptr output_ptr(static_cast(output->data)); + const auto scan_size = data->shape[data->ndim - 1]; + + if (scan_size == 0) return; + + size_t size = 1; + for (int i = 0; i < data->ndim; ++i) size *= data->shape[i]; + + const bool need_cast = std::is_same::value == false; + + auto data_cast_ptr = thrust::make_transform_iterator(data_ptr, [] __host__ __device__(InType v) { + return static_cast(v); + }); // NOLINT(*) + + if (size == static_cast(data->shape[data->ndim - 1])) { + if (exclusive && need_cast) { + thrust::exclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr); + } else if (exclusive && !need_cast) { + thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); + } else if (!exclusive && need_cast) { + thrust::inclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, 
output_ptr); + } else { + thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); + } + } else { + // Use thrust segmented scan to compute scan on the inner most axis + // data->shape[0] * data->shape[1] * ... * data->shape[ndim - 2] scans are + // computed in parallel + + // This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,..., + // without materializing the sequence vector + auto counting_iter = thrust::counting_iterator(0); + // Without __host__ annotation, cub crashes + auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) { + return i / scan_size; + }; // NOLINT(*) + auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key); + + if (exclusive && need_cast) { + thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr); + } else if (exclusive && !need_cast) { + thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); + } else if (!exclusive && need_cast) { + thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr); + } else { + thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); + } + } +} + +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.num_args, 3); + DLTensor* data = args[0]; + DLTensor* output = args[1]; + bool exclusive = args[2]; + + auto in_dtype = DLDataType2String(data->dtype); + auto out_dtype = DLDataType2String(output->dtype); + + if (in_dtype == "bool") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32, and float64"; + } + } else if (in_dtype == "int32") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32, and float64"; + } + } else if (in_dtype == "int64") { + if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int64, float32, and float64"; + } + } else if (in_dtype == "float32") { + if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are float32, and float64"; + } + } else if (in_dtype == "float64") { + if (out_dtype == "float64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtype is float64"; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << in_dtype + << ". 
Supported input dtypes are bool, int32, int64, float32, and float64"; + } +}); + +} // namespace contrib +} // namespace tvm diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py index c5b6a29d57d5..521c20de6cbd 100644 --- a/tests/python/contrib/test_thrust.py +++ b/tests/python/contrib/test_thrust.py @@ -33,25 +33,30 @@ def test_stable_sort_by_key(): keys_out, values_out = stable_sort_by_key_thrust(keys, values) - ctx = tvm.gpu(0) - target = "cuda" - s = te.create_schedule([keys_out.op, values_out.op]) - f = tvm.build(s, [keys, values, keys_out, values_out], target) - - keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) - values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) - keys_np_out = np.zeros(keys_np.shape, np.int32) - values_np_out = np.zeros(values_np.shape, np.int32) - keys_in = tvm.nd.array(keys_np, ctx) - values_in = tvm.nd.array(values_np, ctx) - keys_out = tvm.nd.array(keys_np_out, ctx) - values_out = tvm.nd.array(values_np_out, ctx) - f(keys_in, values_in, keys_out, values_out) - - ref_keys_out = np.sort(keys_np) - ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) - tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) - tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + for target in ["cuda", "rocm"]: + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + continue + + target += " -libs=thrust" + ctx = tvm.context(target, 0) + s = te.create_schedule([keys_out.op, values_out.op]) + f = tvm.build(s, [keys, values, keys_out, values_out], target) + + keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) + values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) + keys_np_out = np.zeros(keys_np.shape, np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + keys_in = tvm.nd.array(keys_np, ctx) + values_in = tvm.nd.array(values_np, ctx) + keys_out = tvm.nd.array(keys_np_out, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(keys_in, values_in, keys_out, values_out) + + ref_keys_out = np.sort(keys_np) + ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) + tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) def test_exclusive_scan(): @@ -59,35 +64,41 @@ def test_exclusive_scan(): print("skip because thrust is not enabled...") return - for ishape in [(10,), (10, 10), (10, 10, 10)]: - values = te.placeholder(ishape, name="values", dtype="int32") + for target in ["cuda", "rocm"]: + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + continue - with tvm.target.Target("cuda"): - scan, reduction = exclusive_scan(values, return_reduction=True) - s = schedule_scan([scan, reduction]) + target += " -libs=thrust" + for ishape in [(10,), (10, 10), (10, 10, 10)]: + values = te.placeholder(ishape, name="values", dtype="int32") - ctx = tvm.gpu(0) - f = tvm.build(s, [values, scan, reduction], "cuda") + with tvm.target.Target(target): + scan, reduction = exclusive_scan(values, return_reduction=True) + s = schedule_scan([scan, reduction]) - values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) - values_np_out = np.zeros(values_np.shape, np.int32) + ctx = tvm.context(target, 0) + f = tvm.build(s, [values, scan, reduction], target) - if len(ishape) == 1: - reduction_shape = () - else: - reduction_shape = ishape[:-1] + values_np = np.random.randint(0, 
10, size=ishape).astype(np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) - reduction_np_out = np.zeros(reduction_shape, np.int32) + if len(ishape) == 1: + reduction_shape = () + else: + reduction_shape = ishape[:-1] - values_in = tvm.nd.array(values_np, ctx) - values_out = tvm.nd.array(values_np_out, ctx) - reduction_out = tvm.nd.array(reduction_np_out, ctx) - f(values_in, values_out, reduction_out) + reduction_np_out = np.zeros(reduction_shape, np.int32) - ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np - tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) - ref_reduction_out = np.sum(values_np, axis=-1) - tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5) + values_in = tvm.nd.array(values_np, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + reduction_out = tvm.nd.array(reduction_np_out, ctx) + f(values_in, values_out, reduction_out) + + ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + ref_reduction_out = np.sum(values_np, axis=-1) + tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5) def test_inclusive_scan(): @@ -97,24 +108,30 @@ def test_inclusive_scan(): out_dtype = "int64" - for ishape in [(10,), (10, 10)]: - values = te.placeholder(ishape, name="values", dtype="int32") + for target in ["cuda", "rocm"]: + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + continue - with tvm.target.Target("cuda"): - scan = scan_thrust(values, out_dtype, exclusive=False) - s = tvm.te.create_schedule([scan.op]) + target += " -libs=thrust" + for ishape in [(10,), (10, 10)]: + values = te.placeholder(ishape, name="values", dtype="int32") - ctx = tvm.gpu(0) - f = tvm.build(s, [values, scan], "cuda") + with tvm.target.Target(target): + scan = scan_thrust(values, out_dtype, exclusive=False) + s = tvm.te.create_schedule([scan.op]) - values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) - values_np_out = np.zeros(values_np.shape, out_dtype) - values_in = tvm.nd.array(values_np, ctx) - values_out = tvm.nd.array(values_np_out, ctx) - f(values_in, values_out) + ctx = tvm.context(target, 0) + f = tvm.build(s, [values, scan], target) - ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype) - tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) + values_np_out = np.zeros(values_np.shape, out_dtype) + values_in = tvm.nd.array(values_np, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(values_in, values_out) + + ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) if __name__ == "__main__": From b30d04687bc353ff93d992fdee6da50827eeaeb3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Feb 2021 16:20:55 +0900 Subject: [PATCH 02/10] add rocm argsort strategy --- python/tvm/relay/op/strategy/rocm.py | 54 ++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index c52da541a8ab..ae292c82c25c 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -19,6 +19,7 @@ from tvm import topi from tvm.auto_scheduler import is_auto_scheduler_enabled from .generic import * +from tvm._ffi import 
get_global_func from .. import op as _op from .cuda import judge_winograd, naive_schedule @@ -219,3 +220,56 @@ def batch_matmul_strategy_rocm(attrs, inputs, out_type, target): plevel=12, ) return strategy + + +def can_use_thrust(target, func_name): + return ( + target.kind.name == "rocm" + and "thrust" in target.libs + and get_global_func(func_name, allow_missing=True) + ) + + +@argsort_strategy.register(["rocm"]) +def argsort_strategy_cuda(attrs, inputs, out_type, target): + """argsort rocm strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_argsort(topi.cuda.argsort), + wrap_topi_schedule(topi.cuda.schedule_argsort), + name="argsort.rocm", + ) + if can_use_thrust(target, "tvm.contrib.thrust.sort"): + strategy.add_implementation( + wrap_compute_argsort(topi.cuda.argsort_thrust), + wrap_topi_schedule(topi.cuda.schedule_argsort), + name="argsort_thrust.rocm", + plevel=15, + ) + return strategy + + +# @scatter_strategy.register(["rocm"]) +# def scatter_cuda(attrs, inputs, out_type, target): +# """scatter rocm strategy""" +# strategy = _op.OpStrategy() +# strategy.add_implementation( +# wrap_compute_scatter(topi.cuda.scatter), +# wrap_topi_schedule(topi.cuda.schedule_scatter), +# name="scatter.cuda", +# plevel=10, +# ) + +# rank = len(inputs[0].shape) + +# with SpecializedCondition(rank == 1): +# if target.kind.name == "rocm" and get_global_func( +# "tvm.contrib.thrust.stable_sort_by_key", allow_missing=True +# ): +# strategy.add_implementation( +# wrap_compute_scatter(topi.cuda.scatter_via_sort), +# wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort), +# name="scatter_via_sort.rocm", +# plevel=9, # use the sequential version by default +# ) +# return strategy From 2e0cefb24ea12d61fa52649caee94b269f636fcf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Feb 2021 16:56:50 +0900 Subject: [PATCH 03/10] Abort if CXX is not hipcc --- cmake/modules/ROCM.cmake | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake index 03196c9ab0dc..57b2aaae0359 100644 --- a/cmake/modules/ROCM.cmake +++ b/cmake/modules/ROCM.cmake @@ -50,9 +50,14 @@ if(USE_ROCM) endif(USE_ROCBLAS) if(USE_THRUST) - message(STATUS "Build with Thrust support") - # Override CXX to hipcc. This is required by rocthrust - set(CMAKE_CXX_COMPILER hipcc) + message(STATUS "Build with rocThrust support") + # We need to override CXX to hipcc. 
This is required by rocthrust + if (${CMAKE_CXX_COMPILER} MATCHES "hipcc$") + message(STATUS "Using hipcc compiler to compile rocthrust code.") + else() + message(FATAL_ERROR "Set CXX=hipcc to compile rocthrust code.") + endif() + find_package(rocprim REQUIRED) find_package(rocthrust REQUIRED) file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/rocthrust/*.cc) From 0a01d1ba63f1293bc60343dee6a09be76b8eccfe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Feb 2021 17:01:43 +0900 Subject: [PATCH 04/10] add more strategy --- python/tvm/relay/op/strategy/rocm.py | 81 ++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index ae292c82c25c..f7780eafbc17 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -249,27 +249,64 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target): return strategy -# @scatter_strategy.register(["rocm"]) -# def scatter_cuda(attrs, inputs, out_type, target): -# """scatter rocm strategy""" -# strategy = _op.OpStrategy() -# strategy.add_implementation( -# wrap_compute_scatter(topi.cuda.scatter), -# wrap_topi_schedule(topi.cuda.schedule_scatter), -# name="scatter.cuda", -# plevel=10, -# ) +@scatter_strategy.register(["rocm"]) +def scatter_cuda(attrs, inputs, out_type, target): + """scatter rocm strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter(topi.cuda.scatter), + wrap_topi_schedule(topi.cuda.schedule_scatter), + name="scatter.rocm", + plevel=10, + ) + + rank = len(inputs[0].shape) + + with SpecializedCondition(rank == 1): + if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"): + strategy.add_implementation( + wrap_compute_scatter(topi.cuda.scatter_via_sort), + wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort), + name="scatter_via_sort.rocm", + plevel=9, # use the sequential version by default + ) + return strategy + + +@sort_strategy.register(["rocm"]) +def sort_strategy_cuda(attrs, inputs, out_type, target): + """sort rocm strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_sort(topi.cuda.sort), + wrap_topi_schedule(topi.cuda.schedule_sort), + name="sort.rocm", + ) + if can_use_thrust(target, "tvm.contrib.thrust.sort"): + strategy.add_implementation( + wrap_compute_sort(topi.cuda.sort_thrust), + wrap_topi_schedule(topi.cuda.schedule_sort), + name="sort_thrust.cuda", + plevel=15, + ) + return strategy + -# rank = len(inputs[0].shape) +@topk_strategy.register(["rocm"]) +def topk_strategy_cuda(attrs, inputs, out_type, target): + """topk rocm strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_topk(topi.cuda.topk), + wrap_topi_schedule(topi.cuda.schedule_topk), + name="topk.rocm", + ) -# with SpecializedCondition(rank == 1): -# if target.kind.name == "rocm" and get_global_func( -# "tvm.contrib.thrust.stable_sort_by_key", allow_missing=True -# ): -# strategy.add_implementation( -# wrap_compute_scatter(topi.cuda.scatter_via_sort), -# wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort), -# name="scatter_via_sort.rocm", -# plevel=9, # use the sequential version by default -# ) -# return strategy + if can_use_thrust(target, "tvm.contrib.thrust.sort"): + strategy.add_implementation( + wrap_compute_topk(topi.cuda.topk_thrust), + wrap_topi_schedule(topi.cuda.schedule_topk), + name="topk_thrust.rocm", + plevel=15, + ) + return strategy From 
a6e6d392436eb825a20f6b83b4eb5568d950c04c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Feb 2021 17:13:01 +0900 Subject: [PATCH 05/10] add missing import --- python/tvm/relay/op/strategy/rocm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index f7780eafbc17..5d06f78b1776 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -19,6 +19,7 @@ from tvm import topi from tvm.auto_scheduler import is_auto_scheduler_enabled from .generic import * +from tvm.te import SpecializedCondition from tvm._ffi import get_global_func from .. import op as _op from .cuda import judge_winograd, naive_schedule From 27001058bf1fc3fd968b8d1d6e215059d4ec93c0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Feb 2021 17:34:40 +0900 Subject: [PATCH 06/10] fix lint --- python/tvm/relay/op/strategy/rocm.py | 2 +- src/runtime/contrib/rocthrust/thrust.cc | 186 +++++++++++------------- 2 files changed, 86 insertions(+), 102 deletions(-) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 5d06f78b1776..934f38625fd3 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -18,9 +18,9 @@ # pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import from tvm import topi from tvm.auto_scheduler import is_auto_scheduler_enabled -from .generic import * from tvm.te import SpecializedCondition from tvm._ffi import get_global_func +from .generic import * from .. import op as _op from .cuda import judge_winograd, naive_schedule diff --git a/src/runtime/contrib/rocthrust/thrust.cc b/src/runtime/contrib/rocthrust/thrust.cc index df83b57847a0..04ea1a0d7e8e 100644 --- a/src/runtime/contrib/rocthrust/thrust.cc +++ b/src/runtime/contrib/rocthrust/thrust.cc @@ -21,18 +21,18 @@ * \file Use external Thrust library call */ +#include #include #include -#include #include #include #include - +#include #include -#include + #include -#include #include +#include namespace tvm { namespace contrib { @@ -40,15 +40,12 @@ namespace contrib { using namespace runtime; // Performs sorting along axis -1 and returns both sorted values and indices. -template -void thrust_sort(DLTensor* input, - DLTensor* out_values, - DLTensor* out_indices, - bool is_ascend, +template +void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, bool is_ascend, int n_values) { - thrust::device_ptr data_ptr(static_cast(input->data)); - thrust::device_ptr values_ptr(static_cast(out_values->data)); - thrust::device_ptr indices_ptr(static_cast(out_indices->data)); + thrust::device_ptr data_ptr(static_cast(input->data)); + thrust::device_ptr values_ptr(static_cast(out_values->data)); + thrust::device_ptr indices_ptr(static_cast(out_indices->data)); size_t size = 1; for (int i = 0; i < input->ndim; ++i) { @@ -85,9 +82,9 @@ void thrust_sort(DLTensor* input, auto counting_iter = thrust::counting_iterator(0); auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) { return i % n_values; - }; // NOLINT(*) - auto init_indices_iter = thrust::make_transform_iterator(counting_iter, - linear_index_to_sort_axis_index); + }; // NOLINT(*) + auto init_indices_iter = + thrust::make_transform_iterator(counting_iter, linear_index_to_sort_axis_index); // This will reorder indices 0, 1, 2 ... 
in the sorted order of values_ptr thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr); @@ -95,7 +92,7 @@ void thrust_sort(DLTensor* input, thrust::device_vector segment_ids(size); auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) { return i / n_values; - }; // NOLINT(*) + }; // NOLINT(*) // We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(), linear_index_to_segment_id); @@ -109,12 +106,8 @@ void thrust_sort(DLTensor* input, } } -void thrust_sort_common(DLTensor* input, - DLTensor* values_out, - DLTensor* indices_out, - bool is_ascend, - int sort_len, - std::string data_dtype, +void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices_out, + bool is_ascend, int sort_len, std::string data_dtype, std::string out_dtype) { if (data_dtype == "float32") { if (out_dtype == "int32") { @@ -152,7 +145,7 @@ void thrust_sort_common(DLTensor* input, } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - } else if (data_dtype == "int64") { + } else if (data_dtype == "int64") { if (out_dtype == "int32") { thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "int64") { @@ -169,8 +162,7 @@ void thrust_sort_common(DLTensor* input, } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort").set_body([](TVMArgs args, TVMRetValue* ret) { ICHECK_GE(args.num_args, 4); DLTensor* input = args[0]; DLTensor* values_out = args[1]; @@ -181,21 +173,17 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") auto out_dtype = DLDataType2String(indices_out->dtype); int n_values = input->shape[input->ndim - 1]; - thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, - data_dtype, out_dtype); + thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype); }); -template -void thrust_stable_sort_by_key(DLTensor* keys_in, - DLTensor* values_in, - DLTensor* keys_out, - DLTensor* values_out, - bool for_scatter) { +template +void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out, + DLTensor* values_out, bool for_scatter) { const auto size = keys_in->shape[0]; - thrust::device_ptr keys_in_ptr(static_cast(keys_in->data)); - thrust::device_ptr values_in_ptr(static_cast(values_in->data)); - thrust::device_ptr keys_out_ptr(static_cast(keys_out->data)); - thrust::device_ptr values_out_ptr(static_cast(values_out->data)); + thrust::device_ptr keys_in_ptr(static_cast(keys_in->data)); + thrust::device_ptr values_in_ptr(static_cast(values_in->data)); + thrust::device_ptr keys_out_ptr(static_cast(keys_out->data)); + thrust::device_ptr values_out_ptr(static_cast(values_out->data)); if (for_scatter) { thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) { @@ -211,67 +199,65 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, } TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") -.set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_GE(args.num_args, 5); - DLTensor* keys_in = args[0]; - DLTensor* values_in = args[1]; - DLTensor* keys_out = args[2]; - DLTensor* values_out = args[3]; - bool for_scatter = args[4]; - - auto key_dtype = DLDataType2String(keys_in->dtype); - auto value_dtype = DLDataType2String(values_in->dtype); - - if (key_dtype == "int32") { - if (value_dtype == 
"int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; - } - } else if (key_dtype == "int64") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + .set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_GE(args.num_args, 5); + DLTensor* keys_in = args[0]; + DLTensor* values_in = args[1]; + DLTensor* keys_out = args[2]; + DLTensor* values_out = args[3]; + bool for_scatter = args[4]; + + auto key_dtype = DLDataType2String(keys_in->dtype); + auto value_dtype = DLDataType2String(values_in->dtype); + + if (key_dtype == "int32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; - } - } else if (key_dtype == "float32") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; - } - } else { - LOG(FATAL) << "Unsupported key dtype: " << key_dtype; - } -}); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "int64") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "float32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else { + LOG(FATAL) << "Unsupported key dtype: " << key_dtype; + } + }); -template -void thrust_scan(DLTensor* data, - DLTensor* output, - bool exclusive) { - thrust::device_ptr data_ptr(static_cast(data->data)); - thrust::device_ptr output_ptr(static_cast(output->data)); +template +void thrust_scan(DLTensor* data, DLTensor* output, bool 
exclusive) { + thrust::device_ptr data_ptr(static_cast(data->data)); + thrust::device_ptr output_ptr(static_cast(output->data)); const auto scan_size = data->shape[data->ndim - 1]; if (scan_size == 0) return; @@ -281,9 +267,8 @@ void thrust_scan(DLTensor* data, const bool need_cast = std::is_same::value == false; - auto data_cast_ptr = thrust::make_transform_iterator(data_ptr, [] __host__ __device__(InType v) { - return static_cast(v); - }); // NOLINT(*) + auto data_cast_ptr = thrust::make_transform_iterator( + data_ptr, [] __host__ __device__(InType v) { return static_cast(v); }); // NOLINT(*) if (size == static_cast(data->shape[data->ndim - 1])) { if (exclusive && need_cast) { @@ -305,8 +290,8 @@ void thrust_scan(DLTensor* data, auto counting_iter = thrust::counting_iterator(0); // Without __host__ annotation, cub crashes auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) { - return i / scan_size; - }; // NOLINT(*) + return i / scan_size; + }; // NOLINT(*) auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key); if (exclusive && need_cast) { @@ -321,8 +306,7 @@ void thrust_scan(DLTensor* data, } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan").set_body([](TVMArgs args, TVMRetValue* ret) { ICHECK_EQ(args.num_args, 3); DLTensor* data = args[0]; DLTensor* output = args[1]; From f9e60a8c44fd637e0508da45fb8aea3840f6e84b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 16 Feb 2021 19:53:09 +0900 Subject: [PATCH 07/10] show supported data type in err msg --- src/runtime/contrib/rocthrust/thrust.cc | 27 ++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/runtime/contrib/rocthrust/thrust.cc b/src/runtime/contrib/rocthrust/thrust.cc index 04ea1a0d7e8e..a8028ff2d2f5 100644 --- a/src/runtime/contrib/rocthrust/thrust.cc +++ b/src/runtime/contrib/rocthrust/thrust.cc @@ -119,7 +119,8 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices } else if (out_dtype == "float64") { thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32 and float64"; } } else if (data_dtype == "float64") { if (out_dtype == "int32") { @@ -131,7 +132,8 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices } else if (out_dtype == "float64") { thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32 and float64"; } } else if (data_dtype == "int32") { if (out_dtype == "int32") { @@ -143,7 +145,8 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices } else if (out_dtype == "float64") { thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". 
Supported output dtypes are int32, int64, float32 and float64"; } } else if (data_dtype == "int64") { if (out_dtype == "int32") { @@ -155,10 +158,12 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices } else if (out_dtype == "float64") { thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32 and float64"; } } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + LOG(FATAL) << "Unsupported input dtype: " << data_dtype + << ". Supported input dtypes are int32, int64, and float32 and float64."; } } @@ -221,7 +226,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + LOG(FATAL) << "Unsupported value dtype: " << value_dtype + << ". Supported value dtypes are int32, int64 and float32"; } } else if (key_dtype == "int64") { if (value_dtype == "int32") { @@ -234,7 +240,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + LOG(FATAL) << "Unsupported value dtype: " << value_dtype + << ". Supported value dtypes are int32, int64 and float32"; } } else if (key_dtype == "float32") { if (value_dtype == "int32") { @@ -247,10 +254,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + LOG(FATAL) << "Unsupported value dtype: " << value_dtype + << ". Supported value dtypes are int32, int64 and float32"; } } else { - LOG(FATAL) << "Unsupported key dtype: " << key_dtype; + LOG(FATAL) << "Unsupported key dtype: " << key_dtype + << ". Supported key dtypes are int32, int64, and float32."; } }); From 684ecc3e89b2cee0911f9b563392252518cbb559 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 17 Feb 2021 08:55:05 +0900 Subject: [PATCH 08/10] try remove rocthrust --- cmake/modules/ROCM.cmake | 3 +- src/runtime/contrib/rocthrust/thrust.cc | 387 ------------------------ 2 files changed, 2 insertions(+), 388 deletions(-) delete mode 100644 src/runtime/contrib/rocthrust/thrust.cc diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake index 57b2aaae0359..b22e9f2532b7 100644 --- a/cmake/modules/ROCM.cmake +++ b/cmake/modules/ROCM.cmake @@ -60,7 +60,8 @@ if(USE_ROCM) find_package(rocprim REQUIRED) find_package(rocthrust REQUIRED) - file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/rocthrust/*.cc) + set_source_files_properties(src/runtime/contrib/thrust/thrust.cu PROPERTIES LANGUAGE CXX) + file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/thrust.cu) list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC}) list(APPEND TVM_RUNTIME_LINKER_LIBS roc::rocthrust) endif(USE_THRUST) diff --git a/src/runtime/contrib/rocthrust/thrust.cc b/src/runtime/contrib/rocthrust/thrust.cc deleted file mode 100644 index a8028ff2d2f5..000000000000 --- a/src/runtime/contrib/rocthrust/thrust.cc +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file Use external Thrust library call - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace contrib { - -using namespace runtime; - -// Performs sorting along axis -1 and returns both sorted values and indices. -template -void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, bool is_ascend, - int n_values) { - thrust::device_ptr data_ptr(static_cast(input->data)); - thrust::device_ptr values_ptr(static_cast(out_values->data)); - thrust::device_ptr indices_ptr(static_cast(out_indices->data)); - - size_t size = 1; - for (int i = 0; i < input->ndim; ++i) { - size *= input->shape[i]; - } - thrust::copy(data_ptr, data_ptr + size, values_ptr); - - if (size == static_cast(input->shape[input->ndim - 1])) { - // A fast path for single segment case - thrust::sequence(indices_ptr, indices_ptr + n_values); - if (is_ascend) { - thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr); - } else { - thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr, - thrust::greater()); - } - } else { - // segmented sort by key - // Follow the back-to-back stable_sort_by_key strategy explained below - // https://groups.google.com/g/thrust-users/c/BoLsxO6b4FY - thrust::device_vector argsort_order(size); - thrust::sequence(argsort_order.begin(), argsort_order.end()); - - // First, sort values and store the sorted order in argsort_order. - if (is_ascend) { - thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin()); - } else { - thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(), - thrust::greater()); - } - - // The following is to create the indices array 0, 1, 2, 0, 1, 2 ... 0, 1, 2 - // without materializing it - auto counting_iter = thrust::counting_iterator(0); - auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) { - return i % n_values; - }; // NOLINT(*) - auto init_indices_iter = - thrust::make_transform_iterator(counting_iter, linear_index_to_sort_axis_index); - - // This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr - thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr); - - thrust::device_vector segment_ids(size); - auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) { - return i / n_values; - }; // NOLINT(*) - // We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr - thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(), - linear_index_to_segment_id); - - // The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ... 
- // values_ptr and indices_ptr will also be sorted in the order of segmend_ids above - // Since sorting has been done in a stable way, relative orderings of values and indices - // in the segment do not change and hence they remain sorted. - auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr)); - thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip); - } -} - -void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices_out, - bool is_ascend, int sort_len, std::string data_dtype, - std::string out_dtype) { - if (data_dtype == "float32") { - if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int32, int64, float32 and float64"; - } - } else if (data_dtype == "float64") { - if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int32, int64, float32 and float64"; - } - } else if (data_dtype == "int32") { - if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int32, int64, float32 and float64"; - } - } else if (data_dtype == "int64") { - if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend, sort_len); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int32, int64, float32 and float64"; - } - } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype - << ". 
Supported input dtypes are int32, int64, and float32 and float64."; - } -} - -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_GE(args.num_args, 4); - DLTensor* input = args[0]; - DLTensor* values_out = args[1]; - DLTensor* indices_out = args[2]; - bool is_ascend = args[3]; - - auto data_dtype = DLDataType2String(input->dtype); - auto out_dtype = DLDataType2String(indices_out->dtype); - - int n_values = input->shape[input->ndim - 1]; - thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype); -}); - -template -void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out, - DLTensor* values_out, bool for_scatter) { - const auto size = keys_in->shape[0]; - thrust::device_ptr keys_in_ptr(static_cast(keys_in->data)); - thrust::device_ptr values_in_ptr(static_cast(values_in->data)); - thrust::device_ptr keys_out_ptr(static_cast(keys_out->data)); - thrust::device_ptr values_out_ptr(static_cast(values_out->data)); - - if (for_scatter) { - thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) { - if (k < 0) return k + static_cast(size); - return k; - }); - } else { - thrust::copy(keys_in_ptr, keys_in_ptr + size, keys_out_ptr); - } - thrust::copy(values_in_ptr, values_in_ptr + size, values_out_ptr); - - thrust::stable_sort_by_key(keys_out_ptr, keys_out_ptr + size, values_out_ptr); -} - -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_GE(args.num_args, 5); - DLTensor* keys_in = args[0]; - DLTensor* values_in = args[1]; - DLTensor* keys_out = args[2]; - DLTensor* values_out = args[3]; - bool for_scatter = args[4]; - - auto key_dtype = DLDataType2String(keys_in->dtype); - auto value_dtype = DLDataType2String(values_in->dtype); - - if (key_dtype == "int32") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype - << ". Supported value dtypes are int32, int64 and float32"; - } - } else if (key_dtype == "int64") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype - << ". Supported value dtypes are int32, int64 and float32"; - } - } else if (key_dtype == "float32") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype - << ". 
-
-template <typename InType, typename OutType>
-void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive) {
-  thrust::device_ptr<InType> data_ptr(static_cast<InType*>(data->data));
-  thrust::device_ptr<OutType> output_ptr(static_cast<OutType*>(output->data));
-  const auto scan_size = data->shape[data->ndim - 1];
-
-  if (scan_size == 0) return;
-
-  size_t size = 1;
-  for (int i = 0; i < data->ndim; ++i) size *= data->shape[i];
-
-  const bool need_cast = std::is_same<InType, OutType>::value == false;
-
-  auto data_cast_ptr = thrust::make_transform_iterator(
-      data_ptr, [] __host__ __device__(InType v) { return static_cast<OutType>(v); });  // NOLINT(*)
-
-  if (size == static_cast<size_t>(data->shape[data->ndim - 1])) {
-    if (exclusive && need_cast) {
-      thrust::exclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
-    } else if (exclusive && !need_cast) {
-      thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
-    } else if (!exclusive && need_cast) {
-      thrust::inclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
-    } else {
-      thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
-    }
-  } else {
-    // Use thrust segmented scan to compute scan on the innermost axis
-    // data->shape[0] * data->shape[1] * ... * data->shape[ndim - 2] scans are
-    // computed in parallel
-
-    // This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,...,
-    // without materializing the sequence vector
-    auto counting_iter = thrust::counting_iterator<size_t>(0);
-    // Without __host__ annotation, cub crashes
-    auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) {
-      return i / scan_size;
-    };  // NOLINT(*)
-    auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key);
-
-    if (exclusive && need_cast) {
-      thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
-    } else if (exclusive && !need_cast) {
-      thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
-    } else if (!exclusive && need_cast) {
-      thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
-    } else {
-      thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
-    }
-  }
-}
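The key-iterator trick in thrust_scan above (deriving a segment key as i / scan_size so a single keyed scan over the flattened buffer behaves like one independent scan per row) can be sanity-checked with a short NumPy model. Illustrative only; the data is made up.

    import numpy as np

    data = np.array([[1, 2, 3], [4, 5, 6]])
    flat = data.ravel()
    keys = np.arange(flat.size) // data.shape[1]   # 0 0 0 1 1 1, the segment ids
    out = np.zeros_like(flat)
    for k in np.unique(keys):                      # what exclusive_scan_by_key computes per key
        seg = flat[keys == k]
        out[keys == k] = np.concatenate(([0], np.cumsum(seg)[:-1]))
    print(out.reshape(data.shape))                 # [[0 1 3] [0 4 9]]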
-
-TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan").set_body([](TVMArgs args, TVMRetValue* ret) {
-  ICHECK_EQ(args.num_args, 3);
-  DLTensor* data = args[0];
-  DLTensor* output = args[1];
-  bool exclusive = args[2];
-
-  auto in_dtype = DLDataType2String(data->dtype);
-  auto out_dtype = DLDataType2String(output->dtype);
-
-  if (in_dtype == "bool") {
-    if (out_dtype == "int32") {
-      thrust_scan<bool, int32_t>(data, output, exclusive);
-    } else if (out_dtype == "int64") {
-      thrust_scan<bool, int64_t>(data, output, exclusive);
-    } else if (out_dtype == "float32") {
-      thrust_scan<bool, float>(data, output, exclusive);
-    } else if (out_dtype == "float64") {
-      thrust_scan<bool, double>(data, output, exclusive);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are int32, int64, float32, and float64";
-    }
-  } else if (in_dtype == "int32") {
-    if (out_dtype == "int32") {
-      thrust_scan<int32_t, int32_t>(data, output, exclusive);
-    } else if (out_dtype == "int64") {
-      thrust_scan<int32_t, int64_t>(data, output, exclusive);
-    } else if (out_dtype == "float32") {
-      thrust_scan<int32_t, float>(data, output, exclusive);
-    } else if (out_dtype == "float64") {
-      thrust_scan<int32_t, double>(data, output, exclusive);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are int32, int64, float32, and float64";
-    }
-  } else if (in_dtype == "int64") {
-    if (out_dtype == "int64") {
-      thrust_scan<int64_t, int64_t>(data, output, exclusive);
-    } else if (out_dtype == "float32") {
-      thrust_scan<int64_t, float>(data, output, exclusive);
-    } else if (out_dtype == "float64") {
-      thrust_scan<int64_t, double>(data, output, exclusive);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are int64, float32, and float64";
-    }
-  } else if (in_dtype == "float32") {
-    if (out_dtype == "float32") {
-      thrust_scan<float, float>(data, output, exclusive);
-    } else if (out_dtype == "float64") {
-      thrust_scan<float, double>(data, output, exclusive);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are float32 and float64";
-    }
-  } else if (in_dtype == "float64") {
-    if (out_dtype == "float64") {
-      thrust_scan<double, double>(data, output, exclusive);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtype is float64";
-    }
-  } else {
-    LOG(FATAL) << "Unsupported input dtype: " << in_dtype
-               << ". Supported input dtypes are bool, int32, int64, float32, and float64";
-  }
-});
-
-}  // namespace contrib
-}  // namespace tvm
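A quick sketch of how the scan entry point registered above can be exercised from Python, again assuming a USE_THRUST build; the device and data are illustrative and not part of the patch.

    import numpy as np
    import tvm

    dev = tvm.rocm(0)  # any device whose runtime was built with USE_THRUST=ON
    scan_fn = tvm.get_global_func("tvm.contrib.thrust.sum_scan")
    flags = tvm.nd.array(np.array([[1, 0, 1, 1], [0, 1, 0, 1]], dtype="bool"), dev)
    csum = tvm.nd.empty((2, 4), "int32", dev)
    scan_fn(flags, csum, True)  # exclusive prefix sum along the last axis, per row
    # csum -> [[0, 1, 1, 2], [0, 0, 1, 1]]  (exercises the bool -> int32 cast path)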

From 3a11204a7c7df2ac6c120f89d502df856c5539cb Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 17 Feb 2021 09:06:51 +0900
Subject: [PATCH 09/10] add missing include for rocthrust

---
 cmake/modules/ROCM.cmake             | 2 +-
 src/runtime/contrib/thrust/thrust.cu | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake
index b22e9f2532b7..ca8682392191 100644
--- a/cmake/modules/ROCM.cmake
+++ b/cmake/modules/ROCM.cmake
@@ -61,7 +61,7 @@ if(USE_ROCM)
     find_package(rocprim REQUIRED)
     find_package(rocthrust REQUIRED)
     set_source_files_properties(src/runtime/contrib/thrust/thrust.cu PROPERTIES LANGUAGE CXX)
-    file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/thrust.cu)
+    file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
     list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
     list(APPEND TVM_RUNTIME_LINKER_LIBS roc::rocthrust)
   endif(USE_THRUST)
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 7295d4c47c3f..df83b57847a0 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -26,6 +26,7 @@
 #include
 #include
 #include
+#include
 #include
 #include

From 2ee3d61c8d369396e8d97dda37ee7a144c349559 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 17 Feb 2021 09:15:24 +0900
Subject: [PATCH 10/10] more minor change

---
 cmake/modules/ROCM.cmake | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake
index ca8682392191..b908df2f869b 100644
--- a/cmake/modules/ROCM.cmake
+++ b/cmake/modules/ROCM.cmake
@@ -61,8 +61,7 @@ if(USE_ROCM)
     find_package(rocprim REQUIRED)
     find_package(rocthrust REQUIRED)
     set_source_files_properties(src/runtime/contrib/thrust/thrust.cu PROPERTIES LANGUAGE CXX)
-    file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
-    list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
+    list(APPEND RUNTIME_SRCS src/runtime/contrib/thrust/thrust.cu)
     list(APPEND TVM_RUNTIME_LINKER_LIBS roc::rocthrust)
   endif(USE_THRUST)
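With these CMake changes, the Thrust-backed packed functions exist only in runtimes built with USE_THRUST=ON (hipcc plus rocprim/rocthrust for ROCm builds). A defensive caller can probe for them before dispatching; this is a sketch and the helper name is made up, not part of the patch.

    import tvm

    def thrust_enabled():
        # The globals registered in thrust.cu are absent in non-USE_THRUST builds,
        # so a missing lookup (returning None) means fall back to the TIR sort/scan.
        return tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None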