Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement native support StringView for overlay #11968

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 118 additions & 64 deletions datafusion/functions/src/string/overlay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;

use datafusion_common::cast::{as_generic_string_array, as_int64_array};
use datafusion_common::cast::{
as_generic_string_array, as_int64_array, as_string_view_array,
};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Volatility};
Expand All @@ -46,8 +48,10 @@ impl OverlayFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8View, Utf8View, Int64, Int64]),
Exact(vec![Utf8, Utf8, Int64, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
Exact(vec![Utf8View, Utf8View, Int64]),
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
Expand Down Expand Up @@ -76,54 +80,107 @@ impl ScalarUDFImpl for OverlayFunc {

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(overlay::<i32>, vec![])(args),
DataType::Utf8View | DataType::Utf8 => {
make_scalar_function(overlay::<i32>, vec![])(args)
}
DataType::LargeUtf8 => make_scalar_function(overlay::<i64>, vec![])(args),
other => exec_err!("Unsupported data type {other:?} for function overlay"),
}
}
}

macro_rules! process_overlay {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have some sort of trait in arrow-rs that allowed us to write this as a generic function

It actually does have
https://github.com/apache/arrow-rs/blob/2461a16c19ee5032531b1c05dd7e7192bc842e0f/arrow-string/src/like.rs#L158-L161

But that is not public

@XiangpengHao do you know of anything that is pub?

We could also implement such a trait for DataFusion's convenience, and then propose upstreaming it 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just behind -- it seems that @Omega359 did exactly this in StringArrayType #11941

trait StringArrayType<'a>: ArrayAccessor<Item = &'a str> + Sized {
fn iter(&self) -> ArrayIter<Self>;
}
impl<'a, T: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray<T> {
fn iter(&self) -> ArrayIter<Self> {
GenericStringArray::<T>::iter(self)
}
}
impl<'a> StringArrayType<'a> for &'a StringViewArray {
fn iter(&self) -> ArrayIter<Self> {
StringViewArray::iter(self)
}
}

Maybe we can start to pull that trait into its own module and start reusing it across the string functions 🤔

Also, there is the ArrayAccessor pattern used elegantly by @devanbenz in #11967 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current style is inspired by rpad in #11942 . I'll be rewriting it with ArrayAccessor, which I've used in other PRs, and it's a much more elegant way.

// For the three-argument case
($string_array:expr, $characters_array:expr, $pos_num:expr) => {{
$string_array
.iter()
.zip($characters_array.iter())
.zip($pos_num.iter())
.map(|((string, characters), start_pos)| {
match (string, characters, start_pos) {
(Some(string), Some(characters), Some(start_pos)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = characters_len as i64;
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
}
Ok(Some(res))
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()
}};

// For the four-argument case
($string_array:expr, $characters_array:expr, $pos_num:expr, $len_num:expr) => {{
$string_array
.iter()
.zip($characters_array.iter())
.zip($pos_num.iter())
.zip($len_num.iter())
.map(|(((string, characters), start_pos), len)| {
match (string, characters, start_pos, len) {
(Some(string), Some(characters), Some(start_pos), Some(len)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = len.min(string_len as i64);
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
}
Ok(Some(res))
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()
}};
}

/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
/// Replaces a substring of string1 with string2 starting at the integer bit
/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead
pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let use_string_view = args[0].data_type() == &DataType::Utf8View;
if use_string_view {
string_view_overlay::<T>(args)
} else {
string_overlay::<T>(args)
}
}

pub fn string_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
3 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;
let pos_num = as_int64_array(&args[2])?;

let result = string_array
.iter()
.zip(characters_array.iter())
.zip(pos_num.iter())
.map(|((string, characters), start_pos)| {
match (string, characters, start_pos) {
(Some(string), Some(characters), Some(start_pos)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = characters_len as i64;
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
}
Ok(Some(res))
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()?;
let result = process_overlay!(string_array, characters_array, pos_num)?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
Expand All @@ -132,37 +189,34 @@ pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let pos_num = as_int64_array(&args[2])?;
let len_num = as_int64_array(&args[3])?;

let result = string_array
.iter()
.zip(characters_array.iter())
.zip(pos_num.iter())
.zip(len_num.iter())
.map(|(((string, characters), start_pos), len)| {
match (string, characters, start_pos, len) {
(Some(string), Some(characters), Some(start_pos), Some(len)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = len.min(string_len as i64);
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
}
Ok(Some(res))
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()?;
let result =
process_overlay!(string_array, characters_array, pos_num, len_num)?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
exec_err!("overlay was called with {other} arguments. It requires 3 or 4.")
}
}
}

pub fn string_view_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
3 => {
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_num = as_int64_array(&args[2])?;

let result = process_overlay!(string_array, characters_array, pos_num)?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_num = as_int64_array(&args[2])?;
let len_num = as_int64_array(&args[3])?;

let result =
process_overlay!(string_array, characters_array, pos_num, len_num)?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: overlay(CAST(test.column1_utf8view AS Utf8), Utf8("foo"), Int64(2)) AS c1
01)Projection: overlay(test.column1_utf8view, Utf8View("foo"), Int64(2)) AS c1
02)--TableScan: test projection=[column1_utf8view]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, can we please also add actually running these queries to the tests

like

query 
SELECT OVERLAY(column1_utf8view PLACING 'foo' FROM 2 ) as c1

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! 907e27e

## Ensure no casts for REGEXP_LIKE
Expand Down