Use scalar emulation of gather instruction for arg methods
r-devulap committed Aug 21, 2023
1 parent 0890de5 commit 323f247
Showing 3 changed files with 152 additions and 60 deletions.
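The gist of the change: the unmasked i64gather member of every vector wrapper stops emitting a hardware gather intrinsic and instead reads the eight elements scalar-wise, assembling the vector with a set intrinsic. A minimal standalone sketch of the before/after pattern, using the double/__m512d case for illustration (function and buffer names are hypothetical, not from this commit):

    #include <immintrin.h>
    #include <cstdint>

    // Before: one vgatherqpd instruction (scale = sizeof(double)).
    static __m512d gather_hw(const double *arr, const int64_t *ind)
    {
        __m512i index = _mm512_loadu_si512(ind);
        return _mm512_i64gather_pd(index, arr, sizeof(double));
    }

    // After: eight scalar loads. _mm512_set_pd fills lanes from high to
    // low, so the element for lane 0 (arr[ind[0]]) is the last argument.
    static __m512d gather_scalar(const double *arr, const int64_t *ind)
    {
        return _mm512_set_pd(arr[ind[7]], arr[ind[6]], arr[ind[5]], arr[ind[4]],
                             arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
    }

The two produce identical vectors; only the instruction mix differs.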
27 changes: 12 additions & 15 deletions src/avx512-64bit-argsort.hpp
@@ -85,7 +85,7 @@ X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N)
     typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01;
     argzmm_t argzmm1 = argtype::loadu(arg);
     argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8);
-    zmm_t arrzmm1 = vtype::template i64gather<sizeof(type_t)>(argzmm1, arr);
+    zmm_t arrzmm1 = vtype::i64gather(arr, arg);
     zmm_t arrzmm2 = vtype::template mask_i64gather<sizeof(type_t)>(
             vtype::zmm_max(), load_mask, argzmm2, arr);
     arrzmm1 = sort_zmm_64bit<vtype, argtype>(arrzmm1, argzmm1);
@@ -111,7 +111,7 @@ X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N)
 #pragma GCC unroll 2
     for (int ii = 0; ii < 2; ++ii) {
         argzmm[ii] = argtype::loadu(arg + 8 * ii);
-        arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
+        arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
         arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
     }

@@ -154,7 +154,7 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N)
 #pragma GCC unroll 4
     for (int ii = 0; ii < 4; ++ii) {
         argzmm[ii] = argtype::loadu(arg + 8 * ii);
-        arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
+        arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
         arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
     }

@@ -206,7 +206,7 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N)
 //#pragma GCC unroll 8
 //    for (int ii = 0; ii < 8; ++ii) {
 //        argzmm[ii] = argtype::loadu(arg + 8*ii);
-//        arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
+//        arrzmm[ii] = vtype::i64gather(argzmm[ii], arr);
 //        arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
 //    }
 //
@@ -257,17 +257,14 @@ type_t get_pivot_64bit(type_t *arr,
     // median of 8
     int64_t size = (right - left) / 8;
     using zmm_t = typename vtype::zmm_t;
-    // TODO: Use gather here too:
-    __m512i rand_index = _mm512_set_epi64(arg[left + size],
-                                          arg[left + 2 * size],
-                                          arg[left + 3 * size],
-                                          arg[left + 4 * size],
-                                          arg[left + 5 * size],
-                                          arg[left + 6 * size],
-                                          arg[left + 7 * size],
-                                          arg[left + 8 * size]);
-    zmm_t rand_vec
-            = vtype::template i64gather<sizeof(type_t)>(rand_index, arr);
+    zmm_t rand_vec = vtype::set(arr[arg[left + size]],
+                                arr[arg[left + 2 * size]],
+                                arr[arg[left + 3 * size]],
+                                arr[arg[left + 4 * size]],
+                                arr[arg[left + 5 * size]],
+                                arr[arg[left + 6 * size]],
+                                arr[arg[left + 7 * size]],
+                                arr[arg[left + 8 * size]]);
     // pivot will never be a nan, since there are no nan's!
     zmm_t sort = sort_zmm_64bit<vtype>(rand_vec);
     return ((type_t *)&sort)[4];
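The pivot selection above now reads its eight samples through the arg indirection with plain scalar loads before the single-vector sort. A hedged scalar equivalent of the median-of-8 step, with the vector sort replaced by std::sort for clarity (parameter names mirror the function above):

    #include <algorithm>
    #include <cstdint>

    // Scalar sketch of get_pivot_64bit: sample eight elements through arg,
    // sort them, and take the upper median, as the vector path does with
    // ((type_t *)&sort)[4].
    static double pivot_of_8(const double *arr, const int64_t *arg,
                             int64_t left, int64_t right)
    {
        int64_t size = (right - left) / 8;
        double samples[8];
        for (int i = 0; i < 8; ++i)
            samples[i] = arr[arg[left + (i + 1) * size]];
        std::sort(samples, samples + 8);
        return samples[4];
    }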
158 changes: 128 additions & 30 deletions src/avx512-64bit-common.h
@@ -39,12 +39,22 @@ struct ymm_vector<float> {
     {
         return _mm256_set1_ps(type_max());
     }
 
     static zmmi_t
     seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static zmm_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm256_set_ps(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static opmask_t kxor_opmask(opmask_t x, opmask_t y)
     {
         return _kxor_mask8(x, y);
@@ -80,10 +90,16 @@ struct ymm_vector<float> {
     {
         return _mm512_mask_i64gather_ps(src, mask, index, base, scale);
     }
-    template <int scale>
-    static zmm_t i64gather(__m512i index, void const *base)
+    static zmm_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_ps(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static zmm_t loadu(void const *mem)
     {
@@ -189,6 +205,17 @@ struct ymm_vector<uint32_t> {
     {
         return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static zmm_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static opmask_t kxor_opmask(opmask_t x, opmask_t y)
     {
         return _kxor_mask8(x, y);
@@ -215,10 +242,16 @@ struct ymm_vector<uint32_t> {
     {
         return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
     }
-    template <int scale>
-    static zmm_t i64gather(__m512i index, void const *base)
+    static zmm_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_epi32(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static zmm_t loadu(void const *mem)
     {
@@ -318,6 +351,17 @@ struct ymm_vector<int32_t> {
     {
         return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static zmm_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static opmask_t kxor_opmask(opmask_t x, opmask_t y)
     {
         return _kxor_mask8(x, y);
@@ -344,10 +388,16 @@ struct ymm_vector<int32_t> {
     {
         return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
     }
-    template <int scale>
-    static zmm_t i64gather(__m512i index, void const *base)
+    static zmm_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_epi32(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static zmm_t loadu(void const *mem)
     {
@@ -448,6 +498,17 @@ struct zmm_vector<int64_t> {
     {
         return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static zmm_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static opmask_t kxor_opmask(opmask_t x, opmask_t y)
     {
         return _kxor_mask8(x, y);
@@ -474,10 +535,16 @@ struct zmm_vector<int64_t> {
     {
         return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
     }
-    template <int scale>
-    static zmm_t i64gather(__m512i index, void const *base)
+    static zmm_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_epi64(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static zmm_t loadu(void const *mem)
     {
@@ -566,16 +633,33 @@ struct zmm_vector<uint64_t> {
     {
         return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
     }
+    static zmm_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     template <int scale>
     static zmm_t
     mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base)
     {
         return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
     }
-    template <int scale>
-    static zmm_t i64gather(__m512i index, void const *base)
+    static zmm_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_epi64(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static opmask_t knot_opmask(opmask_t x)
     {
@@ -666,13 +750,22 @@ struct zmm_vector<double> {
     {
         return _mm512_set1_pd(type_max());
     }
 
     static zmmi_t
     seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
     {
         return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
     }
 
+    static zmm_t set(type_t v1,
+                     type_t v2,
+                     type_t v3,
+                     type_t v4,
+                     type_t v5,
+                     type_t v6,
+                     type_t v7,
+                     type_t v8)
+    {
+        return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8);
+    }
     static zmm_t maskz_loadu(opmask_t mask, void const *mem)
     {
         return _mm512_maskz_loadu_pd(mask, mem);
@@ -704,10 +797,16 @@ struct zmm_vector<double> {
     {
         return _mm512_mask_i64gather_pd(src, mask, index, base, scale);
     }
-    template <int scale>
-    static zmm_t i64gather(__m512i index, void const *base)
+    static zmm_t i64gather(type_t *arr, int64_t *ind)
     {
-        return _mm512_i64gather_pd(index, base, scale);
+        return set(arr[ind[7]],
+                   arr[ind[6]],
+                   arr[ind[5]],
+                   arr[ind[4]],
+                   arr[ind[3]],
+                   arr[ind[2]],
+                   arr[ind[1]],
+                   arr[ind[0]]);
     }
     static zmm_t loadu(void const *mem)
     {
@@ -794,15 +893,14 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
     // median of 8
     int64_t size = (right - left) / 8;
     using zmm_t = typename vtype::zmm_t;
-    __m512i rand_index = _mm512_set_epi64(left + size,
-                                          left + 2 * size,
-                                          left + 3 * size,
-                                          left + 4 * size,
-                                          left + 5 * size,
-                                          left + 6 * size,
-                                          left + 7 * size,
-                                          left + 8 * size);
-    zmm_t rand_vec = vtype::template i64gather<sizeof(type_t)>(rand_index, arr);
+    zmm_t rand_vec = vtype::set(arr[left + size],
+                                arr[left + 2 * size],
+                                arr[left + 3 * size],
+                                arr[left + 4 * size],
+                                arr[left + 5 * size],
+                                arr[left + 6 * size],
+                                arr[left + 7 * size],
+                                arr[left + 8 * size]);
     // pivot will never be a nan, since there are no nan's!
     zmm_t sort = sort_zmm_64bit<vtype>(rand_vec);
     return ((type_t *)&sort)[4];
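Note that only the unmasked path is emulated: mask_i64gather still maps to the hardware masked-gather intrinsic in every specialization above. A condensed sketch of how the two new members fit together, abbreviated from the ymm_vector<float> specialization (not the full trait):

    #include <immintrin.h>
    #include <cstdint>

    // Abbreviated from ymm_vector<float>: set() is the lane-by-lane
    // constructor, and the scalar-emulated i64gather() is built on it.
    struct ymm_vector_float_sketch {
        using type_t = float;
        using zmm_t = __m256; // 8 floats, paired with 8 x 64-bit indices

        static zmm_t set(type_t v1, type_t v2, type_t v3, type_t v4,
                         type_t v5, type_t v6, type_t v7, type_t v8)
        {
            return _mm256_set_ps(v1, v2, v3, v4, v5, v6, v7, v8);
        }
        static zmm_t i64gather(type_t *arr, int64_t *ind)
        {
            return set(arr[ind[7]], arr[ind[6]], arr[ind[5]], arr[ind[4]],
                       arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
        }
    };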
27 changes: 12 additions & 15 deletions src/avx512-common-argsort.h
@@ -75,7 +75,7 @@ static inline int64_t partition_avx512(type_t *arr,
 
     if (right - left == vtype::numlanes) {
         argzmm_t argvec = argtype::loadu(arg + left);
-        zmm_t vec = vtype::template i64gather<sizeof(type_t)>(argvec, arr);
+        zmm_t vec = vtype::i64gather(arr, arg + left);
         int32_t amount_gt_pivot = partition_vec<vtype>(arg,
                                                        left,
                                                        left + vtype::numlanes,
@@ -91,11 +91,9 @@ static inline int64_t partition_avx512(type_t *arr,
 
     // first and last vtype::numlanes values are partitioned at the end
     argzmm_t argvec_left = argtype::loadu(arg + left);
-    zmm_t vec_left
-            = vtype::template i64gather<sizeof(type_t)>(argvec_left, arr);
+    zmm_t vec_left = vtype::i64gather(arr, arg + left);
     argzmm_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes));
-    zmm_t vec_right
-            = vtype::template i64gather<sizeof(type_t)>(argvec_right, arr);
+    zmm_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes));
     // store points of the vectors
     int64_t r_store = right - vtype::numlanes;
     int64_t l_store = left;
@@ -113,11 +111,11 @@ static inline int64_t partition_avx512(type_t *arr,
         if ((r_store + vtype::numlanes) - right < left - l_store) {
             right -= vtype::numlanes;
             arg_vec = argtype::loadu(arg + right);
-            curr_vec = vtype::template i64gather<sizeof(type_t)>(arg_vec, arr);
+            curr_vec = vtype::i64gather(arr, arg + right);
         }
         else {
             arg_vec = argtype::loadu(arg + left);
-            curr_vec = vtype::template i64gather<sizeof(type_t)>(arg_vec, arr);
+            curr_vec = vtype::i64gather(arr, arg + left);
             left += vtype::numlanes;
         }
         // partition the current vector and save it on both sides of the array
@@ -201,12 +199,11 @@ static inline int64_t partition_avx512_unrolled(type_t *arr,
 #pragma GCC unroll 8
     for (int ii = 0; ii < num_unroll; ++ii) {
         argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii);
-        vec_left[ii] = vtype::template i64gather<sizeof(type_t)>(
-                argvec_left[ii], arr);
+        vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii);
         argvec_right[ii] = argtype::loadu(
                 arg + (right - vtype::numlanes * (num_unroll - ii)));
-        vec_right[ii] = vtype::template i64gather<sizeof(type_t)>(
-                argvec_right[ii], arr);
+        vec_right[ii] = vtype::i64gather(
+                arr, arg + (right - vtype::numlanes * (num_unroll - ii)));
     }
     // store points of the vectors
     int64_t r_store = right - vtype::numlanes;
@@ -228,16 +225,16 @@ static inline int64_t partition_avx512_unrolled(type_t *arr,
             for (int ii = 0; ii < num_unroll; ++ii) {
                 arg_vec[ii]
                         = argtype::loadu(arg + right + ii * vtype::numlanes);
-                curr_vec[ii] = vtype::template i64gather<sizeof(type_t)>(
-                        arg_vec[ii], arr);
+                curr_vec[ii] = vtype::i64gather(
+                        arr, arg + right + ii * vtype::numlanes);
             }
         }
         else {
 #pragma GCC unroll 8
             for (int ii = 0; ii < num_unroll; ++ii) {
                 arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes);
-                curr_vec[ii] = vtype::template i64gather<sizeof(type_t)>(
-                        arg_vec[ii], arr);
+                curr_vec[ii] = vtype::i64gather(
+                        arr, arg + left + ii * vtype::numlanes);
             }
             left += num_unroll * vtype::numlanes;
         }
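The signature change explains the mechanical edits above: the scalar emulation needs the raw indices, and every call site already has the arg slice in hand, so i64gather(arr, arg + offset) replaces the old load-an-index-vector-then-gather pair. A simplified standalone sketch of the new load step (the real kernel also keeps the loaded index vector so arg can be permuted alongside arr):

    #include <immintrin.h>
    #include <cstdint>

    // Simplified from partition_avx512: indices are loaded once for
    // bookkeeping, and the values come from scalar loads over the same slice.
    static __m512d load_block(double *arr, int64_t *arg, int64_t left)
    {
        __m512i argvec = _mm512_loadu_si512(arg + left); // argtype::loadu(arg + left)
        (void)argvec; // carried along with the values in the real code
        const int64_t *ind = arg + left;                 // vtype::i64gather(arr, arg + left)
        return _mm512_set_pd(arr[ind[7]], arr[ind[6]], arr[ind[5]], arr[ind[4]],
                             arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]);
    }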
