From 58b23cbb4f54fc48db3c5e6f36e264dcf33c844e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Horstmann?= Date: Wed, 10 Apr 2024 22:22:18 +0200 Subject: [PATCH] Improve autovectorization of to_lowercase / to_uppercase functions Refactor the code in the `convert_while_ascii` helper function to make it more suitable for auto-vectorization and also process the full ascii prefix of the string. The generic case conversion logic will only be invoked starting from the first non-ascii character. The runtime on microbenchmarks with ascii-only inputs improves between 1.5x for short and 4x for long inputs on x86_64 and aarch64. The new implementation also encapsulates all unsafe inside the `convert_while_ascii` function. Fixes #123712 --- library/alloc/benches/str.rs | 2 + library/alloc/src/str.rs | 129 +++++++++++------- library/alloc/tests/str.rs | 3 + ...e-123712-str-to-lower-autovectorization.rs | 25 ++++ 4 files changed, 107 insertions(+), 52 deletions(-) create mode 100644 tests/codegen/issues/issue-123712-str-to-lower-autovectorization.rs diff --git a/library/alloc/benches/str.rs b/library/alloc/benches/str.rs index c148ab6b220a5..92a48e0e6b5a6 100644 --- a/library/alloc/benches/str.rs +++ b/library/alloc/benches/str.rs @@ -347,3 +347,5 @@ make_test!(rsplitn_space_char, s, s.rsplitn(10, ' ').count()); make_test!(split_space_str, s, s.split(" ").count()); make_test!(split_ad_str, s, s.split("ad").count()); + +make_test!(to_lowercase, s, s.to_lowercase()); diff --git a/library/alloc/src/str.rs b/library/alloc/src/str.rs index d7fba3ae159c6..a110b09d21424 100644 --- a/library/alloc/src/str.rs +++ b/library/alloc/src/str.rs @@ -9,6 +9,7 @@ use core::borrow::{Borrow, BorrowMut}; use core::iter::FusedIterator; +use core::mem::MaybeUninit; #[stable(feature = "rust1", since = "1.0.0")] pub use core::str::pattern; use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher}; @@ -365,14 +366,9 @@ impl str { without modifying the original"] #[stable(feature = "unicode_case_mapping", since = "1.2.0")] pub fn to_lowercase(&self) -> String { - let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase); + let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase); - // Safety: we know this is a valid char boundary since - // out.len() is only progressed if ascii bytes are found - let rest = unsafe { self.get_unchecked(out.len()..) }; - - // Safety: We have written only valid ASCII to our vec - let mut s = unsafe { String::from_utf8_unchecked(out) }; + let prefix_len = s.len(); for (i, c) in rest.char_indices() { if c == 'Σ' { @@ -381,8 +377,7 @@ impl str { // in `SpecialCasing.txt`, // so hard-code it rather than have a generic "condition" mechanism. // See https://github.com/rust-lang/rust/issues/26035 - let out_len = self.len() - rest.len(); - let sigma_lowercase = map_uppercase_sigma(&self, i + out_len); + let sigma_lowercase = map_uppercase_sigma(self, prefix_len + i); s.push(sigma_lowercase); } else { match conversions::to_lower(c) { @@ -458,14 +453,7 @@ impl str { without modifying the original"] #[stable(feature = "unicode_case_mapping", since = "1.2.0")] pub fn to_uppercase(&self) -> String { - let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase); - - // Safety: we know this is a valid char boundary since - // out.len() is only progressed if ascii bytes are found - let rest = unsafe { self.get_unchecked(out.len()..) }; - - // Safety: We have written only valid ASCII to our vec - let mut s = unsafe { String::from_utf8_unchecked(out) }; + let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase); for c in rest.chars() { match conversions::to_upper(c) { @@ -614,50 +602,87 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box { unsafe { Box::from_raw(Box::into_raw(v) as *mut str) } } -/// Converts the bytes while the bytes are still ascii. +/// Converts leading ascii bytes in `s` by calling the `convert` function. +/// /// For better average performance, this happens in chunks of `2*size_of::()`. -/// Returns a vec with the converted bytes. +/// +/// Returns a tuple of the converted prefix and the remainder starting from +/// the first non-ascii character. +/// +/// This function is only public so that it can be verified in a codegen test, +/// see `issue-123712-str-to-lower-autovectorization.rs`. +#[unstable(feature = "str_internals", issue = "none")] +#[doc(hidden)] #[inline] #[cfg(not(test))] #[cfg(not(no_global_oom_handling))] -fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec { - let mut out = Vec::with_capacity(b.len()); +pub fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) { + // Process the input in chunks of 16 bytes to enable auto-vectorization. + // Previously the chunk size depended on the size of `usize`, + // but on 32-bit platforms with sse or neon is also the better choice. + // The only downside on other platforms would be a bit more loop-unrolling. + const N: usize = 16; + + let mut slice = s.as_bytes(); + let mut out = Vec::with_capacity(slice.len()); + let mut out_slice = out.spare_capacity_mut(); + + let mut ascii_prefix_len = 0_usize; + let mut is_ascii = [false; N]; + + while slice.len() >= N { + // SAFETY: checked in loop condition + let chunk = unsafe { slice.get_unchecked(..N) }; + // SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets + let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) }; + + for j in 0..N { + is_ascii[j] = chunk[j] <= 127; + } - const USIZE_SIZE: usize = mem::size_of::(); - const MAGIC_UNROLL: usize = 2; - const N: usize = USIZE_SIZE * MAGIC_UNROLL; - const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]); + // Auto-vectorization for this check is a bit fragile, sum and comparing against the chunk + // size gives the best result, specifically a pmovmsk instruction on x86. + // See https://github.com/llvm/llvm-project/issues/96395 for why llvm currently does not + // currently recognize other similar idioms. + if is_ascii.iter().map(|x| *x as u8).sum::() as usize != N { + break; + } - let mut i = 0; - unsafe { - while i + N <= b.len() { - // Safety: we have checks the sizes `b` and `out` to know that our - let in_chunk = b.get_unchecked(i..i + N); - let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N); - - let mut bits = 0; - for j in 0..MAGIC_UNROLL { - // read the bytes 1 usize at a time (unaligned since we haven't checked the alignment) - // safety: in_chunk is valid bytes in the range - bits |= in_chunk.as_ptr().cast::().add(j).read_unaligned(); - } - // if our chunks aren't ascii, then return only the prior bytes as init - if bits & NONASCII_MASK != 0 { - break; - } + for j in 0..N { + out_chunk[j] = MaybeUninit::new(convert(&chunk[j])); + } - // perform the case conversions on N bytes (gets heavily autovec'd) - for j in 0..N { - // safety: in_chunk and out_chunk is valid bytes in the range - let out = out_chunk.get_unchecked_mut(j); - out.write(convert(in_chunk.get_unchecked(j))); - } + ascii_prefix_len += N; + slice = unsafe { slice.get_unchecked(N..) }; + out_slice = unsafe { out_slice.get_unchecked_mut(N..) }; + } - // mark these bytes as initialised - i += N; + // handle the remainder as individual bytes + while slice.len() > 0 { + let byte = slice[0]; + if byte > 127 { + break; + } + // SAFETY: out_slice has at least same length as input slice + unsafe { + *out_slice.get_unchecked_mut(0) = MaybeUninit::new(convert(&byte)); } - out.set_len(i); + ascii_prefix_len += 1; + slice = unsafe { slice.get_unchecked(1..) }; + out_slice = unsafe { out_slice.get_unchecked_mut(1..) }; } - out + unsafe { + // SAFETY: ascii_prefix_len bytes have been initialized above + out.set_len(ascii_prefix_len); + + // SAFETY: We have written only valid ascii to the output vec + let ascii_string = String::from_utf8_unchecked(out); + + // SAFETY: we know this is a valid char boundary + // since we only skipped over leading ascii bytes + let rest = core::str::from_utf8_unchecked(slice); + + (ascii_string, rest) + } } diff --git a/library/alloc/tests/str.rs b/library/alloc/tests/str.rs index a6b1fe5b97945..536581f6a505b 100644 --- a/library/alloc/tests/str.rs +++ b/library/alloc/tests/str.rs @@ -1850,7 +1850,10 @@ fn to_lowercase() { assert_eq!("ΑΣ''Α".to_lowercase(), "ασ''α"); // https://github.com/rust-lang/rust/issues/124714 + // input lengths around the boundary of the chunk size used by the ascii prefix optimization + assert_eq!("abcdefghijklmnoΣ".to_lowercase(), "abcdefghijklmnoς"); assert_eq!("abcdefghijklmnopΣ".to_lowercase(), "abcdefghijklmnopς"); + assert_eq!("abcdefghijklmnopqΣ".to_lowercase(), "abcdefghijklmnopqς"); // a really long string that has it's lowercase form // even longer. this tests that implementations don't assume diff --git a/tests/codegen/issues/issue-123712-str-to-lower-autovectorization.rs b/tests/codegen/issues/issue-123712-str-to-lower-autovectorization.rs new file mode 100644 index 0000000000000..c7f7b9f4f56b2 --- /dev/null +++ b/tests/codegen/issues/issue-123712-str-to-lower-autovectorization.rs @@ -0,0 +1,25 @@ +//@ only-x86_64 +// +//@ needs-llvm-components: x86 +//@ compile-flags: --target=x86_64-unknown-linux-gnu -Copt-level=3 +#![crate_type = "lib"] +#![no_std] +#![feature(str_internals)] + +extern crate alloc; + +/// Ensure that the ascii-prefix loop for `str::to_lowercase` and `str::to_uppercase` uses vector +/// instructions. +/// +/// The llvm ir should be the same for all targets that support some form of simd. Only targets +/// without any simd instructions would see scalarized ir. +/// Unfortunately, there is no `only-simd` directive to only run this test on only such platforms, +/// and using test revisions would still require the core libraries for all platforms. +// CHECK-LABEL: @lower_while_ascii +// CHECK: [[A:%[0-9]]] = load <16 x i8> +// CHECK-NEXT: [[B:%[0-9]]] = icmp slt <16 x i8> [[A]], zeroinitializer +// CHECK-NEXT: [[C:%[0-9]]] = bitcast <16 x i1> [[B]] to i16 +#[no_mangle] +pub fn lower_while_ascii(s: &str) -> (alloc::string::String, &str) { + alloc::str::convert_while_ascii(s, u8::to_ascii_lowercase) +}