Skip to content

Commit

Permalink
[SYCL] Generalize group_algorithm helpers (#12726)
Browse files Browse the repository at this point in the history
This commit generalizes two helper functions in group_algorithm.hpp to
make it so they can also handle non-uniform groups.

---------

Signed-off-by: Larsen, Steffen <steffen.larsen@intel.com>
  • Loading branch information
steffenlarsen authored Feb 26, 2024
1 parent c90de3c commit 77a25db
Showing 1 changed file with 13 additions and 30 deletions.
43 changes: 13 additions & 30 deletions sycl/include/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,42 +59,25 @@ template <> inline id<3> linear_id_to_id(range<3> r, size_t linear_id) {
}

// ---- get_local_linear_range
template <typename Group> size_t get_local_linear_range(Group g);
template <> inline size_t get_local_linear_range<group<1>>(group<1> g) {
return g.get_local_range(0);
}
template <> inline size_t get_local_linear_range<group<2>>(group<2> g) {
return g.get_local_range(0) * g.get_local_range(1);
}
template <> inline size_t get_local_linear_range<group<3>>(group<3> g) {
return g.get_local_range(0) * g.get_local_range(1) * g.get_local_range(2);
}
template <>
inline size_t get_local_linear_range<sycl::sub_group>(sycl::sub_group g) {
return g.get_local_range()[0];
template <typename Group> inline auto get_local_linear_range(Group g) {
auto local_range = g.get_local_range();
auto result = local_range[0];
for (size_t i = 1; i < Group::dimensions; ++i)
result *= local_range[i];
return result;
}

// ---- get_local_linear_id
template <typename Group>
inline typename Group::linear_id_type get_local_linear_id(Group g);

template <typename Group> inline auto get_local_linear_id(Group g) {
#ifdef __SYCL_DEVICE_ONLY__
#define __SYCL_GROUP_GET_LOCAL_LINEAR_ID(D) \
template <> \
inline group<D>::linear_id_type get_local_linear_id<group<D>>(group<D>) { \
nd_item<D> it = sycl::detail::Builder::getNDItem<D>(); \
return it.get_local_linear_id(); \
if constexpr (std::is_same_v<Group, group<1>> ||
std::is_same_v<Group, group<2>> ||
std::is_same_v<Group, group<3>>) {
auto it = sycl::detail::Builder::getNDItem<Group::dimensions>();
return it.get_local_linear_id();
}
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(1);
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(2);
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(3);
#undef __SYCL_GROUP_GET_LOCAL_LINEAR_ID
#endif // __SYCL_DEVICE_ONLY__

template <>
inline sycl::sub_group::linear_id_type
get_local_linear_id<sycl::sub_group>(sycl::sub_group g) {
return g.get_local_id()[0];
return g.get_local_linear_id();
}

// ---- is_native_op
Expand Down

0 comments on commit 77a25db

Please sign in to comment.