Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize find_first_of for 4 and 8 byte elements #4587

Merged
merged 23 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
28fbeee
Vectorize `find_first_of` for 4 and 8 byte elements
AlexGuteniev Apr 13, 2024
721bfed
Format
AlexGuteniev Apr 13, 2024
12e08b4
fix x86 build
AlexGuteniev Apr 13, 2024
7edceb7
We don't actually need dependent `false`
AlexGuteniev Apr 14, 2024
3963bea
Namespace and some renames avoid wrapping
AlexGuteniev Apr 14, 2024
810c695
Swap _Tmp1 and _Tmp2
AlexGuteniev Apr 14, 2024
df433ba
Format
AlexGuteniev Apr 14, 2024
c5e6c9e
Swap other _Tmp1 and _Tmp2 too
AlexGuteniev Apr 14, 2024
6781576
-newline
AlexGuteniev Apr 14, 2024
4240874
Don't have _Needle_length_el == 1 code path
AlexGuteniev Apr 14, 2024
5122201
spelling
AlexGuteniev Apr 14, 2024
0028fa3
missing include
AlexGuteniev Apr 14, 2024
710a7b4
unreachable
AlexGuteniev Apr 14, 2024
09a2973
Drop unnecessary `typename` when `using`.
StephanTLavavej Apr 14, 2024
39bcff8
Add `noexcept`.
StephanTLavavej Apr 14, 2024
42a5675
`__48_impl` => `__4_8_impl`
StephanTLavavej Apr 14, 2024
a528021
more ARM64EC guards
AlexGuteniev Apr 15, 2024
1d71a47
Use uppercase `_Ugly` names.
StephanTLavavej Apr 15, 2024
cbd0d6e
After checking `_Amount == 8`, directly say `8`.
StephanTLavavej Apr 15, 2024
6cb45cb
Mark `_Val` as `const`.
StephanTLavavej Apr 15, 2024
b447a9b
Remove `const` from `__m256i` return type.
StephanTLavavej Apr 15, 2024
1cae60f
`!_mm256_testz_si256(ARGS)` => `_mm256_testz_si256(ARGS) == 0`
StephanTLavavej Apr 15, 2024
74990a6
Revert "`!_mm256_testz_si256(ARGS)` => `_mm256_testz_si256(ARGS) == 0`"
StephanTLavavej Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions benchmarks/src/find_first_of.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,14 @@ void bm(benchmark::State& state) {
}
}

#define ARGS \
Args({2, 3}) \
->Args({7, 4}) \
->Args({9, 3}) \
->Args({22, 5}) \
->Args({58, 2}) \
->Args({102, 4}) \
->Args({325, 1}) \
->Args({1011, 11}) \
->Args({3056, 7});

BENCHMARK(bm<uint8_t>)->ARGS;
BENCHMARK(bm<uint16_t>)->ARGS;
// Registers the shared list of two-value argument tuples on the benchmark `bm`, in a fixed order,
// so that every benchmark instantiation that applies this function runs the same cases.
void common_args(auto bm) {
    constexpr int cases[][2] = {
        {2, 3},
        {7, 4},
        {9, 3},
        {22, 5},
        {58, 2},
        {102, 4},
        {325, 1},
        {1011, 11},
        {3056, 7},
    };

    for (const auto& one_case : cases) {
        bm->Args({one_case[0], one_case[1]});
    }
}

// Register the benchmark once per unsigned element type (8, 16, 32 and 64 bits);
// each registration receives its argument tuples through common_args.
BENCHMARK(bm<uint8_t>)->Apply(common_args);
BENCHMARK(bm<uint16_t>)->Apply(common_args);
BENCHMARK(bm<uint32_t>)->Apply(common_args);
BENCHMARK(bm<uint64_t>)->Apply(common_args);

BENCHMARK_MAIN();
15 changes: 12 additions & 3 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_2(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_4(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_8(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;


__declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept;
Expand Down Expand Up @@ -202,6 +207,12 @@ _Ty1* _Find_first_of_vectorized(
} else if constexpr (sizeof(_Ty1) == 2) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_2(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 4) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_4(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 8) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_8(_First1, _Last1, _First2, _Last2)));
} else {
static_assert(_Always_false<_Ty1>, "Unexpected size");
}
Expand Down Expand Up @@ -230,9 +241,7 @@ _INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

// Can we activate the vector algorithms for find_first_of?
template <class _It1, class _It2, class _Pr>
constexpr bool _Vector_alg_in_find_first_of_is_safe =
_Equal_memcmp_is_safe<_It1, _It2, _Pr> // can replace value comparison with bitwise comparison
&& sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible size
constexpr bool _Vector_alg_in_find_first_of_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr>;

// Can we activate the vector algorithms for replace?
template <class _Iter, class _Ty1>
Expand Down
256 changes: 241 additions & 15 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ namespace {
};

__m256i _Avx2_tail_mask_32(const size_t _Count_in_dwords) noexcept {
// _Count_in_dwords must be within [1, 7].
static constexpr unsigned int _Tail_masks[14] = {~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0};
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_Tail_masks + (7 - _Count_in_dwords)));
// _Count_in_dwords must be within [0, 8].
static constexpr unsigned int _Tail_masks[16] = {
~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0, 0};
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_Tail_masks + (8 - _Count_in_dwords)));
}
} // namespace
#endif // !defined(_M_ARM64EC)
Expand Down Expand Up @@ -2038,7 +2039,26 @@ namespace {
}

template <class _Ty>
const void* __stdcall __std_find_first_of_trivial_impl(
const void* __stdcall __std_find_first_of_trivial_fallback(
const void* _First1, const void* const _Last1, const void* const _First2, const void* const _Last2) {
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
auto _Ptr_haystack = static_cast<const _Ty*>(_First1);
const auto _Ptr_haystack_end = static_cast<const _Ty*>(_Last1);
const auto _Ptr_needle = static_cast<const _Ty*>(_First2);
const auto _Ptr_needle_end = static_cast<const _Ty*>(_Last2);

for (; _Ptr_haystack != _Ptr_haystack_end; ++_Ptr_haystack) {
for (auto _Ptr = _Ptr_needle; _Ptr != _Ptr_needle_end; ++_Ptr) {
if (*_Ptr_haystack == *_Ptr) {
return _Ptr_haystack;
}
}
}

return _Ptr_haystack;
}

template <class _Ty>
const void* __stdcall __std_find_first_of_trivial_pcmpestri_impl(
const void* _First1, const void* const _Last1, const void* const _First2, const void* const _Last2) noexcept {
#ifndef _M_ARM64EC
if (_Use_sse42()) {
Expand Down Expand Up @@ -2175,21 +2195,217 @@ namespace {
}
}
#endif // !_M_ARM64EC
return __std_find_first_of_trivial_fallback<_Ty>(_First1, _Last1, _First2, _Last2);
}

auto _Ptr_haystack = static_cast<const _Ty*>(_First1);
const auto _Ptr_haystack_end = static_cast<const _Ty*>(_Last1);
const auto _Ptr_needle = static_cast<const _Ty*>(_First2);
const auto _Ptr_needle_end = static_cast<const _Ty*>(_Last2);
struct _Find_first_of_traits_4 : _Find_traits_4 {
using _Ty = uint32_t;

template <size_t _Amount>
static __m256i _Spread_avx(__m256i _Val, const size_t _Needle_length_el) noexcept {
if constexpr (_Amount == 1) {
return _mm256_broadcastd_epi32(_mm256_castsi256_si128(_Val));
} else if constexpr (_Amount == 2) {
return _mm256_broadcastq_epi64(_mm256_castsi256_si128(_Val));
} else if constexpr (_Amount == 4) {
if (_Needle_length_el < 4) {
_Val = _mm256_shuffle_epi32(_Val, _MM_SHUFFLE(0, 2, 1, 0));
}

for (; _Ptr_haystack != _Ptr_haystack_end; ++_Ptr_haystack) {
for (auto _Ptr = _Ptr_needle; _Ptr != _Ptr_needle_end; ++_Ptr) {
if (*_Ptr_haystack == *_Ptr) {
return _Ptr_haystack;
return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(1, 0, 1, 0));
} else if constexpr (_Amount == 8) {
if (_Needle_length_el < _Amount) {
const __m256i _Mask = _Avx2_tail_mask_32(_Needle_length_el);
// zero unused elements in the sequential permutation mask, so they will be filled by the 1st element
const __m256i _Perm = _mm256_and_si256(_mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0), _Mask);
_Val = _mm256_permutevar8x32_epi32(_Val, _Perm);
}

return _Val;
} else {
static_assert(_Amount != _Amount, "Unexpected amount");
}
}

return _Ptr_haystack;
template <size_t _Amount>
static __m256i _Shuffle_avx(const __m256i _Val) noexcept {
if constexpr (_Amount == 1) {
return _mm256_shuffle_epi32(_Val, _MM_SHUFFLE(2, 3, 0, 1));
} else if constexpr (_Amount == 2) {
return _mm256_shuffle_epi32(_Val, _MM_SHUFFLE(1, 0, 3, 2));
} else if constexpr (_Amount == 4) {
return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(1, 0, 3, 2));
} else {
static_assert(_Amount != _Amount, "Unexpected amount");
}
}
};

// AVX2 register-shuffling helpers for find_first_of over 8-byte elements.
// A 256-bit register holds four such elements, indexed 0..3 from the low end.
struct _Find_first_of_traits_8 : _Find_traits_8 {
    using _Ty = uint64_t;

    // Spreads the leading needle elements of _Val over the whole register.
    //   _Amount == 1: element 0 is broadcast to all four positions.
    //   _Amount == 2: elements 0 and 1 are repeated in the upper half.
    //   _Amount == 4: _Val is used as is, except that when fewer than 4 needle elements are present
    //                 position 3 is refilled with element 0, so every position holds a needle value.
    // _Needle_length_el is the actual number of needle elements; it is consulted only for _Amount == 4.
    template <size_t _Amount>
    static __m256i _Spread_avx(__m256i _Val, const size_t _Needle_length_el) noexcept {
        static_assert(_Amount == 1 || _Amount == 2 || _Amount == 4, "Unexpected amount");

        if constexpr (_Amount == 4) {
            return _Needle_length_el < 4 ? _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(0, 2, 1, 0)) : _Val;
        } else if constexpr (_Amount == 2) {
            return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(1, 0, 1, 0));
        } else {
            return _mm256_broadcastq_epi64(_mm256_castsi256_si128(_Val));
        }
    }

    // Rearranges the elements of _Val so that position i receives the element from position (i ^ _Amount).
    //   _Amount == 1: swaps the two elements inside each 128-bit half.
    //   _Amount == 2: swaps the two 128-bit halves.
    template <size_t _Amount>
    static __m256i _Shuffle_avx(const __m256i _Val) noexcept {
        static_assert(_Amount == 1 || _Amount == 2, "Unexpected amount");

        if constexpr (_Amount == 2) {
            return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(1, 0, 3, 2));
        } else {
            return _mm256_shuffle_epi32(_Val, _MM_SHUFFLE(1, 0, 3, 2));
        }
    }
};

// Builds the "haystack element equals some needle element" mask for one 32-byte haystack block.
//
// _Data1   - the haystack block.
// _Data2s0 - the needle elements, already spread over the whole register (the caller in this file
//            prepares it with _Traits::_Spread_avx<_Needle_length_el_magnitude>).
// _Needle_length_el_magnitude - how many needle arrangements are compared (1, 2, 4 or 8).
//
// With the _Shuffle_avx<N> implementations in this file, _Shuffle_avx<N> places the element from
// position (i ^ N) at position i. Hence _Data2sK below is _Data2s0 with its element positions XORed
// by K, and ORing the comparisons for K = 0 .. magnitude - 1 checks every haystack element against
// `magnitude` different needle positions.
//
// Returns the bitwise OR of the individual _Traits::_Cmp_avx results.
template <class _Traits, size_t _Needle_length_el_magnitude>
__m256i __std_find_first_of_trivial_shuffle_step(const __m256i _Data1, const __m256i _Data2s0) noexcept {
    __m256i _Eq = _Traits::_Cmp_avx(_Data1, _Data2s0);
    if constexpr (_Needle_length_el_magnitude >= 2) {
        // `template` is required here: _Shuffle_avx is a member template of a dependent type
        const __m256i _Data2s1 = _Traits::template _Shuffle_avx<1>(_Data2s0);
        _Eq                    = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s1));
        if constexpr (_Needle_length_el_magnitude >= 4) {
            const __m256i _Data2s2 = _Traits::template _Shuffle_avx<2>(_Data2s0);
            _Eq                    = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s2));
            const __m256i _Data2s3 = _Traits::template _Shuffle_avx<1>(_Data2s2);
            _Eq                    = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s3));
            if constexpr (_Needle_length_el_magnitude >= 8) {
                const __m256i _Data2s4 = _Traits::template _Shuffle_avx<4>(_Data2s0);
                _Eq                    = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s4));
                const __m256i _Data2s5 = _Traits::template _Shuffle_avx<1>(_Data2s4);
                _Eq                    = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s5));
                const __m256i _Data2s6 = _Traits::template _Shuffle_avx<2>(_Data2s4);
                _Eq                    = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s6));
                const __m256i _Data2s7 = _Traits::template _Shuffle_avx<1>(_Data2s6);
                _Eq                    = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s7));
            }
        }
    }
    return _Eq;
}

// find_first_of for a needle small enough to live in a single 32-byte register.
//
// _First1, _Last1   - haystack range.
// _First2           - start of the needle.
// _Needle_length_el - number of needle elements; the needle must fit in 32 bytes, because it is
//                     loaded with one masked 32-byte load.
// _Needle_length_el_magnitude - compile-time spread/shuffle width passed on to _Traits::_Spread_avx
//                     and __std_find_first_of_trivial_shuffle_step.
//
// Returns the address of the first haystack position whose comparison mask is set, or the position
// reached after scanning the whole haystack when nothing matches.
// NOTE(review): the byte offset taken from the movemask is the start of an element only if
// _Traits::_Cmp_avx produces element-wide all-ones/all-zeros lanes - confirm against _Find_traits_4/8.
template <class _Traits, size_t _Needle_length_el_magnitude>
const void* __std_find_first_of_trivial_shuffle_impl(const void* _First1, const void* const _Last1,
    const void* const _First2, const size_t _Needle_length_el) noexcept {
    using _Ty = typename _Traits::_Ty;

    // Load only the dwords that belong to the needle (the mask is expressed in 4-byte units),
    // then spread them so that every register position holds a needle value.
    const __m256i _Data2 = _mm256_maskload_epi32(
        static_cast<const int*>(_First2), _Avx2_tail_mask_32(_Needle_length_el * (sizeof(_Ty) / 4)));
    // `template` is required here: _Spread_avx is a member template of a dependent type
    const __m256i _Data2s0 = _Traits::template _Spread_avx<_Needle_length_el_magnitude>(_Data2, _Needle_length_el);

    const size_t _Haystack_length = _Byte_length(_First1, _Last1);

    // Whole 32-byte blocks first.
    const void* _Stop1 = _First1;
    _Advance_bytes(_Stop1, _Haystack_length & ~size_t{0x1F});

    for (; _First1 != _Stop1; _Advance_bytes(_First1, 32)) {
        const __m256i _Data1 = _mm256_loadu_si256(static_cast<const __m256i*>(_First1));
        const __m256i _Eq =
            __std_find_first_of_trivial_shuffle_step<_Traits, _Needle_length_el_magnitude>(_Data1, _Data2s0);
        const int _Bingo = _mm256_movemask_epi8(_Eq);

        if (_Bingo != 0) {
            // lowest set bit == byte offset of the first match inside this block
            const unsigned long _Offset = _tzcnt_u32(_Bingo);
            _Advance_bytes(_First1, _Offset);
            return _First1;
        }
    }

    // Remaining whole dwords (fewer than 32 bytes), loaded through a mask so nothing past the end is read.
    if (const size_t _Haystack_tail_length = _Haystack_length & 0x1C; _Haystack_tail_length != 0) {
        const __m256i _Tail_mask = _Avx2_tail_mask_32(_Haystack_tail_length >> 2);
        const __m256i _Data1     = _mm256_maskload_epi32(static_cast<const int*>(_First1), _Tail_mask);
        const __m256i _Eq =
            __std_find_first_of_trivial_shuffle_step<_Traits, _Needle_length_el_magnitude>(_Data1, _Data2s0);
        // Masked-out haystack lanes were loaded as zero; drop their results so they cannot report a match.
        const int _Bingo = _mm256_movemask_epi8(_mm256_and_si256(_Eq, _Tail_mask));

        if (_Bingo != 0) {
            const unsigned long _Offset = _tzcnt_u32(_Bingo);
            _Advance_bytes(_First1, _Offset);
            return _First1;
        }

        _Advance_bytes(_First1, _Haystack_tail_length);
    }

    return _First1;
}

// find_first_of for 4- and 8-byte elements, compared bitwise.
//
// [_First1, _Last1) is the haystack and [_First2, _Last2) is the needle; _Traits supplies the element
// type (_Traits::_Ty) and the AVX2 helpers. Returns a pointer to the first haystack element that equals
// any needle element, or _Last1 when there is none (including when the needle is empty).
//
// When _Use_avx2() is true (and the target is not ARM64EC) one of two AVX2 strategies is used:
//   * small needles go to __std_find_first_of_trivial_shuffle_impl, which keeps the whole needle
//     in a single register;
//   * otherwise each haystack element is compared against the needle 32 bytes at a time.
// In every other case the scalar fallback is used.
template <class _Traits>
const void* __stdcall __std_find_first_of_trivial_48_impl(const void* const _First1, const void* const _Last1,
    const void* const _First2, const void* const _Last2) noexcept {
    using _Ty = typename _Traits::_Ty;
#ifndef _M_ARM64EC
    if (_Use_avx2()) {
        _Zeroupper_on_exit _Guard; // TRANSITION, DevCom-10331414

        const size_t _Needle_length = _Byte_length(_First2, _Last2);
        // Kept as size_t: narrowing to int could wrap for a huge needle and misroute it
        // into the empty/small-needle branches below.
        const size_t _Needle_length_el = _Needle_length / sizeof(_Ty);

        // Special handling of small needle
        // The generic approach could also handle it but with worse performance
        if (_Needle_length_el == 0) {
            return _Last1;
        } else if (_Needle_length_el == 1) {
            // This is expected to be done on an upper level with better efficiency
            return __std_find_first_of_trivial_shuffle_impl<_Traits, 1>(
                _First1, _Last1, _First2, _Needle_length_el);
        } else if (_Needle_length_el == 2) {
            return __std_find_first_of_trivial_shuffle_impl<_Traits, 2>(
                _First1, _Last1, _First2, _Needle_length_el);
        } else if (_Needle_length_el <= 4) {
            return __std_find_first_of_trivial_shuffle_impl<_Traits, 4>(
                _First1, _Last1, _First2, _Needle_length_el);
        } else if (_Needle_length_el <= 8) {
            // Eight elements fit in one 32-byte register only when they are 4 bytes wide;
            // 8-byte elements fall through to the generic approach.
            if constexpr (sizeof(_Ty) == 4) {
                return __std_find_first_of_trivial_shuffle_impl<_Traits, 8>(
                    _First1, _Last1, _First2, _Needle_length_el);
            }
        }

        // Generic approach
        // The needle is split into whole 32-byte blocks plus a tail of whole dwords covered by _Tail_mask.
        const size_t _Needle_length_tail = _Needle_length & 0x1C;
        const __m256i _Tail_mask         = _Avx2_tail_mask_32(_Needle_length_tail >> 2);

        const void* _Stop2 = _First2;
        _Advance_bytes(_Stop2, _Needle_length & ~size_t{0x1F});

        for (auto _Ptr1 = static_cast<const _Ty*>(_First1); _Ptr1 != _Last1; ++_Ptr1) {
            const auto _Data1 = _Traits::_Set_avx(*_Ptr1);
            for (auto _Ptr2 = _First2; _Ptr2 != _Stop2; _Advance_bytes(_Ptr2, 32)) {
                const __m256i _Data2 = _mm256_loadu_si256(static_cast<const __m256i*>(_Ptr2));
                const __m256i _Eq    = _Traits::_Cmp_avx(_Data1, _Data2);
                if (!_mm256_testz_si256(_Eq, _Eq)) {
                    return _Ptr1;
                }
            }

            if (_Needle_length_tail != 0) {
                const __m256i _Data2 = _mm256_maskload_epi32(static_cast<const int*>(_Stop2), _Tail_mask);
                const __m256i _Eq    = _Traits::_Cmp_avx(_Data1, _Data2);
                // Test against the mask so lanes beyond the needle cannot report a match.
                if (!_mm256_testz_si256(_Eq, _Tail_mask)) {
                    return _Ptr1;
                }
            }
        }

        return _Last1;
    }
#endif // !_M_ARM64EC
    return __std_find_first_of_trivial_fallback<_Ty>(_First1, _Last1, _First2, _Last2);
}


Expand Down Expand Up @@ -2352,12 +2568,22 @@ __declspec(noalias) size_t

const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept {
return __std_find_first_of_trivial_impl<uint8_t>(_First1, _Last1, _First2, _Last2);
return __std_find_first_of_trivial_pcmpestri_impl<uint8_t>(_First1, _Last1, _First2, _Last2);
}

const void* __stdcall __std_find_first_of_trivial_2(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept {
return __std_find_first_of_trivial_impl<uint16_t>(_First1, _Last1, _First2, _Last2);
return __std_find_first_of_trivial_pcmpestri_impl<uint16_t>(_First1, _Last1, _First2, _Last2);
}

// find_first_of entry point for 4-byte elements: forwards the haystack [_First1, _Last1) and the
// needle [_First2, _Last2) to the shared 4/8-byte implementation with the 4-byte traits.
const void* __stdcall __std_find_first_of_trivial_4(
    const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept {
    using _Traits = _Find_first_of_traits_4;
    return __std_find_first_of_trivial_48_impl<_Traits>(_First1, _Last1, _First2, _Last2);
}

// find_first_of entry point for 8-byte elements: forwards the haystack [_First1, _Last1) and the
// needle [_First2, _Last2) to the shared 4/8-byte implementation with the 8-byte traits.
const void* __stdcall __std_find_first_of_trivial_8(
    const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept {
    using _Traits = _Find_first_of_traits_8;
    return __std_find_first_of_trivial_48_impl<_Traits>(_First1, _Last1, _First2, _Last2);
}

__declspec(noalias) size_t
Expand Down