Skip to content

Commit

Permalink
Update RPAD scalar function to support Utf8View (#11942)
Browse files Browse the repository at this point in the history
* Update RPAD scalar function to support Utf8View

* adding more test coverage

* optimize macro
  • Loading branch information
Lordworms authored Aug 14, 2024
1 parent f98f8a9 commit 02bfefe
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 80 deletions.
233 changes: 156 additions & 77 deletions datafusion/functions/src/unicode/rpad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
use datafusion_common::cast::{
as_generic_string_array, as_int64_array, as_string_view_array,
};
use unicode_segmentation::UnicodeSegmentation;

use crate::utils::{make_scalar_function, utf8_to_str_type};
Expand All @@ -45,11 +47,17 @@ impl RPadFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8View, Int64]),
Exact(vec![Utf8View, Int64, Utf8View]),
Exact(vec![Utf8View, Int64, Utf8]),
Exact(vec![Utf8View, Int64, LargeUtf8]),
Exact(vec![Utf8, Int64]),
Exact(vec![LargeUtf8, Int64]),
Exact(vec![Utf8, Int64, Utf8View]),
Exact(vec![Utf8, Int64, Utf8]),
Exact(vec![LargeUtf8, Int64, Utf8]),
Exact(vec![Utf8, Int64, LargeUtf8]),
Exact(vec![LargeUtf8, Int64]),
Exact(vec![LargeUtf8, Int64, Utf8View]),
Exact(vec![LargeUtf8, Int64, Utf8]),
Exact(vec![LargeUtf8, Int64, LargeUtf8]),
],
Volatility::Immutable,
Expand All @@ -76,97 +84,168 @@ impl ScalarUDFImpl for RPadFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(rpad::<i32>, vec![])(args),
DataType::LargeUtf8 => make_scalar_function(rpad::<i64>, vec![])(args),
other => exec_err!("Unsupported data type {other:?} for function rpad"),
match args.len() {
2 => match args[0].data_type() {
DataType::Utf8 | DataType::Utf8View => {
make_scalar_function(rpad::<i32, i32>, vec![])(args)
}
DataType::LargeUtf8 => {
make_scalar_function(rpad::<i64, i64>, vec![])(args)
}
other => exec_err!("Unsupported data type {other:?} for function rpad"),
},
3 => match (args[0].data_type(), args[2].data_type()) {
(
DataType::Utf8 | DataType::Utf8View,
DataType::Utf8 | DataType::Utf8View,
) => make_scalar_function(rpad::<i32, i32>, vec![])(args),
(DataType::LargeUtf8, DataType::LargeUtf8) => {
make_scalar_function(rpad::<i64, i64>, vec![])(args)
}
(DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => {
make_scalar_function(rpad::<i64, i32>, vec![])(args)
}
(DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => {
make_scalar_function(rpad::<i32, i64>, vec![])(args)
}
(first_type, last_type) => {
exec_err!("unsupported arguments type for rpad, first argument type is {}, last argument type is {}", first_type, last_type)
}
},
number => {
exec_err!("unsupported arguments number {} for rpad", number)
}
}
}
}

/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
/// rpad('hi', 5, 'xy') = 'hixyx'
pub fn rpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let length_array = as_int64_array(&args[1])?;

let result = string_array
.iter()
.zip(length_array.iter())
.map(|(string, length)| match (string, length) {
(Some(string), Some(length)) => {
if length > i32::MAX as i64 {
return exec_err!(
"rpad requested length {length} too large"
);
}

let length = if length < 0 { 0 } else { length as usize };
if length == 0 {
Ok(Some("".to_string()))
} else {
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
if length < graphemes.len() {
Ok(Some(graphemes[..length].concat()))
} else {
let mut s = string.to_string();
s.push_str(" ".repeat(length - graphemes.len()).as_str());
Ok(Some(s))
}
}
macro_rules! process_rpad {
// For the two-argument case
($string_array:expr, $length_array:expr) => {{
$string_array
.iter()
.zip($length_array.iter())
.map(|(string, length)| match (string, length) {
(Some(string), Some(length)) => {
if length > i32::MAX as i64 {
return exec_err!("rpad requested length {} too large", length);
}
_ => Ok(None),
})
.collect::<Result<GenericStringArray<T>>>()?;
Ok(Arc::new(result) as ArrayRef)
}
3 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let length_array = as_int64_array(&args[1])?;
let fill_array = as_generic_string_array::<T>(&args[2])?;

let result = string_array
.iter()
.zip(length_array.iter())
.zip(fill_array.iter())
.map(|((string, length), fill)| match (string, length, fill) {
(Some(string), Some(length), Some(fill)) => {
if length > i32::MAX as i64 {
return exec_err!(
"rpad requested length {length} too large"
);
}

let length = if length < 0 { 0 } else { length as usize };
let length = if length < 0 { 0 } else { length as usize };
if length == 0 {
Ok(Some("".to_string()))
} else {
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
let fill_chars = fill.chars().collect::<Vec<char>>();

if length < graphemes.len() {
Ok(Some(graphemes[..length].concat()))
} else if fill_chars.is_empty() {
Ok(Some(string.to_string()))
} else {
let mut s = string.to_string();
let mut char_vector =
Vec::<char>::with_capacity(length - graphemes.len());
for l in 0..length - graphemes.len() {
char_vector
.push(*fill_chars.get(l % fill_chars.len()).unwrap());
}
s.push_str(char_vector.iter().collect::<String>().as_str());
s.push_str(" ".repeat(length - graphemes.len()).as_str());
Ok(Some(s))
}
}
_ => Ok(None),
})
.collect::<Result<GenericStringArray<T>>>()?;
}
_ => Ok(None),
})
.collect::<Result<GenericStringArray<StringArrayLen>>>()
}};

// For the three-argument case
($string_array:expr, $length_array:expr, $fill_array:expr) => {{
$string_array
.iter()
.zip($length_array.iter())
.zip($fill_array.iter())
.map(|((string, length), fill)| match (string, length, fill) {
(Some(string), Some(length), Some(fill)) => {
if length > i32::MAX as i64 {
return exec_err!("rpad requested length {} too large", length);
}

let length = if length < 0 { 0 } else { length as usize };
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
let fill_chars = fill.chars().collect::<Vec<char>>();

if length < graphemes.len() {
Ok(Some(graphemes[..length].concat()))
} else if fill_chars.is_empty() {
Ok(Some(string.to_string()))
} else {
let mut s = string.to_string();
let char_vector: Vec<char> = (0..length - graphemes.len())
.map(|l| fill_chars[l % fill_chars.len()])
.collect();
s.push_str(&char_vector.iter().collect::<String>());
Ok(Some(s))
}
}
_ => Ok(None),
})
.collect::<Result<GenericStringArray<StringArrayLen>>>()
}};
}

/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
/// rpad('hi', 5, 'xy') = 'hixyx'
pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
match (args.len(), args[0].data_type()) {
(2, DataType::Utf8View) => {
let string_array = as_string_view_array(&args[0])?;
let length_array = as_int64_array(&args[1])?;

let result = process_rpad!(string_array, length_array)?;
Ok(Arc::new(result) as ArrayRef)
}
(2, _) => {
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
let length_array = as_int64_array(&args[1])?;

let result = process_rpad!(string_array, length_array)?;
Ok(Arc::new(result) as ArrayRef)
}
other => exec_err!(
"rpad was called with {other} arguments. It requires at least 2 and at most 3."
(3, DataType::Utf8View) => {
let string_array = as_string_view_array(&args[0])?;
let length_array = as_int64_array(&args[1])?;
match args[2].data_type() {
DataType::Utf8View => {
let fill_array = as_string_view_array(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
DataType::Utf8 | DataType::LargeUtf8 => {
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
other_type => {
exec_err!("unsupported type for rpad's third operator: {}", other_type)
}
}
}
(3, _) => {
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
let length_array = as_int64_array(&args[1])?;
match args[2].data_type() {
DataType::Utf8View => {
let fill_array = as_string_view_array(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
DataType::Utf8 | DataType::LargeUtf8 => {
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
other_type => {
exec_err!("unsupported type for rpad's third operator: {}", other_type)
}
}
}
(other, other_type) => exec_err!(
"rpad requires 2 or 3 arguments with corresponding types, but got {}. number of arguments with {}",
other, other_type
),
}
}
Expand Down
28 changes: 28 additions & 0 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ SELECT right(NULL, CAST(NULL AS INT))
----
NULL


query T
SELECT rpad('hi', -1, 'xy')
----
Expand Down Expand Up @@ -354,6 +355,33 @@ SELECT rpad('xyxhi', 3)
----
xyx

# test for rpad with largeutf8 and utf8View

query T
SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy')
----
hixyx

query T
SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, 'xy')
----
hixyx

query T
SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8'))
----
hixyx

query T
SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View'))
----
hixyx

query T
SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy')
----
NULL

query I
SELECT strpos('abc', 'c')
----
Expand Down
22 changes: 19 additions & 3 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -926,10 +926,26 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: rpad(__common_expr_1, Int64(1)) AS c1, rpad(__common_expr_1, Int64(2), CAST(test.column2_utf8view AS Utf8)) AS c2
02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view
03)----TableScan: test projection=[column1_utf8view, column2_utf8view]
01)Projection: rpad(test.column1_utf8view, Int64(1)) AS c1, rpad(test.column1_utf8view, Int64(2), test.column2_utf8view) AS c2
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]

query TT
EXPLAIN SELECT
RPAD(column1_utf8view, 12, column2_large_utf8) as c1
FROM test;
----
logical_plan
01)Projection: rpad(test.column1_utf8view, Int64(12), test.column2_large_utf8) AS c1
02)--TableScan: test projection=[column2_large_utf8, column1_utf8view]

query TT
EXPLAIN SELECT
RPAD(column1_utf8view, 12, column2_utf8view) as c1
FROM test;
----
logical_plan
01)Projection: rpad(test.column1_utf8view, Int64(12), test.column2_utf8view) AS c1
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]

## Ensure no casts for SPLIT_PART
## TODO file ticket
Expand Down

0 comments on commit 02bfefe

Please sign in to comment.