diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 5d06f78b1776c..934f38625fd3d 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -18,9 +18,9 @@ # pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import from tvm import topi from tvm.auto_scheduler import is_auto_scheduler_enabled -from .generic import * from tvm.te import SpecializedCondition from tvm._ffi import get_global_func +from .generic import * from .. import op as _op from .cuda import judge_winograd, naive_schedule diff --git a/src/runtime/contrib/rocthrust/thrust.cc b/src/runtime/contrib/rocthrust/thrust.cc index df83b57847a06..04ea1a0d7e8ec 100644 --- a/src/runtime/contrib/rocthrust/thrust.cc +++ b/src/runtime/contrib/rocthrust/thrust.cc @@ -21,18 +21,18 @@ * \file Use external Thrust library call */ +#include #include #include -#include #include #include #include - +#include #include -#include + #include -#include #include +#include namespace tvm { namespace contrib { @@ -40,15 +40,12 @@ namespace contrib { using namespace runtime; // Performs sorting along axis -1 and returns both sorted values and indices. -template -void thrust_sort(DLTensor* input, - DLTensor* out_values, - DLTensor* out_indices, - bool is_ascend, +template +void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, bool is_ascend, int n_values) { - thrust::device_ptr data_ptr(static_cast(input->data)); - thrust::device_ptr values_ptr(static_cast(out_values->data)); - thrust::device_ptr indices_ptr(static_cast(out_indices->data)); + thrust::device_ptr data_ptr(static_cast(input->data)); + thrust::device_ptr values_ptr(static_cast(out_values->data)); + thrust::device_ptr indices_ptr(static_cast(out_indices->data)); size_t size = 1; for (int i = 0; i < input->ndim; ++i) { @@ -85,9 +82,9 @@ void thrust_sort(DLTensor* input, auto counting_iter = thrust::counting_iterator(0); auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) { return i % n_values; - }; // NOLINT(*) - auto init_indices_iter = thrust::make_transform_iterator(counting_iter, - linear_index_to_sort_axis_index); + }; // NOLINT(*) + auto init_indices_iter = + thrust::make_transform_iterator(counting_iter, linear_index_to_sort_axis_index); // This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr); @@ -95,7 +92,7 @@ void thrust_sort(DLTensor* input, thrust::device_vector segment_ids(size); auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) { return i / n_values; - }; // NOLINT(*) + }; // NOLINT(*) // We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(), linear_index_to_segment_id); @@ -109,12 +106,8 @@ void thrust_sort(DLTensor* input, } } -void thrust_sort_common(DLTensor* input, - DLTensor* values_out, - DLTensor* indices_out, - bool is_ascend, - int sort_len, - std::string data_dtype, +void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices_out, + bool is_ascend, int sort_len, std::string data_dtype, std::string out_dtype) { if (data_dtype == "float32") { if (out_dtype == "int32") { @@ -152,7 +145,7 @@ void thrust_sort_common(DLTensor* input, } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - } else if (data_dtype == "int64") { + } else if (data_dtype == "int64") { if (out_dtype == "int32") { thrust_sort(input, values_out, indices_out, is_ascend, sort_len); } else if (out_dtype == "int64") { @@ -169,8 +162,7 @@ void thrust_sort_common(DLTensor* input, } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort").set_body([](TVMArgs args, TVMRetValue* ret) { ICHECK_GE(args.num_args, 4); DLTensor* input = args[0]; DLTensor* values_out = args[1]; @@ -181,21 +173,17 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") auto out_dtype = DLDataType2String(indices_out->dtype); int n_values = input->shape[input->ndim - 1]; - thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, - data_dtype, out_dtype); + thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype); }); -template -void thrust_stable_sort_by_key(DLTensor* keys_in, - DLTensor* values_in, - DLTensor* keys_out, - DLTensor* values_out, - bool for_scatter) { +template +void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out, + DLTensor* values_out, bool for_scatter) { const auto size = keys_in->shape[0]; - thrust::device_ptr keys_in_ptr(static_cast(keys_in->data)); - thrust::device_ptr values_in_ptr(static_cast(values_in->data)); - thrust::device_ptr keys_out_ptr(static_cast(keys_out->data)); - thrust::device_ptr values_out_ptr(static_cast(values_out->data)); + thrust::device_ptr keys_in_ptr(static_cast(keys_in->data)); + thrust::device_ptr values_in_ptr(static_cast(values_in->data)); + thrust::device_ptr keys_out_ptr(static_cast(keys_out->data)); + thrust::device_ptr values_out_ptr(static_cast(values_out->data)); if (for_scatter) { thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) { @@ -211,67 +199,65 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, } TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") -.set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_GE(args.num_args, 5); - DLTensor* keys_in = args[0]; - DLTensor* values_in = args[1]; - DLTensor* keys_out = args[2]; - DLTensor* values_out = args[3]; - bool for_scatter = args[4]; - - auto key_dtype = DLDataType2String(keys_in->dtype); - auto value_dtype = DLDataType2String(values_in->dtype); - - if (key_dtype == "int32") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; - } - } else if (key_dtype == "int64") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + .set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_GE(args.num_args, 5); + DLTensor* keys_in = args[0]; + DLTensor* values_in = args[1]; + DLTensor* keys_out = args[2]; + DLTensor* values_out = args[3]; + bool for_scatter = args[4]; + + auto key_dtype = DLDataType2String(keys_in->dtype); + auto value_dtype = DLDataType2String(values_in->dtype); + + if (key_dtype == "int32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; - } - } else if (key_dtype == "float32") { - if (value_dtype == "int32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "int64") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else if (value_dtype == "float32") { - thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, - for_scatter); - } else { - LOG(FATAL) << "Unsupported value dtype: " << value_dtype; - } - } else { - LOG(FATAL) << "Unsupported key dtype: " << key_dtype; - } -}); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "int64") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else if (key_dtype == "float32") { + if (value_dtype == "int32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else if (value_dtype == "float32") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); + } else { + LOG(FATAL) << "Unsupported value dtype: " << value_dtype; + } + } else { + LOG(FATAL) << "Unsupported key dtype: " << key_dtype; + } + }); -template -void thrust_scan(DLTensor* data, - DLTensor* output, - bool exclusive) { - thrust::device_ptr data_ptr(static_cast(data->data)); - thrust::device_ptr output_ptr(static_cast(output->data)); +template +void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive) { + thrust::device_ptr data_ptr(static_cast(data->data)); + thrust::device_ptr output_ptr(static_cast(output->data)); const auto scan_size = data->shape[data->ndim - 1]; if (scan_size == 0) return; @@ -281,9 +267,8 @@ void thrust_scan(DLTensor* data, const bool need_cast = std::is_same::value == false; - auto data_cast_ptr = thrust::make_transform_iterator(data_ptr, [] __host__ __device__(InType v) { - return static_cast(v); - }); // NOLINT(*) + auto data_cast_ptr = thrust::make_transform_iterator( + data_ptr, [] __host__ __device__(InType v) { return static_cast(v); }); // NOLINT(*) if (size == static_cast(data->shape[data->ndim - 1])) { if (exclusive && need_cast) { @@ -305,8 +290,8 @@ void thrust_scan(DLTensor* data, auto counting_iter = thrust::counting_iterator(0); // Without __host__ annotation, cub crashes auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) { - return i / scan_size; - }; // NOLINT(*) + return i / scan_size; + }; // NOLINT(*) auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key); if (exclusive && need_cast) { @@ -321,8 +306,7 @@ void thrust_scan(DLTensor* data, } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan").set_body([](TVMArgs args, TVMRetValue* ret) { ICHECK_EQ(args.num_args, 3); DLTensor* data = args[0]; DLTensor* output = args[1];