diff --git a/src/lib/lwan-websocket.c b/src/lib/lwan-websocket.c index ab4f4253c..b83c6950f 100644 --- a/src/lib/lwan-websocket.c +++ b/src/lib/lwan-websocket.c @@ -150,44 +150,40 @@ static size_t get_frame_length(struct lwan_request *request, uint16_t header) static void unmask(char *msg, size_t msg_len, char mask[static 4]) { - const uint32_t mask32 = string_as_uint32(mask); - char *msg_end = msg + msg_len; - - if (sizeof(void *) == 8) { - const uint64_t mask64 = (uint64_t)mask32 << 32 | mask32; + const int32_t mask32 = (int32_t)string_as_uint32(mask); + const char *msg_end = msg + msg_len; #if defined(__AVX2__) - const size_t len256 = msg_len / 32; - if (len256) { - const __m256i mask256 = - _mm256_setr_epi64x((int64_t)mask64, (int64_t)mask64, - (int64_t)mask64, (int64_t)mask64); - for (size_t i = 0; i < len256; i++) { - __m256i v = _mm256_loadu_si256((__m256i *)msg); - _mm256_storeu_si256((__m256i *)msg, - _mm256_xor_si256(v, mask256)); - msg += 32; - } - - msg_len = (size_t)(msg_end - msg); + const size_t len256 = msg_len / 32; + if (len256) { + const __m256i mask256 = _mm256_setr_epi32( + mask32, mask32, mask32, mask32, mask32, mask32, mask32, mask32); + for (size_t i = 0; i < len256; i++) { + __m256i v = _mm256_loadu_si256((__m256i *)msg); + _mm256_storeu_si256((__m256i *)msg, _mm256_xor_si256(v, mask256)); + msg += 32; } + + msg_len = (size_t)(msg_end - msg); + } #endif #if defined(__SSE2__) - const size_t len128 = msg_len / 16; - if (len128) { - const __m128i mask128 = - _mm_setr_epi64((__m64)mask64, (__m64)mask64); - for (size_t i = 0; i < len128; i++) { - __m128i v = _mm_loadu_si128((__m128i *)msg); - _mm_storeu_si128((__m128i *)msg, _mm_xor_si128(v, mask128)); - msg += 16; - } - - msg_len = (size_t)(msg_end - msg); + const size_t len128 = msg_len / 16; + if (len128) { + const __m128i mask128 = _mm_setr_epi32(mask32, mask32, mask32, mask32); + for (size_t i = 0; i < len128; i++) { + __m128i v = _mm_loadu_si128((__m128i *)msg); + _mm_storeu_si128((__m128i *)msg, _mm_xor_si128(v, mask128)); + msg += 16; } + + msg_len = (size_t)(msg_end - msg); + } #endif + if (sizeof(void *) == 8) { + const uint64_t mask64 = (uint64_t)mask32 << 32 | (uint64_t)mask32; const size_t len64 = msg_len / 8; for (size_t i = 0; i < len64; i++) { uint64_t v = string_as_uint64(msg); @@ -199,7 +195,7 @@ static void unmask(char *msg, size_t msg_len, char mask[static 4]) const size_t len32 = (size_t)((msg_end - msg) / 4); for (size_t i = 0; i < len32; i++) { uint32_t v = string_as_uint32(msg); - v ^= mask32; + v ^= (uint32_t)mask32; msg = mempcpy(msg, &v, sizeof(v)); }