From 894e797966b855cceed5e3a49cbfa67929f24441 Mon Sep 17 00:00:00 2001
From: Xin Li
Date: Sat, 17 Aug 2024 22:27:11 +0800
Subject: [PATCH] v2 impl

---
 arrow-array/src/array/byte_view_array.rs |  54 +++++++
 arrow-string/src/predicate.rs            | 152 ++++++++++++++++++---
 2 files changed, 187 insertions(+), 19 deletions(-)

diff --git a/arrow-array/src/array/byte_view_array.rs b/arrow-array/src/array/byte_view_array.rs
index 42f945838a45..037a7c600275 100644
--- a/arrow-array/src/array/byte_view_array.rs
+++ b/arrow-array/src/array/byte_view_array.rs
@@ -261,6 +261,37 @@ impl GenericByteViewArray {
         unsafe { self.value_unchecked(i) }
     }
 
+    /// Returns at most the first `prefix_len` bytes of the value at index `idx`
+    ///
+    /// The returned slice is shorter than `prefix_len` only when the value
+    /// itself is shorter than `prefix_len`, in which case the whole value is
+    /// returned.
+    ///
+    /// Prefixes of at most 4 bytes, and prefixes of values that are at most
+    /// 12 bytes long, are read directly from the view; every other prefix is
+    /// read from the value's data buffer.
+    ///
+    /// # Safety
+    ///
+    /// Caller is responsible for ensuring that `idx` is within the bounds of
+    /// the array
+    pub unsafe fn prefix_bytes_unchecked(&self, prefix_len: usize, idx: usize) -> &[u8] {
+        let v = self.views.get_unchecked(idx);
+        // The low 32 bits of a view hold the length of the value
+        let len = (*v as u32) as usize;
+        // Never read past the end of the value
+        let prefix_len = prefix_len.min(len);
+
+        if prefix_len <= 4 || (prefix_len <= 12 && len <= 12) {
+            Self::inline_value(v, prefix_len)
+        } else {
+            let view = ByteView::from(*v);
+            let data = self.buffers.get_unchecked(view.buffer_index as usize);
+            let offset = view.offset as usize;
+            data.get_unchecked(offset..offset + prefix_len)
+        }
+    }
+
     /// Returns the element at index `i`
     /// # Safety
     /// Caller is responsible for ensuring that the index is within the bounds of the array
@@ -278,6 +309,29 @@ impl GenericByteViewArray {
         T::Native::from_bytes_unchecked(b)
     }
 
+    /// Returns the raw bytes of the value at index `idx`
+    ///
+    /// Values of at most 12 bytes are read directly from the view; longer
+    /// values are read from the data buffer the view points to.
+    ///
+    /// # Safety
+    ///
+    /// Caller is responsible for ensuring that `idx` is within the bounds of
+    /// the array
+    pub unsafe fn bytes_unchecked(&self, idx: usize) -> &[u8] {
+        let v = self.views.get_unchecked(idx);
+        // The low 32 bits of a view hold the length of the value
+        let len = *v as u32;
+        if len <= 12 {
+            Self::inline_value(v, len as usize)
+        } else {
+            let view = ByteView::from(*v);
+            let data = self.buffers.get_unchecked(view.buffer_index as usize);
+            let offset = view.offset as usize;
+            data.get_unchecked(offset..offset + len as usize)
+        }
+    }
+
     /// Returns the inline value of the view.
     ///
     /// # Safety
diff --git a/arrow-string/src/predicate.rs b/arrow-string/src/predicate.rs
index ec0c4827830c..0e8ab3530142 100644
--- a/arrow-string/src/predicate.rs
+++ b/arrow-string/src/predicate.rs
@@ -15,7 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow_array::{ArrayAccessor, BooleanArray};
+use arrow_array::{Array, ArrayAccessor, BooleanArray, StringViewArray};
+use arrow_buffer::BooleanBuffer;
 use arrow_schema::ArrowError;
 use memchr::memchr2;
 use memchr::memmem::Finder;
@@ -111,24 +112,137 @@ impl<'a> Predicate<'a> {
             Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| {
                 (haystack.len() == v.len() && haystack == *v) != negate
             }),
-            Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| {
-                haystack.eq_ignore_ascii_case(v) != negate
-            }),
-            Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
-                finder.find(haystack.as_bytes()).is_some() != negate
-            }),
-            Predicate::StartsWith(v) => BooleanArray::from_unary(array, |haystack| {
-                starts_with(haystack, v, equals_kernel) != negate
-            }),
-            Predicate::IStartsWithAscii(v) => BooleanArray::from_unary(array, |haystack| {
-                starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
-            }),
-            Predicate::EndsWith(v) => BooleanArray::from_unary(array, |haystack| {
-                ends_with(haystack, v, equals_kernel) != negate
-            }),
-            Predicate::IEndsWithAscii(v) => BooleanArray::from_unary(array, |haystack| {
-                ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
-            }),
+            Predicate::IEqAscii(v) => {
+                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
+                    let needle_bytes = v.as_bytes();
+                    let null_buffer = string_view_array.logical_nulls();
+                    let boolean_buffer =
+                        BooleanBuffer::collect_bool(string_view_array.len(), |i| {
+                            // SAFETY: `collect_bool` only calls this with `i < len`.
+                            let haystack = unsafe { string_view_array.bytes_unchecked(i) };
+                            haystack.eq_ignore_ascii_case(needle_bytes) != negate
+                        });
+
+                    BooleanArray::new(boolean_buffer, null_buffer)
+                } else {
+                    BooleanArray::from_unary(array, |haystack| {
+                        haystack.eq_ignore_ascii_case(v) != negate
+                    })
+                }
+            }
+            Predicate::Contains(finder) => {
+                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
+                    let null_buffer = string_view_array.logical_nulls();
+                    let boolean_buffer =
+                        BooleanBuffer::collect_bool(string_view_array.len(), |i| {
+                            // SAFETY: `collect_bool` only calls this with `i < len`.
+                            let haystack = unsafe { string_view_array.bytes_unchecked(i) };
+                            finder.find(haystack).is_some() != negate
+                        });
+
+                    BooleanArray::new(boolean_buffer, null_buffer)
+                } else {
+                    BooleanArray::from_unary(array, |haystack| {
+                        finder.find(haystack.as_bytes()).is_some() != negate
+                    })
+                }
+            }
+            Predicate::StartsWith(v) => {
+                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
+                    let needle_bytes = v.as_bytes();
+                    let needle_len = needle_bytes.len();
+                    let null_buffer = string_view_array.logical_nulls();
+                    let boolean_buffer =
+                        BooleanBuffer::collect_bool(string_view_array.len(), |i| {
+                            // SAFETY: `collect_bool` only calls this with `i < len`.
+                            let prefix =
+                                unsafe { string_view_array.prefix_bytes_unchecked(needle_len, i) };
+                            // A prefix shorter than the needle means the value itself is
+                            // shorter than the needle, so it cannot start with it.
+                            let matched = prefix.len() == needle_len
+                                && zip(prefix, needle_bytes).all(equals_kernel);
+                            matched != negate
+                        });
+
+                    BooleanArray::new(boolean_buffer, null_buffer)
+                } else {
+                    BooleanArray::from_unary(array, |haystack| {
+                        starts_with(haystack, v, equals_kernel) != negate
+                    })
+                }
+            }
+            Predicate::IStartsWithAscii(v) => {
+                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
+                    let needle_bytes = v.as_bytes();
+                    let needle_len = needle_bytes.len();
+                    let null_buffer = string_view_array.logical_nulls();
+                    let boolean_buffer =
+                        BooleanBuffer::collect_bool(string_view_array.len(), |i| {
+                            // SAFETY: `collect_bool` only calls this with `i < len`.
+                            let prefix =
+                                unsafe { string_view_array.prefix_bytes_unchecked(needle_len, i) };
+                            // A prefix shorter than the needle means the value itself is
+                            // shorter than the needle, so it cannot start with it.
+                            let matched = prefix.len() == needle_len
+                                && zip(prefix, needle_bytes).all(equals_ignore_ascii_case_kernel);
+                            matched != negate
+                        });
+
+                    BooleanArray::new(boolean_buffer, null_buffer)
+                } else {
+                    BooleanArray::from_unary(array, |haystack| {
+                        starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
+                    })
+                }
+            }
+            Predicate::EndsWith(v) => {
+                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
+                    let needle_bytes = v.as_bytes();
+                    let needle_len = needle_bytes.len();
+                    let null_buffer = string_view_array.logical_nulls();
+                    let boolean_buffer =
+                        BooleanBuffer::collect_bool(string_view_array.len(), |i| {
+                            // SAFETY: `collect_bool` only calls this with `i < len`.
+                            let haystack = unsafe { string_view_array.bytes_unchecked(i) };
+                            // Compare from the back of the whole value; a value shorter than
+                            // the needle can never end with it.
+                            let matched = haystack.len() >= needle_len
+                                && zip(haystack.iter().rev(), needle_bytes.iter().rev())
+                                    .all(equals_kernel);
+                            matched != negate
+                        });
+
+                    BooleanArray::new(boolean_buffer, null_buffer)
+                } else {
+                    BooleanArray::from_unary(array, |haystack| {
+                        ends_with(haystack, v, equals_kernel) != negate
+                    })
+                }
+            }
+            Predicate::IEndsWithAscii(v) => {
+                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
+                    let needle_bytes = v.as_bytes();
+                    let needle_len = needle_bytes.len();
+                    let null_buffer = string_view_array.logical_nulls();
+                    let boolean_buffer =
+                        BooleanBuffer::collect_bool(string_view_array.len(), |i| {
+                            // SAFETY: `collect_bool` only calls this with `i < len`.
+                            let haystack = unsafe { string_view_array.bytes_unchecked(i) };
+                            // Compare from the back of the whole value; a value shorter than
+                            // the needle can never end with it.
+                            let matched = haystack.len() >= needle_len
+                                && zip(haystack.iter().rev(), needle_bytes.iter().rev())
+                                    .all(equals_ignore_ascii_case_kernel);
+                            matched != negate
+                        });
+
+                    BooleanArray::new(boolean_buffer, null_buffer)
+                } else {
+                    BooleanArray::from_unary(array, |haystack| {
+                        ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
+                    })
+                }
+            }
             Predicate::Regex(v) => {
                 BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
             }