Skip to content

Commit

Permalink
sse2 stable impl
Browse files Browse the repository at this point in the history
  • Loading branch information
ijl committed Aug 8, 2024
1 parent bd9dca5 commit d18f26c
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 9 deletions.
25 changes: 22 additions & 3 deletions src/serialize/writer/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ where
unsafe {
reserve_str!(writer, value);

let written = format_escaped_str_impl_generic_128(
let written = format_escaped_str_impl_sse2_128(
writer.as_mut_buffer_ptr(),
value.as_bytes().as_ptr(),
value.len(),
Expand All @@ -631,7 +631,7 @@ where
);
writer.set_written(written);
} else {
let written = format_escaped_str_impl_generic_128(
let written = format_escaped_str_impl_sse2_128(
writer.as_mut_buffer_ptr(),
value.as_bytes().as_ptr(),
value.len(),
Expand All @@ -641,7 +641,7 @@ where
}
}

#[cfg(not(feature = "unstable-simd"))]
#[cfg(all(not(feature = "unstable-simd"), not(target_arch = "x86_64")))]
#[inline(always)]
fn format_escaped_str<W>(writer: &mut W, value: &str)
where
Expand All @@ -659,6 +659,25 @@ where
}
}

#[cfg(all(not(feature = "unstable-simd"), target_arch = "x86_64"))]
#[inline(always)]
fn format_escaped_str<W>(writer: &mut W, value: &str)
where
W: ?Sized + io::Write + WriteExt,
{
unsafe {
reserve_str!(writer, value);

let written = format_escaped_str_impl_sse2_128(
writer.as_mut_buffer_ptr(),
value.as_bytes().as_ptr(),
value.len(),
);

writer.set_written(written);
}
}

#[inline]
pub fn to_writer<W, T>(writer: W, value: &T) -> Result<()>
where
Expand Down
15 changes: 11 additions & 4 deletions src/serialize/writer/str/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,25 @@ mod escape;
#[macro_use]
mod scalar;

#[cfg(target_arch = "x86_64")]
mod sse2;

#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))]
mod avx512;

#[cfg(feature = "unstable-simd")]
mod generic;

#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))]
pub use avx512::format_escaped_str_impl_512vl;
#[cfg(all(not(feature = "unstable-simd"), not(target_arch = "x86_64")))]
pub use scalar::format_escaped_str_scalar;

#[allow(unused_imports)]
#[cfg(feature = "unstable-simd")]
pub use generic::format_escaped_str_impl_generic_128;

#[cfg(not(feature = "unstable-simd"))]
pub use scalar::format_escaped_str_scalar;
#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))]
pub use avx512::format_escaped_str_impl_512vl;

#[allow(unused_imports)]
#[cfg(target_arch = "x86_64")]
pub use sse2::format_escaped_str_impl_sse2_128;
4 changes: 2 additions & 2 deletions src/serialize/writer/str/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

#[cfg(not(feature = "unstable-simd"))]
#[cfg(all(not(feature = "unstable-simd"), not(target_arch = "x86_64")))]
use super::escape::{NEED_ESCAPED, QUOTE_TAB};

macro_rules! impl_format_scalar {
Expand All @@ -20,7 +20,7 @@ macro_rules! impl_format_scalar {
};
}

#[cfg(not(feature = "unstable-simd"))]
#[cfg(all(not(feature = "unstable-simd"), not(target_arch = "x86_64")))]
pub unsafe fn format_escaped_str_scalar(
odst: *mut u8,
value_ptr: *const u8,
Expand Down
127 changes: 127 additions & 0 deletions src/serialize/writer/str/sse2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// SPDX-License-Identifier: Apache-2.0

use super::escape::{NEED_ESCAPED, QUOTE_TAB};

use core::mem::transmute;

use core::arch::x86_64::{
__m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128, _mm_set1_epi8,
_mm_setzero_si128, _mm_storeu_si128, _mm_subs_epu8,
};

macro_rules! splat_mm128 {
($val:expr) => {
_mm_set1_epi8(transmute::<u8, i8>($val))
};
}

macro_rules! impl_format_simd_sse2_128 {
($dst:expr, $src:expr, $value_len:expr) => {
let last_stride_src = $src.add($value_len).sub(STRIDE);
let mut nb: usize = $value_len;

assume!($value_len >= STRIDE);

let blash = splat_mm128!(b'\\');
let quote = splat_mm128!(b'"');
let x20 = splat_mm128!(31);
let v0 = _mm_setzero_si128();

unsafe {
while nb >= STRIDE {
let str_vec = _mm_loadu_si128($src as *const __m128i);

let mask = _mm_movemask_epi8(_mm_or_si128(
_mm_or_si128(
_mm_cmpeq_epi8(str_vec, blash),
_mm_cmpeq_epi8(str_vec, quote),
),
_mm_cmpeq_epi8(_mm_subs_epu8(str_vec, x20), v0),
)) as u32;

_mm_storeu_si128($dst as *mut __m128i, str_vec);

if unlikely!(mask > 0) {
let cn = trailing_zeros!(mask) as usize;
nb -= cn;
$dst = $dst.add(cn);
$src = $src.add(cn);
nb -= 1;
let escape = QUOTE_TAB[*($src) as usize];
write_escape!(escape, $dst);
$dst = $dst.add(escape.1 as usize);
$src = $src.add(1);
} else {
nb -= STRIDE;
$dst = $dst.add(STRIDE);
$src = $src.add(STRIDE);
}
}

if nb > 0 {
let mut scratch: [u8; 32] = [b'a'; 32];
let mut str_vec = _mm_loadu_si128(last_stride_src as *const __m128i);
_mm_storeu_si128(scratch.as_mut_ptr() as *mut __m128i, str_vec);

let mut scratch_ptr = scratch.as_mut_ptr().add(16 - nb);
str_vec = _mm_loadu_si128(scratch_ptr as *const __m128i);

let mut mask = _mm_movemask_epi8(_mm_or_si128(
_mm_or_si128(
_mm_cmpeq_epi8(str_vec, blash),
_mm_cmpeq_epi8(str_vec, quote),
),
_mm_cmpeq_epi8(_mm_subs_epu8(str_vec, x20), v0),
)) as u32;

while nb > 0 {
_mm_storeu_si128($dst as *mut __m128i, str_vec);

if unlikely!(mask > 0) {
let cn = trailing_zeros!(mask) as usize;
nb -= cn;
$dst = $dst.add(cn);
scratch_ptr = scratch_ptr.add(cn);
nb -= 1;
mask >>= cn + 1;
let escape = QUOTE_TAB[*(scratch_ptr) as usize];
write_escape!(escape, $dst);
$dst = $dst.add(escape.1 as usize);
scratch_ptr = scratch_ptr.add(1);
str_vec = _mm_loadu_si128(scratch_ptr as *const __m128i);
} else {
$dst = $dst.add(nb);
break;
}
}
}
}
};
}

#[allow(dead_code)]
#[inline(never)]
pub unsafe fn format_escaped_str_impl_sse2_128(
odst: *mut u8,
value_ptr: *const u8,
value_len: usize,
) -> usize {
const STRIDE: usize = 16;

let mut dst = odst;
let mut src = value_ptr;

core::ptr::write(dst, b'"');
dst = dst.add(1);

if value_len < STRIDE {
impl_format_scalar!(dst, src, value_len)
} else {
impl_format_simd_sse2_128!(dst, src, value_len);
}

core::ptr::write(dst, b'"');
dst = dst.add(1);

dst as usize - odst as usize
}

0 comments on commit d18f26c

Please sign in to comment.