Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support FixedSizeList Type Coercion #9108

Merged
merged 21 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{
};

use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
use datafusion_common::{internal_err, plan_err, DataFusionError, Result};
use datafusion_common::{exec_err, plan_err, DataFusionError, Result};

use strum::IntoEnumIterator;
use strum_macros::EnumIter;
Expand Down Expand Up @@ -543,10 +543,11 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Flatten => {
fn get_base_type(data_type: &DataType) -> Result<DataType> {
match data_type {
DataType::List(field) if matches!(field.data_type(), DataType::List(_)) => get_base_type(field.data_type()),
DataType::List(field) | DataType::FixedSizeList(field, _) if matches!(field.data_type(), DataType::List(_)|DataType::FixedSizeList(_,_ )) => get_base_type(field.data_type()),
DataType::LargeList(field) if matches!(field.data_type(), DataType::LargeList(_)) => get_base_type(field.data_type()),
DataType::Null | DataType::List(_) | DataType::LargeList(_) => Ok(data_type.to_owned()),
_ => internal_err!("Not reachable, data_type should be List or LargeList"),
DataType::FixedSizeList(field,_ ) => Ok(DataType::List(field.clone())),
_ => exec_err!("Not reachable, data_type should be List, LargeList or FixedSizeList"),
}
}

Expand Down Expand Up @@ -929,18 +930,18 @@ impl BuiltinScalarFunction {
// 0 or more arguments of arbitrary type
Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility())
}
BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPopFront => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayDims => Signature::array(self.volatility()),
Weijun-H marked this conversation as resolved.
Show resolved Hide resolved
BuiltinScalarFunction::ArrayEmpty => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayElement => {
Signature::array_and_index(self.volatility())
}
BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()),
BuiltinScalarFunction::Flatten => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny => {
Signature::any(2, self.volatility())
}
Expand All @@ -950,8 +951,8 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayNdims => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPosition => {
Signature::array_and_element_and_optional_index(self.volatility())
}
Expand Down Expand Up @@ -981,7 +982,7 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayResize => {
Signature::variadic_any(self.volatility())
}
Expand Down
17 changes: 16 additions & 1 deletion datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ pub enum TypeSignature {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArrayFunctionSignature {
/// Specialized Signature for ArrayAppend and similar functions
/// The first argument should be List/LargeList, and the second argument should be non-list or list.
/// The first argument should be List/LargeList/FixedSizedList, and the second argument should be non-list or list.
/// The second argument's list dimension should be one dimension less than the first argument's list dimension.
/// List dimension of the List/LargeList is equivalent to the number of List.
/// List dimension of the non-list is 0.
Expand All @@ -133,9 +133,14 @@ pub enum ArrayFunctionSignature {
/// The first argument's list dimension should be one dimension less than the second argument's list dimension.
ElementAndArray,
/// Specialized Signature for Array functions of the form (List/LargeList, Index)
/// The first argument should be List/LargeList/FixedSizedList, and the second argument should be Int64.
ArrayAndIndex,
/// Specialized Signature for Array functions of the form (List/LargeList, Element, Optional Index)
ArrayAndElementAndOptionalIndex,
/// Specialized Signature for ArrayEmpty and similar functions
Weijun-H marked this conversation as resolved.
Show resolved Hide resolved
/// The function takes a single argument that must be a List/LargeList/FixedSizeList
/// or something that can be coerced to one of those types.
Array,
}

impl std::fmt::Display for ArrayFunctionSignature {
Expand All @@ -153,6 +158,9 @@ impl std::fmt::Display for ArrayFunctionSignature {
ArrayFunctionSignature::ArrayAndIndex => {
write!(f, "array, index")
}
ArrayFunctionSignature::Array => {
write!(f, "array")
}
}
}
}
Expand Down Expand Up @@ -325,6 +333,13 @@ impl Signature {
volatility,
}
}
/// Specialized Signature for ArrayEmpty and similar functions
pub fn array(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
volatility,
}
}
}

/// Monotonicity of the `ScalarFunctionExpr` with respect to its arguments.
Expand Down
109 changes: 59 additions & 50 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,36 @@ fn get_valid_types(
signature: &TypeSignature,
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
fn array_element_and_optional_index(
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
// make sure there's 2 or 3 arguments
if !(current_types.len() == 2 || current_types.len() == 3) {
return Ok(vec![vec![]]);
}

let first_two_types = &current_types[0..2];
let mut valid_types = array_append_or_prepend_valid_types(first_two_types, true)?;

// Early return if there are only 2 arguments
if current_types.len() == 2 {
return Ok(valid_types);
}

let valid_types_with_index = valid_types
.iter()
.map(|t| {
let mut t = t.clone();
t.push(DataType::Int64);
t
})
.collect::<Vec<_>>();

valid_types.extend(valid_types_with_index);

Ok(valid_types)
}

fn array_append_or_prepend_valid_types(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see -- so one part of this PR is extracting this function into ArrayFunctionSignature ?

current_types: &[DataType],
is_append: bool,
Expand Down Expand Up @@ -111,71 +141,37 @@ fn get_valid_types(
)
})?;

let array_type = datafusion_common::utils::coerced_type_with_base_type_only(
let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
array_type,
&new_base_type,
);

match array_type {
match new_array_type {
DataType::List(ref field)
| DataType::LargeList(ref field)
| DataType::FixedSizeList(ref field, _) => {
let elem_type = field.data_type();
let new_elem_type = field.data_type();
if is_append {
Ok(vec![vec![array_type.clone(), elem_type.clone()]])
Ok(vec![vec![new_array_type.clone(), new_elem_type.clone()]])
} else {
Ok(vec![vec![elem_type.to_owned(), array_type.clone()]])
Ok(vec![vec![new_elem_type.to_owned(), new_array_type.clone()]])
}
}
_ => Ok(vec![vec![]]),
}
}
fn array_element_and_optional_index(
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
// make sure there's 2 or 3 arguments
if !(current_types.len() == 2 || current_types.len() == 3) {
return Ok(vec![vec![]]);
}

let first_two_types = &current_types[0..2];
let mut valid_types = array_append_or_prepend_valid_types(first_two_types, true)?;

// Early return if there are only 2 arguments
if current_types.len() == 2 {
return Ok(valid_types);
}

let valid_types_with_index = valid_types
.iter()
.map(|t| {
let mut t = t.clone();
t.push(DataType::Int64);
t
})
.collect::<Vec<_>>();

valid_types.extend(valid_types_with_index);

Ok(valid_types)
}
fn array_and_index(current_types: &[DataType]) -> Result<Vec<Vec<DataType>>> {
if current_types.len() != 2 {
return Ok(vec![vec![]]);
}

let array_type = &current_types[0];

fn array(array_type: &DataType) -> Option<DataType> {
match array_type {
DataType::List(_)
| DataType::LargeList(_)
| DataType::FixedSizeList(_, _) => {
let array_type = coerced_fixed_size_list_to_list(array_type);
Ok(vec![vec![array_type, DataType::Int64]])
Some(array_type)
}
_ => Ok(vec![vec![]]),
_ => None,
}
}

let valid_types = match signature {
TypeSignature::Variadic(valid_types) => valid_types
.iter()
Expand Down Expand Up @@ -211,19 +207,32 @@ fn get_valid_types(
TypeSignature::ArraySignature(ref function_signature) => match function_signature
{
ArrayFunctionSignature::ArrayAndElement => {
return array_append_or_prepend_valid_types(current_types, true)
array_append_or_prepend_valid_types(current_types, true)?
}
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
return array_element_and_optional_index(current_types)
ArrayFunctionSignature::ElementAndArray => {
array_append_or_prepend_valid_types(current_types, false)?
}
ArrayFunctionSignature::ArrayAndIndex => {
return array_and_index(current_types)
if current_types.len() != 2 {
return Ok(vec![vec![]]);
}
array(&current_types[0]).map_or_else(
|| vec![vec![]],
|array_type| vec![vec![array_type, DataType::Int64]],
)
}
ArrayFunctionSignature::ElementAndArray => {
return array_append_or_prepend_valid_types(current_types, false)
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
array_element_and_optional_index(current_types)?
}
},
ArrayFunctionSignature::Array => {
if current_types.len() != 1 {
return Ok(vec![vec![]]);
}

array(&current_types[0])
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
}
},
TypeSignature::Any(number) => {
if current_types.len() != *number {
return plan_err!(
Expand Down
Loading
Loading