From 8d6c7cc10d351ced4ccc2fe335241a26e8ac848d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 26 Jul 2024 11:03:21 +0200 Subject: [PATCH] add tests and refactor slightly --- arrow-string/src/like.rs | 15 +++- arrow-string/src/predicate.rs | 139 ++++++++++++++++++++++++---------- 2 files changed, 110 insertions(+), 44 deletions(-) diff --git a/arrow-string/src/like.rs b/arrow-string/src/like.rs index 8db2e0622a87..2cc2bc4cb6c2 100644 --- a/arrow-string/src/like.rs +++ b/arrow-string/src/like.rs @@ -273,9 +273,18 @@ fn op_binary<'a>( match op { Op::Like(neg) => binary_predicate(l, r, neg, Predicate::like), Op::ILike(neg) => binary_predicate(l, r, neg, |s| Predicate::ilike(s, false)), - Op::Contains => Ok(l.zip(r).map(|(l, r)| Some(l?.contains(r?))).collect()), - Op::StartsWith => Ok(l.zip(r).map(|(l, r)| Some(crate::predicate::starts_with(l?, r?))).collect()), - Op::EndsWith => Ok(l.zip(r).map(|(l, r)| Some(crate::predicate::ends_with(l?, r?))).collect()), + Op::Contains => Ok(l + .zip(r) + .map(|(l, r)| Some(Predicate::Contains(r?).evaluate(l?))) + .collect()), + Op::StartsWith => Ok(l + .zip(r) + .map(|(l, r)| Some(Predicate::StartsWith(r?).evaluate(l?))) + .collect()), + Op::EndsWith => Ok(l + .zip(r) + .map(|(l, r)| Some(Predicate::EndsWith(r?).evaluate(l?))) + .collect()), } } diff --git a/arrow-string/src/predicate.rs b/arrow-string/src/predicate.rs index 415e0cbb8877..de682c11129a 100644 --- a/arrow-string/src/predicate.rs +++ b/arrow-string/src/predicate.rs @@ -83,10 +83,10 @@ impl<'a> Predicate<'a> { Predicate::Eq(v) => *v == haystack, Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v), Predicate::Contains(v) => haystack.contains(v), - Predicate::StartsWith(v) => starts_with(haystack, v), - Predicate::IStartsWithAscii(v) => starts_with_ignore_ascii_case(haystack, v), - Predicate::EndsWith(v) => ends_with(haystack, v), - Predicate::IEndsWithAscii(v) => ends_with_ignore_ascii_case(haystack, v), + Predicate::StartsWith(v) => starts_with(haystack, v, equals_kernel), + Predicate::IStartsWithAscii(v) => starts_with(haystack, v, equals_ignore_ascii_case_kernel), + Predicate::EndsWith(v) => ends_with(haystack, v, equals_kernel), + Predicate::IEndsWithAscii(v) => ends_with(haystack, v, equals_ignore_ascii_case_kernel), Predicate::Regex(v) => v.is_match(haystack), } } @@ -109,17 +109,17 @@ impl<'a> Predicate<'a> { Predicate::Contains(v) => { BooleanArray::from_unary(array, |haystack| haystack.contains(v) != negate) } - Predicate::StartsWith(v) => { - BooleanArray::from_unary(array, |haystack| starts_with(haystack, v) != negate) - } + Predicate::StartsWith(v) => BooleanArray::from_unary(array, |haystack| { + starts_with(haystack, v, equals_kernel) != negate + }), Predicate::IStartsWithAscii(v) => BooleanArray::from_unary(array, |haystack| { - starts_with_ignore_ascii_case(haystack, v) != negate + starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate + }), + Predicate::EndsWith(v) => BooleanArray::from_unary(array, |haystack| { + ends_with(haystack, v, equals_kernel) != negate }), - Predicate::EndsWith(v) => { - BooleanArray::from_unary(array, |haystack| ends_with(haystack, v) != negate) - } Predicate::IEndsWithAscii(v) => BooleanArray::from_unary(array, |haystack| { - ends_with_ignore_ascii_case(haystack, v) != negate + ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate }), Predicate::Regex(v) => { BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate) @@ -128,43 +128,23 @@ impl<'a> Predicate<'a> { } } -#[inline] -pub(crate) fn starts_with(haystack: &str, needle: &str) -> bool { - if needle.len() > haystack.len() { - false - } else { - std::iter::zip(haystack.as_bytes(), needle.as_bytes()).all(equals_kernel) - } -} - -#[inline] -pub(crate) fn starts_with_ignore_ascii_case(haystack: &str, needle: &str) -> bool { - debug_assert!(needle.is_ascii(), "needle must be ascii"); - - if needle.len() > haystack.len() { - false - } else { - std::iter::zip(haystack.as_bytes().iter(), needle.as_bytes().iter()).all(i_equals_kernel) - } -} - -#[inline] -pub(crate) fn ends_with(haystack: &str, needle: &str) -> bool { +fn starts_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool { if needle.len() > haystack.len() { false } else { - std::iter::zip(haystack.as_bytes().iter().rev(), needle.as_bytes().iter().rev()).all(equals_kernel) + std::iter::zip(haystack.as_bytes(), needle.as_bytes()).all(byte_eq_kernel) } } -#[inline] -pub(crate) fn ends_with_ignore_ascii_case(haystack: &str, needle: &str) -> bool { - debug_assert!(needle.is_ascii(), "needle must be ascii"); - +fn ends_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool { if needle.len() > haystack.len() { false } else { - std::iter::zip(haystack.as_bytes().iter().rev(), needle.as_bytes().iter().rev()).all(i_equals_kernel) + std::iter::zip( + haystack.as_bytes().iter().rev(), + needle.as_bytes().iter().rev(), + ) + .all(byte_eq_kernel) } } @@ -172,7 +152,7 @@ fn equals_kernel((n, h): (&u8, &u8)) -> bool { n == h } -fn i_equals_kernel((n, h): (&u8, &u8)) -> bool { +fn equals_ignore_ascii_case_kernel((n, h): (&u8, &u8)) -> bool { n.to_ascii_lowercase() == h.to_ascii_lowercase() } @@ -265,4 +245,81 @@ mod tests { let r = regex_like(a_eq, false).unwrap(); assert_eq!(r.to_string(), expected); } + + #[test] + fn test_starts_with() { + assert!(Predicate::StartsWith("hay").evaluate("haystack")); + assert!(Predicate::StartsWith("h£ay").evaluate("h£aystack")); + assert!(Predicate::StartsWith("haystack").evaluate("haystack")); + assert!(Predicate::StartsWith("ha").evaluate("haystack")); + assert!(Predicate::StartsWith("h").evaluate("haystack")); + assert!(Predicate::StartsWith("").evaluate("haystack")); + + assert!(!Predicate::StartsWith("stack").evaluate("haystack")); + assert!(!Predicate::StartsWith("haystacks").evaluate("haystack")); + assert!(!Predicate::StartsWith("HAY").evaluate("haystack")); + assert!(!Predicate::StartsWith("h£ay").evaluate("haystack")); + assert!(!Predicate::StartsWith("hay").evaluate("h£aystack")); + } + + #[test] + fn test_ends_with() { + assert!(Predicate::EndsWith("stack").evaluate("haystack")); + assert!(Predicate::EndsWith("st£ack").evaluate("hayst£ack")); + assert!(Predicate::EndsWith("haystack").evaluate("haystack")); + assert!(Predicate::EndsWith("ck").evaluate("haystack")); + assert!(Predicate::EndsWith("k").evaluate("haystack")); + assert!(Predicate::EndsWith("").evaluate("haystack")); + + assert!(!Predicate::EndsWith("hay").evaluate("haystack")); + assert!(!Predicate::EndsWith("STACK").evaluate("haystack")); + assert!(!Predicate::EndsWith("haystacks").evaluate("haystack")); + assert!(!Predicate::EndsWith("xhaystack").evaluate("haystack")); + assert!(!Predicate::EndsWith("st£ack").evaluate("haystack")); + assert!(!Predicate::EndsWith("stack").evaluate("hayst£ack")); + } + + #[test] + fn test_istarts_with() { + assert!(Predicate::IStartsWithAscii("hay").evaluate("haystack")); + assert!(Predicate::IStartsWithAscii("hay").evaluate("HAYSTACK")); + assert!(Predicate::IStartsWithAscii("HAY").evaluate("haystack")); + assert!(Predicate::IStartsWithAscii("HaY").evaluate("haystack")); + assert!(Predicate::IStartsWithAscii("hay").evaluate("HaYsTaCk")); + assert!(Predicate::IStartsWithAscii("HAY").evaluate("HaYsTaCk")); + assert!(Predicate::IStartsWithAscii("haystack").evaluate("HaYsTaCk")); + assert!(Predicate::IStartsWithAscii("HaYsTaCk").evaluate("HaYsTaCk")); + assert!(Predicate::IStartsWithAscii("").evaluate("HaYsTaCk")); + + assert!(!Predicate::IStartsWithAscii("stack").evaluate("haystack")); + assert!(!Predicate::IStartsWithAscii("haystacks").evaluate("haystack")); + assert!(!Predicate::IStartsWithAscii("h.ay").evaluate("haystack")); + assert!(!Predicate::IStartsWithAscii("hay").evaluate("h£aystack")); + } + + #[test] + fn test_iends_with() { + assert!(Predicate::IEndsWithAscii("stack").evaluate("haystack")); + assert!(Predicate::IEndsWithAscii("STACK").evaluate("haystack")); + assert!(Predicate::IEndsWithAscii("StAcK").evaluate("haystack")); + assert!(Predicate::IEndsWithAscii("stack").evaluate("HAYSTACK")); + assert!(Predicate::IEndsWithAscii("STACK").evaluate("HAYSTACK")); + assert!(Predicate::IEndsWithAscii("StAcK").evaluate("HAYSTACK")); + assert!(Predicate::IEndsWithAscii("stack").evaluate("HAYsTaCk")); + assert!(Predicate::IEndsWithAscii("STACK").evaluate("HAYsTaCk")); + assert!(Predicate::IEndsWithAscii("StAcK").evaluate("HAYsTaCk")); + assert!(Predicate::IEndsWithAscii("haystack").evaluate("haystack")); + assert!(Predicate::IEndsWithAscii("HAYSTACK").evaluate("haystack")); + assert!(Predicate::IEndsWithAscii("haystack").evaluate("HAYSTACK")); + assert!(Predicate::IEndsWithAscii("ck").evaluate("haystack")); + assert!(Predicate::IEndsWithAscii("cK").evaluate("haystack")); + assert!(Predicate::IEndsWithAscii("ck").evaluate("haystacK")); + assert!(Predicate::IEndsWithAscii("").evaluate("haystack")); + + assert!(!Predicate::IEndsWithAscii("hay").evaluate("haystack")); + assert!(!Predicate::IEndsWithAscii("stac").evaluate("HAYSTACK")); + assert!(!Predicate::IEndsWithAscii("haystacks").evaluate("haystack")); + assert!(!Predicate::IEndsWithAscii("stack").evaluate("haystac£k")); + assert!(!Predicate::IEndsWithAscii("xhaystack").evaluate("haystack")); + } }