diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index eaedc30289989..0d431f778ebc8 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -924,51 +924,63 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayAppend => { - Signature::array_and_element(self.volatility()) + Signature::array_and_element(false, self.volatility()) } BuiltinScalarFunction::MakeArray => { // 0 or more arguments of arbitrary type Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility()) } - BuiltinScalarFunction::ArrayPopFront => Signature::array(self.volatility()), - BuiltinScalarFunction::ArrayPopBack => Signature::array(self.volatility()), + BuiltinScalarFunction::ArrayPopFront => { + Signature::array(false, self.volatility()) + } + BuiltinScalarFunction::ArrayPopBack => { + Signature::array(false, self.volatility()) + } BuiltinScalarFunction::ArrayConcat => { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::ArrayDims => Signature::array(self.volatility()), - BuiltinScalarFunction::ArrayEmpty => Signature::array(self.volatility()), + BuiltinScalarFunction::ArrayDims => { + Signature::array(false, self.volatility()) + } + BuiltinScalarFunction::ArrayEmpty => { + Signature::array(false, self.volatility()) + } BuiltinScalarFunction::ArrayElement => { - Signature::array_and_index(self.volatility()) + Signature::array_and_index(false, self.volatility()) } BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()), - BuiltinScalarFunction::Flatten => Signature::array(self.volatility()), + BuiltinScalarFunction::Flatten => Signature::array(false, self.volatility()), BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny => { Signature::any(2, self.volatility()) } BuiltinScalarFunction::ArrayHas => { - Signature::array_and_element(self.volatility()) + Signature::array_and_element(false, self.volatility()) } BuiltinScalarFunction::ArrayLength => { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::ArrayNdims => Signature::array(self.volatility()), - BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()), + BuiltinScalarFunction::ArrayNdims => { + Signature::array(false, self.volatility()) + } + BuiltinScalarFunction::ArrayDistinct => { + Signature::array(true, self.volatility()) + } BuiltinScalarFunction::ArrayPosition => { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayPositions => { - Signature::array_and_element(self.volatility()) + Signature::array_and_element(false, self.volatility()) } BuiltinScalarFunction::ArrayPrepend => { - Signature::element_and_array(self.volatility()) + Signature::element_and_array(false, self.volatility()) } BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemove => { - Signature::array_and_element(self.volatility()) + Signature::array_and_element(false, self.volatility()) } BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayRemoveAll => { - Signature::array_and_element(self.volatility()) + Signature::array_and_element(false, self.volatility()) } BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()), @@ -985,7 +997,9 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), - BuiltinScalarFunction::Cardinality => Signature::array(self.volatility()), + BuiltinScalarFunction::Cardinality => { + Signature::array(false, self.volatility()) + } BuiltinScalarFunction::ArrayResize => { Signature::variadic_any(self.volatility()) } diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 7e4aaf5705a38..79f1a95281a9b 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -122,7 +122,8 @@ pub enum TypeSignature { /// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), /// Specifies Signatures for array functions - ArraySignature(ArrayFunctionSignature), + /// Boolean value specifies whether null type coercion is allowed + ArraySignature(ArrayFunctionSignature, bool), } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -144,13 +145,19 @@ pub enum ArrayFunctionSignature { } impl ArrayFunctionSignature { + /// Arguments to ArrayFunctionSignature + /// `current_types` - The data types of the arguments + /// `coercion` - Whether null type coercion is allowed + /// Returns the valid types for the function signature pub fn get_type_signature( &self, current_types: &[DataType], + allow_null_coercion: bool, ) -> Result>> { fn array_append_or_prepend_valid_types( current_types: &[DataType], is_append: bool, + allow_null_coercion: bool, ) -> Result>> { if current_types.len() != 2 { return Ok(vec![vec![]]); @@ -163,7 +170,7 @@ impl ArrayFunctionSignature { }; // We follow Postgres on `array_append(Null, T)`, which is not valid. - if array_type.eq(&DataType::Null) { + if array_type.eq(&DataType::Null) && !allow_null_coercion { return Ok(vec![vec![]]); } @@ -215,8 +222,13 @@ impl ArrayFunctionSignature { _ => Ok(vec![vec![]]), } } - fn array(current_types: &[DataType]) -> Result>> { - if current_types.len() != 1 { + fn array( + current_types: &[DataType], + allow_null_coercion: bool, + ) -> Result>> { + if current_types.len() != 1 + || (current_types[0].is_null() && !allow_null_coercion) + { return Ok(vec![vec![]]); } @@ -229,7 +241,6 @@ impl ArrayFunctionSignature { let array_type = coerced_fixed_size_list_to_list(array_type); Ok(vec![vec![array_type]]) } - DataType::Null => Ok(vec![vec![array_type.to_owned()]]), _ => Ok(vec![vec![DataType::List(Arc::new(Field::new( "item", array_type.to_owned(), @@ -239,13 +250,21 @@ impl ArrayFunctionSignature { } match self { ArrayFunctionSignature::ArrayAndElement => { - array_append_or_prepend_valid_types(current_types, true) + array_append_or_prepend_valid_types( + current_types, + true, + allow_null_coercion, + ) } ArrayFunctionSignature::ElementAndArray => { - array_append_or_prepend_valid_types(current_types, false) + array_append_or_prepend_valid_types( + current_types, + false, + allow_null_coercion, + ) } ArrayFunctionSignature::ArrayAndIndex => array_and_index(current_types), - ArrayFunctionSignature::Array => array(current_types), + ArrayFunctionSignature::Array => array(current_types, allow_null_coercion), } } } @@ -297,7 +316,7 @@ impl TypeSignature { TypeSignature::OneOf(sigs) => { sigs.iter().flat_map(|s| s.to_string_repr()).collect() } - TypeSignature::ArraySignature(array_signature) => { + TypeSignature::ArraySignature(array_signature, _) => { vec![array_signature.to_string()] } } @@ -402,36 +421,42 @@ impl Signature { } } /// Specialized Signature for ArrayAppend and similar functions - pub fn array_and_element(volatility: Volatility) -> Self { + pub fn array_and_element(allow_null_coercion: bool, volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( ArrayFunctionSignature::ArrayAndElement, + allow_null_coercion, ), volatility, } } /// Specialized Signature for ArrayPrepend and similar functions - pub fn element_and_array(volatility: Volatility) -> Self { + pub fn element_and_array(allow_null_coercion: bool, volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( ArrayFunctionSignature::ElementAndArray, + allow_null_coercion, ), volatility, } } /// Specialized Signature for ArrayElement and similar functions - pub fn array_and_index(volatility: Volatility) -> Self { + pub fn array_and_index(allow_null_coercion: bool, volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( ArrayFunctionSignature::ArrayAndIndex, + allow_null_coercion, ), volatility, } } /// Specialized Signature for ArrayEmpty and similar functions - pub fn array(volatility: Volatility) -> Self { + pub fn array(allow_null_coercion: bool, volatility: Volatility) -> Self { Signature { - type_signature: TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array, + allow_null_coercion, + ), volatility, } } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 70015c6992966..a54e88dd879f8 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -301,7 +301,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option Option Option { +fn allow_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { match (lhs_type, rhs_type) { (DataType::Null, other_type) | (other_type, DataType::Null) => { if can_cast_types(&DataType::Null, other_type) { diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 4e5a2f1b69552..b0054aa28e161 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -110,8 +110,8 @@ fn get_valid_types( } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], - TypeSignature::ArraySignature(ref function_signature) => { - function_signature.get_type_signature(current_types)? + TypeSignature::ArraySignature(ref function_signature, allow_null_coercion) => { + function_signature.get_type_signature(current_types, *allow_null_coercion)? } TypeSignature::Any(number) => { diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 844dae0917c72..2e5dfac574bea 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2659,7 +2659,7 @@ pub fn array_distinct(args: &[ArrayRef]) -> Result { } // handle for list & largelist - match args[0].data_type() { + match dbg!(args[0].data_type()) { DataType::List(field) => { let array = as_list_array(&args[0])?; general_array_distinct(array, field) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 7ff43b7d6ddf3..fc2293b9acd18 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -4253,7 +4253,10 @@ NULL [3] [4] # array_ndims scalar function #1 query error -selrct array_ndims(1), array_ndims(null) +select array_ndims(1) + +query error +select array_ndims(null) query I select