[mono] simdhash tweaks (#101082)
* Optimize simdhash scalar bucket scan to generate chained cmovs instead of conditional branches, and to skip suffix slots that will always be empty
* Optimize temporary copies of hash->buffers out of most simdhash APIs
* Build fixes for ght-compatible simdhash
* Add missing license headers
* Improve simdhash microbenchmark suite and make it compatible with Windows x64 MSVC
* Use xoshiro256 to generate random values in microbenchmark, with a fixed seed, instead of libc rand (a reference sketch of the generator follows this list)
* Add filtering support to simdhash microbenchmark makefile
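
The xoshiro256 generator mentioned above lives in the microbenchmark sources, which are not among the diffs shown on this page. For reference, a minimal sketch of the xoshiro256** variant's next() step is shown below; the seed constants and the bench_rng_* names are illustrative, not taken from this commit.

#include <stdint.h>

// Minimal xoshiro256** sketch (public-domain algorithm by Blackman & Vigna).
// The state/seed values here are arbitrary nonzero placeholders; the commit
// uses its own fixed seed so benchmark runs are reproducible.
static uint64_t bench_rng_state[4] = {
	0x9E3779B97F4A7C15ULL, 0xBF58476D1CE4E5B9ULL,
	0x94D049BB133111EBULL, 0x2545F4914F6CDD1DULL,
};

static inline uint64_t rotl64 (uint64_t x, int k) {
	return (x << k) | (x >> (64 - k));
}

uint64_t bench_rng_next (void) {
	uint64_t *s = bench_rng_state;
	const uint64_t result = rotl64(s[1] * 5, 7) * 9;
	const uint64_t t = s[1] << 17;
	s[2] ^= s[0];
	s[3] ^= s[1];
	s[1] ^= s[2];
	s[0] ^= s[3];
	s[2] ^= t;
	s[3] ^= rotl64(s[3], 45);
	return result;
}
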
kg authored Apr 19, 2024
1 parent d28c577 commit fee6794
Showing 9 changed files with 134 additions and 116 deletions.
69 changes: 7 additions & 62 deletions src/native/containers/dn-simdhash-arch.h
@@ -92,15 +92,10 @@ build_search_vector (uint8_t needle)

// returns an index in range 0-13 on match, 14-32 if no match
static DN_FORCEINLINE(uint32_t)
find_first_matching_suffix (
find_first_matching_suffix_simd (
dn_simdhash_search_vector needle,
// Only used by the vectorized implementations; discarded by scalar.
dn_simdhash_suffixes haystack,
// HACK: Pass the address of haystack.values directly, for scalar fallback.
// Without this, clang makes a full unaligned copy of haystack before calling us.
// Discarded by the vectorized implementations.
uint8_t haystack_values[DN_SIMDHASH_VECTOR_WIDTH],
uint32_t count
dn_simdhash_suffixes haystack
) {
#if defined(__wasm_simd128__)
return ctz(wasm_i8x16_bitmask(wasm_i8x16_eq(needle.vec, haystack.vec)));
@@ -123,37 +118,8 @@ find_first_matching_suffix (
msb.b[1] = vaddv_u8(vget_high_u8(masked.vec));
return ctz(msb.u);
#else
// HACK: We can't put this in a common helper function without introducing a temporary
// unaligned copy-from-table-to-stack in wasm-without-simd
#define ITER(offset) \
if (needle == haystack_values[offset]) \
return offset;

// It is safe to unroll this without bounds checks
// One would expect this to blow out the branch predictor, but in my testing
// it's significantly faster when there is no match, and slightly faster
// for cases where there is a match.
// Looping from 0-count is slower than this in my testing, even though it's
// going to check fewer suffixes most of the time - probably due to the
// comparison against count for each suffix.
// FIXME: If we move this into the specialization header, we can limit the
// number of unrolled iterations to the number of keys in the bucket.
ITER(0);
ITER(1);
ITER(2);
ITER(3);
ITER(4);
ITER(5);
ITER(6);
ITER(7);
ITER(8);
ITER(9);
ITER(10);
ITER(11);
ITER(12);
ITER(13);
#undef ITER
return 32;
dn_simdhash_assert(!"Scalar fallback should be in use here");
return 32;
#endif
}

@@ -194,17 +160,12 @@ build_search_vector (uint8_t needle)

// returns an index in range 0-13 on match, 14-32 if no match
static DN_FORCEINLINE(uint32_t)
find_first_matching_suffix_internal (
__m128i needle, __m128i haystack,
uint32_t count
find_first_matching_suffix_simd (
dn_simdhash_search_vector needle, dn_simdhash_suffixes haystack
) {
return ctz(_mm_movemask_epi8(_mm_cmpeq_epi8(needle, haystack)));
return ctz(_mm_movemask_epi8(_mm_cmpeq_epi8(needle.m128, haystack.m128)));
}

// use a macro to discard haystack_values, otherwise MSVC's codegen is worse
#define find_first_matching_suffix(needle, haystack, haystack_values, count) \
find_first_matching_suffix_internal(needle.m128, haystack.m128, count)

#else // unknown compiler and/or unknown non-simd arch

#define DN_SIMDHASH_USE_SCALAR_FALLBACK 1
@@ -228,22 +189,6 @@ build_search_vector (uint8_t needle)
return needle;
}

// returns an index in range 0-14 on match, 32 if no match
static DN_FORCEINLINE(uint32_t)
find_first_matching_suffix (
dn_simdhash_search_vector needle, dn_simdhash_suffixes haystack,
uint8_t haystack_values[DN_SIMDHASH_VECTOR_WIDTH], uint32_t count
) {
// TODO: It might be profitable to hand-unroll this loop, but right now doing so
// hits a bug in clang and generates really bad WASM.
// HACK: We can't put this in a common helper function without introducing a temporary
// unaligned copy-from-table-to-stack in wasm-without-simd
for (uint32_t i = 0; i < count; i++)
if (needle == haystack_values[i])
return i;
return 32;
}

#endif // end of clang/gcc or msvc or fallback

#endif // __DN_SIMDHASH_ARCH_H__
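
As a standalone illustration of the SSE2 path above (a sketch written for this page, not code from the commit): broadcast the needle suffix, compare all 16 lanes at once, and take the trailing-zero count of the resulting movemask to get the index of the first matching slot. __builtin_ctz is the GCC/clang builtin; the header's own ctz() wrapper is assumed to behave similarly.

#include <emmintrin.h>
#include <stdint.h>
#include <stdio.h>

int main (void) {
	uint8_t suffixes[16] = { 7, 9, 42, 3, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
	__m128i haystack = _mm_loadu_si128((const __m128i *)suffixes);
	__m128i needle = _mm_set1_epi8(42);
	uint32_t mask = (uint32_t)_mm_movemask_epi8(_mm_cmpeq_epi8(needle, haystack));
	// __builtin_ctz is undefined for 0, so map "no match" to 32 explicitly,
	// mirroring the "14-32 if no match" contract described above.
	uint32_t index = mask ? (uint32_t)__builtin_ctz(mask) : 32;
	printf("first match at %u\n", index); // prints 2
	return 0;
}
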
1 change: 0 additions & 1 deletion src/native/containers/dn-simdhash-ght-compatible.c
@@ -80,7 +80,6 @@ dn_simdhash_ght_replaced (dn_simdhash_ght_data data, void * old_key, void * new_
#define DN_SIMDHASH_NO_DEFAULT_NEW 1

#include "dn-simdhash-specialization.h"
#include "dn-simdhash-ght-compatible.h"

dn_simdhash_ght_t *
dn_simdhash_ght_new (
5 changes: 5 additions & 0 deletions src/native/containers/dn-simdhash-ght-compatible.h
@@ -1,3 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

typedef dn_simdhash_t dn_simdhash_ght_t;

typedef void (*dn_simdhash_ght_destroy_func) (void * data);
typedef unsigned int (*dn_simdhash_ght_hash_func) (const void * key);
typedef int32_t (*dn_simdhash_ght_equal_func) (const void * a, const void * b);
105 changes: 69 additions & 36 deletions src/native/containers/dn-simdhash-specialization.h
@@ -95,6 +95,48 @@ dn_simdhash_meta_t DN_SIMDHASH_T_META = {
sizeof(DN_SIMDHASH_INSTANCE_DATA_T),
};

static DN_FORCEINLINE(uint32_t)
find_first_matching_suffix_scalar (
uint8_t needle,
uint8_t haystack[DN_SIMDHASH_VECTOR_WIDTH]
) {
uint32_t result = 32;
// ITERs for indices beyond our specialization's bucket capacity will be
// constant-false and not check the specific bucket slot
#define ITER(offset) \
{ \
/* Avoid MSVC C4127 by computing this separately in a temp local */ \
uint8_t in_bounds = (offset < DN_SIMDHASH_BUCKET_CAPACITY); \
if (in_bounds && (needle == haystack[offset])) \
result = offset; \
}

// It is safe to unroll this without bounds checks
// Looping from 0-count is slower than this in my testing, even though it's
// going to check fewer suffixes most of the time - probably due to the
// comparison against count for each suffix.
// Scanning in reverse and conditionally modifying result allows clang to
// emit a chain of 'select' operations per slot on wasm, which produces
// smaller code that seems to be much faster than a chain of
// 'if (...) return' for successful matches, and only slightly slower
// for failed matches
ITER(13);
ITER(12);
ITER(11);
ITER(10);
ITER(9);
ITER(8);
ITER(7);
ITER(6);
ITER(5);
ITER(4);
ITER(3);
ITER(2);
ITER(1);
ITER(0);
#undef ITER
return result;
}

static DN_FORCEINLINE(void)
check_self (DN_SIMDHASH_T_PTR self)
@@ -152,7 +194,12 @@ DN_SIMDHASH_SCAN_BUCKET_INTERNAL (DN_SIMDHASH_T_PTR hash, bucket_t *restrict buc
overflow_count = dn_simdhash_extract_lane(bucket_suffixes, DN_SIMDHASH_CASCADED_SLOT);
// We could early-out here when count==0, but it doesn't appear to meaningfully improve
// search performance to do so, and might actually worsen it
uint32_t index = find_first_matching_suffix(search_vector, bucket_suffixes, bucket_suffixes.values, count);
#ifdef DN_SIMDHASH_USE_SCALAR_FALLBACK
uint32_t index = find_first_matching_suffix_scalar(search_vector, bucket->suffixes.values);
#else
uint32_t index = find_first_matching_suffix_simd(search_vector, bucket_suffixes);
#endif
#undef bucket_suffixes
for (; index < count; index++) {
// FIXME: Could be profitable to manually hoist the data load outside of the loop,
// if not out of SCAN_BUCKET_INTERNAL entirely. Clang appears to do LICM on it.
@@ -164,22 +211,20 @@ DN_SIMDHASH_SCAN_BUCKET_INTERNAL (DN_SIMDHASH_T_PTR hash, bucket_t *restrict buc
return index;
}

#undef bucket_suffixes

if (overflow_count)
return DN_SIMDHASH_SCAN_BUCKET_OVERFLOWED;
else
return DN_SIMDHASH_SCAN_BUCKET_NO_OVERFLOW;
}

// Helper macros so that we can optimize and change scan logic more easily
#define BEGIN_SCAN_BUCKETS(initial_index, bucket_index, bucket_address) \
#define BEGIN_SCAN_BUCKETS(buffers, initial_index, bucket_index, bucket_address) \
{ \
uint32_t bucket_index = initial_index, scan_buckets_length = buffers.buckets_length; \
bucket_t *restrict bucket_address = address_of_bucket(buffers, bucket_index); \
do {

#define END_SCAN_BUCKETS(initial_index, bucket_index, bucket_address) \
#define END_SCAN_BUCKETS(buffers, initial_index, bucket_index, bucket_address) \
bucket_index++; \
bucket_address++; \
/* Wrap around if we hit the last bucket. */ \
@@ -211,7 +256,7 @@ DN_SIMDHASH_SCAN_BUCKET_INTERNAL (DN_SIMDHASH_T_PTR hash, bucket_t *restrict buc
static void
adjust_cascaded_counts (dn_simdhash_buffers_t buffers, uint32_t first_bucket_index, uint32_t last_bucket_index, uint8_t increase)
{
BEGIN_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
BEGIN_SCAN_BUCKETS(buffers, first_bucket_index, bucket_index, bucket_address)
if (bucket_index == last_bucket_index)
break;

@@ -224,26 +269,25 @@ adjust_cascaded_counts (dn_simdhash_buffers_t buffers, uint32_t first_bucket_ind
dn_simdhash_bucket_set_cascaded_count(bucket_address->suffixes, cascaded_count - 1);
}
}
END_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
END_SCAN_BUCKETS(buffers, first_bucket_index, bucket_index, bucket_address)
}

static DN_SIMDHASH_VALUE_T *
static DN_FORCEINLINE(DN_SIMDHASH_VALUE_T *)
DN_SIMDHASH_FIND_VALUE_INTERNAL (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key, uint32_t key_hash)
{
dn_simdhash_buffers_t buffers = hash->buffers;
uint8_t suffix = dn_simdhash_select_suffix(key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(buffers, key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(hash->buffers, key_hash);
dn_simdhash_search_vector search_vector = build_search_vector(suffix);

BEGIN_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
BEGIN_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)
int index_in_bucket = DN_SIMDHASH_SCAN_BUCKET_INTERNAL(hash, bucket_address, key, search_vector);
if (index_in_bucket >= 0) {
uint32_t value_slot_index = (bucket_index * DN_SIMDHASH_BUCKET_CAPACITY) + index_in_bucket;
return address_of_value(buffers, value_slot_index);
return address_of_value(hash->buffers, value_slot_index);
} else if (index_in_bucket == DN_SIMDHASH_SCAN_BUCKET_NO_OVERFLOW) {
return NULL;
}
END_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
END_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)

return NULL;
}
@@ -288,12 +332,11 @@ DN_SIMDHASH_TRY_INSERT_INTERNAL (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
return DN_SIMDHASH_INSERT_NEED_TO_GROW;
}

dn_simdhash_buffers_t buffers = hash->buffers;
uint8_t suffix = dn_simdhash_select_suffix(key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(hash->buffers, key_hash);
dn_simdhash_search_vector search_vector = build_search_vector(suffix);

BEGIN_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
BEGIN_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)
// If necessary, check the current bucket for the key
if (mode != DN_SIMDHASH_INSERT_MODE_REHASHING) {
int index_in_bucket = DN_SIMDHASH_SCAN_BUCKET_INTERNAL(hash, bucket_address, key, search_vector);
@@ -325,19 +368,19 @@ DN_SIMDHASH_TRY_INSERT_INTERNAL (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
*key_slot_address = key;
// Now store the value, it's in a different cache line
uint32_t value_slot_index = (bucket_index * DN_SIMDHASH_BUCKET_CAPACITY) + new_index;
DN_SIMDHASH_VALUE_T *restrict value_slot_address = address_of_value(buffers, value_slot_index);
DN_SIMDHASH_VALUE_T *restrict value_slot_address = address_of_value(hash->buffers, value_slot_index);
*value_slot_address = value;
// printf("Inserted [%zd, %zd] in bucket %d at index %d\n", key, value, bucket_index, new_index);
// If we cascaded out of our original target bucket, scan through our probe path
// and increase the cascade counters. We have to wait until now to do that, because
// during the process of getting here we may end up finding a duplicate, which would
// leave the cascade counters in a corrupted state
adjust_cascaded_counts(buffers, first_bucket_index, bucket_index, 1);
adjust_cascaded_counts(hash->buffers, first_bucket_index, bucket_index, 1);
return DN_SIMDHASH_INSERT_OK_ADDED_NEW;
}

// The current bucket is full, so try the next bucket.
END_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
END_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)

return DN_SIMDHASH_INSERT_NEED_TO_GROW;
}
@@ -362,10 +405,9 @@ DN_SIMDHASH_REHASH_INTERNAL (DN_SIMDHASH_T_PTR hash, dn_simdhash_buffers_t old_b
static void
DN_SIMDHASH_DESTROY_ALL (DN_SIMDHASH_T_PTR hash)
{
dn_simdhash_buffers_t buffers = hash->buffers;
BEGIN_SCAN_PAIRS(buffers, key_address, value_address)
BEGIN_SCAN_PAIRS(hash->buffers, key_address, value_address)
DN_SIMDHASH_ON_REMOVE(DN_SIMDHASH_GET_DATA(hash), *key_address, *value_address);
END_SCAN_PAIRS(buffers, key_address, value_address)
END_SCAN_PAIRS(hash->buffers, key_address, value_address)
}
#endif

@@ -385,14 +427,6 @@ dn_simdhash_vtable_t DN_SIMDHASH_T_VTABLE = {
DN_SIMDHASH_T_PTR
DN_SIMDHASH_NEW (uint32_t capacity, dn_allocator_t *allocator)
{
// If this isn't satisfied, the generic code will allocate incorrectly sized buffers
// HACK: Use static_assert because for some reason assert produces unused variable warnings only on CI
struct silence_nuisance_msvc_warning { bucket_t a, b; };
static_assert(
sizeof(struct silence_nuisance_msvc_warning) == (sizeof(bucket_t) * 2),
"Inconsistent spacing/sizing for bucket_t"
);

return dn_simdhash_new_internal(&DN_SIMDHASH_T_META, DN_SIMDHASH_T_VTABLE, capacity, allocator);
}
#endif
@@ -475,12 +509,11 @@ DN_SIMDHASH_TRY_REMOVE_WITH_HASH (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
{
check_self(hash);

dn_simdhash_buffers_t buffers = hash->buffers;
uint8_t suffix = dn_simdhash_select_suffix(key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(buffers, key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(hash->buffers, key_hash);
dn_simdhash_search_vector search_vector = build_search_vector(suffix);

BEGIN_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
BEGIN_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)
int index_in_bucket = DN_SIMDHASH_SCAN_BUCKET_INTERNAL(hash, bucket_address, key, search_vector);
if (index_in_bucket >= 0) {
// We found the item. Replace it with the last item in the bucket, then erase
@@ -490,8 +523,8 @@ DN_SIMDHASH_TRY_REMOVE_WITH_HASH (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
uint32_t value_slot_index = (bucket_index * DN_SIMDHASH_BUCKET_CAPACITY) + index_in_bucket,
replacement_value_slot_index = (bucket_index * DN_SIMDHASH_BUCKET_CAPACITY) + replacement_index_in_bucket;

DN_SIMDHASH_VALUE_T *value_address = address_of_value(buffers, value_slot_index);
DN_SIMDHASH_VALUE_T *replacement_address = address_of_value(buffers, replacement_value_slot_index);
DN_SIMDHASH_VALUE_T *value_address = address_of_value(hash->buffers, value_slot_index);
DN_SIMDHASH_VALUE_T *replacement_address = address_of_value(hash->buffers, replacement_value_slot_index);
DN_SIMDHASH_KEY_T *key_address = &bucket_address->keys[index_in_bucket];
DN_SIMDHASH_KEY_T *replacement_key_address = &bucket_address->keys[replacement_index_in_bucket];

@@ -529,7 +562,7 @@ DN_SIMDHASH_TRY_REMOVE_WITH_HASH (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
// to go through all the buckets we visited on the way here and reduce
// their cascade counters (if possible), to maintain better scan performance.
if (bucket_index != first_bucket_index)
adjust_cascaded_counts(buffers, first_bucket_index, bucket_index, 0);
adjust_cascaded_counts(hash->buffers, first_bucket_index, bucket_index, 0);

#if DN_SIMDHASH_HAS_REMOVE_HANDLER
// We've finished removing the item, so we're in a consistent state and can notify
Expand All @@ -539,7 +572,7 @@ DN_SIMDHASH_TRY_REMOVE_WITH_HASH (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
return 1;
} else if (index_in_bucket == DN_SIMDHASH_SCAN_BUCKET_NO_OVERFLOW)
return 0;
END_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
END_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)

return 0;
}
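
To make the commit message's "chained cmovs" point concrete, here is a minimal sketch (not from this PR) of the two scan shapes. When fully unrolled, as the ITER macro above does, clang typically lowers the second form to a chain of conditional moves / selects with a single return, while the first form keeps one conditional branch per slot.

#include <stdint.h>

// Old shape: early return per slot -> one conditional branch for every suffix.
uint32_t scan_branchy (uint8_t needle, const uint8_t haystack[14]) {
	for (uint32_t i = 0; i < 14; i++)
		if (needle == haystack[i])
			return i;
	return 32;
}

// New shape: scan in reverse and conditionally overwrite the result, so the
// earliest match wins and each comparison can become a conditional move
// instead of a taken/not-taken branch.
uint32_t scan_select (uint8_t needle, const uint8_t haystack[14]) {
	uint32_t result = 32;
	for (int i = 13; i >= 0; i--)
		if (needle == haystack[i])
			result = (uint32_t)i;
	return result;
}
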
9 changes: 3 additions & 6 deletions src/native/containers/dn-simdhash.h
@@ -108,12 +108,9 @@ dn_simdhash_select_suffix (uint32_t key_hash)
return (key_hash >> 24) | DN_SIMDHASH_SUFFIX_SALT;
}

static DN_FORCEINLINE(uint32_t)
dn_simdhash_select_bucket_index (dn_simdhash_buffers_t buffers, uint32_t key_hash)
{
// This relies on bucket count being a power of two.
return key_hash & (buffers.buckets_length - 1);
}
// This relies on bucket count being a power of two.
#define dn_simdhash_select_bucket_index(buffers, key_hash) \
((key_hash) & ((buffers).buckets_length - 1))


// Creates a simdhash with the provided configuration metadata, vtable, size, and allocator.
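
The new macro relies on the usual power-of-two trick: when buckets_length is a power of two, key_hash & (buckets_length - 1) equals key_hash % buckets_length, with no division. A quick self-contained check (the values are illustrative):

#include <assert.h>
#include <stdint.h>

int main (void) {
	uint32_t buckets_length = 64;   // must be a power of two
	uint32_t key_hash = 0xDEADBEEF;
	// 0xDEADBEEF % 64 == 0xDEADBEEF & 63 == 47
	assert((key_hash & (buckets_length - 1)) == (key_hash % buckets_length));
	assert((key_hash & (buckets_length - 1)) == 47);
	return 0;
}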