From 386c7ed8337e00b3e099c566db5c5a6be803962c Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 6 Feb 2025 20:18:48 +0800 Subject: [PATCH 01/18] coerciblev2 Signed-off-by: Jay Zhan --- datafusion/expr-common/src/signature.rs | 54 ++++++++++++++ .../expr/src/type_coercion/functions.rs | 71 +++++++++++++++++++ datafusion/expr/src/udf.rs | 1 + datafusion/functions/src/string/ascii.rs | 12 +++- 4 files changed, 136 insertions(+), 2 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 1bfae28af840..6408056a5190 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -133,6 +133,7 @@ pub enum TypeSignature { /// /// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`]. Coercible(Vec), + CoercibleV2(Vec), /// One or more arguments coercible to a single, comparable type. /// /// Each argument will be coerced to a single type using the @@ -313,6 +314,10 @@ impl TypeSignature { TypeSignature::Coercible(types) => { vec![Self::join_types(types, ", ")] } + TypeSignature::CoercibleV2(param_types) => { + // todo!("123") + vec![Self::join_types(param_types, ", ")] + } TypeSignature::Exact(types) => { vec![Self::join_types(types, ", ")] } @@ -426,6 +431,7 @@ impl TypeSignature { | TypeSignature::Nullary | TypeSignature::VariadicAny | TypeSignature::ArraySignature(_) + | TypeSignature::CoercibleV2(_) | TypeSignature::UserDefined => vec![], } } @@ -460,6 +466,46 @@ fn get_data_types(native_type: &NativeType) -> Vec { } } +#[derive(Debug, Clone)] +pub struct FunctionSignature { + pub parameters: Vec, + /// The volatility of the function. See [Volatility] for more information. 
+ pub volatility: Volatility, +} + +pub type ParameterSignature = Vec; + +#[derive(Debug, Clone, Eq, PartialOrd, Hash)] +pub struct Coercion { + pub desired_type: TypeSignatureClass, + pub allowed_casts: Vec, +} + +impl Display for Coercion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ParameterType({}", self.desired_type)?; + if !self.allowed_casts.is_empty() { + write!( + f, + ", allowed_casts=[{}]", + self.allowed_casts + .iter() + .map(|cast| cast.to_string()) + .join(", ") + ) + } else { + write!(f, ")") + } + } +} + +impl PartialEq for Coercion { + fn eq(&self, other: &Self) -> bool { + self.desired_type == other.desired_type + && self.allowed_casts == other.allowed_casts + } +} + /// Defines the supported argument types ([`TypeSignature`]) and [`Volatility`] for a function. /// /// DataFusion will automatically coerce (cast) argument types to one of the supported @@ -547,6 +593,14 @@ impl Signature { } } + /// Target coerce types in order + pub fn coercible_v2(target_types: Vec, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::CoercibleV2(target_types), + volatility, + } + } + /// Used for function that expects comparable data types, it will try to coerced all the types into single final one. 
pub fn comparable(arg_count: usize, volatility: Volatility) -> Self { Self { diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 0f9dbec722c2..481535d44e50 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -18,6 +18,7 @@ use super::binary::{binary_numeric_coercion, comparison_coercion}; use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; use arrow::{ + array::cast, compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; @@ -593,6 +594,76 @@ fn get_valid_types( vec![vec![target_type; *num]] } } + TypeSignature::CoercibleV2(param_types) => { + function_length_check(current_types.len(), param_types.len())?; + + let mut new_types = Vec::with_capacity(current_types.len()); + for (current_type, param) in current_types.iter().zip(param_types.iter()) { + let current_logical_type: NativeType = current_type.into(); + + fn can_coerce_to( + desired_type: &TypeSignatureClass, + logical_type: &NativeType, + current_type: &DataType, + ) -> Option { + match desired_type { + TypeSignatureClass::Native(t) if t.native() == logical_type => { + t.native().default_cast_for(current_type).ok() + } + TypeSignatureClass::Native(t) + if logical_type == &NativeType::Null => + { + t.native().default_cast_for(current_type).ok() + } + // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp + TypeSignatureClass::Timestamp + if logical_type == &NativeType::String => + { + Some(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { + Some(current_type.to_owned()) + } + TypeSignatureClass::Date if logical_type.is_date() => { + Some(current_type.to_owned()) + } + TypeSignatureClass::Time if logical_type.is_time() => { + Some(current_type.to_owned()) + } + TypeSignatureClass::Interval if logical_type.is_interval() => { + 
Some(current_type.to_owned()) + } + TypeSignatureClass::Duration if logical_type.is_duration() => { + Some(current_type.to_owned()) + } + + _ => None, + } + } + + if let Some(casted_type) = can_coerce_to( + ¶m.desired_type, + ¤t_logical_type, + current_type, + ) + .or_else(|| { + param.allowed_casts.iter().find_map(|t| { + can_coerce_to(t, ¤t_logical_type, current_type) + }) + }) { + new_types.push(casted_type); + } else { + return internal_err!( + "Expect {} but received NativeType: {}, DataType: {}", + param.desired_type, + current_logical_type, + current_type + ); + } + } + + vec![new_types] + } TypeSignature::Coercible(target_types) => { function_length_check( function_name, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 7c91b6b3b4ab..605afa8a96c9 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -26,6 +26,7 @@ use crate::{ use arrow::datatypes::DataType; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::signature::FunctionSignature; use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 858eddc7c8f8..17dff95580da 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -19,9 +19,11 @@ use crate::utils::make_scalar_function; use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; use arrow::datatypes::DataType; use arrow::error::ArrowError; +use datafusion_common::types::{logical_binary, logical_string}; use datafusion_common::{internal_err, Result}; -use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr_common::signature::Coercion; use datafusion_macros::user_doc; use 
std::any::Any; use std::sync::Arc; @@ -61,7 +63,13 @@ impl Default for AsciiFunc { impl AsciiFunc { pub fn new() -> Self { Self { - signature: Signature::string(1, Volatility::Immutable), + signature: Signature::coercible_v2( + vec![Coercion { + desired_type: TypeSignatureClass::Native(logical_string()), + allowed_casts: vec![TypeSignatureClass::Native(logical_binary())], + }], + Volatility::Immutable, + ), } } } From 579ffef90202f617d79c1124e6302783d4ba5bfc Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 6 Feb 2025 21:06:27 +0800 Subject: [PATCH 02/18] repeat Signed-off-by: Jay Zhan --- datafusion/expr-common/src/signature.rs | 24 +++++- .../expr/src/type_coercion/functions.rs | 75 +++++++++++-------- datafusion/expr/src/udf.rs | 1 - datafusion/functions/src/string/repeat.rs | 15 +++- datafusion/sqllogictest/test_files/expr.slt | 2 +- 5 files changed, 77 insertions(+), 40 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 6408056a5190..5feeab9f2419 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -23,7 +23,8 @@ use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::types::{LogicalTypeRef, NativeType}; +use datafusion_common::Result; +use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; use itertools::Itertools; /// Constant that is used as a placeholder for any valid timezone. 
@@ -217,7 +218,7 @@ pub enum TypeSignatureClass { Native(LogicalTypeRef), // TODO: // Numeric - // Integer + Integer, } impl Display for TypeSignatureClass { @@ -226,6 +227,22 @@ impl Display for TypeSignatureClass { } } +impl TypeSignatureClass { + /// Return the default casted type for the given `TypeSignatureClass` + /// We return the largest common type for the given `TypeSignatureClass` + pub fn default_casted_type(&self, data_type: &DataType) -> Result { + Ok(match self { + TypeSignatureClass::Native(logical_type) => return logical_type.native().default_cast_for(data_type), + TypeSignatureClass::Timestamp => DataType::Timestamp(TimeUnit::Nanosecond, None), + TypeSignatureClass::Date => DataType::Date64, + TypeSignatureClass::Time => DataType::Time64(TimeUnit::Nanosecond), + TypeSignatureClass::Interval => DataType::Interval(IntervalUnit::DayTime), + TypeSignatureClass::Duration => DataType::Duration(TimeUnit::Nanosecond), + TypeSignatureClass::Integer => DataType::Int64, + }) + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { /// Specialized Signature for ArrayAppend and similar functions @@ -408,6 +425,9 @@ impl TypeSignature { TypeSignatureClass::Duration => { vec![DataType::Duration(TimeUnit::Nanosecond)] } + TypeSignatureClass::Integer => { + vec![DataType::Int64] + } }) .multi_cartesian_product() .collect(), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 481535d44e50..dc8d19aa136f 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -18,7 +18,6 @@ use super::binary::{binary_numeric_coercion, comparison_coercion}; use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; use arrow::{ - array::cast, compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; @@ -601,65 +600,77 @@ fn get_valid_types( for (current_type, param) in 
current_types.iter().zip(param_types.iter()) { let current_logical_type: NativeType = current_type.into(); - fn can_coerce_to( - desired_type: &TypeSignatureClass, + fn is_matched_type( + target_type: &TypeSignatureClass, logical_type: &NativeType, - current_type: &DataType, - ) -> Option { - match desired_type { + ) -> bool { + match target_type { TypeSignatureClass::Native(t) if t.native() == logical_type => { - t.native().default_cast_for(current_type).ok() + true } - TypeSignatureClass::Native(t) + TypeSignatureClass::Native(_) if logical_type == &NativeType::Null => { - t.native().default_cast_for(current_type).ok() + true } // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp TypeSignatureClass::Timestamp if logical_type == &NativeType::String => { - Some(DataType::Timestamp(TimeUnit::Nanosecond, None)) + true } TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { - Some(current_type.to_owned()) - } - TypeSignatureClass::Date if logical_type.is_date() => { - Some(current_type.to_owned()) - } - TypeSignatureClass::Time if logical_type.is_time() => { - Some(current_type.to_owned()) + true } + TypeSignatureClass::Date if logical_type.is_date() => true, + TypeSignatureClass::Time if logical_type.is_time() => true, TypeSignatureClass::Interval if logical_type.is_interval() => { - Some(current_type.to_owned()) + true } TypeSignatureClass::Duration if logical_type.is_duration() => { - Some(current_type.to_owned()) + true } - - _ => None, + TypeSignatureClass::Integer if logical_type.is_integer() => true, + _ => false, } } - if let Some(casted_type) = can_coerce_to( - ¶m.desired_type, - ¤t_logical_type, - current_type, - ) - .or_else(|| { - param.allowed_casts.iter().find_map(|t| { - can_coerce_to(t, ¤t_logical_type, current_type) - }) - }) { + if is_matched_type(¶m.desired_type, ¤t_logical_type) + || param + .allowed_casts + .iter() + .any(|t| is_matched_type(t, ¤t_logical_type)) + { + let casted_type = 
param.desired_type.default_casted_type(current_type)?; new_types.push(casted_type); } else { return internal_err!( - "Expect {} but received NativeType: {}, DataType: {}", + "Expect {} but received {}, DataType: {}", param.desired_type, current_logical_type, current_type ); } + + // if let Some(casted_type) = get_casted_type( + // ¶m.desired_type, + // ¤t_logical_type, + // current_type, + // ) + // .or_else(|| { + // param.allowed_casts.iter().find_map(|t| { + // get_casted_type(t, ¤t_logical_type, current_type) + // }) + // }) { + // new_types.push(casted_type); + // } else { + // return internal_err!( + // "Expect {} but received NativeType: {}, DataType: {}", + // param.desired_type, + // current_logical_type, + // current_type + // ); + // } } vec![new_types] diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 605afa8a96c9..7c91b6b3b4ab 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -26,7 +26,6 @@ use crate::{ use arrow::datatypes::DataType; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; -use datafusion_expr_common::signature::FunctionSignature; use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index eea9af2ba749..e7d86536894e 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -30,7 +30,7 @@ use datafusion_common::types::{logical_int64, logical_string}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use datafusion_expr_common::signature::TypeSignatureClass; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; #[user_doc( @@ -65,10 +65,17 @@ impl Default for RepeatFunc { impl RepeatFunc { pub fn new() 
-> Self { Self { - signature: Signature::coercible( + signature: Signature::coercible_v2( vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Native(logical_int64()), + Coercion { + desired_type: TypeSignatureClass::Native(logical_string()), + allowed_casts: vec![], + }, + // Accept all integer types but cast them to i64 + Coercion { + desired_type: TypeSignatureClass::Native(logical_int64()), + allowed_casts: vec![TypeSignatureClass::Integer], + }, ], Volatility::Immutable, ), diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index a0264c43622f..b5f88212b036 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -571,7 +571,7 @@ select repeat('-1.2', arrow_cast(3, 'Int32')); ---- -1.2-1.2-1.2 -query error DataFusion error: Error during planning: Internal error: Function 'repeat' expects TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received Float64 +query error Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received NativeType::Float64, DataType: Float64. 
select repeat('-1.2', 3.2); query T From 104da4346fec8c97f04cbbb3f8727249055cf9c2 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 6 Feb 2025 21:29:50 +0800 Subject: [PATCH 03/18] fix possible types --- datafusion/expr-common/src/signature.rs | 69 ++++++++++++++++--- .../expr/src/type_coercion/functions.rs | 23 +------ 2 files changed, 61 insertions(+), 31 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 5feeab9f2419..07369ab6d1fe 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -23,8 +23,8 @@ use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::Result; use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; +use datafusion_common::{HashSet, Result}; use itertools::Itertools; /// Constant that is used as a placeholder for any valid timezone. @@ -232,14 +232,18 @@ impl TypeSignatureClass { /// We return the largest common type for the given `TypeSignatureClass` pub fn default_casted_type(&self, data_type: &DataType) -> Result { Ok(match self { - TypeSignatureClass::Native(logical_type) => return logical_type.native().default_cast_for(data_type), - TypeSignatureClass::Timestamp => DataType::Timestamp(TimeUnit::Nanosecond, None), - TypeSignatureClass::Date => DataType::Date64, - TypeSignatureClass::Time => DataType::Time64(TimeUnit::Nanosecond), - TypeSignatureClass::Interval => DataType::Interval(IntervalUnit::DayTime), - TypeSignatureClass::Duration => DataType::Duration(TimeUnit::Nanosecond), - TypeSignatureClass::Integer => DataType::Int64, - }) + TypeSignatureClass::Native(logical_type) => { + return logical_type.native().default_cast_for(data_type) + } + TypeSignatureClass::Timestamp => { + DataType::Timestamp(TimeUnit::Nanosecond, None) + } + TypeSignatureClass::Date => DataType::Date64, + TypeSignatureClass::Time => 
DataType::Time64(TimeUnit::Nanosecond), + TypeSignatureClass::Interval => DataType::Interval(IntervalUnit::DayTime), + TypeSignatureClass::Duration => DataType::Duration(TimeUnit::Nanosecond), + TypeSignatureClass::Integer => DataType::Int64, + }) } } @@ -400,6 +404,23 @@ impl TypeSignature { .cloned() .map(|data_type| vec![data_type; *arg_count]) .collect(), + TypeSignature::CoercibleV2(coercions) => coercions + .iter() + .map(|c| { + let mut all_types: HashSet = + get_possible_types_from_signature_classes(&c.desired_type) + .into_iter() + .collect(); + let allowed_casts: Vec = c + .allowed_casts + .iter() + .flat_map(|t| get_possible_types_from_signature_classes(t)) + .collect(); + all_types.extend(allowed_casts.into_iter()); + all_types.into_iter().collect::>() + }) + .multi_cartesian_product() + .collect(), TypeSignature::Coercible(types) => types .iter() .map(|logical_type| match logical_type { @@ -451,12 +472,40 @@ impl TypeSignature { | TypeSignature::Nullary | TypeSignature::VariadicAny | TypeSignature::ArraySignature(_) - | TypeSignature::CoercibleV2(_) | TypeSignature::UserDefined => vec![], } } } +fn get_possible_types_from_signature_classes( + signature_classes: &TypeSignatureClass, +) -> Vec { + match signature_classes { + TypeSignatureClass::Native(l) => get_data_types(l.native()), + TypeSignatureClass::Timestamp => { + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())), + ] + } + TypeSignatureClass::Date => { + vec![DataType::Date64] + } + TypeSignatureClass::Time => { + vec![DataType::Time64(TimeUnit::Nanosecond)] + } + TypeSignatureClass::Interval => { + vec![DataType::Interval(IntervalUnit::DayTime)] + } + TypeSignatureClass::Duration => { + vec![DataType::Duration(TimeUnit::Nanosecond)] + } + TypeSignatureClass::Integer => { + vec![DataType::Int64] + } + } +} + fn get_data_types(native_type: &NativeType) -> Vec { match native_type { NativeType::Null => 
vec![DataType::Null], diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index dc8d19aa136f..feea07a9193c 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -641,7 +641,8 @@ fn get_valid_types( .iter() .any(|t| is_matched_type(t, ¤t_logical_type)) { - let casted_type = param.desired_type.default_casted_type(current_type)?; + let casted_type = + param.desired_type.default_casted_type(current_type)?; new_types.push(casted_type); } else { return internal_err!( @@ -651,26 +652,6 @@ fn get_valid_types( current_type ); } - - // if let Some(casted_type) = get_casted_type( - // ¶m.desired_type, - // ¤t_logical_type, - // current_type, - // ) - // .or_else(|| { - // param.allowed_casts.iter().find_map(|t| { - // get_casted_type(t, ¤t_logical_type, current_type) - // }) - // }) { - // new_types.push(casted_type); - // } else { - // return internal_err!( - // "Expect {} but received NativeType: {}, DataType: {}", - // param.desired_type, - // current_logical_type, - // current_type - // ); - // } } vec![new_types] From aae48ff2813e1451f7876d9021496211bc544848 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 6 Feb 2025 21:46:15 +0800 Subject: [PATCH 04/18] replace all coerciblev1 --- datafusion/expr-common/Cargo.toml | 1 + datafusion/expr-common/src/signature.rs | 25 ++++---- .../expr/src/type_coercion/functions.rs | 6 +- .../functions/src/datetime/date_part.rs | 62 ++++++++++++++----- 4 files changed, 65 insertions(+), 29 deletions(-) diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 109d8e0b89a6..e8d808052bfd 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -40,4 +40,5 @@ path = "src/lib.rs" arrow = { workspace = true } datafusion-common = { workspace = true } itertools = { workspace = true } +indexmap = { workspace = true } paste = "^1.0" diff --git 
a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 07369ab6d1fe..cbc6fe412b61 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -24,7 +24,8 @@ use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; -use datafusion_common::{HashSet, Result}; +use datafusion_common::Result; +use indexmap::IndexSet; use itertools::Itertools; /// Constant that is used as a placeholder for any valid timezone. @@ -230,19 +231,21 @@ impl Display for TypeSignatureClass { impl TypeSignatureClass { /// Return the default casted type for the given `TypeSignatureClass` /// We return the largest common type for the given `TypeSignatureClass` - pub fn default_casted_type(&self, data_type: &DataType) -> Result { + pub fn default_casted_type( + &self, + logical_type: &NativeType, + data_type: &DataType, + ) -> Result { Ok(match self { + // TODO: Able to elimnate this special case? 
+ // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp + TypeSignatureClass::Timestamp if logical_type == &NativeType::String => { + DataType::Timestamp(TimeUnit::Nanosecond, None) + } TypeSignatureClass::Native(logical_type) => { return logical_type.native().default_cast_for(data_type) } - TypeSignatureClass::Timestamp => { - DataType::Timestamp(TimeUnit::Nanosecond, None) - } - TypeSignatureClass::Date => DataType::Date64, - TypeSignatureClass::Time => DataType::Time64(TimeUnit::Nanosecond), - TypeSignatureClass::Interval => DataType::Interval(IntervalUnit::DayTime), - TypeSignatureClass::Duration => DataType::Duration(TimeUnit::Nanosecond), - TypeSignatureClass::Integer => DataType::Int64, + _ => data_type.clone(), }) } } @@ -407,7 +410,7 @@ impl TypeSignature { TypeSignature::CoercibleV2(coercions) => coercions .iter() .map(|c| { - let mut all_types: HashSet = + let mut all_types: IndexSet = get_possible_types_from_signature_classes(&c.desired_type) .into_iter() .collect(); diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index feea07a9193c..82794694f7f4 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -209,6 +209,7 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { | TypeSignature::Numeric(_) | TypeSignature::String(_) | TypeSignature::Coercible(_) + | TypeSignature::CoercibleV2(_) | TypeSignature::Any(_) | TypeSignature::Nullary | TypeSignature::Comparable(_) @@ -641,8 +642,9 @@ fn get_valid_types( .iter() .any(|t| is_matched_type(t, ¤t_logical_type)) { - let casted_type = - param.desired_type.default_casted_type(current_type)?; + let casted_type = param + .desired_type + .default_casted_type(¤t_logical_type, current_type)?; new_types.push(casted_type); } else { return internal_err!( diff --git a/datafusion/functions/src/datetime/date_part.rs 
b/datafusion/functions/src/datetime/date_part.rs index c7dbf089e530..de6732c832aa 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -45,7 +45,7 @@ use datafusion_expr::{ ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use datafusion_expr_common::signature::TypeSignatureClass; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; #[user_doc( @@ -95,25 +95,55 @@ impl DatePartFunc { Self { signature: Signature::one_of( vec![ - TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Timestamp, + TypeSignature::CoercibleV2(vec![ + Coercion { + desired_type: TypeSignatureClass::Native(logical_string()), + allowed_casts: vec![], + }, + Coercion { + desired_type: TypeSignatureClass::Timestamp, + allowed_casts: vec![], + }, ]), - TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Date, + TypeSignature::CoercibleV2(vec![ + Coercion { + desired_type: TypeSignatureClass::Native(logical_string()), + allowed_casts: vec![], + }, + Coercion { + desired_type: TypeSignatureClass::Date, + allowed_casts: vec![], + }, ]), - TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Time, + TypeSignature::CoercibleV2(vec![ + Coercion { + desired_type: TypeSignatureClass::Native(logical_string()), + allowed_casts: vec![], + }, + Coercion { + desired_type: TypeSignatureClass::Time, + allowed_casts: vec![], + }, ]), - TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Interval, + TypeSignature::CoercibleV2(vec![ + Coercion { + desired_type: TypeSignatureClass::Native(logical_string()), + allowed_casts: vec![], + }, + Coercion { + desired_type: TypeSignatureClass::Interval, + allowed_casts: vec![], + }, ]), - 
TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Duration, + TypeSignature::CoercibleV2(vec![ + Coercion { + desired_type: TypeSignatureClass::Native(logical_string()), + allowed_casts: vec![], + }, + Coercion { + desired_type: TypeSignatureClass::Duration, + allowed_casts: vec![], + }, ]), ], Volatility::Immutable, From f585136fd3e8f7d99e031ee16d4f76b732eb2cd8 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 7 Feb 2025 08:51:56 +0800 Subject: [PATCH 05/18] cleanup --- datafusion-cli/Cargo.lock | 1 + datafusion/expr-common/Cargo.toml | 2 +- datafusion/expr-common/src/signature.rs | 26 +++++++++---------- .../expr/src/type_coercion/functions.rs | 2 +- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index e320b2ffc835..08c99ad18434 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1403,6 +1403,7 @@ version = "45.0.0" dependencies = [ "arrow", "datafusion-common", + "indexmap", "itertools 0.14.0", "paste", ] diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index e8d808052bfd..abc78a9f084b 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -39,6 +39,6 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } datafusion-common = { workspace = true } -itertools = { workspace = true } indexmap = { workspace = true } +itertools = { workspace = true } paste = "^1.0" diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index cbc6fe412b61..c76abf152ef0 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -19,6 +19,7 @@ //! and return types of functions in DataFusion. 
use std::fmt::Display; +use std::hash::Hash; use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; @@ -339,7 +340,6 @@ impl TypeSignature { vec![Self::join_types(types, ", ")] } TypeSignature::CoercibleV2(param_types) => { - // todo!("123") vec![Self::join_types(param_types, ", ")] } TypeSignature::Exact(types) => { @@ -417,9 +417,9 @@ impl TypeSignature { let allowed_casts: Vec = c .allowed_casts .iter() - .flat_map(|t| get_possible_types_from_signature_classes(t)) + .flat_map(get_possible_types_from_signature_classes) .collect(); - all_types.extend(allowed_casts.into_iter()); + all_types.extend(allowed_casts); all_types.into_iter().collect::>() }) .multi_cartesian_product() @@ -538,16 +538,7 @@ fn get_data_types(native_type: &NativeType) -> Vec { } } -#[derive(Debug, Clone)] -pub struct FunctionSignature { - pub parameters: Vec, - /// The volatility of the function. See [Volatility] for more information. - pub volatility: Volatility, -} - -pub type ParameterSignature = Vec; - -#[derive(Debug, Clone, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, Eq, PartialOrd)] pub struct Coercion { pub desired_type: TypeSignatureClass, pub allowed_casts: Vec, @@ -555,7 +546,7 @@ pub struct Coercion { impl Display for Coercion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "ParameterType({}", self.desired_type)?; + write!(f, "Coercion({}", self.desired_type)?; if !self.allowed_casts.is_empty() { write!( f, @@ -578,6 +569,13 @@ impl PartialEq for Coercion { } } +impl Hash for Coercion { + fn hash(&self, state: &mut H) { + self.desired_type.hash(state); + self.allowed_casts.hash(state); + } +} + /// Defines the supported argument types ([`TypeSignature`]) and [`Volatility`] for a function. 
/// /// DataFusion will automatically coerce (cast) argument types to one of the supported diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 82794694f7f4..df79851b755c 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -595,7 +595,7 @@ fn get_valid_types( } } TypeSignature::CoercibleV2(param_types) => { - function_length_check(current_types.len(), param_types.len())?; + function_length_check(function_name, current_types.len(), param_types.len())?; let mut new_types = Vec::with_capacity(current_types.len()); for (current_type, param) in current_types.iter().zip(param_types.iter()) { From bad7348d00d080f7fdcb72d67c5dd266dbe00eba Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 7 Feb 2025 09:41:27 +0800 Subject: [PATCH 06/18] remove specialize logic --- datafusion/common/src/types/builtin.rs | 10 ++++ datafusion/expr-common/src/signature.rs | 46 +++++++++++++------ .../expr/src/type_coercion/functions.rs | 18 ++------ .../functions/src/datetime/date_part.rs | 5 +- 4 files changed, 51 insertions(+), 28 deletions(-) diff --git a/datafusion/common/src/types/builtin.rs b/datafusion/common/src/types/builtin.rs index ec69db790377..10e841147e4a 100644 --- a/datafusion/common/src/types/builtin.rs +++ b/datafusion/common/src/types/builtin.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use arrow_schema::TimeUnit; + use crate::types::{LogicalTypeRef, NativeType}; use std::sync::{Arc, LazyLock}; @@ -47,3 +49,11 @@ singleton!(LOGICAL_FLOAT64, logical_float64, Float64); singleton!(LOGICAL_DATE, logical_date, Date); singleton!(LOGICAL_BINARY, logical_binary, Binary); singleton!(LOGICAL_STRING, logical_string, String); + +// TODO: Extend macro +// TODO: Should we use LOGICAL_TIMESTAMP_NANO to distinguish unit and timzeone? 
+static LOGICAL_TIMESTAMP: LazyLock = + LazyLock::new(|| Arc::new(NativeType::Timestamp(TimeUnit::Nanosecond, None))); +pub fn logical_timestamp() -> LogicalTypeRef { + Arc::clone(&LOGICAL_TIMESTAMP) +} diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index c76abf152ef0..54731db8e784 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -24,8 +24,10 @@ use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; -use datafusion_common::Result; +use datafusion_common::types::{ + logical_timestamp, LogicalType, LogicalTypeRef, NativeType, +}; +use datafusion_common::{not_impl_err, Result}; use indexmap::IndexSet; use itertools::Itertools; @@ -230,24 +232,40 @@ impl Display for TypeSignatureClass { } impl TypeSignatureClass { - /// Return the default casted type for the given `TypeSignatureClass` - /// We return the largest common type for the given `TypeSignatureClass` + /// Returns the default cast type for the given `TypeSignatureClass`. + /// Be cautious to avoid adding specialized logic here, as this function is public and intended for general use. pub fn default_casted_type( &self, logical_type: &NativeType, - data_type: &DataType, + origin_type: &DataType, ) -> Result { - Ok(match self { - // TODO: Able to elimnate this special case? 
- // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp - TypeSignatureClass::Timestamp if logical_type == &NativeType::String => { - DataType::Timestamp(TimeUnit::Nanosecond, None) - } + match self { TypeSignatureClass::Native(logical_type) => { - return logical_type.native().default_cast_for(data_type) + logical_type.native().default_cast_for(origin_type) + } + // If the given type is already a timestamp, we don't change the unit and timezone + TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Timestamp => { + // TODO: Consider allowing the user to specify the default timestamp type instead of having it predefined in DataFusion when we have such use case + // Use default timestamp type with nanosecond precision and no timezone + logical_timestamp().default_cast_for(origin_type) + } + TypeSignatureClass::Date if logical_type.is_date() => { + Ok(origin_type.to_owned()) } - _ => data_type.clone(), - }) + TypeSignatureClass::Time if logical_type.is_time() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Interval if logical_type.is_interval() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Duration if logical_type.is_duration() => { + Ok(origin_type.to_owned()) + } + _ => not_impl_err!("Other cases are not implemented yet"), + } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index df79851b755c..9aef4482add0 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -605,21 +605,14 @@ fn get_valid_types( target_type: &TypeSignatureClass, logical_type: &NativeType, ) -> bool { + if logical_type == &NativeType::Null { + return true; + } + match target_type { TypeSignatureClass::Native(t) if t.native() == logical_type => { true } - TypeSignatureClass::Native(_) - if logical_type == &NativeType::Null => - { - 
true - } - // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp - TypeSignatureClass::Timestamp - if logical_type == &NativeType::String => - { - true - } TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { true } @@ -636,8 +629,7 @@ fn get_valid_types( } } - if is_matched_type(&param.desired_type, &current_logical_type) - || param + if is_matched_type(&param.desired_type, &current_logical_type) || param .allowed_casts .iter() .any(|t| is_matched_type(t, &current_logical_type)) diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index de6732c832aa..b440d458bcaa 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -102,7 +102,10 @@ impl DatePartFunc { }, Coercion { desired_type: TypeSignatureClass::Timestamp, - allowed_casts: vec![], + // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp + allowed_casts: vec![TypeSignatureClass::Native( + logical_string(), + )], }, ]), TypeSignature::CoercibleV2(vec![ From c99e9861307dc1c222a7124e149cae0ec969e6e6 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 7 Feb 2025 09:47:21 +0800 Subject: [PATCH 07/18] comment --- datafusion/expr-common/src/signature.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 54731db8e784..299ab75bdfdb 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -247,9 +247,10 @@ impl TypeSignatureClass { TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { Ok(origin_type.to_owned()) } - TypeSignatureClass::Timestamp => { - // TODO: Consider allowing the user to specify the default timestamp type instead of having it predefined in DataFusion when we have such use case - // Use default timestamp type with nanosecond precision and
no timezone + // This is an existing use case for casting string to timestamp, since we don't have specific unit and timezone from string, + // so we use the default timestamp type with nanosecond precision and no timezone + // TODO: Consider allowing the user to specify the default timestamp type instead of having it predefined in DataFusion when we have more use cases + TypeSignatureClass::Timestamp if logical_type == &NativeType::String => { logical_timestamp().default_cast_for(origin_type) } TypeSignatureClass::Date if logical_type.is_date() => { From 07f97d07eb3461327bc5e28e36406ff7b6111e51 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 7 Feb 2025 10:15:29 +0800 Subject: [PATCH 08/18] err msg --- datafusion/sqllogictest/test_files/expr.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index b5f88212b036..dbeb1df5e8aa 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -571,7 +571,7 @@ select repeat('-1.2', arrow_cast(3, 'Int32')); ---- -1.2-1.2-1.2 -query error Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received NativeType::Float64, DataType: Float64. +query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received NativeType::Float64, DataType: Float64. 
select repeat('-1.2', 3.2); query T From da84394654f06ebcb84a0e6abf35418f254e8b52 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 7 Feb 2025 10:44:43 +0800 Subject: [PATCH 09/18] ci escape --- datafusion/sqllogictest/test_files/expr.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index dbeb1df5e8aa..2ce7d93c1322 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -571,7 +571,7 @@ select repeat('-1.2', arrow_cast(3, 'Int32')); ---- -1.2-1.2-1.2 -query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received NativeType::Float64, DataType: Float64. +query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Native\(LogicalType\(Native\(Int64\), Int64\)\) but received NativeType::Float64, DataType: Float64 select repeat('-1.2', 3.2); query T From e0a889e1e0b75fa3b230bf1304be3c4d960c32d7 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 7 Feb 2025 20:23:38 +0800 Subject: [PATCH 10/18] rm coerciblev1 Signed-off-by: Jay Zhan --- datafusion/expr-common/src/signature.rs | 70 ++++------------ .../expr/src/type_coercion/functions.rs | 80 +------------------ .../functions/src/datetime/date_part.rs | 10 +-- 3 files changed, 23 insertions(+), 137 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 299ab75bdfdb..3b78d3942219 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -131,14 +131,12 @@ pub enum TypeSignature { /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. Exact(Vec), /// One or more arguments belonging to the [`TypeSignatureClass`], in order. 
- /// - /// For example, `Coercible(vec![logical_float64()])` accepts - /// arguments like `vec![Int32]` or `vec![Float32]` - /// since i32 and f32 can be cast to f64 + /// + /// [`Coercion`] contains not only the desired type but also the allowed casts. + /// For example, if you expect a function has string type, but you also allow it to be casted from binary type. /// /// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`]. - Coercible(Vec), - CoercibleV2(Vec), + Coercible(Vec), /// One or more arguments coercible to a single, comparable type. /// /// Each argument will be coerced to a single type using the @@ -355,10 +353,7 @@ impl TypeSignature { TypeSignature::Comparable(num) => { vec![format!("Comparable({num})")] } - TypeSignature::Coercible(types) => { - vec![Self::join_types(types, ", ")] - } - TypeSignature::CoercibleV2(param_types) => { + TypeSignature::Coercible(param_types) => { vec![Self::join_types(param_types, ", ")] } TypeSignature::Exact(types) => { @@ -426,7 +421,7 @@ impl TypeSignature { .cloned() .map(|data_type| vec![data_type; *arg_count]) .collect(), - TypeSignature::CoercibleV2(coercions) => coercions + TypeSignature::Coercible(coercions) => coercions .iter() .map(|c| { let mut all_types: IndexSet = @@ -443,37 +438,6 @@ impl TypeSignature { }) .multi_cartesian_product() .collect(), - TypeSignature::Coercible(types) => types - .iter() - .map(|logical_type| match logical_type { - TypeSignatureClass::Native(l) => get_data_types(l.native()), - TypeSignatureClass::Timestamp => { - vec![ - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp( - TimeUnit::Nanosecond, - Some(TIMEZONE_WILDCARD.into()), - ), - ] - } - TypeSignatureClass::Date => { - vec![DataType::Date64] - } - TypeSignatureClass::Time => { - vec![DataType::Time64(TimeUnit::Nanosecond)] - } - TypeSignatureClass::Interval => { - vec![DataType::Interval(IntervalUnit::DayTime)] - } - TypeSignatureClass::Duration => { - 
vec![DataType::Duration(TimeUnit::Nanosecond)] - } - TypeSignatureClass::Integer => { - vec![DataType::Int64] - } - }) - .multi_cartesian_product() - .collect(), TypeSignature::Variadic(types) => types .iter() .cloned() @@ -671,21 +635,11 @@ impl Signature { volatility, } } - /// Target coerce types in order - pub fn coercible( - target_types: Vec, - volatility: Volatility, - ) -> Self { - Self { - type_signature: TypeSignature::Coercible(target_types), - volatility, - } - } /// Target coerce types in order pub fn coercible_v2(target_types: Vec, volatility: Volatility) -> Self { Self { - type_signature: TypeSignature::CoercibleV2(target_types), + type_signature: TypeSignature::Coercible(target_types), volatility, } } @@ -882,8 +836,14 @@ mod tests { ); let type_signature = TypeSignature::Coercible(vec![ - TypeSignatureClass::Native(logical_string()), - TypeSignatureClass::Native(logical_int64()), + Coercion { + desired_type: TypeSignatureClass::Native(logical_string()), + allowed_casts: vec![], + }, + Coercion { + desired_type: TypeSignatureClass::Native(logical_int64()), + allowed_casts: vec![], + }, ]); let possible_types = type_signature.get_possible_types(); assert_eq!( diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 9aef4482add0..14553af192dc 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -23,8 +23,8 @@ use arrow::{ }; use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, - types::{LogicalType, NativeType}, + exec_err, internal_datafusion_err, internal_err, plan_err, + types::NativeType, utils::list_ndims, Result, }; @@ -209,7 +209,6 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { | TypeSignature::Numeric(_) | TypeSignature::String(_) | TypeSignature::Coercible(_) - | 
TypeSignature::CoercibleV2(_) | TypeSignature::Any(_) | TypeSignature::Nullary | TypeSignature::Comparable(_) @@ -594,7 +593,7 @@ fn get_valid_types( vec![vec![target_type; *num]] } } - TypeSignature::CoercibleV2(param_types) => { + TypeSignature::Coercible(param_types) => { function_length_check(function_name, current_types.len(), param_types.len())?; let mut new_types = Vec::with_capacity(current_types.len()); @@ -650,79 +649,6 @@ fn get_valid_types( vec![new_types] } - TypeSignature::Coercible(target_types) => { - function_length_check( - function_name, - current_types.len(), - target_types.len(), - )?; - - // Aim to keep this logic as SIMPLE as possible! - // Make sure the corresponding test is covered - // If this function becomes COMPLEX, create another new signature! - fn can_coerce_to( - function_name: &str, - current_type: &DataType, - target_type_class: &TypeSignatureClass, - ) -> Result { - let logical_type: NativeType = current_type.into(); - - match target_type_class { - TypeSignatureClass::Native(native_type) => { - let target_type = native_type.native(); - if &logical_type == target_type { - return target_type.default_cast_for(current_type); - } - - if logical_type == NativeType::Null { - return target_type.default_cast_for(current_type); - } - - if target_type.is_integer() && logical_type.is_integer() { - return target_type.default_cast_for(current_type); - } - - internal_err!( - "Function '{function_name}' expects {target_type_class} but received {current_type}" - ) - } - // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp - TypeSignatureClass::Timestamp - if logical_type == NativeType::String => - { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Date if logical_type.is_date() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Time if 
logical_type.is_time() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Interval if logical_type.is_interval() => { - Ok(current_type.to_owned()) - } - TypeSignatureClass::Duration if logical_type.is_duration() => { - Ok(current_type.to_owned()) - } - _ => { - not_impl_err!("Function '{function_name}' got logical_type: {logical_type} with target_type_class: {target_type_class}") - } - } - } - - let mut new_types = Vec::with_capacity(current_types.len()); - for (current_type, target_type_class) in - current_types.iter().zip(target_types.iter()) - { - let target_type = can_coerce_to(function_name, current_type, target_type_class)?; - new_types.push(target_type); - } - - vec![new_types] - } TypeSignature::Uniform(number, valid_types) => { if *number == 0 { return plan_err!("The function '{function_name}' expected at least one argument"); diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index b440d458bcaa..202beb48515e 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -95,7 +95,7 @@ impl DatePartFunc { Self { signature: Signature::one_of( vec![ - TypeSignature::CoercibleV2(vec![ + TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), allowed_casts: vec![], @@ -108,7 +108,7 @@ impl DatePartFunc { )], }, ]), - TypeSignature::CoercibleV2(vec![ + TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), allowed_casts: vec![], @@ -118,7 +118,7 @@ impl DatePartFunc { allowed_casts: vec![], }, ]), - TypeSignature::CoercibleV2(vec![ + TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), allowed_casts: vec![], @@ -128,7 +128,7 @@ impl DatePartFunc { allowed_casts: vec![], }, ]), - TypeSignature::CoercibleV2(vec![ + TypeSignature::Coercible(vec![ Coercion { desired_type: 
TypeSignatureClass::Native(logical_string()), allowed_casts: vec![], @@ -138,7 +138,7 @@ impl DatePartFunc { allowed_casts: vec![], }, ]), - TypeSignature::CoercibleV2(vec![ + TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), allowed_casts: vec![], From 7a78a6d552abb49a6bd454210e25d0aaf91c4c6d Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 7 Feb 2025 20:27:37 +0800 Subject: [PATCH 11/18] fmt --- datafusion/expr-common/src/signature.rs | 2 +- datafusion/expr/src/type_coercion/functions.rs | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 3b78d3942219..33c1d51f98d8 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -131,7 +131,7 @@ pub enum TypeSignature { /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. Exact(Vec), /// One or more arguments belonging to the [`TypeSignatureClass`], in order. - /// + /// /// [`Coercion`] contains not only the desired type but also the allowed casts. /// For example, if you expect a function has string type, but you also allow it to be casted from binary type. 
/// diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 14553af192dc..7d43d6e3ffe2 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -23,10 +23,8 @@ use arrow::{ }; use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, - types::NativeType, - utils::list_ndims, - Result, + exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, + utils::list_ndims, Result, }; use datafusion_expr_common::{ signature::{ From 2f8c2ada280ffb191a2996652239ff3794121f7a Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 7 Feb 2025 20:32:16 +0800 Subject: [PATCH 12/18] rename --- datafusion/expr-common/src/signature.rs | 6 +++--- datafusion/functions/src/string/ascii.rs | 2 +- datafusion/functions/src/string/repeat.rs | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 33c1d51f98d8..cb3100c4f404 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -353,8 +353,8 @@ impl TypeSignature { TypeSignature::Comparable(num) => { vec![format!("Comparable({num})")] } - TypeSignature::Coercible(param_types) => { - vec![Self::join_types(param_types, ", ")] + TypeSignature::Coercible(coercions) => { + vec![Self::join_types(coercions, ", ")] } TypeSignature::Exact(types) => { vec![Self::join_types(types, ", ")] @@ -637,7 +637,7 @@ impl Signature { } /// Target coerce types in order - pub fn coercible_v2(target_types: Vec, volatility: Volatility) -> Self { + pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { Self { type_signature: TypeSignature::Coercible(target_types), volatility, diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 
17dff95580da..912fcff17cd5 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -63,7 +63,7 @@ impl Default for AsciiFunc { impl AsciiFunc { pub fn new() -> Self { Self { - signature: Signature::coercible_v2( + signature: Signature::coercible( vec![Coercion { desired_type: TypeSignatureClass::Native(logical_string()), allowed_casts: vec![TypeSignatureClass::Native(logical_binary())], diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index e7d86536894e..042c5e2e5bfa 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -65,7 +65,7 @@ impl Default for RepeatFunc { impl RepeatFunc { pub fn new() -> Self { Self { - signature: Signature::coercible_v2( + signature: Signature::coercible( vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), From 80549152fceec2c31b5b38b509e085a7d5a692fb Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 9 Feb 2025 16:22:35 +0800 Subject: [PATCH 13/18] rename --- Cargo.lock | 2 +- datafusion/expr-common/src/signature.rs | 16 +++++++-------- .../expr/src/type_coercion/functions.rs | 2 +- .../functions/src/datetime/date_part.rs | 20 +++++++++---------- datafusion/functions/src/string/ascii.rs | 2 +- datafusion/functions/src/string/repeat.rs | 4 ++-- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8b20baefd011..36f56b01cad5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1990,7 +1990,7 @@ version = "45.0.0" dependencies = [ "arrow", "datafusion-common", - "indexmap", + "indexmap 2.7.1", "itertools 0.14.0", "paste", ] diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index cb3100c4f404..cd0475883830 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -429,7 +429,7 @@ impl TypeSignature { .into_iter() .collect(); let allowed_casts: Vec = 
c - .allowed_casts + .allowed_source_types .iter() .flat_map(get_possible_types_from_signature_classes) .collect(); @@ -524,17 +524,17 @@ fn get_data_types(native_type: &NativeType) -> Vec { #[derive(Debug, Clone, Eq, PartialOrd)] pub struct Coercion { pub desired_type: TypeSignatureClass, - pub allowed_casts: Vec, + pub allowed_source_types: Vec, } impl Display for Coercion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Coercion({}", self.desired_type)?; - if !self.allowed_casts.is_empty() { + if !self.allowed_source_types.is_empty() { write!( f, ", allowed_casts=[{}]", - self.allowed_casts + self.allowed_source_types .iter() .map(|cast| cast.to_string()) .join(", ") @@ -548,14 +548,14 @@ impl Display for Coercion { impl PartialEq for Coercion { fn eq(&self, other: &Self) -> bool { self.desired_type == other.desired_type - && self.allowed_casts == other.allowed_casts + && self.allowed_source_types == other.allowed_source_types } } impl Hash for Coercion { fn hash(&self, state: &mut H) { self.desired_type.hash(state); - self.allowed_casts.hash(state); + self.allowed_source_types.hash(state); } } @@ -838,11 +838,11 @@ mod tests { let type_signature = TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), - allowed_casts: vec![], + allowed_source_types: vec![], }, Coercion { desired_type: TypeSignatureClass::Native(logical_int64()), - allowed_casts: vec![], + allowed_source_types: vec![], }, ]); let possible_types = type_signature.get_possible_types(); diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index bf783c09fa0b..bf6c160075d4 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -630,7 +630,7 @@ fn get_valid_types( } if is_matched_type(&param.desired_type, &current_logical_type) || param - .allowed_casts + .allowed_source_types .iter() .any(|t| is_matched_type(t,
&current_logical_type)) { diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 202beb48515e..ef06f8ac6dd1 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -98,12 +98,12 @@ impl DatePartFunc { TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), - allowed_casts: vec![], + allowed_source_types: vec![], }, Coercion { desired_type: TypeSignatureClass::Timestamp, // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp - allowed_casts: vec![TypeSignatureClass::Native( + allowed_source_types: vec![TypeSignatureClass::Native( logical_string(), )], }, @@ -111,41 +111,41 @@ impl DatePartFunc { TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), - allowed_casts: vec![], + allowed_source_types: vec![], }, Coercion { desired_type: TypeSignatureClass::Date, - allowed_casts: vec![], + allowed_source_types: vec![], }, ]), TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), - allowed_casts: vec![], + allowed_source_types: vec![], }, Coercion { desired_type: TypeSignatureClass::Time, - allowed_casts: vec![], + allowed_source_types: vec![], }, ]), TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), - allowed_casts: vec![], + allowed_source_types: vec![], }, Coercion { desired_type: TypeSignatureClass::Interval, - allowed_casts: vec![], + allowed_source_types: vec![], }, ]), TypeSignature::Coercible(vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), - allowed_casts: vec![], + allowed_source_types: vec![], }, Coercion { desired_type: TypeSignatureClass::Duration, - allowed_casts: vec![], + allowed_source_types: vec![], }, ]), ], diff --git a/datafusion/functions/src/string/ascii.rs
b/datafusion/functions/src/string/ascii.rs index 912fcff17cd5..f1763a7e0870 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -66,7 +66,7 @@ impl AsciiFunc { signature: Signature::coercible( vec![Coercion { desired_type: TypeSignatureClass::Native(logical_string()), - allowed_casts: vec![TypeSignatureClass::Native(logical_binary())], + allowed_source_types: vec![TypeSignatureClass::Native(logical_binary())], }], Volatility::Immutable, ), diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 042c5e2e5bfa..0317ec7c1d72 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -69,12 +69,12 @@ impl RepeatFunc { vec![ Coercion { desired_type: TypeSignatureClass::Native(logical_string()), - allowed_casts: vec![], + allowed_source_types: vec![], }, // Accept all integer types but cast them to i64 Coercion { desired_type: TypeSignatureClass::Native(logical_int64()), - allowed_casts: vec![TypeSignatureClass::Integer], + allowed_source_types: vec![TypeSignatureClass::Integer], }, ], Volatility::Immutable, From 62da381999e44ffd9cc616a27705d8eef9c55ddf Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 9 Feb 2025 18:27:41 +0800 Subject: [PATCH 14/18] refactor --- datafusion/common/src/types/native.rs | 3 + datafusion/expr-common/src/signature.rs | 127 ++++++++++++------ .../expr/src/type_coercion/functions.rs | 25 ++-- .../functions/src/datetime/date_part.rs | 57 +++----- datafusion/functions/src/string/ascii.rs | 11 +- datafusion/functions/src/string/repeat.rs | 16 +-- 6 files changed, 133 insertions(+), 106 deletions(-) diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index c5f180a15035..5453bf50788f 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -226,6 +226,9 @@ impl LogicalType for NativeType { (Self::Decimal(p, s), _) if p <= &38 
=> Decimal128(*p, *s), (Self::Decimal(p, s), _) => Decimal256(*p, *s), (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), + (Self::Date, origin) if matches!(origin, Date32 | Date64) => { + origin.to_owned() + } (Self::Date, _) => Date32, (Self::Time(tu), _) => match tu { TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu), diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index cd0475883830..f632ae8bb82e 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -24,9 +24,7 @@ use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::types::{ - logical_timestamp, LogicalType, LogicalTypeRef, NativeType, -}; +use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; use datafusion_common::{not_impl_err, Result}; use indexmap::IndexSet; use itertools::Itertools; @@ -213,7 +211,6 @@ impl TypeSignature { #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)] pub enum TypeSignatureClass { Timestamp, - Date, Time, Interval, Duration, @@ -245,15 +242,6 @@ impl TypeSignatureClass { TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { Ok(origin_type.to_owned()) } - // This is an existing use case for casting string to timestamp, since we don't have specific unit and timezone from string, - // so we use the default timestamp type with nanosecond precision and no timezone - // TODO: Consider allowing the user to specify the default timestamp type instead of having it predefined in DataFusion when we have more use cases - TypeSignatureClass::Timestamp if logical_type == &NativeType::String => { - logical_timestamp().default_cast_for(origin_type) - } - TypeSignatureClass::Date if logical_type.is_date() => { - Ok(origin_type.to_owned()) - } TypeSignatureClass::Time if logical_type.is_time() => { Ok(origin_type.to_owned()) } @@ -428,12 +416,16 @@ impl 
TypeSignature { get_possible_types_from_signature_classes(&c.desired_type) .into_iter() .collect(); - let allowed_casts: Vec = c - .allowed_source_types - .iter() - .flat_map(get_possible_types_from_signature_classes) - .collect(); - all_types.extend(allowed_casts); + + if let Some(implicit_coercion) = &c.implicit_coercion { + let allowed_casts: Vec = implicit_coercion + .allowed_source_types + .iter() + .flat_map(get_possible_types_from_signature_classes) + .collect(); + all_types.extend(allowed_casts); + } + all_types.into_iter().collect::>() }) .multi_cartesian_product() @@ -474,9 +466,6 @@ fn get_possible_types_from_signature_classes( DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())), ] } - TypeSignatureClass::Date => { - vec![DataType::Date64] - } TypeSignatureClass::Time => { vec![DataType::Time64(TimeUnit::Nanosecond)] } @@ -524,21 +513,50 @@ fn get_data_types(native_type: &NativeType) -> Vec { #[derive(Debug, Clone, Eq, PartialOrd)] pub struct Coercion { pub desired_type: TypeSignatureClass, - pub allowed_source_types: Vec, + implicit_coercion: Option, +} + +impl Coercion { + pub fn new(desired_type: TypeSignatureClass) -> Self { + Self { + desired_type, + implicit_coercion: None, + } + } + + pub fn new_with_implicit_coercion( + desired_type: TypeSignatureClass, + allowed_source_types: Vec, + default_type: NativeType, + ) -> Self { + Self { + desired_type, + implicit_coercion: Some(ImplicitCoercion { + allowed_source_types, + default_casted_type: default_type, + }), + } + } + + pub fn allowed_source_types(&self) -> &[TypeSignatureClass] { + self.implicit_coercion + .as_ref() + .map(|c| c.allowed_source_types.as_slice()) + .unwrap_or_default() + } + + pub fn default_casted_type(&self) -> Option<&NativeType> { + self.implicit_coercion + .as_ref() + .map(|c| &c.default_casted_type) + } } impl Display for Coercion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Coercion({}", self.desired_type)?; - if 
!self.allowed_source_types.is_empty() { - write!( - f, - ", allowed_casts=[{}]", - self.allowed_source_types - .iter() - .map(|cast| cast.to_string()) - .join(", ") - ) + if let Some(implicit_coercion) = &self.implicit_coercion { + write!(f, ", implicit_coercion={implicit_coercion}",) } else { write!(f, ")") } @@ -548,14 +566,47 @@ impl Display for Coercion { impl PartialEq for Coercion { fn eq(&self, other: &Self) -> bool { self.desired_type == other.desired_type - && self.allowed_source_types == other.allowed_source_types + && self.implicit_coercion == other.implicit_coercion } } impl Hash for Coercion { fn hash(&self, state: &mut H) { self.desired_type.hash(state); + self.implicit_coercion.hash(state); + } +} + +#[derive(Debug, Clone, Eq, PartialOrd)] +pub struct ImplicitCoercion { + pub allowed_source_types: Vec, + /// For types like Timestamp, there are multiple possible timeunit and timezone from a given TypeSignatureClass + /// We need to specify the default type to be used for coercion if we cast from other types via `allowed_source_types` + /// Other types like Int64, you don't need to specify this field since there is only one possible type. 
+ pub default_casted_type: NativeType, +} + +impl Display for ImplicitCoercion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ImplicitCoercion({:?}, default_type={:?})", + self.allowed_source_types, self.default_casted_type + ) + } +} + +impl PartialEq for ImplicitCoercion { + fn eq(&self, other: &Self) -> bool { + self.allowed_source_types == other.allowed_source_types + && self.default_casted_type == other.default_casted_type + } +} + +impl Hash for ImplicitCoercion { + fn hash(&self, state: &mut H) { self.allowed_source_types.hash(state); + self.default_casted_type.hash(state); } } @@ -836,14 +887,8 @@ mod tests { ); let type_signature = TypeSignature::Coercible(vec![ - Coercion { - desired_type: TypeSignatureClass::Native(logical_string()), - allowed_source_types: vec![], - }, - Coercion { - desired_type: TypeSignatureClass::Native(logical_int64()), - allowed_source_types: vec![], - }, + Coercion::new(TypeSignatureClass::Native(logical_string())), + Coercion::new(TypeSignatureClass::Native(logical_int64())), ]); let possible_types = type_signature.get_possible_types(); assert_eq!( diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index bf6c160075d4..a243819955e4 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,11 +21,11 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result, }; +use datafusion_common::{types::LogicalType, utils::coerced_fixed_size_list_to_list}; use datafusion_expr_common::{ signature::{ ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD, @@ -616,7 +616,7 @@ fn get_valid_types( TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { 
true } - TypeSignatureClass::Date if logical_type.is_date() => true, + // TypeSignatureClass::Date if logical_type.is_date() => true, TypeSignatureClass::Time if logical_type.is_time() => true, TypeSignatureClass::Interval if logical_type.is_interval() => { true @@ -629,14 +629,21 @@ fn get_valid_types( } } - if is_matched_type(¶m.desired_type, ¤t_logical_type) || param - .allowed_source_types - .iter() - .any(|t| is_matched_type(t, ¤t_logical_type)) - { + if is_matched_type(¶m.desired_type, ¤t_logical_type) { let casted_type = param - .desired_type - .default_casted_type(¤t_logical_type, current_type)?; + .desired_type + .default_casted_type(¤t_logical_type, current_type)?; + new_types.push(casted_type); + } else if param + .allowed_source_types() + .iter() + .any(|t| is_matched_type(t, ¤t_logical_type)) { + + if param.default_casted_type().is_none() { + return exec_err!("This shouldn't be None"); + } + let default_type = param.default_casted_type().unwrap(); + let casted_type = default_type.default_cast_for(current_type)?; new_types.push(casted_type); } else { return internal_err!( diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index ef06f8ac6dd1..cd16ebc94815 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -27,6 +27,7 @@ use arrow::datatypes::DataType::{ }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; +use datafusion_common::types::{logical_date, NativeType}; use crate::utils::take_function_args; use datafusion_common::not_impl_err; @@ -96,57 +97,29 @@ impl DatePartFunc { signature: Signature::one_of( vec![ TypeSignature::Coercible(vec![ - Coercion { - desired_type: TypeSignatureClass::Native(logical_string()), - allowed_source_types: vec![], - }, - Coercion { - desired_type: TypeSignatureClass::Timestamp, + 
Coercion::new(TypeSignatureClass::Native(logical_string())), + Coercion::new_with_implicit_coercion( + TypeSignatureClass::Timestamp, // Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp - allowed_source_types: vec![TypeSignatureClass::Native( - logical_string(), - )], - }, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(Nanosecond, None), + ), ]), TypeSignature::Coercible(vec![ - Coercion { - desired_type: TypeSignatureClass::Native(logical_string()), - allowed_source_types: vec![], - }, - Coercion { - desired_type: TypeSignatureClass::Date, - allowed_source_types: vec![], - }, + Coercion::new(TypeSignatureClass::Native(logical_string())), + Coercion::new(TypeSignatureClass::Native(logical_date())), ]), TypeSignature::Coercible(vec![ - Coercion { - desired_type: TypeSignatureClass::Native(logical_string()), - allowed_source_types: vec![], - }, - Coercion { - desired_type: TypeSignatureClass::Time, - allowed_source_types: vec![], - }, + Coercion::new(TypeSignatureClass::Native(logical_string())), + Coercion::new(TypeSignatureClass::Time), ]), TypeSignature::Coercible(vec![ - Coercion { - desired_type: TypeSignatureClass::Native(logical_string()), - allowed_source_types: vec![], - }, - Coercion { - desired_type: TypeSignatureClass::Interval, - allowed_source_types: vec![], - }, + Coercion::new(TypeSignatureClass::Native(logical_string())), + Coercion::new(TypeSignatureClass::Interval), ]), TypeSignature::Coercible(vec![ - Coercion { - desired_type: TypeSignatureClass::Native(logical_string()), - allowed_source_types: vec![], - }, - Coercion { - desired_type: TypeSignatureClass::Duration, - allowed_source_types: vec![], - }, + Coercion::new(TypeSignatureClass::Native(logical_string())), + Coercion::new(TypeSignatureClass::Duration), ]), ], Volatility::Immutable, diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 
f1763a7e0870..1a2a87e2d6cf 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -19,7 +19,7 @@ use crate::utils::make_scalar_function; use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_common::types::{logical_binary, logical_string}; +use datafusion_common::types::{logical_binary, logical_string, NativeType}; use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; @@ -64,10 +64,11 @@ impl AsciiFunc { pub fn new() -> Self { Self { signature: Signature::coercible( - vec![Coercion { - desired_type: TypeSignatureClass::Native(logical_string()), - allowed_source_types: vec![TypeSignatureClass::Native(logical_binary())], - }], + vec![Coercion::new_with_implicit_coercion( + TypeSignatureClass::Native(logical_string()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::String, + )], Volatility::Immutable, ), } diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 0317ec7c1d72..85b7b5c59836 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -26,7 +26,7 @@ use arrow::array::{ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; -use datafusion_common::types::{logical_int64, logical_string}; +use datafusion_common::types::{logical_int64, logical_string, NativeType}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; @@ -67,15 +67,13 @@ impl RepeatFunc { Self { signature: Signature::coercible( vec![ - Coercion { - desired_type: TypeSignatureClass::Native(logical_string()), - 
allowed_source_types: vec![], - }, + Coercion::new(TypeSignatureClass::Native(logical_string())), // Accept all integer types but cast them to i64 - Coercion { - desired_type: TypeSignatureClass::Native(logical_int64()), - allowed_source_types: vec![TypeSignatureClass::Integer], - }, + Coercion::new_with_implicit_coercion( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), ], Volatility::Immutable, ), From 231d75b7504d8d5ba7eb758d044be1e5e00b4135 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 9 Feb 2025 18:52:13 +0800 Subject: [PATCH 15/18] make default_casted_type private --- datafusion/common/src/types/native.rs | 6 +++ datafusion/expr-common/src/signature.rs | 52 ++++++------------- .../expr/src/type_coercion/functions.rs | 48 +++++++++++++---- 3 files changed, 58 insertions(+), 48 deletions(-) diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 5453bf50788f..f8ec0795fb7f 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -198,6 +198,11 @@ impl LogicalType for NativeType { TypeSignature::Native(self) } + /// Returns the default casted type for the given arrow type + /// + /// For types like String or Date, multiple arrow types mapped to the same logical type + /// If the given arrow type is one of them, we return the same type + /// Otherwise, we define the default casted type for the given arrow type fn default_cast_for(&self, origin: &DataType) -> Result { use DataType::*; @@ -226,6 +231,7 @@ impl LogicalType for NativeType { (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s), (Self::Decimal(p, s), _) => Decimal256(*p, *s), (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), + // If given type is Date, return the same type (Self::Date, origin) if matches!(origin, Date32 | Date64) => { origin.to_owned() } diff --git a/datafusion/expr-common/src/signature.rs 
b/datafusion/expr-common/src/signature.rs index f632ae8bb82e..fcc61e613bac 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -24,8 +24,7 @@ use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::types::{LogicalTypeRef, NativeType}; use indexmap::IndexSet; use itertools::Itertools; @@ -226,36 +225,6 @@ impl Display for TypeSignatureClass { } } -impl TypeSignatureClass { - /// Returns the default cast type for the given `TypeSignatureClass`. - /// Be cautious to avoid adding specialized logic here, as this function is public and intended for general use. - pub fn default_casted_type( - &self, - logical_type: &NativeType, - origin_type: &DataType, - ) -> Result { - match self { - TypeSignatureClass::Native(logical_type) => { - logical_type.native().default_cast_for(origin_type) - } - // If the given type is already a timestamp, we don't change the unit and timezone - TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { - Ok(origin_type.to_owned()) - } - TypeSignatureClass::Time if logical_type.is_time() => { - Ok(origin_type.to_owned()) - } - TypeSignatureClass::Interval if logical_type.is_interval() => { - Ok(origin_type.to_owned()) - } - TypeSignatureClass::Duration if logical_type.is_duration() => { - Ok(origin_type.to_owned()) - } - _ => not_impl_err!("Other cases are not implemented yet"), - } - } -} - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { /// Specialized Signature for ArrayAppend and similar functions @@ -396,7 +365,12 @@ impl TypeSignature { } } - /// get all possible types for the given `TypeSignature` + /// This function is used specifically internally for `information_schema` + /// We suggest not to rely on this function 
+ /// + /// Get all possible types for `information_schema` from the given `TypeSignature` + // + // TODO: Make this function private pub fn get_possible_types(&self) -> Vec> { match self { TypeSignature::Exact(types) => vec![types.clone()], @@ -524,16 +498,20 @@ impl Coercion { } } + /// Create a new coercion with implicit coercion rules. + /// + /// `allowed_source_types` defines the possible types that can be coerced to `desired_type`. + /// `default_casted_type` is the default type to be used for coercion if we cast from other types via `allowed_source_types`. pub fn new_with_implicit_coercion( desired_type: TypeSignatureClass, allowed_source_types: Vec, - default_type: NativeType, + default_casted_type: NativeType, ) -> Self { Self { desired_type, implicit_coercion: Some(ImplicitCoercion { allowed_source_types, - default_casted_type: default_type, + default_casted_type, }), } } @@ -579,11 +557,11 @@ impl Hash for Coercion { #[derive(Debug, Clone, Eq, PartialOrd)] pub struct ImplicitCoercion { - pub allowed_source_types: Vec, + allowed_source_types: Vec, /// For types like Timestamp, there are multiple possible timeunit and timezone from a given TypeSignatureClass /// We need to specify the default type to be used for coercion if we cast from other types via `allowed_source_types` /// Other types like Int64, you don't need to specify this field since there is only one possible type. 
- pub default_casted_type: NativeType, + default_casted_type: NativeType, } impl Display for ImplicitCoercion { diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index a243819955e4..647c124005ea 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -22,8 +22,8 @@ use arrow::{ datatypes::{DataType, TimeUnit}, }; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, - utils::list_ndims, Result, + exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, + types::NativeType, utils::list_ndims, Result, }; use datafusion_common::{types::LogicalType, utils::coerced_fixed_size_list_to_list}; use datafusion_expr_common::{ @@ -629,21 +629,47 @@ fn get_valid_types( } } + fn default_casted_type( + signature_class: &TypeSignatureClass, + logical_type: &NativeType, + origin_type: &DataType, + ) -> Result { + match signature_class { + TypeSignatureClass::Native(logical_type) => { + logical_type.native().default_cast_for(origin_type) + } + // If the given type is already a timestamp, we don't change the unit and timezone + TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Time if logical_type.is_time() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Interval if logical_type.is_interval() => { + Ok(origin_type.to_owned()) + } + TypeSignatureClass::Duration if logical_type.is_duration() => { + Ok(origin_type.to_owned()) + } + _ => not_impl_err!("Other cases are not implemented yet"), + } + } + if is_matched_type(¶m.desired_type, ¤t_logical_type) { - let casted_type = param - .desired_type - .default_casted_type(¤t_logical_type, current_type)?; + let casted_type = default_casted_type( + ¶m.desired_type, + ¤t_logical_type, + current_type, + )?; + new_types.push(casted_type); } else if param .allowed_source_types() 
.iter() .any(|t| is_matched_type(t, ¤t_logical_type)) { - - if param.default_casted_type().is_none() { - return exec_err!("This shouldn't be None"); - } - let default_type = param.default_casted_type().unwrap(); - let casted_type = default_type.default_cast_for(current_type)?; + // If the condition is met which means `implicit coercion`` is provided so we can safely unwrap + let default_casted_type = param.default_casted_type().unwrap(); + let casted_type = default_casted_type.default_cast_for(current_type)?; new_types.push(casted_type); } else { return internal_err!( From 44250fc415d4da31f0d53d57612f064565055a91 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 9 Feb 2025 18:54:52 +0800 Subject: [PATCH 16/18] cleanup --- datafusion/common/src/types/builtin.rs | 12 +----------- datafusion/expr/src/type_coercion/functions.rs | 1 - 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/datafusion/common/src/types/builtin.rs b/datafusion/common/src/types/builtin.rs index 10e841147e4a..d6ab6d582167 100644 --- a/datafusion/common/src/types/builtin.rs +++ b/datafusion/common/src/types/builtin.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::TimeUnit; - use crate::types::{LogicalTypeRef, NativeType}; use std::sync::{Arc, LazyLock}; @@ -48,12 +46,4 @@ singleton!(LOGICAL_FLOAT32, logical_float32, Float32); singleton!(LOGICAL_FLOAT64, logical_float64, Float64); singleton!(LOGICAL_DATE, logical_date, Date); singleton!(LOGICAL_BINARY, logical_binary, Binary); -singleton!(LOGICAL_STRING, logical_string, String); - -// TODO: Extend macro -// TODO: Should we use LOGICAL_TIMESTAMP_NANO to distinguish unit and timzeone? 
-static LOGICAL_TIMESTAMP: LazyLock = - LazyLock::new(|| Arc::new(NativeType::Timestamp(TimeUnit::Nanosecond, None))); -pub fn logical_timestamp() -> LogicalTypeRef { - Arc::clone(&LOGICAL_TIMESTAMP) -} +singleton!(LOGICAL_STRING, logical_string, String); \ No newline at end of file diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 647c124005ea..b64383481c03 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -616,7 +616,6 @@ fn get_valid_types( TypeSignatureClass::Timestamp if logical_type.is_timestamp() => { true } - // TypeSignatureClass::Date if logical_type.is_date() => true, TypeSignatureClass::Time if logical_type.is_time() => true, TypeSignatureClass::Interval if logical_type.is_interval() => { true From 2f6b5d65baf09bda49cd3cd179e9008ad1b658dd Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 9 Feb 2025 18:55:04 +0800 Subject: [PATCH 17/18] fmt --- datafusion/common/src/types/builtin.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/types/builtin.rs b/datafusion/common/src/types/builtin.rs index d6ab6d582167..ec69db790377 100644 --- a/datafusion/common/src/types/builtin.rs +++ b/datafusion/common/src/types/builtin.rs @@ -46,4 +46,4 @@ singleton!(LOGICAL_FLOAT32, logical_float32, Float32); singleton!(LOGICAL_FLOAT64, logical_float64, Float64); singleton!(LOGICAL_DATE, logical_date, Date); singleton!(LOGICAL_BINARY, logical_binary, Binary); -singleton!(LOGICAL_STRING, logical_string, String); \ No newline at end of file +singleton!(LOGICAL_STRING, logical_string, String); From 5387e5ed73ab183b3d2d4278f35332244535df08 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 9 Feb 2025 19:58:18 +0800 Subject: [PATCH 18/18] integer --- datafusion/expr/src/type_coercion/functions.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git 
a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index b64383481c03..c37b9a13d475 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -22,8 +22,8 @@ use arrow::{ datatypes::{DataType, TimeUnit}, }; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, - types::NativeType, utils::list_ndims, Result, + exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, + utils::list_ndims, Result, }; use datafusion_common::{types::LogicalType, utils::coerced_fixed_size_list_to_list}; use datafusion_expr_common::{ @@ -650,7 +650,10 @@ fn get_valid_types( TypeSignatureClass::Duration if logical_type.is_duration() => { Ok(origin_type.to_owned()) } - _ => not_impl_err!("Other cases are not implemented yet"), + TypeSignatureClass::Integer if logical_type.is_integer() => { + Ok(origin_type.to_owned()) + } + _ => internal_err!("May miss the matching logic in `is_matched_type`"), } }