diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 039f3fb9a6aa..f28c2705816b 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -934,9 +934,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.15" +version = "1.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6" +checksum = "e9d013ecb737093c0e86b151a7b837993cf9ec6c502946cfb44bedc392421e0b" dependencies = [ "jobserver", "libc", @@ -986,9 +986,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.16" +version = "4.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019" +checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" dependencies = [ "clap_builder", "clap_derive", @@ -996,9 +996,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.15" +version = "4.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6" +checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" dependencies = [ "anstream", "anstyle", @@ -1182,9 +1182,9 @@ checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" [[package]] name = "dashmap" -version = "6.0.1" +version = "6.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804c8821570c3f8b70230c2ba75ffa5c0f9a4189b9a432b6656c536712acae28" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" dependencies = [ "cfg-if", "crossbeam-utils", @@ -2134,16 +2134,16 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.2" +version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", "hyper 1.4.1", "hyper-util", "rustls 0.23.12", - "rustls-native-certs 0.7.3", + "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -3091,7 +3091,7 @@ dependencies = [ "http-body 1.0.1", "http-body-util", "hyper 1.4.1", - "hyper-rustls 0.27.2", + "hyper-rustls 0.27.3", "hyper-util", "ipnet", "js-sys", @@ -3253,6 +3253,19 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.3", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -3425,9 +3438,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.127" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "itoa", "memchr", @@ -3819,9 +3832,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ "bytes", "futures-core", diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index db218f9127ae..8a70b380669c 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -19,18 +19,18 @@ use std::any::Any; use std::cmp::max; use std::sync::Arc; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, + make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView, + GenericStringArray, OffsetSizeTrait, StringViewArray, }; use arrow::datatypes::DataType; - +use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_datafusion_err, exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; - #[derive(Debug)] pub struct SubstrFunc { signature: Signature, @@ -77,7 +77,11 @@ impl ScalarUDFImpl for SubstrFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "substr") + if arg_types[0] == DataType::Utf8View { + Ok(DataType::Utf8View) + } else { + utf8_to_str_type(&arg_types[0], "substr") + } } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -89,29 +93,188 @@ impl ScalarUDFImpl for SubstrFunc { } } +/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) +/// substr('alphabet', 3) = 'phabet' +/// substr('alphabet', 3, 2) = 'ph' +/// The implementation uses UTF-8 code points as characters pub fn substr(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Utf8 => { let string_array = args[0].as_string::(); - calculate_substr::<_, i32>(string_array, &args[1..]) + string_substr::<_, i32>(string_array, &args[1..]) } DataType::LargeUtf8 => { let string_array = args[0].as_string::(); - calculate_substr::<_, i64>(string_array, &args[1..]) + string_substr::<_, i64>(string_array, &args[1..]) } DataType::Utf8View => { let string_array = args[0].as_string_view(); - calculate_substr::<_, i32>(string_array, &args[1..]) + string_view_substr(string_array, &args[1..]) } - other => exec_err!("Unsupported data type {other:?} for function substr"), + other => exec_err!( + "Unsupported data type {other:?} for function substr,\ + expected Utf8View, Utf8 or LargeUtf8." + ), } } -/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) -/// substr('alphabet', 3) = 'phabet' -/// substr('alphabet', 3, 2) = 'ph' -/// The implementation uses UTF-8 code points as characters -fn calculate_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result +// Return the exact byte index for [start, end), set count to -1 to ignore count +fn get_true_start_end(input: &str, start: usize, count: i64) -> (usize, usize) { + let (mut st, mut ed) = (input.len(), input.len()); + let mut start_counting = false; + let mut cnt = 0; + for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() { + if char_cnt == start { + st = byte_cnt; + if count != -1 { + start_counting = true; + } else { + break; + } + } + if start_counting { + if cnt == count { + ed = byte_cnt; + break; + } + cnt += 1; + } + } + (st, ed) +} + +/// Make a `u128` based on the given substr, start(offset to view.offset), and +/// push into to the given buffers +fn make_and_append_view( + views_buffer: &mut Vec, + null_builder: &mut NullBufferBuilder, + raw: &u128, + substr: &str, + start: u32, +) { + let substr_len = substr.len(); + if substr_len == 0 { + null_builder.append_null(); + views_buffer.push(0); + } else { + let sub_view = if substr_len > 12 { + let view = ByteView::from(*raw); + make_view(substr.as_bytes(), view.buffer_index, view.offset + start) + } else { + // inline value does not need block id or offset + make_view(substr.as_bytes(), 0, 0) + }; + views_buffer.push(sub_view); + null_builder.append_non_null(); + } +} + +// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44 +// From for ByteView +fn string_view_substr( + string_view_array: &StringViewArray, + args: &[ArrayRef], +) -> Result { + let mut views_buf = Vec::with_capacity(string_view_array.len()); + let mut null_builder = NullBufferBuilder::new(string_view_array.len()); + + let start_array = as_int64_array(&args[0])?; + + match args.len() { + 1 => { + for (idx, (raw, start)) in string_view_array + .views() + .iter() + .zip(start_array.iter()) + .enumerate() + { + if let Some(start) = start { + let start = (start - 1).max(0) as usize; + + // Safety: + // idx is always smaller or equal to string_view_array.views.len() + unsafe { + let str = string_view_array.value_unchecked(idx); + let (start, end) = get_true_start_end(str, start, -1); + let substr = &str[start..end]; + + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw, + substr, + start as u32, + ); + } + } else { + null_builder.append_null(); + views_buf.push(0); + } + } + } + 2 => { + let count_array = as_int64_array(&args[1])?; + for (idx, ((raw, start), count)) in string_view_array + .views() + .iter() + .zip(start_array.iter()) + .zip(count_array.iter()) + .enumerate() + { + if let (Some(start), Some(count)) = (start, count) { + let start = (start - 1).max(0) as usize; + if count < 0 { + return exec_err!( + "negative substring length not allowed: substr(, {start}, {count})" + ); + } else { + // Safety: + // idx is always smaller or equal to string_view_array.views.len() + unsafe { + let str = string_view_array.value_unchecked(idx); + let (start, end) = get_true_start_end(str, start, count); + let substr = &str[start..end]; + + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw, + substr, + start as u32, + ); + } + } + } else { + null_builder.append_null(); + views_buf.push(0); + } + } + } + other => { + return exec_err!( + "substr was called with {other} arguments. It requires 2 or 3." + ) + } + } + + let views_buf = ScalarBuffer::from(views_buf); + let nulls_buf = null_builder.finish(); + + // Safety: + // (1) The blocks of the given views are all provided + // (2) Each of the range `view.offset+start..end` of view in views_buf is within + // the bounds of each of the blocks + unsafe { + let array = StringViewArray::new_unchecked( + views_buf, + string_view_array.data_buffers().to_vec(), + nulls_buf, + ); + Ok(Arc::new(array) as ArrayRef) + } +} + +fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result where V: ArrayAccessor, T: OffsetSizeTrait, @@ -174,8 +337,8 @@ where #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -193,8 +356,8 @@ mod tests { ], Ok(None), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( SubstrFunc::new(), @@ -206,8 +369,35 @@ mod tests { ], Ok(Some("alphabet")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "this és longer than 12B" + )))), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some(" é")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "this is longer than 12B" + )))), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some(" is longer than 12B")), + &str, + Utf8View, + StringViewArray ); test_function!( SubstrFunc::new(), @@ -219,8 +409,8 @@ mod tests { ], Ok(Some("ésoj")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( SubstrFunc::new(), @@ -233,8 +423,8 @@ mod tests { ], Ok(Some("ph")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( SubstrFunc::new(), @@ -247,8 +437,8 @@ mod tests { ], Ok(Some("phabet")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( SubstrFunc::new(),