Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Add support for key/value sorting APIs #13942

Merged
merged 8 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/include/sycl/detail/group_sort_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ template <size_t items_per_work_item, uint32_t radix_bits, bool is_comp_asc,
typename ValsT, typename GroupT>
void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
const uint32_t last_iter, KeysT *keys,
ValsT vals, const ScratchMemory &memory) {
ValsT *vals, const ScratchMemory &memory) {
const uint32_t radix_states = getStatesInBits(radix_bits);
const size_t wgsize = group.get_local_linear_range();
const size_t idx = group.get_local_linear_id();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,7 @@ template <typename CompareT = std::less<>> class joint_sorter {
}

template <typename T>
static constexpr size_t memory_required(sycl::memory_scope,
size_t range_size) {
static size_t memory_required(sycl::memory_scope, size_t range_size) {
return range_size * sizeof(T) + alignof(T);
}
};
Expand Down Expand Up @@ -336,13 +335,47 @@ class group_sorter {
return val;
}

static constexpr std::size_t memory_required(sycl::memory_scope scope,
size_t range_size) {
static std::size_t memory_required(sycl::memory_scope scope,
size_t range_size) {
return 2 * joint_sorter<>::template memory_required<T>(
scope, range_size * ElementsPerWorkItem);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is regarding https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp#L252

It's expected for joint_sorter (default and radix sorters both) to support oneDPL zip_iterator (https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp#L252 : Example 3).
Do you plan to add this support (no matter in this PR or a separate one)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will look into this but not as part of this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. Thanks

};

template <typename KeyTy, typename ValueTy, typename CompareT = std::less<>,
std::size_t ElementsPerWorkItem = 1>
class group_key_value_sorter {
CompareT comp;
sycl::span<std::byte> scratch;

public:
template <std::size_t Extent>
group_key_value_sorter(sycl::span<std::byte, Extent> scratch_,
CompareT comp_ = {})
: comp(comp_), scratch(scratch_) {}

template <typename Group>
std::tuple<KeyTy, ValueTy> operator()(Group g, KeyTy key, ValueTy value) {
static_assert(ElementsPerWorkItem == 1,
"ElementsPerWorkItem must be equal 1");

using KeyValue = std::tuple<KeyTy, ValueTy>;
auto comp_key_value = [this_comp = this->comp](const KeyValue &lhs,
const KeyValue &rhs) {
return this_comp(std::get<0>(lhs), std::get<0>(rhs));
};
return group_sorter<KeyValue, decltype(comp_key_value),
ElementsPerWorkItem>(scratch, comp_key_value)(
g, KeyValue(key, value));
}

static std::size_t memory_required(sycl::memory_scope scope,
std::size_t range_size) {
return group_sorter<std::tuple<KeyTy, ValueTy>, CompareT,
ElementsPerWorkItem>::memory_required(scope,
range_size);
}
};
} // namespace default_sorters

// Radix sorters provided by the second version of the extension specification.
Expand Down Expand Up @@ -455,6 +488,57 @@ class group_sorter {
}
};

template <typename KeyTy, typename ValueTy,
sorting_order Order = sorting_order::ascending,
size_t ElementsPerWorkItem = 1, unsigned int BitsPerPass = 4>
class group_key_value_sorter {
sycl::span<std::byte> scratch;
uint32_t first_bit;
uint32_t last_bit;

static constexpr uint32_t bits = BitsPerPass;
using bitset_t = std::bitset<sizeof(KeyTy) * CHAR_BIT>;

public:
template <std::size_t Extent>
group_key_value_sorter(sycl::span<std::byte, Extent> scratch_,
const bitset_t mask = bitset_t{}.set())
: scratch(scratch_) {
static_assert((std::is_arithmetic<KeyTy>::value ||
std::is_same<KeyTy, sycl::half>::value),
"radix sort is not usable");
for (first_bit = 0; first_bit < mask.size() && !mask[first_bit];
++first_bit)
;
for (last_bit = first_bit; last_bit < mask.size() && mask[last_bit];
++last_bit)
;
}

template <typename Group>
std::tuple<KeyTy, ValueTy> operator()([[maybe_unused]] Group g, KeyTy key,
ValueTy val) {
static_assert(ElementsPerWorkItem == 1, "ElementsPerWorkItem must be 1");
KeyTy key_result[]{key};
ValueTy val_result[]{val};
#ifdef __SYCL_DEVICE_ONLY__
sycl::detail::privateStaticSort<
/*is_key_value=*/true,
/*is_blocked=*/true, Order == sorting_order::ascending, 1, bits>(
g, key_result, val_result, scratch.data(), first_bit, last_bit);
#endif
key = key_result[0];
val = val_result[0];
return {key, val};
}

static constexpr std::size_t memory_required(sycl::memory_scope,
std::size_t range_size) {
return (std::max)(range_size * ElementsPerWorkItem *
(sizeof(KeyTy) + sizeof(ValueTy)),
range_size * (1 << bits) * sizeof(uint32_t));
}
};
} // namespace radix_sorters

} // namespace ext::oneapi::experimental
Expand Down
55 changes: 55 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/group_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,19 @@ struct is_sorter_impl<Sorter, Group, Ptr,
template <typename Sorter, typename Group, typename ValOrPtr>
struct is_sorter : decltype(is_sorter_impl<Sorter, Group, ValOrPtr>::test(0)) {
};

template <typename Sorter, typename Group, typename Key, typename Value,
typename = void>
struct is_key_value_sorter : std::false_type {};

template <typename Sorter, typename Group, typename Key, typename Value>
struct is_key_value_sorter<
Sorter, Group, Key, Value,
std::enable_if_t<
std::is_same_v<std::invoke_result_t<Sorter, Group, Key, Value>,
std::tuple<Key, Value>> &&
sycl::is_group_v<Group>>> : std::true_type {};

} // namespace detail

// ---- sort_over_group
Expand Down Expand Up @@ -131,6 +144,48 @@ joint_sort(experimental::group_with_scratchpad<Group, Extent> exec, Iter first,
default_sorters::joint_sorter<>(exec.get_memory()));
}

template <typename Group, typename KeyTy, typename ValueTy, typename Sorter>
std::enable_if_t<
detail::is_key_value_sorter<Sorter, Group, KeyTy, ValueTy>::value,
std::tuple<KeyTy, ValueTy>>
sort_key_value_over_group([[maybe_unused]] Group g, [[maybe_unused]] KeyTy key,
[[maybe_unused]] ValueTy value,
[[maybe_unused]] Sorter sorter) {
#ifdef __SYCL_DEVICE_ONLY__
return sorter(g, key, value);
#else
throw sycl::exception(
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
"Group algorithms are not supported on host device.");
#endif
}

template <typename Group, typename KeyTy, typename ValueTy, typename Compare,
std::size_t Extent>
std::enable_if_t<
!detail::is_key_value_sorter<Compare, Group, KeyTy, ValueTy>::value,
std::tuple<KeyTy, ValueTy>>
sort_key_value_over_group(
experimental::group_with_scratchpad<Group, Extent> exec, KeyTy key,
ValueTy value, Compare comp) {
return sort_key_value_over_group(
exec.get_group(), key, value,
default_sorters::group_key_value_sorter<KeyTy, ValueTy, Compare>(
exec.get_memory(), comp));
}

template <typename KeyTy, typename ValueTy, typename Group, std::size_t Extent>
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>,
std::tuple<KeyTy, ValueTy>>
sort_key_value_over_group(
experimental::group_with_scratchpad<Group, Extent> exec, KeyTy key,
ValueTy value) {
return sort_key_value_over_group(
exec.get_group(), key, value,
default_sorters::group_key_value_sorter<KeyTy, ValueTy>(
exec.get_memory()));
}

} // namespace ext::oneapi::experimental
} // namespace _V1
} // namespace sycl
Expand Down
50 changes: 50 additions & 0 deletions sycl/test-e2e/GroupAlgorithm/SYCL2020/group_sort/common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@

#include <sycl/detail/core.hpp>
#include <sycl/ext/oneapi/experimental/group_sort.hpp>

#pragma once

namespace oneapi_exp = sycl::ext::oneapi::experimental;

enum class UseGroupT { SubGroup = true, WorkGroup = false };

// these classes are needed to pass non-type template parameters to KernelName
template <int> class IntWrapper;
template <UseGroupT> class UseGroupWrapper;

class CustomType {
public:
CustomType(size_t Val) : MVal(Val) {}
CustomType() : MVal(0) {}

bool operator<(const CustomType &RHS) const { return MVal < RHS.MVal; }
bool operator>(const CustomType &RHS) const { return MVal > RHS.MVal; }
bool operator==(const CustomType &RHS) const { return MVal == RHS.MVal; }

private:
size_t MVal = 0;
};

template <class T> struct ConvertToSimpleType {
using Type = T;
};

// Dummy overloads for CustomType which is not supported by radix sorter
template <> struct ConvertToSimpleType<CustomType> {
using Type = int;
};

template <class SorterT> struct ConvertToSortingOrder;

template <class T> struct ConvertToSortingOrder<std::greater<T>> {
static const auto Type = oneapi_exp::sorting_order::descending;
};

template <class T> struct ConvertToSortingOrder<std::less<T>> {
static const auto Type = oneapi_exp::sorting_order::ascending;
};

constexpr size_t ReqSubGroupSize = 8;

template <typename...> class KernelNameOverGroup;
template <typename...> class KernelNameJoint;
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <sycl/detail/core.hpp>

#include "common.hpp"
#include <sycl/builtins.hpp>
#include <sycl/ext/oneapi/experimental/group_sort.hpp>
#include <sycl/group_algorithm.hpp>
Expand All @@ -39,30 +40,6 @@
#include <random>
#include <vector>

namespace oneapi_exp = sycl::ext::oneapi::experimental;

template <typename...> class KernelNameOverGroup;
template <typename...> class KernelNameJoint;

enum class UseGroupT { SubGroup = true, WorkGroup = false };

// these classes are needed to pass non-type template parameters to KernelName
template <int> class IntWrapper;
template <UseGroupT> class UseGroupWrapper;

class CustomType {
public:
CustomType(size_t Val) : MVal(Val) {}
CustomType() : MVal(0) {}

bool operator<(const CustomType &RHS) const { return MVal < RHS.MVal; }
bool operator>(const CustomType &RHS) const { return MVal > RHS.MVal; }
bool operator==(const CustomType &RHS) const { return MVal == RHS.MVal; }

private:
size_t MVal = 0;
};

#if VERSION == 1
template <class CompT, class T> struct RadixSorterType;

Expand All @@ -86,29 +63,8 @@ template <> struct RadixSorterType<std::greater<CustomType>, CustomType> {
using Type =
oneapi_exp::radix_sorter<int, oneapi_exp::sorting_order::descending>;
};
#else
template <class T> struct ConvertToSimpleType {
using Type = T;
};

// Dummy overloads for CustomType which is not supported by radix sorter
template <> struct ConvertToSimpleType<CustomType> {
using Type = int;
};

template <class SorterT> struct ConvertToSortingOrder;

template <class T> struct ConvertToSortingOrder<std::greater<T>> {
static const auto Type = oneapi_exp::sorting_order::descending;
};

template <class T> struct ConvertToSortingOrder<std::less<T>> {
static const auto Type = oneapi_exp::sorting_order::ascending;
};
#endif

constexpr size_t ReqSubGroupSize = 8;

template <UseGroupT UseGroup, int Dims, class T, class Compare>
void RunJointSort(sycl::queue &Q, const std::vector<T> &DataToSort,
const Compare &Comp) {
Expand Down
Loading
Loading