diff --git a/include/oneapi/dpl/pstl/execution_impl.h b/include/oneapi/dpl/pstl/execution_impl.h index 79ed46570bf..c6f32f17495 100644 --- a/include/oneapi/dpl/pstl/execution_impl.h +++ b/include/oneapi/dpl/pstl/execution_impl.h @@ -106,25 +106,25 @@ __select_backend(oneapi::dpl::execution::parallel_unsequenced_policy, _IteratorT namespace __ranges { -::oneapi::dpl::__internal::__serial_tag +inline ::oneapi::dpl::__internal::__serial_tag __select_backend(oneapi::dpl::execution::sequenced_policy) { return {}; } -::oneapi::dpl::__internal::__serial_tag //vectorization allowed +inline ::oneapi::dpl::__internal::__serial_tag //vectorization allowed __select_backend(oneapi::dpl::execution::unsequenced_policy) { return {}; } -::oneapi::dpl::__internal::__parallel_tag +inline ::oneapi::dpl::__internal::__parallel_tag __select_backend(oneapi::dpl::execution::parallel_policy) { return {}; } -::oneapi::dpl::__internal::__parallel_tag //vectorization allowed +inline ::oneapi::dpl::__internal::__parallel_tag //vectorization allowed __select_backend(oneapi::dpl::execution::parallel_unsequenced_policy) { return {}; diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h index 8c61bb9de0a..a220b3c29ff 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h @@ -816,11 +816,11 @@ __parallel_radix_sort(oneapi::dpl::__internal::__device_backend_tag, _ExecutionP else if (__n <= 4096 && __wg_size * 4 <= __max_wg_size) __event = __subgroup_radix_sort<_RadixSortKernel, __wg_size * 4, 16, __radix_bits, __is_ascending>{}( __exec.queue(), ::std::forward<_Range>(__in_rng), __proj); - // In __subgroup_radix_sort, we request a sub-group size via _ONEDPL_SYCL_REQD_SUB_GROUP_SIZE_IF_SUPPORTED - // based upon the iters per item. For the below cases, register spills that result in runtime exceptions have - // been observed on accelerators that do not support the requested sub-group size of 16. For the above cases - // that request but may not receive a sub-group size of 16, inputs are small enough to avoid register - // spills on assessed hardware. + // In __subgroup_radix_sort, we request a sub-group size of 16 via _ONEDPL_SYCL_REQD_SUB_GROUP_SIZE_IF_SUPPORTED + // for compilation targets that support this option. For the below cases, register spills that result in + // runtime exceptions have been observed on accelerators that do not support the requested sub-group size of 16. + // For the above cases that request but may not receive a sub-group size of 16, inputs are small enough to avoid + // register spills on assessed hardware. else if (__n <= 8192 && __wg_size * 8 <= __max_wg_size && __dev_has_sg16) __event = __subgroup_radix_sort<_RadixSortKernel, __wg_size * 8, 16, __radix_bits, __is_ascending>{}( __exec.queue(), ::std::forward<_Range>(__in_rng), __proj); diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort_one_wg.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort_one_wg.h index fbf80582d43..6dd3b193a08 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort_one_wg.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort_one_wg.h @@ -30,8 +30,7 @@ template class __radix_sort_one_wg_kernel; template + std::uint32_t __radix = 4, bool __is_asc = true> struct __subgroup_radix_sort { template @@ -164,9 +163,12 @@ struct __subgroup_radix_sort auto __counter_lacc = __buf_count.get_acc(__cgh); __cgh.parallel_for<_Name...>( - __range, - ([=](sycl::nd_item<1> __it)[[_ONEDPL_SYCL_REQD_SUB_GROUP_SIZE_IF_SUPPORTED(__req_sub_group_size)]] { - union __storage { _ValT __v[__block_size]; __storage(){} } __values; + __range, ([=](sycl::nd_item<1> __it) [[_ONEDPL_SYCL_REQD_SUB_GROUP_SIZE_IF_SUPPORTED(16)]] { + union __storage + { + _ValT __v[__block_size]; + __storage() {} + } __values; uint16_t __wi = __it.get_local_linear_id(); uint16_t __begin_bit = 0; constexpr uint16_t __end_bit = sizeof(_KeyT) * ::std::numeric_limits::digits;