Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mono] simdhash tweaks #101082

Merged
merged 9 commits into from
Apr 19, 2024
69 changes: 7 additions & 62 deletions src/native/containers/dn-simdhash-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,10 @@ build_search_vector (uint8_t needle)

// returns an index in range 0-13 on match, 14-32 if no match
static DN_FORCEINLINE(uint32_t)
find_first_matching_suffix (
find_first_matching_suffix_simd (
dn_simdhash_search_vector needle,
// Only used by the vectorized implementations; discarded by scalar.
dn_simdhash_suffixes haystack,
// HACK: Pass the address of haystack.values directly, for scalar fallback.
// Without this, clang makes a full unaligned copy of haystack before calling us.
// Discarded by the vectorized implementations.
uint8_t haystack_values[DN_SIMDHASH_VECTOR_WIDTH],
uint32_t count
dn_simdhash_suffixes haystack
) {
#if defined(__wasm_simd128__)
return ctz(wasm_i8x16_bitmask(wasm_i8x16_eq(needle.vec, haystack.vec)));
Expand All @@ -123,37 +118,8 @@ find_first_matching_suffix (
msb.b[1] = vaddv_u8(vget_high_u8(masked.vec));
return ctz(msb.u);
#else
// HACK: We can't put this in a common helper function without introducing a temporary
// unaligned copy-from-table-to-stack in wasm-without-simd
#define ITER(offset) \
if (needle == haystack_values[offset]) \
return offset;

// It is safe to unroll this without bounds checks
// One would expect this to blow out the branch predictor, but in my testing
// it's significantly faster when there is no match, and slightly faster
// for cases where there is a match.
// Looping from 0-count is slower than this in my testing, even though it's
// going to check fewer suffixes most of the time - probably due to the
// comparison against count for each suffix.
// FIXME: If we move this into the specialization header, we can limit the
// number of unrolled iterations to the number of keys in the bucket.
ITER(0);
ITER(1);
ITER(2);
ITER(3);
ITER(4);
ITER(5);
ITER(6);
ITER(7);
ITER(8);
ITER(9);
ITER(10);
ITER(11);
ITER(12);
ITER(13);
#undef ITER
return 32;
dn_simdhash_assert(!"Scalar fallback should be in use here");
return 32;
#endif
}

Expand Down Expand Up @@ -194,17 +160,12 @@ build_search_vector (uint8_t needle)

// returns an index in range 0-13 on match, 14-32 if no match
static DN_FORCEINLINE(uint32_t)
find_first_matching_suffix_internal (
__m128i needle, __m128i haystack,
uint32_t count
find_first_matching_suffix_simd (
dn_simdhash_search_vector needle, dn_simdhash_suffixes haystack
) {
return ctz(_mm_movemask_epi8(_mm_cmpeq_epi8(needle, haystack)));
return ctz(_mm_movemask_epi8(_mm_cmpeq_epi8(needle.m128, haystack.m128)));
}

// use a macro to discard haystack_values, otherwise MSVC's codegen is worse
#define find_first_matching_suffix(needle, haystack, haystack_values, count) \
find_first_matching_suffix_internal(needle.m128, haystack.m128, count)

#else // unknown compiler and/or unknown non-simd arch

#define DN_SIMDHASH_USE_SCALAR_FALLBACK 1
Expand All @@ -228,22 +189,6 @@ build_search_vector (uint8_t needle)
return needle;
}

// Scalar fallback: walks the first `count` suffix bytes in order and returns the
// index of the first one equal to needle, or 32 if none of them match.
// The haystack argument itself is not read here; only haystack_values is.
static DN_FORCEINLINE(uint32_t)
find_first_matching_suffix (
	dn_simdhash_search_vector needle, dn_simdhash_suffixes haystack,
	uint8_t haystack_values[DN_SIMDHASH_VECTOR_WIDTH], uint32_t count
) {
	// TODO: It might be profitable to hand-unroll this loop, but right now doing so
	//  hits a bug in clang and generates really bad WASM.
	// HACK: We can't put this in a common helper function without introducing a temporary
	//  unaligned copy-from-table-to-stack in wasm-without-simd
	uint32_t slot = 0;
	while (slot < count) {
		if (haystack_values[slot] == needle)
			return slot;
		slot++;
	}
	// Sentinel for "no match"
	return 32;
}

#endif // end of clang/gcc or msvc or fallback

#endif // __DN_SIMDHASH_ARCH_H__
1 change: 0 additions & 1 deletion src/native/containers/dn-simdhash-ght-compatible.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ dn_simdhash_ght_replaced (dn_simdhash_ght_data data, void * old_key, void * new_
#define DN_SIMDHASH_NO_DEFAULT_NEW 1

#include "dn-simdhash-specialization.h"
#include "dn-simdhash-ght-compatible.h"

dn_simdhash_ght_t *
dn_simdhash_ght_new (
Expand Down
5 changes: 5 additions & 0 deletions src/native/containers/dn-simdhash-ght-compatible.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

// dn_simdhash_ght_t is an alias for the generic dn_simdhash_t type.
typedef dn_simdhash_t dn_simdhash_ght_t;

// Function-pointer types for the ght-compatible callbacks.
// Destroy callback: receives a single untyped pointer and returns nothing.
typedef void (*dn_simdhash_ght_destroy_func) (void * data);
// Hash callback: receives a key pointer and returns an unsigned int hash.
typedef unsigned int (*dn_simdhash_ght_hash_func) (const void * key);
// Equality callback: receives two pointers and returns an int32_t.
// NOTE(review): presumably nonzero means "equal" - confirm against the callers.
typedef int32_t (*dn_simdhash_ght_equal_func) (const void * a, const void * b);
Expand Down
105 changes: 69 additions & 36 deletions src/native/containers/dn-simdhash-specialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,48 @@ dn_simdhash_meta_t DN_SIMDHASH_T_META = {
sizeof(DN_SIMDHASH_INSTANCE_DATA_T),
};

// Scalar suffix search: returns the lowest slot index below
// DN_SIMDHASH_BUCKET_CAPACITY whose suffix byte equals needle, or 32 when no
// such slot exists.
static DN_FORCEINLINE(uint32_t)
find_first_matching_suffix_scalar (
	uint8_t needle,
	uint8_t haystack[DN_SIMDHASH_VECTOR_WIDTH]
) {
	uint32_t first_match = 32;
	// Slots at or beyond this specialization's bucket capacity yield a
	// constant-false test, so their bytes are never compared against needle.
#define CHECK_SLOT(slot) \
	{ \
		/* Avoid MSVC C4127 by computing this separately in a temp local */ \
		uint8_t slot_in_use = ((slot) < DN_SIMDHASH_BUCKET_CAPACITY); \
		first_match = (slot_in_use && (needle == haystack[slot])) ? (slot) : first_match; \
	}

	// The checks are fully unrolled and carry no per-slot comparison against
	// the bucket's item count. In the author's testing, looping from 0-count
	// was slower even though it usually inspects fewer suffixes - probably
	// due to the comparison against count for each suffix.
	// Visiting slots from highest to lowest and conditionally overwriting the
	// running answer means the lowest matching slot is written last and wins.
	// This shape lets clang emit a chain of 'select' operations per slot on
	// wasm, which produces smaller code that seems to be much faster than a
	// chain of 'if (...) return' for successful matches, and only slightly
	// slower for failed matches.
	CHECK_SLOT(13);
	CHECK_SLOT(12);
	CHECK_SLOT(11);
	CHECK_SLOT(10);
	CHECK_SLOT(9);
	CHECK_SLOT(8);
	CHECK_SLOT(7);
	CHECK_SLOT(6);
	CHECK_SLOT(5);
	CHECK_SLOT(4);
	CHECK_SLOT(3);
	CHECK_SLOT(2);
	CHECK_SLOT(1);
	CHECK_SLOT(0);
#undef CHECK_SLOT
	return first_match;
}

static DN_FORCEINLINE(void)
check_self (DN_SIMDHASH_T_PTR self)
Expand Down Expand Up @@ -152,7 +194,12 @@ DN_SIMDHASH_SCAN_BUCKET_INTERNAL (DN_SIMDHASH_T_PTR hash, bucket_t *restrict buc
overflow_count = dn_simdhash_extract_lane(bucket_suffixes, DN_SIMDHASH_CASCADED_SLOT);
// We could early-out here when count==0, but it doesn't appear to meaningfully improve
// search performance to do so, and might actually worsen it
uint32_t index = find_first_matching_suffix(search_vector, bucket_suffixes, bucket_suffixes.values, count);
#ifdef DN_SIMDHASH_USE_SCALAR_FALLBACK
uint32_t index = find_first_matching_suffix_scalar(search_vector, bucket->suffixes.values);
#else
uint32_t index = find_first_matching_suffix_simd(search_vector, bucket_suffixes);
#endif
#undef bucket_suffixes
for (; index < count; index++) {
// FIXME: Could be profitable to manually hoist the data load outside of the loop,
// if not out of SCAN_BUCKET_INTERNAL entirely. Clang appears to do LICM on it.
Expand All @@ -164,22 +211,20 @@ DN_SIMDHASH_SCAN_BUCKET_INTERNAL (DN_SIMDHASH_T_PTR hash, bucket_t *restrict buc
return index;
}

#undef bucket_suffixes

if (overflow_count)
return DN_SIMDHASH_SCAN_BUCKET_OVERFLOWED;
else
return DN_SIMDHASH_SCAN_BUCKET_NO_OVERFLOW;
}

// Helper macros so that we can optimize and change scan logic more easily
#define BEGIN_SCAN_BUCKETS(initial_index, bucket_index, bucket_address) \
#define BEGIN_SCAN_BUCKETS(buffers, initial_index, bucket_index, bucket_address) \
{ \
uint32_t bucket_index = initial_index, scan_buckets_length = buffers.buckets_length; \
bucket_t *restrict bucket_address = address_of_bucket(buffers, bucket_index); \
do {

#define END_SCAN_BUCKETS(initial_index, bucket_index, bucket_address) \
#define END_SCAN_BUCKETS(buffers, initial_index, bucket_index, bucket_address) \
bucket_index++; \
bucket_address++; \
/* Wrap around if we hit the last bucket. */ \
Expand Down Expand Up @@ -211,7 +256,7 @@ DN_SIMDHASH_SCAN_BUCKET_INTERNAL (DN_SIMDHASH_T_PTR hash, bucket_t *restrict buc
static void
adjust_cascaded_counts (dn_simdhash_buffers_t buffers, uint32_t first_bucket_index, uint32_t last_bucket_index, uint8_t increase)
{
BEGIN_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
BEGIN_SCAN_BUCKETS(buffers, first_bucket_index, bucket_index, bucket_address)
if (bucket_index == last_bucket_index)
break;

Expand All @@ -224,26 +269,25 @@ adjust_cascaded_counts (dn_simdhash_buffers_t buffers, uint32_t first_bucket_ind
dn_simdhash_bucket_set_cascaded_count(bucket_address->suffixes, cascaded_count - 1);
}
}
END_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
END_SCAN_BUCKETS(buffers, first_bucket_index, bucket_index, bucket_address)
}

static DN_SIMDHASH_VALUE_T *
static DN_FORCEINLINE(DN_SIMDHASH_VALUE_T *)
DN_SIMDHASH_FIND_VALUE_INTERNAL (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key, uint32_t key_hash)
{
dn_simdhash_buffers_t buffers = hash->buffers;
uint8_t suffix = dn_simdhash_select_suffix(key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(buffers, key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(hash->buffers, key_hash);
dn_simdhash_search_vector search_vector = build_search_vector(suffix);

BEGIN_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
BEGIN_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)
int index_in_bucket = DN_SIMDHASH_SCAN_BUCKET_INTERNAL(hash, bucket_address, key, search_vector);
if (index_in_bucket >= 0) {
uint32_t value_slot_index = (bucket_index * DN_SIMDHASH_BUCKET_CAPACITY) + index_in_bucket;
return address_of_value(buffers, value_slot_index);
return address_of_value(hash->buffers, value_slot_index);
} else if (index_in_bucket == DN_SIMDHASH_SCAN_BUCKET_NO_OVERFLOW) {
return NULL;
}
END_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
END_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)

return NULL;
}
Expand Down Expand Up @@ -288,12 +332,11 @@ DN_SIMDHASH_TRY_INSERT_INTERNAL (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
return DN_SIMDHASH_INSERT_NEED_TO_GROW;
}

dn_simdhash_buffers_t buffers = hash->buffers;
uint8_t suffix = dn_simdhash_select_suffix(key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(hash->buffers, key_hash);
dn_simdhash_search_vector search_vector = build_search_vector(suffix);

BEGIN_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
BEGIN_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)
// If necessary, check the current bucket for the key
if (mode != DN_SIMDHASH_INSERT_MODE_REHASHING) {
int index_in_bucket = DN_SIMDHASH_SCAN_BUCKET_INTERNAL(hash, bucket_address, key, search_vector);
Expand Down Expand Up @@ -325,19 +368,19 @@ DN_SIMDHASH_TRY_INSERT_INTERNAL (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
*key_slot_address = key;
// Now store the value, it's in a different cache line
uint32_t value_slot_index = (bucket_index * DN_SIMDHASH_BUCKET_CAPACITY) + new_index;
DN_SIMDHASH_VALUE_T *restrict value_slot_address = address_of_value(buffers, value_slot_index);
DN_SIMDHASH_VALUE_T *restrict value_slot_address = address_of_value(hash->buffers, value_slot_index);
*value_slot_address = value;
// printf("Inserted [%zd, %zd] in bucket %d at index %d\n", key, value, bucket_index, new_index);
// If we cascaded out of our original target bucket, scan through our probe path
// and increase the cascade counters. We have to wait until now to do that, because
// during the process of getting here we may end up finding a duplicate, which would
// leave the cascade counters in a corrupted state
adjust_cascaded_counts(buffers, first_bucket_index, bucket_index, 1);
adjust_cascaded_counts(hash->buffers, first_bucket_index, bucket_index, 1);
return DN_SIMDHASH_INSERT_OK_ADDED_NEW;
}

// The current bucket is full, so try the next bucket.
END_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
END_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)

return DN_SIMDHASH_INSERT_NEED_TO_GROW;
}
Expand All @@ -362,10 +405,9 @@ DN_SIMDHASH_REHASH_INTERNAL (DN_SIMDHASH_T_PTR hash, dn_simdhash_buffers_t old_b
static void
DN_SIMDHASH_DESTROY_ALL (DN_SIMDHASH_T_PTR hash)
{
dn_simdhash_buffers_t buffers = hash->buffers;
BEGIN_SCAN_PAIRS(buffers, key_address, value_address)
BEGIN_SCAN_PAIRS(hash->buffers, key_address, value_address)
DN_SIMDHASH_ON_REMOVE(DN_SIMDHASH_GET_DATA(hash), *key_address, *value_address);
END_SCAN_PAIRS(buffers, key_address, value_address)
END_SCAN_PAIRS(hash->buffers, key_address, value_address)
}
#endif

Expand All @@ -385,14 +427,6 @@ dn_simdhash_vtable_t DN_SIMDHASH_T_VTABLE = {
// Creates a new table for this specialization by forwarding capacity and allocator
// to dn_simdhash_new_internal together with the specialization's meta and vtable.
// Returns whatever dn_simdhash_new_internal returns.
DN_SIMDHASH_T_PTR
DN_SIMDHASH_NEW (uint32_t capacity, dn_allocator_t *allocator)
{
// If this isn't satisfied, the generic code will allocate incorrectly sized buffers
// HACK: Use static_assert because for some reason assert produces unused variable warnings only on CI
// The two-member struct lets the assert compare the size of two adjacent
// bucket_t members against exactly twice sizeof(bucket_t).
struct silence_nuisance_msvc_warning { bucket_t a, b; };
static_assert(
sizeof(struct silence_nuisance_msvc_warning) == (sizeof(bucket_t) * 2),
"Inconsistent spacing/sizing for bucket_t"
);

return dn_simdhash_new_internal(&DN_SIMDHASH_T_META, DN_SIMDHASH_T_VTABLE, capacity, allocator);
}
#endif
Expand Down Expand Up @@ -475,12 +509,11 @@ DN_SIMDHASH_TRY_REMOVE_WITH_HASH (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
{
check_self(hash);

dn_simdhash_buffers_t buffers = hash->buffers;
uint8_t suffix = dn_simdhash_select_suffix(key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(buffers, key_hash);
uint32_t first_bucket_index = dn_simdhash_select_bucket_index(hash->buffers, key_hash);
dn_simdhash_search_vector search_vector = build_search_vector(suffix);

BEGIN_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
BEGIN_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)
int index_in_bucket = DN_SIMDHASH_SCAN_BUCKET_INTERNAL(hash, bucket_address, key, search_vector);
if (index_in_bucket >= 0) {
// We found the item. Replace it with the last item in the bucket, then erase
Expand All @@ -490,8 +523,8 @@ DN_SIMDHASH_TRY_REMOVE_WITH_HASH (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
uint32_t value_slot_index = (bucket_index * DN_SIMDHASH_BUCKET_CAPACITY) + index_in_bucket,
replacement_value_slot_index = (bucket_index * DN_SIMDHASH_BUCKET_CAPACITY) + replacement_index_in_bucket;

DN_SIMDHASH_VALUE_T *value_address = address_of_value(buffers, value_slot_index);
DN_SIMDHASH_VALUE_T *replacement_address = address_of_value(buffers, replacement_value_slot_index);
DN_SIMDHASH_VALUE_T *value_address = address_of_value(hash->buffers, value_slot_index);
DN_SIMDHASH_VALUE_T *replacement_address = address_of_value(hash->buffers, replacement_value_slot_index);
DN_SIMDHASH_KEY_T *key_address = &bucket_address->keys[index_in_bucket];
DN_SIMDHASH_KEY_T *replacement_key_address = &bucket_address->keys[replacement_index_in_bucket];

Expand Down Expand Up @@ -529,7 +562,7 @@ DN_SIMDHASH_TRY_REMOVE_WITH_HASH (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
// to go through all the buckets we visited on the way here and reduce
// their cascade counters (if possible), to maintain better scan performance.
if (bucket_index != first_bucket_index)
adjust_cascaded_counts(buffers, first_bucket_index, bucket_index, 0);
adjust_cascaded_counts(hash->buffers, first_bucket_index, bucket_index, 0);

#if DN_SIMDHASH_HAS_REMOVE_HANDLER
// We've finished removing the item, so we're in a consistent state and can notify
Expand All @@ -539,7 +572,7 @@ DN_SIMDHASH_TRY_REMOVE_WITH_HASH (DN_SIMDHASH_T_PTR hash, DN_SIMDHASH_KEY_T key,
return 1;
} else if (index_in_bucket == DN_SIMDHASH_SCAN_BUCKET_NO_OVERFLOW)
return 0;
END_SCAN_BUCKETS(first_bucket_index, bucket_index, bucket_address)
END_SCAN_BUCKETS(hash->buffers, first_bucket_index, bucket_index, bucket_address)

return 0;
}
Expand Down
9 changes: 3 additions & 6 deletions src/native/containers/dn-simdhash.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,9 @@ dn_simdhash_select_suffix (uint32_t key_hash)
return (key_hash >> 24) | DN_SIMDHASH_SUFFIX_SALT;
}

// Picks the bucket index for key_hash by masking it with (buckets_length - 1).
static DN_FORCEINLINE(uint32_t)
dn_simdhash_select_bucket_index (dn_simdhash_buffers_t buffers, uint32_t key_hash)
{
	// This relies on bucket count being a power of two.
	uint32_t bucket_mask = buffers.buckets_length - 1;
	return key_hash & bucket_mask;
}
// Selects the bucket index for key_hash by masking it with (buckets_length - 1).
// This relies on bucket count being a power of two.
// Each argument is expanded exactly once and is fully parenthesized; buffers
// must be an expression with a buckets_length member.
#define dn_simdhash_select_bucket_index(buffers, key_hash) \
((key_hash) & ((buffers).buckets_length - 1))


// Creates a simdhash with the provided configuration metadata, vtable, size, and allocator.
Expand Down
Loading
Loading