From 02bfefe7f68f8255caf3d20b5bc5e3ca7080dda6 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:48:33 -0700 Subject: [PATCH] Update RPAD scalar function to support Utf8View (#11942) * Update RPAD scalar function to support Utf8View * adding more test coverage * optimize macro --- datafusion/functions/src/unicode/rpad.rs | 233 ++++++++++++------ .../sqllogictest/test_files/functions.slt | 28 +++ .../sqllogictest/test_files/string_view.slt | 22 +- 3 files changed, 203 insertions(+), 80 deletions(-) diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index fc6bf1ffe748..4bcf102c8793 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -20,7 +20,9 @@ use std::sync::Arc; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; use unicode_segmentation::UnicodeSegmentation; use crate::utils::{make_scalar_function, utf8_to_str_type}; @@ -45,11 +47,17 @@ impl RPadFunc { Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8View, Int64, Utf8View]), + Exact(vec![Utf8View, Int64, Utf8]), + Exact(vec![Utf8View, Int64, LargeUtf8]), Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8View]), Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![LargeUtf8, Int64, Utf8View]), + Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![LargeUtf8, Int64, LargeUtf8]), ], Volatility::Immutable, @@ -76,97 +84,168 @@ impl ScalarUDFImpl for RPadFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(rpad::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(rpad::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function rpad"), + match args.len() { + 2 => match args[0].data_type() { + DataType::Utf8 | DataType::Utf8View => { + make_scalar_function(rpad::, vec![])(args) + } + DataType::LargeUtf8 => { + make_scalar_function(rpad::, vec![])(args) + } + other => exec_err!("Unsupported data type {other:?} for function rpad"), + }, + 3 => match (args[0].data_type(), args[2].data_type()) { + ( + DataType::Utf8 | DataType::Utf8View, + DataType::Utf8 | DataType::Utf8View, + ) => make_scalar_function(rpad::, vec![])(args), + (DataType::LargeUtf8, DataType::LargeUtf8) => { + make_scalar_function(rpad::, vec![])(args) + } + (DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => { + make_scalar_function(rpad::, vec![])(args) + } + (DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => { + make_scalar_function(rpad::, vec![])(args) + } + (first_type, last_type) => { + exec_err!("unsupported arguments type for rpad, first argument type is {}, last argument type is {}", first_type, last_type) + } + }, + number => { + exec_err!("unsupported arguments number {} for rpad", number) + } } } } -/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. -/// rpad('hi', 5, 'xy') = 'hixyx' -pub fn rpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s = string.to_string(); - s.push_str(" ".repeat(length - graphemes.len()).as_str()); - Ok(Some(s)) - } - } +macro_rules! process_rpad { + // For the two-argument case + ($string_array:expr, $length_array:expr) => {{ + $string_array + .iter() + .zip($length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!("rpad requested length {} too large", length); } - _ => Ok(None), - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } - let length = if length < 0 { 0 } else { length as usize }; + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - if length < graphemes.len() { Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) } else { let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector - .push(*fill_chars.get(l % fill_chars.len()).unwrap()); - } - s.push_str(char_vector.iter().collect::().as_str()); + s.push_str(" ".repeat(length - graphemes.len()).as_str()); Ok(Some(s)) } } - _ => Ok(None), - }) - .collect::>>()?; + } + _ => Ok(None), + }) + .collect::>>() + }}; + // For the three-argument case + ($string_array:expr, $length_array:expr, $fill_array:expr) => {{ + $string_array + .iter() + .zip($length_array.iter()) + .zip($fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!("rpad requested length {} too large", length); + } + + let length = if length < 0 { 0 } else { length as usize }; + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) + } else { + let mut s = string.to_string(); + let char_vector: Vec = (0..length - graphemes.len()) + .map(|l| fill_chars[l % fill_chars.len()]) + .collect(); + s.push_str(&char_vector.iter().collect::()); + Ok(Some(s)) + } + } + _ => Ok(None), + }) + .collect::>>() + }}; +} + +/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. +/// rpad('hi', 5, 'xy') = 'hixyx' +pub fn rpad( + args: &[ArrayRef], +) -> Result { + match (args.len(), args[0].data_type()) { + (2, DataType::Utf8View) => { + let string_array = as_string_view_array(&args[0])?; + let length_array = as_int64_array(&args[1])?; + + let result = process_rpad!(string_array, length_array)?; + Ok(Arc::new(result) as ArrayRef) + } + (2, _) => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + + let result = process_rpad!(string_array, length_array)?; Ok(Arc::new(result) as ArrayRef) } - other => exec_err!( - "rpad was called with {other} arguments. It requires at least 2 and at most 3." + (3, DataType::Utf8View) => { + let string_array = as_string_view_array(&args[0])?; + let length_array = as_int64_array(&args[1])?; + match args[2].data_type() { + DataType::Utf8View => { + let fill_array = as_string_view_array(&args[2])?; + let result = process_rpad!(string_array, length_array, fill_array)?; + Ok(Arc::new(result) as ArrayRef) + } + DataType::Utf8 | DataType::LargeUtf8 => { + let fill_array = as_generic_string_array::(&args[2])?; + let result = process_rpad!(string_array, length_array, fill_array)?; + Ok(Arc::new(result) as ArrayRef) + } + other_type => { + exec_err!("unsupported type for rpad's third operator: {}", other_type) + } + } + } + (3, _) => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + match args[2].data_type() { + DataType::Utf8View => { + let fill_array = as_string_view_array(&args[2])?; + let result = process_rpad!(string_array, length_array, fill_array)?; + Ok(Arc::new(result) as ArrayRef) + } + DataType::Utf8 | DataType::LargeUtf8 => { + let fill_array = as_generic_string_array::(&args[2])?; + let result = process_rpad!(string_array, length_array, fill_array)?; + Ok(Arc::new(result) as ArrayRef) + } + other_type => { + exec_err!("unsupported type for rpad's third operator: {}", other_type) + } + } + } + (other, other_type) => exec_err!( + "rpad requires 2 or 3 arguments with corresponding types, but got {}. number of arguments with {}", + other, other_type ), } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index f728942b38c3..f2f37a59cc2a 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -294,6 +294,7 @@ SELECT right(NULL, CAST(NULL AS INT)) ---- NULL + query T SELECT rpad('hi', -1, 'xy') ---- @@ -354,6 +355,33 @@ SELECT rpad('xyxhi', 3) ---- xyx +# test for rpad with largeutf8 and utf8View + +query T +SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) +---- +hixyx + +query T +SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') +---- +NULL + query I SELECT strpos('abc', 'c') ---- diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index a84b0c7b4594..8bc053234e8c 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -926,10 +926,26 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: rpad(__common_expr_1, Int64(1)) AS c1, rpad(__common_expr_1, Int64(2), CAST(test.column2_utf8view AS Utf8)) AS c2 -02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view -03)----TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: rpad(test.column1_utf8view, Int64(1)) AS c1, rpad(test.column1_utf8view, Int64(2), test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +query TT +EXPLAIN SELECT + RPAD(column1_utf8view, 12, column2_large_utf8) as c1 +FROM test; +---- +logical_plan +01)Projection: rpad(test.column1_utf8view, Int64(12), test.column2_large_utf8) AS c1 +02)--TableScan: test projection=[column2_large_utf8, column1_utf8view] +query TT +EXPLAIN SELECT + RPAD(column1_utf8view, 12, column2_utf8view) as c1 +FROM test; +---- +logical_plan +01)Projection: rpad(test.column1_utf8view, Int64(12), test.column2_utf8view) AS c1 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for SPLIT_PART ## TODO file ticket