Skip to content

Commit

Permalink
Merge pull request #105 from r-devulap/key-value
Browse files Browse the repository at this point in the history
Add key-value sort to runtime dispatch
  • Loading branch information
r-devulap committed Nov 17, 2023
2 parents 15c3379 + aba8371 commit cb4358f
Show file tree
Hide file tree
Showing 12 changed files with 244 additions and 19 deletions.
1 change: 1 addition & 0 deletions benchmarks/bench-all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
#include "bench-partial-qsort.hpp"
#include "bench-qselect.hpp"
#include "bench-qsort.hpp"
#include "bench-keyvalue.hpp"
48 changes: 48 additions & 0 deletions benchmarks/bench-keyvalue.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "x86simdsort-scalar.h"

template <typename T, class... Args>
static void scalarkvsort(benchmark::State &state, Args &&...args)
{
// Get args
auto args_tuple = std::make_tuple(std::move(args)...);
size_t arrsize = std::get<0>(args_tuple);
std::string arrtype = std::get<1>(args_tuple);
// set up array
std::vector<T> key = get_array<T>(arrtype, arrsize);
std::vector<T> val = get_array<T>("random", arrsize);
std::vector<T> key_bkp = key;
// benchmark
for (auto _ : state) {
xss::scalar::keyvalue_qsort(key.data(), val.data(), arrsize, false);
state.PauseTiming();
key = key_bkp;
state.ResumeTiming();
}
}

template <typename T, class... Args>
static void simdkvsort(benchmark::State &state, Args &&...args)
{
auto args_tuple = std::make_tuple(std::move(args)...);
size_t arrsize = std::get<0>(args_tuple);
std::string arrtype = std::get<1>(args_tuple);
// set up array
std::vector<T> key = get_array<T>(arrtype, arrsize);
std::vector<T> val = get_array<T>("random", arrsize);
std::vector<T> key_bkp = key;
// benchmark
for (auto _ : state) {
x86simdsort::keyvalue_qsort(key.data(), val.data(), arrsize);
state.PauseTiming();
key = key_bkp;
state.ResumeTiming();
}
}

#define BENCH_BOTH_KVSORT(type) \
BENCH_SORT(simdkvsort, type) \
BENCH_SORT(scalarkvsort, type)

BENCH_BOTH_KVSORT(uint64_t)
BENCH_BOTH_KVSORT(int64_t)
BENCH_BOTH_KVSORT(double)
9 changes: 9 additions & 0 deletions lib/x86simdsort-internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ namespace avx512 {
// quicksort
template <typename T>
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
// key-value quicksort
template <typename T1, typename T2>
XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan = false);
// quickselect
template <typename T>
XSS_HIDE_SYMBOL void
Expand All @@ -30,6 +33,9 @@ namespace avx2 {
// quicksort
template <typename T>
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
// key-value quicksort
template <typename T1, typename T2>
XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan = false);
// quickselect
template <typename T>
XSS_HIDE_SYMBOL void
Expand All @@ -51,6 +57,9 @@ namespace scalar {
// quicksort
template <typename T>
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
// key-value quicksort
template <typename T1, typename T2>
XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan = false);
// quickselect
template <typename T>
XSS_HIDE_SYMBOL void
Expand Down
28 changes: 28 additions & 0 deletions lib/x86simdsort-scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,27 @@
#include <numeric>

namespace xss {
namespace utils {
/* O(1) permute array in place: stolen from
* http://www.davidespataro.it/apply-a-permutation-to-a-vector */
template<typename T>
void apply_permutation_in_place(T* arr, std::vector<size_t> arg)
{
for(size_t i = 0 ; i < arg.size() ; i++) {
size_t curr = i;
size_t next = arg[curr];
while(next != i)
{
std::swap(arr[curr], arr[next]);
arg[curr] = curr;
curr = next;
next = arg[next];
}
arg[curr] = curr;
}
}
} // utils

namespace scalar {
template <typename T>
void qsort(T *arr, size_t arrsize, bool hasnan)
Expand Down Expand Up @@ -57,6 +78,13 @@ namespace scalar {
compare_arg<T, std::less<T>>(arr));
return arg;
}
template <typename T1, typename T2>
void keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan)
{
std::vector<size_t> arg = argsort(key, arrsize, hasnan);
utils::apply_permutation_in_place(key, arg);
utils::apply_permutation_in_place(val, arg);
}

} // namespace scalar
} // namespace xss
18 changes: 18 additions & 0 deletions lib/x86simdsort-skx.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// SKX specific routines:
#include "avx512-32bit-qsort.hpp"
#include "avx512-64bit-keyvaluesort.hpp"
#include "avx512-64bit-argsort.hpp"
#include "avx512-64bit-qsort.hpp"
#include "x86simdsort-internal.h"
Expand Down Expand Up @@ -32,6 +33,14 @@
return avx512_argselect(arr, k, arrsize, hasnan); \
}

#define DEFINE_KEYVALUE_METHODS(type1, type2) \
template <> \
void keyvalue_qsort(type1 *key, type2* val, size_t arrsize, bool hasnan) \
{ \
avx512_qsort_kv(key, val, arrsize, hasnan); \
} \


namespace xss {
namespace avx512 {
DEFINE_ALL_METHODS(uint32_t)
Expand All @@ -40,5 +49,14 @@ namespace avx512 {
DEFINE_ALL_METHODS(uint64_t)
DEFINE_ALL_METHODS(int64_t)
DEFINE_ALL_METHODS(double)
DEFINE_KEYVALUE_METHODS(double, uint64_t)
DEFINE_KEYVALUE_METHODS(double, int64_t)
DEFINE_KEYVALUE_METHODS(double, double)
DEFINE_KEYVALUE_METHODS(uint64_t, uint64_t)
DEFINE_KEYVALUE_METHODS(uint64_t, int64_t)
DEFINE_KEYVALUE_METHODS(uint64_t, double)
DEFINE_KEYVALUE_METHODS(int64_t, uint64_t)
DEFINE_KEYVALUE_METHODS(int64_t, int64_t)
DEFINE_KEYVALUE_METHODS(int64_t, double)
} // namespace avx512
} // namespace xss
40 changes: 39 additions & 1 deletion lib/x86simdsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ dispatch_requested(std::string_view cpurequested,
return false;
}

namespace x86simdsort {

#define CAT_(a, b) a##b
#define CAT(a, b) CAT_(a, b)

Expand Down Expand Up @@ -120,6 +122,33 @@ dispatch_requested(std::string_view cpurequested,
return; \
} \
} \
} \

#define DISPATCH_KEYVALUE_SORT(TYPE1, TYPE2, ISA) \
static void (CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))(TYPE1*, TYPE2*, size_t, bool) = NULL; \
template <> \
void keyvalue_qsort(TYPE1 *key, TYPE2* val, size_t arrsize, bool hasnan) \
{ \
(CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))(key, val, arrsize, hasnan); \
} \
static __attribute__((constructor)) void \
CAT(CAT(resolve_keyvalue_qsort_, TYPE1), TYPE2)(void) \
{ \
CAT(CAT(internal_kv_qsort_, TYPE1), TYPE2) = &xss::scalar::keyvalue_qsort<TYPE1, TYPE2>; \
__builtin_cpu_init(); \
std::string_view preferred_cpu = find_preferred_cpu(ISA); \
if constexpr (dispatch_requested("avx512", ISA)) { \
if (preferred_cpu.find("avx512") != std::string_view::npos) { \
CAT(CAT(internal_kv_qsort_, TYPE1), TYPE2) = &xss::avx512::keyvalue_qsort<TYPE1, TYPE2>; \
return; \
} \
} \
if constexpr (dispatch_requested("avx2", ISA)) { \
if (preferred_cpu.find("avx2") != std::string_view::npos) { \
CAT(CAT(internal_kv_qsort_, TYPE1), TYPE2) = &xss::avx2::keyvalue_qsort<TYPE1, TYPE2>; \
return; \
} \
} \
}

#define ISA_LIST(...) \
Expand All @@ -128,7 +157,6 @@ dispatch_requested(std::string_view cpurequested,
__VA_ARGS__ \
}

namespace x86simdsort {
#ifdef __FLT16_MAX__
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr"))
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr"))
Expand Down Expand Up @@ -168,4 +196,14 @@ DISPATCH_ALL(argselect,
(ISA_LIST("avx512_skx")),
(ISA_LIST("avx512_skx")))

DISPATCH_KEYVALUE_SORT(uint64_t, int64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(uint64_t, uint64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(uint64_t, double, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(int64_t, int64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(int64_t, uint64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(int64_t, double, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(double, int64_t, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(double, double, (ISA_LIST("avx512_skx")))
DISPATCH_KEYVALUE_SORT(double, uint64_t, (ISA_LIST("avx512_skx")))

} // namespace x86simdsort
5 changes: 5 additions & 0 deletions lib/x86simdsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,10 @@ template <typename T>
XSS_EXPORT_SYMBOL std::vector<size_t>
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);

// argselect
template <typename T1, typename T2>
XSS_EXPORT_SYMBOL void
keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan = false);

} // namespace x86simdsort
#endif
3 changes: 3 additions & 0 deletions run-bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
elif "argsort" in args.benchcompare:
baseline = "scalarargsort.*" + filterb
contender = "simdargsort.*" + filterb
elif "keyvalue" in args.benchcompare:
baseline = "scalarkvsort.*" + filterb
contender = "simdkvsort.*" + filterb
else:
parser.print_help(sys.stderr)
parser.error("ERROR: Unknown argument '%s'" % args.benchcompare)
Expand Down
10 changes: 6 additions & 4 deletions src/avx512-64bit-keyvaluesort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,8 +542,8 @@ heapify(type1_t *keys, type2_t *indexes, arrsize_t idx, arrsize_t size)
arrsize_t i = idx;
while (true) {
arrsize_t j = 2 * i + 1;
if (j >= size || j < 0) { break; }
int k = j + 1;
if (j >= size) { break; }
arrsize_t k = j + 1;
if (k < size && keys[j] < keys[k]) { j = k; }
if (keys[j] < keys[i]) { break; }
std::swap(keys[i], keys[j]);
Expand All @@ -558,8 +558,9 @@ template <typename vtype1,
X86_SIMD_SORT_INLINE void
heap_sort(type1_t *keys, type2_t *indexes, arrsize_t size)
{
for (arrsize_t i = size / 2 - 1; i >= 0; i--) {
for (arrsize_t i = size / 2 - 1; ; i--) {
heapify<vtype1, vtype2>(keys, indexes, i, size);
if (i == 0) { break; }
}
for (arrsize_t i = size - 1; i > 0; i--) {
std::swap(keys[0], keys[i]);
Expand Down Expand Up @@ -614,8 +615,9 @@ X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys,

template <typename T1, typename T2>
X86_SIMD_SORT_INLINE void
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize)
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
{
UNUSED(hasnan);
if (arrsize > 1) {
if constexpr (std::is_floating_point_v<T1>) {
arrsize_t nan_count
Expand Down
6 changes: 6 additions & 0 deletions tests/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ libtests += static_library('tests_qsort',
include_directories : [lib, utils],
)

libtests += static_library('tests_kvsort',
files('test-keyvalue.cpp', ),
dependencies: gtest_dep,
include_directories : [lib, utils],
)

#if cancompilefp16
# libtests += static_library('tests_qsortfp16',
# files('test-qsortfp16.cpp', ),
Expand Down
67 changes: 67 additions & 0 deletions tests/test-keyvalue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*******************************************
* * Copyright (C) 2022-2023 Intel Corporation
* * SPDX-License-Identifier: BSD-3-Clause
* *******************************************/

#include "rand_array.h"
#include "x86simdsort.h"
#include "x86simdsort-scalar.h"
#include <gtest/gtest.h>

template <typename T>
class simdkvsort : public ::testing::Test {
public:
simdkvsort()
{
std::iota(arrsize.begin(), arrsize.end(), 1);
arrtype = {"random",
"constant",
"sorted",
"reverse",
"smallrange",
"max_at_the_end",
"rand_max"};
}
std::vector<std::string> arrtype;
std::vector<size_t> arrsize = std::vector<size_t>(1024);
};

TYPED_TEST_SUITE_P(simdkvsort);

TYPED_TEST_P(simdkvsort, test_kvsort)
{
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
for (auto size : this->arrsize) {
std::vector<T1> key = get_array<T1>(type, size);
std::vector<T2> val = get_array<T2>(type, size);
std::vector<T1> key_bckp = key;
std::vector<T2> val_bckp = val;
x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan);
xss::scalar::keyvalue_qsort(key_bckp.data(), val_bckp.data(), size, hasnan);
ASSERT_EQ(key, key_bckp);
const bool hasDuplicates = std::adjacent_find(key.begin(), key.end()) != key.end();
if (!hasDuplicates) {
ASSERT_EQ(val, val_bckp);
}
key.clear(); val.clear();
key_bckp.clear(); val_bckp.clear();
}
}
}

REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort);

using QKVSortTestTypes = testing::Types<std::tuple<double, double>,
std::tuple<double, uint64_t>,
std::tuple<double, int64_t>,
std::tuple<uint64_t, double>,
std::tuple<uint64_t, uint64_t>,
std::tuple<uint64_t, int64_t>,
std::tuple<int64_t, double>,
std::tuple<int64_t, uint64_t>,
std::tuple<int64_t, int64_t>>;

INSTANTIATE_TYPED_TEST_SUITE_P(xss, simdkvsort, QKVSortTestTypes);
Loading

0 comments on commit cb4358f

Please sign in to comment.