From e5fc0ce9d76814fd362a4ff143d9507c725ae0b3 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sun, 11 Aug 2024 21:22:10 -0400 Subject: [PATCH 1/3] Update LPAD scalar function to support Utf8View --- datafusion/functions/src/unicode/lpad.rs | 677 +++++++++++------- .../sqllogictest/test_files/functions.slt | 26 + .../sqllogictest/test_files/string_view.slt | 20 +- 3 files changed, 463 insertions(+), 260 deletions(-) diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index ce5e0064362b..5caa6acd6745 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -18,16 +18,21 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, + OffsetSizeTrait, StringViewArray, +}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; use unicode_segmentation::UnicodeSegmentation; +use DataType::{LargeUtf8, Utf8, Utf8View}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use crate::utils::{make_scalar_function, utf8_to_str_type}; + #[derive(Debug)] pub struct LPadFunc { signature: Signature, @@ -45,11 +50,17 @@ impl LPadFunc { Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8View, Int64, Utf8View]), + Exact(vec![Utf8View, Int64, Utf8]), + Exact(vec![Utf8View, Int64, LargeUtf8]), Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8View]), Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![LargeUtf8, Int64, Utf8View]), + Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![LargeUtf8, Int64, LargeUtf8]), ], Volatility::Immutable, @@ -76,300 +87,450 @@ impl ScalarUDFImpl for LPadFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(lpad::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(lpad::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function lpad"), - } + make_scalar_function(lpad, vec![])(args) } } -/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). +/// Extends the string to length 'length' by prepending the characters fill (a space by default). +/// If the string is already longer than length then it is truncated (on the right). /// lpad('hi', 5, 'xy') = 'xyxhi' -pub fn lpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } +pub fn lpad(args: &[ArrayRef]) -> Result { + if args.len() <= 1 || args.len() > 3 { + return exec_err!( + "lpad was called with {} arguments. It requires at least 2 and at most 3.", + args.len() + ); + } + + let length_array = as_int64_array(&args[1])?; + + match args[0].data_type() { + Utf8 => match args.len() { + 2 => lpad_impl::<&GenericStringArray, &GenericStringArray, i32>( + args[0].as_string::(), + length_array, + None, + ), + 3 => lpad_with_replace::<&GenericStringArray, i32>( + args[0].as_string::(), + length_array, + &args[2], + ), + _ => unreachable!(), + }, + LargeUtf8 => match args.len() { + 2 => lpad_impl::<&GenericStringArray, &GenericStringArray, i64>( + args[0].as_string::(), + length_array, + None, + ), + 3 => lpad_with_replace::<&GenericStringArray, i64>( + args[0].as_string::(), + length_array, + &args[2], + ), + _ => unreachable!(), + }, + Utf8View => match args.len() { + 2 => lpad_impl::<&StringViewArray, &GenericStringArray, i32>( + args[0].as_string_view(), + length_array, + None, + ), + 3 => lpad_with_replace::<&StringViewArray, i32>( + args[0].as_string_view(), + length_array, + &args[2], + ), + _ => unreachable!(), + }, + other => { + exec_err!("Unsupported data type {other:?} for function lpad") + } + } +} - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) +fn lpad_with_replace<'a, V, T: OffsetSizeTrait>( + string_array: V, + length_array: &Int64Array, + fill_array: &'a ArrayRef, +) -> Result +where + V: StringArrayType<'a>, +{ + match fill_array.data_type() { + Utf8 => lpad_impl::, T>( + string_array, + length_array, + Some(fill_array.as_string::()), + ), + LargeUtf8 => lpad_impl::, T>( + string_array, + length_array, + Some(fill_array.as_string::()), + ), + Utf8View => lpad_impl::( + string_array, + length_array, + Some(fill_array.as_string_view()), + ), + other => { + exec_err!("Unsupported data type {other:?} for function lpad") + } + } +} + +fn lpad_impl<'a, V, V2, T>( + string_array: V, + length_array: &Int64Array, + fill_array: Option, +) -> Result +where + V: StringArrayType<'a>, + V2: StringArrayType<'a>, + T: OffsetSizeTrait, +{ + if fill_array.is_none() { + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s: String = " ".repeat(length - graphemes.len()); - s.push_str(string); - Ok(Some(s)) - } + let mut s: String = " ".repeat(length - graphemes.len()); + s.push_str(string); + Ok(Some(s)) } } - _ => Ok(None), - }) - .collect::>>()?; + } + _ => Ok(None), + }) + .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } + Ok(Arc::new(result) as ArrayRef) + } else { + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.unwrap().iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) } else { - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector.push( - *fill_chars.get(l % fill_chars.len()).unwrap(), - ); - } - s.insert_str( - 0, - char_vector.iter().collect::().as_str(), - ); - Ok(Some(s)) + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector + .push(*fill_chars.get(l % fill_chars.len()).unwrap()); } + s.insert_str( + 0, + char_vector.iter().collect::().as_str(), + ); + Ok(Some(s)) } } - _ => Ok(None), - }) - .collect::>>()?; + } + _ => Ok(None), + }) + .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!( - "lpad was called with {other} arguments. It requires at least 2 and at most 3." - ), + Ok(Arc::new(result) as ArrayRef) + } +} + +trait StringArrayType<'a>: ArrayAccessor + Sized { + fn iter(&self) -> ArrayIter; +} +impl<'a, O: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { + fn iter(&self) -> ArrayIter { + GenericStringArray::::iter(self) + } +} +impl<'a> StringArrayType<'a> for &'a StringViewArray { + fn iter(&self) -> ArrayIter { + StringViewArray::iter(self) } } #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use crate::unicode::lpad::LPadFunc; + use crate::utils::test::test_function; + + use arrow::array::{Array, LargeStringArray, StringArray}; + use arrow::datatypes::DataType::{LargeUtf8, Utf8}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::unicode::lpad::LPadFunc; - use crate::utils::test::test_function; + macro_rules! test_lpad { + ($INPUT:expr, $LENGTH:expr, $EXPECTED:expr) => { + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + }; + + ($INPUT:expr, $LENGTH:expr, $REPLACE:expr, $EXPECTED:expr) => { + // utf8, utf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8, largeutf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8, utf8view + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + + // largeutf8, utf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + // largeutf8, largeutf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + // largeutf8, utf8view + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + + // utf8view, utf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8view, largeutf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8view, utf8view + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + }; + } #[test] fn test_functions() -> Result<()> { - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(Some(" josé")), - &str, - Utf8, - StringArray - ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), - ], - Ok(Some("")), - &str, - Utf8, - StringArray + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(5i64)), + Ok(Some(" josé")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Ok(Some(" hi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(0i64)), + Ok(Some("")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(Some("xyxhi")), - &str, - Utf8, - StringArray + test_lpad!(Some("hi".into()), ScalarValue::Int64(None), Ok(None)); + test_lpad!(None, ScalarValue::Int64(Some(5i64)), Ok(None)); + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Some("xy".into()), + Ok(Some("xyxhi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(21i64)), - ColumnarValue::Scalar(ScalarValue::from("abcdef")), - ], - Ok(Some("abcdefabcdefabcdefahi")), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(21i64)), + Some("abcdef".into()), + Ok(Some("abcdefabcdefabcdefahi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from(" ")), - ], - Ok(Some(" hi")), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Some(" ".into()), + Ok(Some(" hi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("")), - ], - Ok(Some("hi")), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Some("".into()), + Ok(Some("hi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + None, + ScalarValue::Int64(Some(5i64)), + Some("xy".into()), + Ok(None) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(None), + Some("xy".into()), + Ok(None) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + None, + Ok(None) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(Some("xyxyxyjosé")), - &str, - Utf8, - StringArray + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(10i64)), + Some("xy".into()), + Ok(Some("xyxyxyjosé")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("éñ")), - ], - Ok(Some("éñéñéñjosé")), - &str, - Utf8, - StringArray + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(10i64)), + Some("éñ".into()), + Ok(Some("éñéñéñjosé")) ); + #[cfg(not(feature = "unicode_expressions"))] - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - internal_err!( + test_lpad!(Some("josé".into()), ScalarValue::Int64(Some(5i64)), internal_err!( "function lpad requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); + )); + Ok(()) } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index c3dd791f6ca8..8a4855ea2c05 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -203,6 +203,32 @@ SELECT lpad(NULL, 5, 'xy') ---- NULL +# test largeutf8, utf8view for lpad +query T +SELECT lpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') +---- +xyxhi + +query T +SELECT lpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') +---- +xyxhi + +query T +SELECT lpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) +---- +xyxhi + +query T +SELECT lpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) +---- +xyxhi + +query T +SELECT lpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') +---- +NULL + query T SELECT reverse('abcde') ---- diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index e7166690580f..dcc6784bf44a 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -634,16 +634,32 @@ logical_plan 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for LPAD -## TODO https://github.com/apache/datafusion/issues/11857 query TT EXPLAIN SELECT LPAD(column1_utf8view, 12, ' ') as c1 FROM test; ---- logical_plan -01)Projection: lpad(CAST(test.column1_utf8view AS Utf8), Int64(12), Utf8(" ")) AS c1 +01)Projection: lpad(test.column1_utf8view, Int64(12), Utf8(" ")) AS c1 02)--TableScan: test projection=[column1_utf8view] +query TT +EXPLAIN SELECT + LPAD(column1_utf8view, 12, column2_large_utf8) as c1 +FROM test; +---- +logical_plan +01)Projection: lpad(test.column1_utf8view, Int64(12), test.column2_large_utf8) AS c1 +02)--TableScan: test projection=[column2_large_utf8, column1_utf8view] + +query TT +EXPLAIN SELECT + LPAD(column1_utf8view, 12, column2_utf8view) as c1 +FROM test; +---- +logical_plan +01)Projection: lpad(test.column1_utf8view, Int64(12), test.column2_utf8view) AS c1 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for OCTET_LENGTH ## TODO https://github.com/apache/datafusion/issues/11858 From d5cee12587f9b498bd4f6a1ef9f7f9548fa2b2a5 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Tue, 13 Aug 2024 14:22:28 -0400 Subject: [PATCH 2/3] Lpad code improvements and benchmark. --- datafusion/functions/Cargo.toml | 5 + datafusion/functions/benches/pad.rs | 141 +++++++++++++++ datafusion/functions/src/unicode/lpad.rs | 214 +++++++++++------------ 3 files changed, 245 insertions(+), 115 deletions(-) create mode 100644 datafusion/functions/benches/pad.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 9675d03a0161..688563baecfa 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -146,3 +146,8 @@ required-features = ["string_expressions"] harness = false name = "upper" required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "pad" +required-features = ["unicode_expressions"] diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs new file mode 100644 index 000000000000..5ff1e2fb860d --- /dev/null +++ b/datafusion/functions/benches/pad.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; +use arrow::datatypes::Int64Type; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::unicode::{lpad, rpad}; +use rand::distributions::{Distribution, Uniform}; +use rand::Rng; +use std::sync::Arc; + +struct Filter { + dist: Dist, +} + +impl Distribution for Filter +where + Dist: Distribution, +{ + fn sample(&self, rng: &mut R) -> T { + self.dist.sample(rng) + } +} + +pub fn create_primitive_array( + size: usize, + null_density: f32, + len: usize, +) -> PrimitiveArray +where + T: ArrowPrimitiveType, +{ + let dist = Filter { + dist: Uniform::new_inclusive::(0, len as i64), + }; + + let mut rng = rand::thread_rng(); + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + Some(rng.sample(&dist)) + } + }) + .collect() +} + +fn create_args( + size: usize, + str_len: usize, + use_string_view: bool, +) -> Vec { + let length_array = Arc::new(create_primitive_array::(size, 0.0, str_len)); + + if !use_string_view { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + let fill_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&length_array) as ArrayRef), + ColumnarValue::Array(fill_array), + ] + } else { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + let fill_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&length_array) as ArrayRef), + ColumnarValue::Array(fill_array), + ] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 2048] { + let mut group = c.benchmark_group("lpad function"); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { + b.iter(|| criterion::black_box(lpad().invoke(&args).unwrap())) + }); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { + b.iter(|| criterion::black_box(lpad().invoke(&args).unwrap())) + }); + + let args = create_args::(size, 32, true); + group.bench_function(BenchmarkId::new("stringview type", size), |b| { + b.iter(|| criterion::black_box(lpad().invoke(&args).unwrap())) + }); + + group.finish(); + + let mut group = c.benchmark_group("rpad function"); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { + b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + }); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { + b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + }); + // + // let args = create_args::(size, 32, true); + // group.bench_function(BenchmarkId::new("stringview type", size), |b| { + // b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + // }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 5caa6acd6745..73e65cf0731d 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -19,8 +19,8 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, - OffsetSizeTrait, StringViewArray, + Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, + GenericStringBuilder, Int64Array, OffsetSizeTrait, StringViewArray, }; use arrow::datatypes::DataType; use unicode_segmentation::UnicodeSegmentation; @@ -87,14 +87,18 @@ impl ScalarUDFImpl for LPadFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - make_scalar_function(lpad, vec![])(args) + match args[0].data_type() { + Utf8 | Utf8View => make_scalar_function(lpad::, vec![])(args), + LargeUtf8 => make_scalar_function(lpad::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function lpad"), + } } } /// Extends the string to length 'length' by prepending the characters fill (a space by default). /// If the string is already longer than length then it is truncated (on the right). /// lpad('hi', 5, 'xy') = 'xyxhi' -pub fn lpad(args: &[ArrayRef]) -> Result { +pub fn lpad(args: &[ArrayRef]) -> Result { if args.len() <= 1 || args.len() > 3 { return exec_err!( "lpad was called with {} arguments. It requires at least 2 and at most 3.", @@ -104,49 +108,28 @@ pub fn lpad(args: &[ArrayRef]) -> Result { let length_array = as_int64_array(&args[1])?; - match args[0].data_type() { - Utf8 => match args.len() { - 2 => lpad_impl::<&GenericStringArray, &GenericStringArray, i32>( - args[0].as_string::(), - length_array, - None, - ), - 3 => lpad_with_replace::<&GenericStringArray, i32>( - args[0].as_string::(), - length_array, - &args[2], - ), - _ => unreachable!(), - }, - LargeUtf8 => match args.len() { - 2 => lpad_impl::<&GenericStringArray, &GenericStringArray, i64>( - args[0].as_string::(), - length_array, - None, - ), - 3 => lpad_with_replace::<&GenericStringArray, i64>( - args[0].as_string::(), - length_array, - &args[2], - ), - _ => unreachable!(), - }, - Utf8View => match args.len() { - 2 => lpad_impl::<&StringViewArray, &GenericStringArray, i32>( - args[0].as_string_view(), - length_array, - None, - ), - 3 => lpad_with_replace::<&StringViewArray, i32>( - args[0].as_string_view(), - length_array, - &args[2], - ), - _ => unreachable!(), - }, - other => { - exec_err!("Unsupported data type {other:?} for function lpad") - } + match (args.len(), args[0].data_type()) { + (2, Utf8View) => lpad_impl::<&StringViewArray, &GenericStringArray, T>( + args[0].as_string_view(), + length_array, + None, + ), + (2, Utf8 | LargeUtf8) => lpad_impl::< + &GenericStringArray, + &GenericStringArray, + T, + >(args[0].as_string::(), length_array, None), + (3, Utf8View) => lpad_with_replace::<&StringViewArray, T>( + args[0].as_string_view(), + length_array, + &args[2], + ), + (3, Utf8 | LargeUtf8) => lpad_with_replace::<&GenericStringArray, T>( + args[0].as_string::(), + length_array, + &args[2], + ), + (_, _) => unreachable!(), } } @@ -159,20 +142,20 @@ where V: StringArrayType<'a>, { match fill_array.data_type() { - Utf8 => lpad_impl::, T>( + Utf8View => lpad_impl::( string_array, length_array, - Some(fill_array.as_string::()), + Some(fill_array.as_string_view()), ), LargeUtf8 => lpad_impl::, T>( string_array, length_array, Some(fill_array.as_string::()), ), - Utf8View => lpad_impl::( + Utf8 => lpad_impl::, T>( string_array, length_array, - Some(fill_array.as_string_view()), + Some(fill_array.as_string::()), ), other => { exec_err!("Unsupported data type {other:?} for function lpad") @@ -190,87 +173,88 @@ where V2: StringArrayType<'a>, T: OffsetSizeTrait, { - if fill_array.is_none() { - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!("lpad requested length {length} too large"); - } + let array = if fill_array.is_none() { + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s: String = " ".repeat(length - graphemes.len()); - s.push_str(string); - Ok(Some(s)) - } - } + for (string, length) in string_array.iter().zip(length_array.iter()) { + if let (Some(string), Some(length)) = (string, length) { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); } - _ => Ok(None), - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + builder.append_value(""); + continue; + } + + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else { + let mut s: String = " ".repeat(length - graphemes.len()); + s.push_str(string); + builder.append_value(s); + } + } else { + builder.append_null(); + } + } + + builder.finish() } else { - let result = string_array + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + + for ((string, length), fill) in string_array .iter() .zip(length_array.iter()) .zip(fill_array.unwrap().iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!("lpad requested length {length} too large"); - } + { + if let (Some(string), Some(length), Some(fill)) = (string, length, fill) { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); + } - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector - .push(*fill_chars.get(l % fill_chars.len()).unwrap()); - } - s.insert_str( - 0, - char_vector.iter().collect::().as_str(), - ); - Ok(Some(s)) - } + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + builder.append_value(""); + continue; + } + + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else if fill_chars.is_empty() { + builder.append_value(string); + } else { + let capacity = length - graphemes.len(); + let mut char_vector = Vec::::with_capacity(capacity); + for l in 0..capacity { + char_vector.push(*fill_chars.get(l % fill_chars.len()).unwrap()); } + let mut s = char_vector.iter().collect::(); + s.push_str(string); + builder.append_value(s); } - _ => Ok(None), - }) - .collect::>>()?; + } else { + builder.append_null(); + } + } - Ok(Arc::new(result) as ArrayRef) - } + builder.finish() + }; + + Ok(Arc::new(array) as ArrayRef) } trait StringArrayType<'a>: ArrayAccessor + Sized { fn iter(&self) -> ArrayIter; } -impl<'a, O: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { +impl<'a, T: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { fn iter(&self) -> ArrayIter { - GenericStringArray::::iter(self) + GenericStringArray::::iter(self) } } impl<'a> StringArrayType<'a> for &'a StringViewArray { From b445ef94aa557d2313d14e80c8a729eb2095313d Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Wed, 14 Aug 2024 10:22:23 -0400 Subject: [PATCH 3/3] Improved use of GenericStringBuilder. --- datafusion/functions/src/unicode/lpad.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 73e65cf0731d..521cdc5d0ff0 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -16,6 +16,7 @@ // under the License. use std::any::Any; +use std::fmt::Write; use std::sync::Arc; use arrow::array::{ @@ -192,9 +193,9 @@ where if length < graphemes.len() { builder.append_value(graphemes[..length].concat()); } else { - let mut s: String = " ".repeat(length - graphemes.len()); - s.push_str(string); - builder.append_value(s); + builder.write_str(" ".repeat(length - graphemes.len()).as_str())?; + builder.write_str(string)?; + builder.append_value(""); } } else { builder.append_null(); @@ -229,14 +230,12 @@ where } else if fill_chars.is_empty() { builder.append_value(string); } else { - let capacity = length - graphemes.len(); - let mut char_vector = Vec::::with_capacity(capacity); - for l in 0..capacity { - char_vector.push(*fill_chars.get(l % fill_chars.len()).unwrap()); + for l in 0..length - graphemes.len() { + let c = *fill_chars.get(l % fill_chars.len()).unwrap(); + builder.write_char(c)?; } - let mut s = char_vector.iter().collect::(); - s.push_str(string); - builder.append_value(s); + builder.write_str(string)?; + builder.append_value(""); } } else { builder.append_null();