From d18f26cf3b86afcf2116a0ecb496981c349e77b5 Mon Sep 17 00:00:00 2001 From: ijl Date: Thu, 8 Aug 2024 23:26:47 +0000 Subject: [PATCH] sse2 stable impl --- src/serialize/writer/json.rs | 25 +++++- src/serialize/writer/str/mod.rs | 15 +++- src/serialize/writer/str/scalar.rs | 4 +- src/serialize/writer/str/sse2.rs | 127 +++++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 9 deletions(-) create mode 100644 src/serialize/writer/str/sse2.rs diff --git a/src/serialize/writer/json.rs b/src/serialize/writer/json.rs index fe575831..1d61ebec 100644 --- a/src/serialize/writer/json.rs +++ b/src/serialize/writer/json.rs @@ -604,7 +604,7 @@ where unsafe { reserve_str!(writer, value); - let written = format_escaped_str_impl_generic_128( + let written = format_escaped_str_impl_sse2_128( writer.as_mut_buffer_ptr(), value.as_bytes().as_ptr(), value.len(), @@ -631,7 +631,7 @@ where ); writer.set_written(written); } else { - let written = format_escaped_str_impl_generic_128( + let written = format_escaped_str_impl_sse2_128( writer.as_mut_buffer_ptr(), value.as_bytes().as_ptr(), value.len(), @@ -641,7 +641,7 @@ where } } -#[cfg(not(feature = "unstable-simd"))] +#[cfg(all(not(feature = "unstable-simd"), not(target_arch = "x86_64")))] #[inline(always)] fn format_escaped_str(writer: &mut W, value: &str) where @@ -659,6 +659,25 @@ where } } +#[cfg(all(not(feature = "unstable-simd"), target_arch = "x86_64"))] +#[inline(always)] +fn format_escaped_str(writer: &mut W, value: &str) +where + W: ?Sized + io::Write + WriteExt, +{ + unsafe { + reserve_str!(writer, value); + + let written = format_escaped_str_impl_sse2_128( + writer.as_mut_buffer_ptr(), + value.as_bytes().as_ptr(), + value.len(), + ); + + writer.set_written(written); + } +} + #[inline] pub fn to_writer(writer: W, value: &T) -> Result<()> where diff --git a/src/serialize/writer/str/mod.rs b/src/serialize/writer/str/mod.rs index 0836ecba..6b18f066 100644 --- a/src/serialize/writer/str/mod.rs +++ b/src/serialize/writer/str/mod.rs @@ -5,18 +5,25 @@ mod escape; #[macro_use] mod scalar; +#[cfg(target_arch = "x86_64")] +mod sse2; + #[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] mod avx512; #[cfg(feature = "unstable-simd")] mod generic; -#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] -pub use avx512::format_escaped_str_impl_512vl; +#[cfg(all(not(feature = "unstable-simd"), not(target_arch = "x86_64")))] +pub use scalar::format_escaped_str_scalar; #[allow(unused_imports)] #[cfg(feature = "unstable-simd")] pub use generic::format_escaped_str_impl_generic_128; -#[cfg(not(feature = "unstable-simd"))] -pub use scalar::format_escaped_str_scalar; +#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] +pub use avx512::format_escaped_str_impl_512vl; + +#[allow(unused_imports)] +#[cfg(target_arch = "x86_64")] +pub use sse2::format_escaped_str_impl_sse2_128; diff --git a/src/serialize/writer/str/scalar.rs b/src/serialize/writer/str/scalar.rs index f324e098..29150809 100644 --- a/src/serialize/writer/str/scalar.rs +++ b/src/serialize/writer/str/scalar.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: (Apache-2.0 OR MIT) -#[cfg(not(feature = "unstable-simd"))] +#[cfg(all(not(feature = "unstable-simd"), not(target_arch = "x86_64")))] use super::escape::{NEED_ESCAPED, QUOTE_TAB}; macro_rules! impl_format_scalar { @@ -20,7 +20,7 @@ macro_rules! impl_format_scalar { }; } -#[cfg(not(feature = "unstable-simd"))] +#[cfg(all(not(feature = "unstable-simd"), not(target_arch = "x86_64")))] pub unsafe fn format_escaped_str_scalar( odst: *mut u8, value_ptr: *const u8, diff --git a/src/serialize/writer/str/sse2.rs b/src/serialize/writer/str/sse2.rs new file mode 100644 index 00000000..45114e0c --- /dev/null +++ b/src/serialize/writer/str/sse2.rs @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 + +use super::escape::{NEED_ESCAPED, QUOTE_TAB}; + +use core::mem::transmute; + +use core::arch::x86_64::{ + __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128, _mm_set1_epi8, + _mm_setzero_si128, _mm_storeu_si128, _mm_subs_epu8, +}; + +macro_rules! splat_mm128 { + ($val:expr) => { + _mm_set1_epi8(transmute::($val)) + }; +} + +macro_rules! impl_format_simd_sse2_128 { + ($dst:expr, $src:expr, $value_len:expr) => { + let last_stride_src = $src.add($value_len).sub(STRIDE); + let mut nb: usize = $value_len; + + assume!($value_len >= STRIDE); + + let blash = splat_mm128!(b'\\'); + let quote = splat_mm128!(b'"'); + let x20 = splat_mm128!(31); + let v0 = _mm_setzero_si128(); + + unsafe { + while nb >= STRIDE { + let str_vec = _mm_loadu_si128($src as *const __m128i); + + let mask = _mm_movemask_epi8(_mm_or_si128( + _mm_or_si128( + _mm_cmpeq_epi8(str_vec, blash), + _mm_cmpeq_epi8(str_vec, quote), + ), + _mm_cmpeq_epi8(_mm_subs_epu8(str_vec, x20), v0), + )) as u32; + + _mm_storeu_si128($dst as *mut __m128i, str_vec); + + if unlikely!(mask > 0) { + let cn = trailing_zeros!(mask) as usize; + nb -= cn; + $dst = $dst.add(cn); + $src = $src.add(cn); + nb -= 1; + let escape = QUOTE_TAB[*($src) as usize]; + write_escape!(escape, $dst); + $dst = $dst.add(escape.1 as usize); + $src = $src.add(1); + } else { + nb -= STRIDE; + $dst = $dst.add(STRIDE); + $src = $src.add(STRIDE); + } + } + + if nb > 0 { + let mut scratch: [u8; 32] = [b'a'; 32]; + let mut str_vec = _mm_loadu_si128(last_stride_src as *const __m128i); + _mm_storeu_si128(scratch.as_mut_ptr() as *mut __m128i, str_vec); + + let mut scratch_ptr = scratch.as_mut_ptr().add(16 - nb); + str_vec = _mm_loadu_si128(scratch_ptr as *const __m128i); + + let mut mask = _mm_movemask_epi8(_mm_or_si128( + _mm_or_si128( + _mm_cmpeq_epi8(str_vec, blash), + _mm_cmpeq_epi8(str_vec, quote), + ), + _mm_cmpeq_epi8(_mm_subs_epu8(str_vec, x20), v0), + )) as u32; + + while nb > 0 { + _mm_storeu_si128($dst as *mut __m128i, str_vec); + + if unlikely!(mask > 0) { + let cn = trailing_zeros!(mask) as usize; + nb -= cn; + $dst = $dst.add(cn); + scratch_ptr = scratch_ptr.add(cn); + nb -= 1; + mask >>= cn + 1; + let escape = QUOTE_TAB[*(scratch_ptr) as usize]; + write_escape!(escape, $dst); + $dst = $dst.add(escape.1 as usize); + scratch_ptr = scratch_ptr.add(1); + str_vec = _mm_loadu_si128(scratch_ptr as *const __m128i); + } else { + $dst = $dst.add(nb); + break; + } + } + } + } + }; +} + +#[allow(dead_code)] +#[inline(never)] +pub unsafe fn format_escaped_str_impl_sse2_128( + odst: *mut u8, + value_ptr: *const u8, + value_len: usize, +) -> usize { + const STRIDE: usize = 16; + + let mut dst = odst; + let mut src = value_ptr; + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + if value_len < STRIDE { + impl_format_scalar!(dst, src, value_len) + } else { + impl_format_simd_sse2_128!(dst, src, value_len); + } + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + dst as usize - odst as usize +}