Skip to content

Commit

Permalink
Add ARM NEON check implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Lynnesbian committed Sep 17, 2024
1 parent 0feb5d5 commit 6d33a11
Showing 1 changed file with 166 additions and 8 deletions.
174 changes: 166 additions & 8 deletions src/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// avx2 decode modified from https://github.com/zbjornson/fast-hex/blob/master/src/hex.cc

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
Expand Down Expand Up @@ -114,7 +116,16 @@ pub fn hex_check_with_case(src: &[u8], check_case: CheckCase) -> bool {
}
}

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg(target_arch = "aarch64")]
{
match crate::vectorization_support() {
crate::Vectorization::Neon => unsafe { hex_check_neon_with_case(src, check_case) },
crate::Vectorization::None => hex_check_fallback_with_case(src, check_case),
_ => unreachable!(),
}
}

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
hex_check_fallback_with_case(src, check_case)
}

Expand Down Expand Up @@ -210,6 +221,72 @@ pub unsafe fn hex_check_sse_with_case(mut src: &[u8], check_case: CheckCase) ->
hex_check_fallback_with_case(src, check_case)
}

#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
pub unsafe fn hex_check_neon(src: &[u8]) -> bool {
hex_check_neon_with_case(src, CheckCase::None)
}

#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
pub unsafe fn hex_check_neon_with_case(mut src: &[u8], check_case: CheckCase) -> bool {
let ascii_zero = vdupq_n_u8(b'0' - 1);
let ascii_nine = vdupq_n_u8(b'9' + 1);
let ascii_ua = vdupq_n_u8(b'A' - 1);
let ascii_uf = vdupq_n_u8(b'F' + 1);
let ascii_la = vdupq_n_u8(b'a' - 1);
let ascii_lf = vdupq_n_u8(b'f' + 1);

while src.len() >= 16 {
let unchecked = vld1q_u8(src.as_ptr() as *const _);

let gt0 = vcgtq_u8(unchecked, ascii_zero);
let lt9 = vcltq_u8(unchecked, ascii_nine);
let valid_digit = vandq_u8(gt0, lt9);

let (valid_la_lf, valid_ua_uf) = match check_case {
CheckCase::None => {
let gtua = vcgtq_u8(unchecked, ascii_ua);
let ltuf = vcltq_u8(unchecked, ascii_uf);

let gtla = vcgtq_u8(unchecked, ascii_la);
let ltlf = vcltq_u8(unchecked, ascii_lf);

(Some(vandq_u8(gtla, ltlf)), Some(vandq_u8(gtua, ltuf)))
}
CheckCase::Lower => {
let gtla = vcgtq_u8(unchecked, ascii_la);
let ltlf = vcltq_u8(unchecked, ascii_lf);

(Some(vandq_u8(gtla, ltlf)), None)
}
CheckCase::Upper => {
let gtua = vcgtq_u8(unchecked, ascii_ua);
let ltuf = vcltq_u8(unchecked, ascii_uf);

(None, Some(vandq_u8(gtua, ltuf)))
}
};

let valid_letter = match (valid_la_lf, valid_ua_uf) {
(Some(valid_lower), Some(valid_upper)) => vorrq_u8(valid_lower, valid_upper),
(Some(valid_lower), None) => valid_lower,
(None, Some(valid_upper)) => valid_upper,
_ => unreachable!(),
};

let ret = vminvq_u8(vorrq_u8(valid_digit, valid_letter));

if ret == 0 {
return false;
}

src = &src[16..];
}

hex_check_fallback_with_case(src, check_case)
}

/// Hex decode src into dst.
/// The length of src must be even, and it's allowed to decode a zero length src.
/// The length of dst must be at least src.len() / 2.
Expand Down Expand Up @@ -247,7 +324,7 @@ pub fn hex_decode_unchecked(src: &[u8], dst: &mut [u8]) {
crate::Vectorization::AVX2 => unsafe { hex_decode_avx2(src, dst) },
crate::Vectorization::None | crate::Vectorization::SSE41 => {
hex_decode_fallback(src, dst)
},
}
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -455,15 +532,25 @@ mod tests {
}
}

#[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))]
mod test_sse {
#[cfg(all(
test,
any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")
))]
mod test_simd {
use crate::decode::{
hex_check, hex_check_fallback, hex_check_fallback_with_case, hex_check_sse,
hex_check_sse_with_case, hex_check_with_case, hex_decode, hex_decode_unchecked,
hex_decode_with_case, CheckCase,
hex_check, hex_check_fallback, hex_check_fallback_with_case, hex_check_with_case,
hex_decode, hex_decode_unchecked, hex_decode_with_case, CheckCase,
};
#[cfg(target_arch = "aarch64")]
use crate::decode::{hex_check_neon, hex_check_neon_with_case};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use crate::decode::{hex_check_sse, hex_check_sse_with_case};
#[cfg(target_arch = "aarch64")]
use std::arch::is_aarch64_feature_detected;

use proptest::proptest;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn _test_check_sse_with_case(s: &String, check_case: CheckCase, expect_result: bool) {
if is_x86_feature_detected!("sse4.1") {
assert_eq!(
Expand All @@ -473,12 +560,14 @@ mod test_sse {
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn _test_check_sse_true(s: &String) {
if is_x86_feature_detected!("sse4.1") {
assert!(unsafe { hex_check_sse(s.as_bytes()) });
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
proptest! {
#[test]
fn test_check_sse_true(ref s in "([0-9a-fA-F][0-9a-fA-F])+") {
Expand All @@ -505,12 +594,13 @@ mod test_sse {
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn _test_check_sse_false(s: &String) {
if is_x86_feature_detected!("sse4.1") {
assert!(!unsafe { hex_check_sse(s.as_bytes()) });
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
proptest! {
#[test]
fn test_check_sse_false(ref s in ".{16}[^0-9a-fA-F]+") {
Expand All @@ -521,6 +611,67 @@ mod test_sse {
}
}

#[cfg(target_arch = "aarch64")]
fn _test_check_neon_with_case(s: &String, check_case: CheckCase, expect_result: bool) {
if is_aarch64_feature_detected!("neon") {
assert_eq!(
unsafe { hex_check_neon_with_case(s.as_bytes(), check_case) },
expect_result
)
}
}

#[cfg(target_arch = "aarch64")]
fn _test_check_neon_true(s: &String) {
if is_aarch64_feature_detected!("neon") {
assert!(unsafe { hex_check_neon(s.as_bytes()) });
}
}

#[cfg(target_arch = "aarch64")]
proptest! {
#[test]
fn test_check_neon_true(ref s in "([0-9a-fA-F][0-9a-fA-F])+") {
_test_check_neon_true(s);
_test_check_neon_with_case(s, CheckCase::None, true);
match (s.contains(char::is_lowercase), s.contains(char::is_uppercase)){
(true, true) => {
_test_check_neon_with_case(s, CheckCase::Lower, false);
_test_check_neon_with_case(s, CheckCase::Upper, false);
},
(true, false) => {
_test_check_neon_with_case(s, CheckCase::Lower, true);
_test_check_neon_with_case(s, CheckCase::Upper, false);
},
(false, true) => {
_test_check_neon_with_case(s, CheckCase::Lower, false);
_test_check_neon_with_case(s, CheckCase::Upper, true);
},
(false, false) => {
_test_check_neon_with_case(s, CheckCase::Lower, true);
_test_check_neon_with_case(s, CheckCase::Upper, true);
}
}
}
}

#[cfg(target_arch = "aarch64")]
fn _test_check_neon_false(s: &String) {
if is_aarch64_feature_detected!("neon") {
assert!(!unsafe { hex_check_neon(s.as_bytes()) });
}
}
#[cfg(target_arch = "aarch64")]
proptest! {
#[test]
fn test_check_neon_false(ref s in ".{16}[^0-9a-fA-F]+") {
_test_check_neon_false(s);
_test_check_neon_with_case(s, CheckCase::None, false);
_test_check_neon_with_case(s, CheckCase::Lower, false);
_test_check_neon_with_case(s, CheckCase::Upper, false);
}
}

#[test]
fn test_decode_zero_length_src_should_not_be_ok() {
let src = b"";
Expand All @@ -536,11 +687,18 @@ mod test_sse {
assert!(hex_check_fallback(src));
assert!(hex_check_fallback_with_case(src, CheckCase::None));

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("sse4.1") {
assert!(unsafe { hex_check_sse_with_case(src, CheckCase::None) });
assert!(unsafe { hex_check_sse(src) });
}

#[cfg(target_arch = "aarch64")]
if is_aarch64_feature_detected!("neon") {
assert!(unsafe { hex_check_neon_with_case(src, CheckCase::None) });
assert!(unsafe { hex_check_neon(src) });
}

// this function have no return value, so we just execute it and expect no panic
hex_decode_unchecked(src, &mut dst);
}
Expand Down

0 comments on commit 6d33a11

Please sign in to comment.