Skip to content

Commit

Permalink
Vectorize basic_string::find_first_of (#4744)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
Co-authored-by: Casey Carter <cacarter@microsoft.com>
  • Loading branch information
3 people authored Sep 4, 2024
1 parent 91e4255 commit 77b31f7
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 64 deletions.
36 changes: 26 additions & 10 deletions benchmarks/src/find_first_of.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,26 @@
#include <cstdint>
#include <cstdlib>
#include <numeric>
#include <string>
#include <type_traits>
#include <vector>

using namespace std;

template <class T>
enum class AlgType : bool { std_func, str_member };

template <AlgType Alg, class T, T Start = T{'a'}>
void bm(benchmark::State& state) {
const size_t Pos = static_cast<size_t>(state.range(0));
const size_t NSize = static_cast<size_t>(state.range(1));
const size_t HSize = Pos * 2;
const size_t Which = 0;

vector<T> h(HSize, T{'.'});
vector<T> n(NSize);
iota(n.begin(), n.end(), T{'a'});
using container = conditional_t<Alg == AlgType::str_member, basic_string<T>, vector<T>>;

container h(HSize, T{'.'});
container n(NSize, T{0});
iota(n.begin(), n.end(), Start);

if (Pos >= HSize || Which >= NSize) {
abort();
Expand All @@ -29,18 +35,28 @@ void bm(benchmark::State& state) {
h[Pos] = n[Which];

for (auto _ : state) {
benchmark::DoNotOptimize(find_first_of(h.begin(), h.end(), n.begin(), n.end()));
benchmark::DoNotOptimize(h);
benchmark::DoNotOptimize(n);
if constexpr (Alg == AlgType::str_member) {
benchmark::DoNotOptimize(h.find_first_of(n.data(), 0, n.size()));
} else {
benchmark::DoNotOptimize(find_first_of(h.begin(), h.end(), n.begin(), n.end()));
}
}
}

void common_args(auto bm) {
bm->Args({2, 3})->Args({7, 4})->Args({9, 3})->Args({22, 5})->Args({58, 2});
bm->Args({102, 4})->Args({325, 1})->Args({1011, 11})->Args({3056, 7});
bm->Args({102, 4})->Args({325, 1})->Args({1011, 11})->Args({1502, 23})->Args({3056, 7});
}

BENCHMARK(bm<uint8_t>)->Apply(common_args);
BENCHMARK(bm<uint16_t>)->Apply(common_args);
BENCHMARK(bm<uint32_t>)->Apply(common_args);
BENCHMARK(bm<uint64_t>)->Apply(common_args);
BENCHMARK(bm<AlgType::std_func, uint8_t>)->Apply(common_args);
BENCHMARK(bm<AlgType::std_func, uint16_t>)->Apply(common_args);
BENCHMARK(bm<AlgType::std_func, uint32_t>)->Apply(common_args);
BENCHMARK(bm<AlgType::std_func, uint64_t>)->Apply(common_args);

BENCHMARK(bm<AlgType::str_member, char>)->Apply(common_args);
BENCHMARK(bm<AlgType::str_member, wchar_t>)->Apply(common_args);
BENCHMARK(bm<AlgType::str_member, wchar_t, L'\x03B1'>)->Apply(common_args);

BENCHMARK_MAIN();
60 changes: 45 additions & 15 deletions stl/inc/__msvc_string_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,25 +708,55 @@ constexpr size_t _Traits_find_first_of(_In_reads_(_Hay_size) const _Traits_ptr_t
const size_t _Needle_size) noexcept {
// in [_Haystack, _Haystack + _Hay_size), look for one of [_Needle, _Needle + _Needle_size), at/after _Start_at
if (_Needle_size != 0 && _Start_at < _Hay_size) { // room for match, look for it
const auto _Hay_start = _Haystack + _Start_at;
const auto _Hay_end = _Haystack + _Hay_size;

if constexpr (_Special) {
_String_bitmap<typename _Traits::char_type> _Matches;
if (!_Matches._Mark(_Needle, _Needle + _Needle_size)) { // couldn't put one of the characters into the
// bitmap, fall back to the serial algorithm
return _Traits_find_first_of<_Traits, false>(_Haystack, _Hay_size, _Start_at, _Needle, _Needle_size);
}
if (!_STD _Is_constant_evaluated()) {
using _Elem = typename _Traits::char_type;

#if _USE_STD_VECTOR_ALGORITHMS
const bool _Try_vectorize = _Hay_size - _Start_at > _Threshold_find_first_of;

// Additional condition for when the vectorization outperforms the table lookup
const bool _Use_bitmap = !_Try_vectorize || (sizeof(_Elem) > 1 && sizeof(_Elem) * _Needle_size > 16);
#else
const bool _Use_bitmap = true;
#endif // _USE_STD_VECTOR_ALGORITHMS

if (_Use_bitmap) {
_String_bitmap<_Elem> _Matches;

if (_Matches._Mark(_Needle, _Needle + _Needle_size)) {
for (auto _Match_try = _Hay_start; _Match_try < _Hay_end; ++_Match_try) {
if (_Matches._Match(*_Match_try)) {
return static_cast<size_t>(_Match_try - _Haystack); // found a match
}
}
return static_cast<size_t>(-1); // no match
}

// couldn't put one of the characters into the bitmap, fall back to vectorized or serial algorithms
}

const auto _End = _Haystack + _Hay_size;
for (auto _Match_try = _Haystack + _Start_at; _Match_try < _End; ++_Match_try) {
if (_Matches._Match(*_Match_try)) {
return static_cast<size_t>(_Match_try - _Haystack); // found a match
#if _USE_STD_VECTOR_ALGORITHMS
if (_Try_vectorize) {
const _Traits_ptr_t<_Traits> _Found =
_STD _Find_first_of_vectorized(_Hay_start, _Hay_end, _Needle, _Needle + _Needle_size);

if (_Found != _Hay_end) {
return static_cast<size_t>(_Found - _Haystack); // found a match
} else {
return static_cast<size_t>(-1); // no match
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
}
} else {
const auto _End = _Haystack + _Hay_size;
for (auto _Match_try = _Haystack + _Start_at; _Match_try < _End; ++_Match_try) {
if (_Traits::find(_Needle, _Needle_size, *_Match_try)) {
return static_cast<size_t>(_Match_try - _Haystack); // found a match
}
}

for (auto _Match_try = _Hay_start; _Match_try < _Hay_end; ++_Match_try) {
if (_Traits::find(_Needle, _Needle_size, *_Match_try)) {
return static_cast<size_t>(_Match_try - _Haystack); // found a match
}
}
}
Expand Down
33 changes: 0 additions & 33 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,6 @@ const void* __stdcall __std_find_last_trivial_2(const void* _First, const void*
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_2(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_4(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_8(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;

__declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_2i __stdcall __std_minmax_2i(const void* _First, const void* _Last) noexcept;
Expand Down Expand Up @@ -198,27 +189,6 @@ _Ty* _Find_last_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val
}
}

template <class _Ty1, class _Ty2>
_Ty1* _Find_first_of_vectorized(
_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, _Ty2* const _Last2) noexcept {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));
if constexpr (sizeof(_Ty1) == 1) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_1(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 2) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_2(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 4) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_4(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 8) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_8(_First1, _Last1, _First2, _Last2)));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}

template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
Expand All @@ -237,9 +207,6 @@ __declspec(noalias) void _Replace_vectorized(
}
}

// find_first_of vectorization is likely to be a win after this size (in elements)
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

// Can we activate the vector algorithms for find_first_of?
template <class _It1, class _It2, class _Pr>
constexpr bool _Vector_alg_in_find_first_of_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr>;
Expand Down
33 changes: 33 additions & 0 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ const void* __stdcall __std_find_trivial_2(const void* _First, const void* _Last
const void* __stdcall __std_find_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_2(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_4(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_8(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;

const void* __stdcall __std_min_element_1(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
Expand Down Expand Up @@ -198,6 +207,30 @@ _Ty* _Find_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val) noe
}
}

// find_first_of vectorization is likely to be a win after this size (in elements)
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

template <class _Ty1, class _Ty2>
_Ty1* _Find_first_of_vectorized(
_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, _Ty2* const _Last2) noexcept {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));
if constexpr (sizeof(_Ty1) == 1) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_1(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 2) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_2(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 4) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_4(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 8) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_8(_First1, _Last1, _First2, _Last2)));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}

template <class _Ty>
_Ty* _Min_element_vectorized(_Ty* const _First, _Ty* const _Last) noexcept {
constexpr bool _Signed = is_signed_v<_Ty>;
Expand Down
72 changes: 66 additions & 6 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@

using namespace std;

#pragma warning(disable : 4984) // 'if constexpr' is a C++17 language extension
#ifdef __clang__
#pragma clang diagnostic ignored "-Wc++17-extensions" // constexpr if is a C++17 extension
#endif // __clang__

template <class FwdIt, class T>
ptrdiff_t last_known_good_count(FwdIt first, FwdIt last, T v) {
ptrdiff_t result = 0;
Expand Down Expand Up @@ -246,11 +251,12 @@ void test_case_find_first_of(const vector<T>& input_haystack, const vector<T>& i
#endif // _HAS_CXX20
}

constexpr size_t haystackDataCount = 200;
constexpr size_t needleDataCount = 35;

template <class T>
void test_find_first_of(mt19937_64& gen) {
constexpr size_t haystackDataCount = 200;
constexpr size_t needleDataCount = 35;
using TD = conditional_t<sizeof(T) == 1, int, T>;
using TD = conditional_t<sizeof(T) == 1, int, T>;
uniform_int_distribution<TD> dis('a', 'z');
vector<T> input_haystack;
vector<T> input_needle;
Expand Down Expand Up @@ -310,9 +316,7 @@ void test_case_search(const vector<T>& input_haystack, const vector<T>& input_ne

template <class T>
void test_search(mt19937_64& gen) {
constexpr size_t haystackDataCount = 200;
constexpr size_t needleDataCount = 35;
using TD = conditional_t<sizeof(T) == 1, int, T>;
using TD = conditional_t<sizeof(T) == 1, int, T>;
uniform_int_distribution<TD> dis('0', '9');
vector<T> input_haystack;
vector<T> input_needle;
Expand Down Expand Up @@ -1024,6 +1028,61 @@ void test_bitset(mt19937_64& gen) {
test_randomized_bitset_base_count<512 - 5, 32 + 10>(gen);
}

template <class T>
void test_case_string_find_first_of(const basic_string<T>& input_haystack, const basic_string<T>& input_needle) {
auto expected_iter = last_known_good_find_first_of(
input_haystack.begin(), input_haystack.end(), input_needle.begin(), input_needle.end());
auto expected = (expected_iter != input_haystack.end()) ? expected_iter - input_haystack.begin() : ptrdiff_t{-1};
auto actual = static_cast<ptrdiff_t>(input_haystack.find_first_of(input_needle.data(), 0, input_needle.size()));
assert(expected == actual);
}

template <class T, class D>
void test_basic_string_dis(mt19937_64& gen, D& dis) {
basic_string<T> input_haystack;
basic_string<T> input_needle;
input_haystack.reserve(haystackDataCount);
input_needle.reserve(needleDataCount);

for (;;) {
input_needle.clear();

test_case_string_find_first_of(input_haystack, input_needle);
for (size_t attempts = 0; attempts < needleDataCount; ++attempts) {
input_needle.push_back(static_cast<T>(dis(gen)));
test_case_string_find_first_of(input_haystack, input_needle);
}

if (input_haystack.size() == haystackDataCount) {
break;
}

input_haystack.push_back(static_cast<T>(dis(gen)));
}
}

template <class T>
void test_basic_string(mt19937_64& gen) {
using dis_int_type = conditional_t<is_signed_v<T>, int32_t, uint32_t>;

uniform_int_distribution<dis_int_type> dis_latin('a', 'z');
test_basic_string_dis<T>(gen, dis_latin);
if constexpr (sizeof(T) >= 2) {
uniform_int_distribution<dis_int_type> dis_greek(0x391, 0x3C9);
test_basic_string_dis<T>(gen, dis_greek);
}
}

void test_string(mt19937_64& gen) {
test_basic_string<char>(gen);
test_basic_string<wchar_t>(gen);
#ifdef __cpp_lib_char8_t
test_basic_string<char8_t>(gen);
#endif // __cpp_lib_char8_t
test_basic_string<char16_t>(gen);
test_basic_string<char32_t>(gen);
}

void test_various_containers() {
test_one_container<vector<int>>(); // contiguous, vectorizable
test_one_container<deque<int>>(); // random-access, not vectorizable
Expand Down Expand Up @@ -1096,5 +1155,6 @@ int main() {
test_vector_algorithms(gen);
test_various_containers();
test_bitset(gen);
test_string(gen);
});
}

0 comments on commit 77b31f7

Please sign in to comment.