diff --git a/common/datavalues/src/columns/nullable/mod.rs b/common/datavalues/src/columns/nullable/mod.rs index 360af6d6441f..fd87e7f1a39b 100644 --- a/common/datavalues/src/columns/nullable/mod.rs +++ b/common/datavalues/src/columns/nullable/mod.rs @@ -87,7 +87,7 @@ impl Column for NullableColumn { fn data_type(&self) -> DataTypePtr { let nest = self.column.data_type(); - Arc::new(NullableType::create(nest)) + NullableType::arc(nest) } fn column_type_name(&self) -> String { @@ -111,7 +111,7 @@ impl Column for NullableColumn { } fn validity(&self) -> (bool, Option<&Bitmap>) { - (false, Some(&self.validity)) + (self.only_null(), Some(&self.validity)) } fn memory_size(&self) -> usize { diff --git a/common/datavalues/src/types/data_type.rs b/common/datavalues/src/types/data_type.rs index fc50272d33d5..746e38f5af71 100644 --- a/common/datavalues/src/types/data_type.rs +++ b/common/datavalues/src/types/data_type.rs @@ -179,7 +179,7 @@ pub fn from_arrow_field(f: &ArrowField) -> DataTypePtr { let is_nullable = f.is_nullable; if is_nullable && ty.can_inside_nullable() { - Arc::new(NullableType::create(ty)) + NullableType::arc(ty) } else { ty } @@ -209,7 +209,7 @@ pub fn wrap_nullable(data_type: &DataTypePtr) -> DataTypePtr { if !data_type.can_inside_nullable() { return data_type.clone(); } - Arc::new(NullableType::create(data_type.clone())) + NullableType::arc(data_type.clone()) } pub fn remove_nullable(data_type: &DataTypePtr) -> DataTypePtr { diff --git a/common/datavalues/src/types/type_factory.rs b/common/datavalues/src/types/type_factory.rs index 362f5c2abb38..7a7b1ad6d760 100644 --- a/common/datavalues/src/types/type_factory.rs +++ b/common/datavalues/src/types/type_factory.rs @@ -114,7 +114,7 @@ impl TypeFactory { let mut nulls = HashMap::new(); for (k, v) in self.case_insensitive_types.iter() { if v.can_inside_nullable() { - let data_type: DataTypePtr = Arc::new(NullableType::create(v.clone())); + let data_type: DataTypePtr = NullableType::arc(v.clone()); 
nulls.insert( format!("Nullable({})", k).to_ascii_lowercase(), data_type.clone(), diff --git a/common/datavalues/src/types/type_nullable.rs b/common/datavalues/src/types/type_nullable.rs index ffe9f2f2a91e..aac96c5d4633 100644 --- a/common/datavalues/src/types/type_nullable.rs +++ b/common/datavalues/src/types/type_nullable.rs @@ -31,6 +31,10 @@ pub struct NullableType { } impl NullableType { + pub fn arc(inner: DataTypePtr) -> DataTypePtr { + Arc::new(Self::create(inner)) + } + pub fn create(inner: DataTypePtr) -> Self { debug_assert!(inner.can_inside_nullable()); NullableType { diff --git a/common/datavalues/tests/it/types/create_column.rs b/common/datavalues/tests/it/types/create_column.rs index 4069f5fc2098..c0fe902e4786 100644 --- a/common/datavalues/tests/it/types/create_column.rs +++ b/common/datavalues/tests/it/types/create_column.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use common_datavalues::prelude::*; use common_exception::Result; use pretty_assertions::assert_eq; @@ -82,7 +80,7 @@ fn test_create_constant() -> Result<()> { }, Test { name: "nullable_i32", - data_type: Arc::new(NullableType::create(Int32Type::arc())), + data_type: NullableType::arc(Int32Type::arc()), value: DataValue::Null, size: 2, column_expected: Series::from_data(&[None, None, Some(1i32)][0..2]), diff --git a/common/functions/src/scalars/arithmetics/arithmetic.rs b/common/functions/src/scalars/arithmetics/arithmetic.rs index d87d358bff3c..c29d5070279f 100644 --- a/common/functions/src/scalars/arithmetics/arithmetic.rs +++ b/common/functions/src/scalars/arithmetics/arithmetic.rs @@ -26,18 +26,18 @@ pub struct ArithmeticFunction; impl ArithmeticFunction { pub fn register(factory: &mut FunctionFactory) { - factory.register_typed("negate", ArithmeticNegateFunction::desc()); - factory.register_typed("+", ArithmeticPlusFunction::desc()); - factory.register_typed("plus", 
ArithmeticPlusFunction::desc()); - factory.register_typed("-", ArithmeticMinusFunction::desc()); - factory.register_typed("minus", ArithmeticMinusFunction::desc()); - factory.register_typed("*", ArithmeticMulFunction::desc()); - factory.register_typed("multiply", ArithmeticMulFunction::desc()); - factory.register_typed("/", ArithmeticDivFunction::desc()); - factory.register_typed("divide", ArithmeticDivFunction::desc()); - factory.register_typed("div", ArithmeticIntDivFunction::desc()); - factory.register_typed("%", ArithmeticModuloFunction::desc()); - factory.register_typed("modulo", ArithmeticModuloFunction::desc()); - factory.register_typed("mod", ArithmeticModuloFunction::desc()); + factory.register("negate", ArithmeticNegateFunction::desc()); + factory.register("+", ArithmeticPlusFunction::desc()); + factory.register("plus", ArithmeticPlusFunction::desc()); + factory.register("-", ArithmeticMinusFunction::desc()); + factory.register("minus", ArithmeticMinusFunction::desc()); + factory.register("*", ArithmeticMulFunction::desc()); + factory.register("multiply", ArithmeticMulFunction::desc()); + factory.register("/", ArithmeticDivFunction::desc()); + factory.register("divide", ArithmeticDivFunction::desc()); + factory.register("div", ArithmeticIntDivFunction::desc()); + factory.register("%", ArithmeticModuloFunction::desc()); + factory.register("modulo", ArithmeticModuloFunction::desc()); + factory.register("mod", ArithmeticModuloFunction::desc()); } } diff --git a/common/functions/src/scalars/arithmetics/arithmetic_div.rs b/common/functions/src/scalars/arithmetics/arithmetic_div.rs index caa74464fd11..6849a7fadc92 100644 --- a/common/functions/src/scalars/arithmetics/arithmetic_div.rs +++ b/common/functions/src/scalars/arithmetics/arithmetic_div.rs @@ -22,9 +22,9 @@ use super::arithmetic_mul::arithmetic_mul_div_monotonicity; use crate::scalars::BinaryArithmeticFunction; use crate::scalars::EvalContext; use crate::scalars::Function; +use 
crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; use crate::scalars::Monotonicity; -use crate::scalars::TypedFunctionDescription; #[inline] fn div_scalar(l: impl AsPrimitive, r: impl AsPrimitive, _ctx: &mut EvalContext) -> f64 { @@ -53,8 +53,8 @@ impl ArithmeticDivFunction { arithmetic_mul_div_monotonicity(args, DataValueBinaryOperator::Div) } - pub fn desc() -> TypedFunctionDescription { - TypedFunctionDescription::creator(Box::new(Self::try_create_func)).features( + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create_func)).features( FunctionFeatures::default() .deterministic() .monotonicity() diff --git a/common/functions/src/scalars/arithmetics/arithmetic_intdiv.rs b/common/functions/src/scalars/arithmetics/arithmetic_intdiv.rs index c4b0160a0a55..6ab51a1bb151 100644 --- a/common/functions/src/scalars/arithmetics/arithmetic_intdiv.rs +++ b/common/functions/src/scalars/arithmetics/arithmetic_intdiv.rs @@ -22,8 +22,8 @@ use num_traits::AsPrimitive; use crate::scalars::BinaryArithmeticFunction; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; -use crate::scalars::TypedFunctionDescription; #[inline] fn intdiv_scalar(l: impl AsPrimitive, r: impl AsPrimitive, ctx: &mut EvalContext) -> O @@ -58,8 +58,8 @@ impl ArithmeticIntDivFunction { }) } - pub fn desc() -> TypedFunctionDescription { - TypedFunctionDescription::creator(Box::new(Self::try_create_func)).features( + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create_func)).features( FunctionFeatures::default() .deterministic() .monotonicity() diff --git a/common/functions/src/scalars/arithmetics/arithmetic_minus.rs b/common/functions/src/scalars/arithmetics/arithmetic_minus.rs index 56794cb2acfb..9118f565d6e7 100644 --- a/common/functions/src/scalars/arithmetics/arithmetic_minus.rs +++ 
b/common/functions/src/scalars/arithmetics/arithmetic_minus.rs @@ -26,10 +26,10 @@ use num_traits::WrappingSub; use crate::scalars::BinaryArithmeticFunction; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionDescription; use crate::scalars::FunctionFactory; use crate::scalars::FunctionFeatures; use crate::scalars::Monotonicity; -use crate::scalars::TypedFunctionDescription; #[inline] fn sub_scalar(l: impl AsPrimitive, r: impl AsPrimitive, _ctx: &mut EvalContext) -> O @@ -118,8 +118,8 @@ impl ArithmeticMinusFunction { }) } - pub fn desc() -> TypedFunctionDescription { - TypedFunctionDescription::creator(Box::new(Self::try_create_func)).features( + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create_func)).features( FunctionFeatures::default() .deterministic() .monotonicity() diff --git a/common/functions/src/scalars/arithmetics/arithmetic_modulo.rs b/common/functions/src/scalars/arithmetics/arithmetic_modulo.rs index 74c2affa71c8..b9737ab7f923 100644 --- a/common/functions/src/scalars/arithmetics/arithmetic_modulo.rs +++ b/common/functions/src/scalars/arithmetics/arithmetic_modulo.rs @@ -25,8 +25,9 @@ use num_traits::AsPrimitive; use super::utils::rem_scalar; use crate::scalars::Function; +use crate::scalars::FunctionContext; +use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; -use crate::scalars::TypedFunctionDescription; pub struct ArithmeticModuloFunction; @@ -44,8 +45,8 @@ impl ArithmeticModuloFunction { }) } - pub fn desc() -> TypedFunctionDescription { - TypedFunctionDescription::creator(Box::new(Self::try_create_func)) + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create_func)) .features(FunctionFeatures::default().deterministic().num_arguments(2)) } } @@ -73,8 +74,8 @@ where "ModuloFunctionImpl" } - fn return_type(&self, _args: &[&DataTypePtr]) -> Result { - Ok(O::to_data_type()) + fn 
return_type(&self) -> DataTypePtr { + O::to_data_type() } fn eval( @@ -160,4 +161,3 @@ where write!(f, "div") } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/arithmetics/arithmetic_mul.rs b/common/functions/src/scalars/arithmetics/arithmetic_mul.rs index c17697a72f01..fa7e617ae049 100644 --- a/common/functions/src/scalars/arithmetics/arithmetic_mul.rs +++ b/common/functions/src/scalars/arithmetics/arithmetic_mul.rs @@ -24,9 +24,9 @@ use num_traits::WrappingMul; use crate::scalars::BinaryArithmeticFunction; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; use crate::scalars::Monotonicity; -use crate::scalars::TypedFunctionDescription; #[inline] fn mul_scalar(l: impl AsPrimitive, r: impl AsPrimitive, _ctx: &mut EvalContext) -> O @@ -79,8 +79,8 @@ impl ArithmeticMulFunction { }) } - pub fn desc() -> TypedFunctionDescription { - TypedFunctionDescription::creator(Box::new(Self::try_create_func)).features( + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create_func)).features( FunctionFeatures::default() .deterministic() .monotonicity() diff --git a/common/functions/src/scalars/arithmetics/arithmetic_negate.rs b/common/functions/src/scalars/arithmetics/arithmetic_negate.rs index d8682e9d4be4..a3e3914f74a9 100644 --- a/common/functions/src/scalars/arithmetics/arithmetic_negate.rs +++ b/common/functions/src/scalars/arithmetics/arithmetic_negate.rs @@ -23,9 +23,9 @@ use num_traits::WrappingNeg; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; use crate::scalars::Monotonicity; -use crate::scalars::TypedFunctionDescription; use crate::scalars::UnaryArithmeticFunction; fn neg(l: impl AsPrimitive, _ctx: &mut EvalContext) -> O @@ -79,8 +79,8 @@ impl ArithmeticNegateFunction { }) } - pub fn desc() -> 
TypedFunctionDescription { - TypedFunctionDescription::creator(Box::new(Self::try_create_func)).features( + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create_func)).features( FunctionFeatures::default() .deterministic() .monotonicity() diff --git a/common/functions/src/scalars/arithmetics/arithmetic_plus.rs b/common/functions/src/scalars/arithmetics/arithmetic_plus.rs index 103a93e1ffca..aabf42651dee 100644 --- a/common/functions/src/scalars/arithmetics/arithmetic_plus.rs +++ b/common/functions/src/scalars/arithmetics/arithmetic_plus.rs @@ -26,10 +26,10 @@ use num_traits::WrappingAdd; use crate::scalars::BinaryArithmeticFunction; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionDescription; use crate::scalars::FunctionFactory; use crate::scalars::FunctionFeatures; use crate::scalars::Monotonicity; -use crate::scalars::TypedFunctionDescription; #[inline] fn add_scalar(l: impl AsPrimitive, r: impl AsPrimitive, _ctx: &mut EvalContext) -> O @@ -125,8 +125,8 @@ impl ArithmeticPlusFunction { }) } - pub fn desc() -> TypedFunctionDescription { - TypedFunctionDescription::creator(Box::new(Self::try_create_func)).features( + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create_func)).features( FunctionFeatures::default() .deterministic() .monotonicity() diff --git a/common/functions/src/scalars/arithmetics/binary_arithmetic.rs b/common/functions/src/scalars/arithmetics/binary_arithmetic.rs index aa9fa5c76fbf..bd5874cddb77 100644 --- a/common/functions/src/scalars/arithmetics/binary_arithmetic.rs +++ b/common/functions/src/scalars/arithmetics/binary_arithmetic.rs @@ -70,8 +70,8 @@ where "BinaryArithmeticFunction" } - fn return_type(&self, _args: &[&DataTypePtr]) -> Result { - Ok(self.result_type.clone()) + fn return_type(&self) -> DataTypePtr { + self.result_type.clone() } fn eval( diff --git 
a/common/functions/src/scalars/arithmetics/unary_arithmetic.rs b/common/functions/src/scalars/arithmetics/unary_arithmetic.rs index dc1a7f262220..ec92ad42b542 100644 --- a/common/functions/src/scalars/arithmetics/unary_arithmetic.rs +++ b/common/functions/src/scalars/arithmetics/unary_arithmetic.rs @@ -65,8 +65,8 @@ where "UnaryArithmeticFunction" } - fn return_type(&self, _args: &[&DataTypePtr]) -> Result { - Ok(self.result_type.clone()) + fn return_type(&self) -> DataTypePtr { + self.result_type.clone() } fn eval( diff --git a/common/functions/src/scalars/comparisons/comparison.rs b/common/functions/src/scalars/comparisons/comparison.rs index 6285ebb73d6f..6cdd57dcaa1d 100644 --- a/common/functions/src/scalars/comparisons/comparison.rs +++ b/common/functions/src/scalars/comparisons/comparison.rs @@ -44,9 +44,9 @@ use crate::scalars::ComparisonRegexpFunction; use crate::scalars::EvalContext; use crate::scalars::Function; use crate::scalars::FunctionContext; +use crate::scalars::FunctionDescription; use crate::scalars::FunctionFactory; use crate::scalars::FunctionFeatures; -use crate::scalars::TypedFunctionDescription; #[derive(Clone)] pub struct ComparisonFunction { @@ -56,19 +56,19 @@ pub struct ComparisonFunction { impl ComparisonFunction { pub fn register(factory: &mut FunctionFactory) { - factory.register_typed("=", ComparisonEqFunction::desc("<>")); - factory.register_typed("<", ComparisonLtFunction::desc(">=")); - factory.register_typed(">", ComparisonGtFunction::desc("<=")); - factory.register_typed("<=", ComparisonLtEqFunction::desc(">")); - factory.register_typed(">=", ComparisonGtEqFunction::desc("<")); - factory.register_typed("!=", ComparisonNotEqFunction::desc("=")); - factory.register_typed("<>", ComparisonNotEqFunction::desc("=")); - factory.register_typed("like", ComparisonLikeFunction::desc("not like")); - factory.register_typed("not like", ComparisonNotLikeFunction::desc("like")); - factory.register_typed("regexp", 
ComparisonRegexpFunction::desc("not regexp")); - factory.register_typed("not regexp", ComparisonNotRegexpFunction::desc("regexp")); - factory.register_typed("rlike", ComparisonRegexpFunction::desc("not regexp")); - factory.register_typed("not rlike", ComparisonNotRegexpFunction::desc("regexp")); + factory.register("=", ComparisonEqFunction::desc("<>")); + factory.register("<", ComparisonLtFunction::desc(">=")); + factory.register(">", ComparisonGtFunction::desc("<=")); + factory.register("<=", ComparisonLtEqFunction::desc(">")); + factory.register(">=", ComparisonGtEqFunction::desc("<")); + factory.register("!=", ComparisonNotEqFunction::desc("=")); + factory.register("<>", ComparisonNotEqFunction::desc("=")); + factory.register("like", ComparisonLikeFunction::desc("not like")); + factory.register("not like", ComparisonNotLikeFunction::desc("like")); + factory.register("regexp", ComparisonRegexpFunction::desc("not regexp")); + factory.register("not regexp", ComparisonNotRegexpFunction::desc("regexp")); + factory.register("rlike", ComparisonRegexpFunction::desc("not regexp")); + factory.register("not rlike", ComparisonNotRegexpFunction::desc("regexp")); } pub fn try_create_func( @@ -87,8 +87,8 @@ impl Function for ComparisonFunction { self.display_name.as_str() } - fn return_type(&self, _args: &[&DataTypePtr]) -> Result { - Ok(BooleanType::arc()) + fn return_type(&self) -> DataTypePtr { + BooleanType::arc() } fn eval( @@ -167,8 +167,8 @@ impl ComparisonFunctionCreator { }) } - pub fn desc(negative_name: &str) -> TypedFunctionDescription { - TypedFunctionDescription::creator(Box::new(Self::try_create_func)).features( + pub fn desc(negative_name: &str) -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create_func)).features( FunctionFeatures::default() .deterministic() .negative_function(negative_name) @@ -197,8 +197,8 @@ impl StringSearchCreator { ComparisonFunction::try_create_func(display_name, func) } - pub fn desc(negative_name: &str) -> 
TypedFunctionDescription { - TypedFunctionDescription::creator(Box::new(Self::try_create_func)).features( + pub fn desc(negative_name: &str) -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create_func)).features( FunctionFeatures::default() .deterministic() .negative_function(negative_name) diff --git a/common/functions/src/scalars/conditionals/if.rs b/common/functions/src/scalars/conditionals/if.rs index 5bc4f570a05a..934a627f8eea 100644 --- a/common/functions/src/scalars/conditionals/if.rs +++ b/common/functions/src/scalars/conditionals/if.rs @@ -29,12 +29,17 @@ use crate::scalars::FunctionFeatures; #[derive(Clone, Debug)] pub struct IfFunction { display_name: String, + least_supertype: DataTypePtr, } impl IfFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + let dts = vec![args[1].clone(), args[2].clone()]; + let least_supertype = aggregate_types(dts.as_slice())?; + Ok(Box::new(IfFunction { display_name: display_name.to_string(), + least_supertype, })) } @@ -80,12 +85,8 @@ impl IfFunction { (&columns[1], &columns[0], true) }; - // cast to least super type - let dts = vec![lhs_col.data_type().clone(), rhs_col.data_type().clone()]; - let least_supertype = aggregate_types(dts.as_slice())?; - - let lhs = cast_column_field(lhs_col, &least_supertype)?; - let rhs = cast_column_field(rhs_col, &least_supertype)?; + let lhs = cast_column_field(lhs_col, &self.least_supertype)?; + let rhs = cast_column_field(rhs_col, &self.least_supertype)?; let type_id = remove_nullable(&lhs.data_type()).data_type_id(); @@ -184,13 +185,10 @@ impl IfFunction { let lhs_col = &columns[0]; let rhs_col = &columns[1]; - let dts = vec![lhs_col.data_type().clone(), rhs_col.data_type().clone()]; - let least_supertype = aggregate_types(dts.as_slice())?; - - let lhs = cast_column_field(lhs_col, &least_supertype)?; - let rhs = cast_column_field(rhs_col, &least_supertype)?; + let lhs = 
cast_column_field(lhs_col, &self.least_supertype)?; + let rhs = cast_column_field(rhs_col, &self.least_supertype)?; - let type_id = remove_nullable(&least_supertype).data_type_id(); + let type_id = remove_nullable(&self.least_supertype).data_type_id(); with_match_scalar_type!(type_id.to_physical_type(), |$T| { let lhs_viewer = $T::try_create_viewer(&lhs)?; @@ -223,14 +221,11 @@ impl IfFunction { let lhs_col = &columns[0]; let rhs_col = &columns[1]; - let dts = vec![lhs_col.data_type().clone(), rhs_col.data_type().clone()]; - let least_supertype = aggregate_types(dts.as_slice())?; - - let lhs = cast_column_field(lhs_col, &least_supertype)?; - let rhs = cast_column_field(rhs_col, &least_supertype)?; + let lhs = cast_column_field(lhs_col, &self.least_supertype)?; + let rhs = cast_column_field(rhs_col, &self.least_supertype)?; - debug_assert!(!least_supertype.is_nullable()); - let type_id = least_supertype.data_type_id(); + debug_assert!(!self.least_supertype.is_nullable()); + let type_id = self.least_supertype.data_type_id(); with_match_scalar_type!(type_id.to_physical_type(), |$T| { let lhs = Series::check_get_scalar::<$T>(&lhs)?; @@ -255,11 +250,8 @@ impl Function for IfFunction { "IfFunction" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - let dts = vec![args[1].clone(), args[2].clone()]; - let least_supertype = aggregate_types(dts.as_slice())?; - - Ok(least_supertype) + fn return_type(&self) -> DataTypePtr { + self.least_supertype.clone() } fn eval( diff --git a/common/functions/src/scalars/conditionals/in_basic.rs b/common/functions/src/scalars/conditionals/in_basic.rs index 2aa218b931c5..7811f11b5058 100644 --- a/common/functions/src/scalars/conditionals/in_basic.rs +++ b/common/functions/src/scalars/conditionals/in_basic.rs @@ -28,11 +28,28 @@ use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; #[derive(Clone)] -pub struct InFunction; +pub struct InFunction { + is_null: bool, +} impl InFunction { - pub fn 
try_create(_display_name: &str) -> Result> { - Ok(Box::new(InFunction:: {})) + pub fn try_create(_display_name: &str, args: &[&DataTypePtr]) -> Result> { + for dt in args { + let type_id = remove_nullable(dt).data_type_id(); + if type_id.is_date_or_date_time() + || type_id.is_interval() + || type_id.is_array() + || type_id.is_struct() + { + return Err(ErrorCode::UnexpectedError(format!( + "{} type is not supported for IN now", + type_id + ))); + } + } + + let is_null = args[0].data_type_id() == TypeID::Null; + Ok(Box::new(InFunction:: { is_null })) } pub fn desc() -> FunctionDescription { @@ -94,25 +111,11 @@ impl Function for InFunction { "InFunction" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for dt in args { - let type_id = remove_nullable(dt).data_type_id(); - if type_id.is_date_or_date_time() - || type_id.is_interval() - || type_id.is_array() - || type_id.is_struct() - { - return Err(ErrorCode::UnexpectedError(format!( - "{} type is not supported for IN now", - type_id - ))); - } - } - let input_dt = remove_nullable(args[0]).data_type_id(); - if input_dt == TypeID::Null { - return Ok(NullType::arc()); + fn return_type(&self) -> DataTypePtr { + if self.is_null { + return NullType::arc(); } - Ok(BooleanType::arc()) + BooleanType::arc() } fn eval( @@ -121,24 +124,7 @@ impl Function for InFunction { input_rows: usize, _func_ctx: FunctionContext, ) -> Result { - for col in columns { - let dt = col.column().data_type(); - let type_id = remove_nullable(&dt).data_type_id(); - if type_id.is_date_or_date_time() - || type_id.is_interval() - || type_id.is_array() - || type_id.is_struct() - { - return Err(ErrorCode::UnexpectedError(format!( - "{} type is not supported for IN now", - type_id - ))); - } - } - - let input_col = &columns[0]; - let input_dt = remove_nullable(input_col.data_type()).data_type_id(); - if input_dt == TypeID::Null { + if self.is_null { let col = NullType::arc().create_constant_column(&DataValue::Null, input_rows)?; return 
Ok(col); } @@ -147,7 +133,7 @@ impl Function for InFunction { let least_super_dt = aggregate_types(&types)?; let least_super_type_id = remove_nullable(&least_super_dt).data_type_id(); - let input_col = cast_column_field(input_col, &least_super_dt)?; + let input_col = cast_column_field(&columns[0], &least_super_dt)?; match least_super_type_id { TypeID::Boolean => { diff --git a/common/functions/src/scalars/conditionals/is_not_null.rs b/common/functions/src/scalars/conditionals/is_not_null.rs index d2af1323a018..7fad11ea19af 100644 --- a/common/functions/src/scalars/conditionals/is_not_null.rs +++ b/common/functions/src/scalars/conditionals/is_not_null.rs @@ -28,7 +28,10 @@ pub struct IsNotNullFunction { } impl IsNotNullFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + _args: &[&DataTypePtr], + ) -> Result> { Ok(Box::new(IsNotNullFunction { _display_name: "isNotNull".to_string(), })) @@ -51,11 +54,8 @@ impl Function for IsNotNullFunction { "IsNotNullFunction" } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(bool::to_data_type()) + fn return_type(&self) -> DataTypePtr { + bool::to_data_type() } fn eval( diff --git a/common/functions/src/scalars/conditionals/is_null.rs b/common/functions/src/scalars/conditionals/is_null.rs index 2af2b2e10cf6..cf64462f3834 100644 --- a/common/functions/src/scalars/conditionals/is_null.rs +++ b/common/functions/src/scalars/conditionals/is_null.rs @@ -28,7 +28,10 @@ pub struct IsNullFunction { } impl IsNullFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + _args: &[&DataTypePtr], + ) -> Result> { Ok(Box::new(IsNullFunction { _display_name: "isNull".to_string(), })) @@ -51,11 +54,8 @@ impl Function for IsNullFunction { "IsNullFunction" } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(bool::to_data_type()) + fn 
return_type(&self) -> DataTypePtr { + bool::to_data_type() } fn eval( diff --git a/common/functions/src/scalars/contexts/current_user.rs b/common/functions/src/scalars/contexts/current_user.rs index e376ba320185..9678589db0c9 100644 --- a/common/functions/src/scalars/contexts/current_user.rs +++ b/common/functions/src/scalars/contexts/current_user.rs @@ -14,6 +14,7 @@ use std::fmt; +use common_datavalues::DataTypePtr; use common_datavalues::StringType; use common_exception::Result; @@ -26,7 +27,7 @@ use crate::scalars::FunctionFeatures; pub struct CurrentUserFunction {} impl CurrentUserFunction { - pub fn try_create(_display_name: &str) -> Result> { + pub fn try_create(_display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(CurrentUserFunction {})) } @@ -44,11 +45,8 @@ impl Function for CurrentUserFunction { "CurrentUserFunction" } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( diff --git a/common/functions/src/scalars/contexts/database.rs b/common/functions/src/scalars/contexts/database.rs index 27f36e385a92..f9fb0cdfaff8 100644 --- a/common/functions/src/scalars/contexts/database.rs +++ b/common/functions/src/scalars/contexts/database.rs @@ -14,6 +14,7 @@ use std::fmt; +use common_datavalues::DataTypePtr; use common_datavalues::StringType; use common_exception::Result; @@ -27,7 +28,7 @@ pub struct DatabaseFunction {} // we bind database as first argument in eval impl DatabaseFunction { - pub fn try_create(_display_name: &str) -> Result> { + pub fn try_create(_display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(DatabaseFunction {})) } @@ -45,11 +46,8 @@ impl Function for DatabaseFunction { "DatabaseFunction" } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( diff 
--git a/common/functions/src/scalars/contexts/version.rs b/common/functions/src/scalars/contexts/version.rs index 1cd824a99436..fbffb71b8594 100644 --- a/common/functions/src/scalars/contexts/version.rs +++ b/common/functions/src/scalars/contexts/version.rs @@ -14,6 +14,7 @@ use std::fmt; +use common_datavalues::DataTypePtr; use common_datavalues::StringType; use common_exception::Result; @@ -28,7 +29,7 @@ pub struct VersionFunction { } impl VersionFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(VersionFunction { _display_name: display_name.to_string(), })) @@ -48,11 +49,8 @@ impl Function for VersionFunction { "VersionFunction" } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( diff --git a/common/functions/src/scalars/dates/date.rs b/common/functions/src/scalars/dates/date.rs index 1efaf51185c1..5e6beb9851c4 100644 --- a/common/functions/src/scalars/dates/date.rs +++ b/common/functions/src/scalars/dates/date.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use common_exception::Result; - use super::now::NowFunction; use super::number_function::ToMondayFunction; use super::number_function::ToYearFunction; @@ -41,7 +39,6 @@ use super::TodayFunction; use super::TomorrowFunction; use super::YesterdayFunction; use crate::scalars::function_factory::FactoryCreator; -use crate::scalars::Function; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFactory; use crate::scalars::FunctionFeatures; @@ -51,8 +48,8 @@ pub struct DateFunction {} impl DateFunction { fn round_function_creator(round: u32) -> FunctionDescription { - let creator: FactoryCreator = Box::new(move |display_name| -> Result> { - RoundFunction::try_create(display_name, round) + let creator: FactoryCreator = Box::new(move |display_name, args| { + RoundFunction::try_create(display_name, args, round) }); FunctionDescription::creator(creator).features( @@ -102,17 +99,17 @@ impl DateFunction { factory.register("toStartOfWeek", ToStartOfWeekFunction::desc()); //interval functions - factory.register_typed("addYears", AddYearsFunction::desc(1)); - factory.register_typed("addMonths", AddMonthsFunction::desc(1)); - factory.register_typed("addDays", AddDaysFunction::desc(1)); - factory.register_typed("addHours", AddTimesFunction::desc(3600)); - factory.register_typed("addMinutes", AddTimesFunction::desc(60)); - factory.register_typed("addSeconds", AddTimesFunction::desc(1)); - factory.register_typed("subtractYears", AddYearsFunction::desc(-1)); - factory.register_typed("subtractMonths", AddMonthsFunction::desc(-1)); - factory.register_typed("subtractDays", AddDaysFunction::desc(-1)); - factory.register_typed("subtractHours", AddTimesFunction::desc(-3600)); - factory.register_typed("subtractMinutes", AddTimesFunction::desc(-60)); - factory.register_typed("subtractSeconds", AddTimesFunction::desc(-1)); + factory.register("addYears", AddYearsFunction::desc(1)); + factory.register("addMonths", AddMonthsFunction::desc(1)); + factory.register("addDays", 
AddDaysFunction::desc(1)); + factory.register("addHours", AddTimesFunction::desc(3600)); + factory.register("addMinutes", AddTimesFunction::desc(60)); + factory.register("addSeconds", AddTimesFunction::desc(1)); + factory.register("subtractYears", AddYearsFunction::desc(-1)); + factory.register("subtractMonths", AddMonthsFunction::desc(-1)); + factory.register("subtractDays", AddDaysFunction::desc(-1)); + factory.register("subtractHours", AddTimesFunction::desc(-3600)); + factory.register("subtractMinutes", AddTimesFunction::desc(-60)); + factory.register("subtractSeconds", AddTimesFunction::desc(-1)); } } diff --git a/common/functions/src/scalars/dates/interval_function.rs b/common/functions/src/scalars/dates/interval_function.rs index cdfffef9501c..2c6542584e00 100644 --- a/common/functions/src/scalars/dates/interval_function.rs +++ b/common/functions/src/scalars/dates/interval_function.rs @@ -32,11 +32,11 @@ use crate::define_datetime64_add_year_months; use crate::impl_interval_year_month; use crate::scalars::scalar_binary_op; use crate::scalars::EvalContext; -use crate::scalars::FactoryCreatorWithTypes; +use crate::scalars::FactoryCreator; use crate::scalars::Function; use crate::scalars::FunctionContext; +use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; -use crate::scalars::TypedFunctionDescription; pub struct IntervalFunctionCreator { t: PhantomData, @@ -95,11 +95,11 @@ where T: IntervalArithmeticImpl + Send + Sync + Clone + 'static }) } - pub fn desc(factor: i64) -> TypedFunctionDescription { - let function_creator: FactoryCreatorWithTypes = + pub fn desc(factor: i64) -> FunctionDescription { + let function_creator: FactoryCreator = Box::new(move |display_name, args| Self::try_create_func(display_name, factor, args)); - TypedFunctionDescription::creator(function_creator) + FunctionDescription::creator(function_creator) .features(FunctionFeatures::default().deterministic().num_arguments(2)) } } @@ -150,8 +150,8 @@ where 
self.display_name.as_str() } - fn return_type(&self, _args: &[&DataTypePtr]) -> Result { - Ok(self.result_type.clone()) + fn return_type(&self) -> DataTypePtr { + self.result_type.clone() } fn eval( diff --git a/common/functions/src/scalars/dates/now.rs b/common/functions/src/scalars/dates/now.rs index efab388c34e3..489c2993fa1c 100644 --- a/common/functions/src/scalars/dates/now.rs +++ b/common/functions/src/scalars/dates/now.rs @@ -32,7 +32,7 @@ pub struct NowFunction { } impl NowFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(NowFunction { display_name: display_name.to_string(), })) @@ -48,11 +48,8 @@ impl Function for NowFunction { self.display_name.as_str() } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(DateTime32Type::arc(None)) + fn return_type(&self) -> DataTypePtr { + DateTime32Type::arc(None) } fn eval( diff --git a/common/functions/src/scalars/dates/number_function.rs b/common/functions/src/scalars/dates/number_function.rs index fe02b037512a..765dd725ad8b 100644 --- a/common/functions/src/scalars/dates/number_function.rs +++ b/common/functions/src/scalars/dates/number_function.rs @@ -29,6 +29,7 @@ use crate::scalars::function_factory::FunctionDescription; use crate::scalars::scalar_unary_op; use crate::scalars::CastFunction; use crate::scalars::EvalContext; +use crate::scalars::FactoryCreator; use crate::scalars::Function; use crate::scalars::FunctionAdapter; use crate::scalars::FunctionContext; @@ -57,7 +58,7 @@ pub trait NumberOperator { None } - fn return_type() -> Option { + fn return_type() -> Option { None } } @@ -111,7 +112,7 @@ impl NumberOperator for ToStartOfYear { get_day(end) as u16 } - fn return_type() -> Option { + fn return_type() -> Option { Some(Date16Type::arc()) } } @@ -132,7 +133,7 @@ impl NumberOperator for ToStartOfISOYear { get_day(end) as u16 } - fn return_type() -> Option { + fn 
return_type() -> Option { Some(Date16Type::arc()) } } @@ -149,7 +150,7 @@ impl NumberOperator for ToStartOfQuarter { get_day(date) as u16 } - fn return_type() -> Option { + fn return_type() -> Option { Some(Date16Type::arc()) } } @@ -165,7 +166,7 @@ impl NumberOperator for ToStartOfMonth { get_day(date) as u16 } - fn return_type() -> Option { + fn return_type() -> Option { Some(Date16Type::arc()) } } @@ -265,7 +266,10 @@ impl NumberOperator for ToMinute { // ToMinute is NOT a monotonic function in general, unless the time range is within the same hour. fn factor_function() -> Option> { - Some(RoundFunction::try_create("toStartOfHour", 60 * 60).unwrap()) + Some( + RoundFunction::try_create("toStartOfHour", &[&DateTime32Type::arc(None)], 60 * 60) + .unwrap(), + ) } } @@ -281,7 +285,10 @@ impl NumberOperator for ToSecond { // ToSecond is NOT a monotonic function in general, unless the time range is within the same minute. fn factor_function() -> Option> { - Some(RoundFunction::try_create("toStartOfMinute", 60).unwrap()) + Some( + RoundFunction::try_create("toStartOfMinute", &[&DateTime32Type::arc(None)], 60) + .unwrap(), + ) } } @@ -311,7 +318,7 @@ impl NumberOperator for ToYear { impl NumberFunction where T: NumberOperator + Clone + Sync + Send + 'static, - R: PrimitiveType + Clone + ToDataType + common_datavalues::Scalar = R>, + R: PrimitiveType + Clone + ToDataType + Scalar = R>, { pub fn try_create(display_name: &str) -> Result> { Ok(Box::new(NumberFunction:: { @@ -328,7 +335,10 @@ where features = features.deterministic(); } - FunctionDescription::creator(Box::new(Self::try_create)).features(features) + let function_creator: FactoryCreator = + Box::new(move |display_name, _args| Self::try_create(display_name)); + + FunctionDescription::creator(function_creator).features(features) } } @@ -341,13 +351,10 @@ where self.display_name.as_str() } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { + fn return_type(&self) -> 
DataTypePtr { match T::return_type() { - None => Ok(R::to_data_type()), - Some(v) => Ok(v), + None => R::to_data_type(), + Some(v) => v, } } diff --git a/common/functions/src/scalars/dates/round_function.rs b/common/functions/src/scalars/dates/round_function.rs index bd57290f66eb..a91e63b44e9b 100644 --- a/common/functions/src/scalars/dates/round_function.rs +++ b/common/functions/src/scalars/dates/round_function.rs @@ -31,7 +31,19 @@ pub struct RoundFunction { } impl RoundFunction { - pub fn try_create(display_name: &str, round: u32) -> Result> { + pub fn try_create( + display_name: &str, + args: &[&DataTypePtr], + round: u32, + ) -> Result> { + if args[0].data_type_id() != TypeID::DateTime32 { + return Err(ErrorCode::BadDataValueType(format!( + "Function {} must have a DateTime type as argument, but got {}", + display_name, + args[0].name(), + ))); + } + let s = Self { display_name: display_name.to_owned(), round, @@ -54,19 +66,8 @@ impl Function for RoundFunction { self.display_name.as_str() } - fn return_type( - &self, - args: &[&common_datavalues::DataTypePtr], - ) -> Result { - if args[0].data_type_id() == TypeID::DateTime32 { - return Ok(DateTime32Type::arc(None)); - } else { - return Err(ErrorCode::BadDataValueType(format!( - "Function {} must have a DateTime type as argument, but got {}", - self.display_name, - args[0].name(), - ))); - } + fn return_type(&self) -> DataTypePtr { + DateTime32Type::arc(None) } fn eval( diff --git a/common/functions/src/scalars/dates/simple_date.rs b/common/functions/src/scalars/dates/simple_date.rs index ec12a352e4d1..1bb94ae129ba 100644 --- a/common/functions/src/scalars/dates/simple_date.rs +++ b/common/functions/src/scalars/dates/simple_date.rs @@ -80,7 +80,7 @@ impl NoArgDateFunction for Tomorrow { impl SimpleFunction where T: NoArgDateFunction + Clone + Sync + Send + 'static { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result> { 
Ok(Box::new(SimpleFunction:: { display_name: display_name.to_string(), t: PhantomData, @@ -100,11 +100,8 @@ where T: NoArgDateFunction + Clone + Sync + Send + 'static self.display_name.as_str() } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(Date16Type::arc()) + fn return_type(&self) -> DataTypePtr { + Date16Type::arc() } fn eval( diff --git a/common/functions/src/scalars/dates/week_date.rs b/common/functions/src/scalars/dates/week_date.rs index 0147852768b4..a05a103c1320 100644 --- a/common/functions/src/scalars/dates/week_date.rs +++ b/common/functions/src/scalars/dates/week_date.rs @@ -44,7 +44,7 @@ pub struct WeekFunction { pub trait WeekResultFunction { const IS_DETERMINISTIC: bool; - fn return_type() -> Result; + fn return_type() -> DataTypePtr; fn to_number(_value: DateTime, mode: u64) -> R; fn factor_function() -> Option> { None @@ -57,8 +57,8 @@ pub struct ToStartOfWeek; impl WeekResultFunction for ToStartOfWeek { const IS_DETERMINISTIC: bool = true; - fn return_type() -> Result { - Ok(Date16Type::arc()) + fn return_type() -> DataTypePtr { + Date16Type::arc() } fn to_number(value: DateTime, week_mode: u64) -> u32 { let mut weekday = value.weekday().number_from_sunday(); @@ -80,7 +80,12 @@ where for<'a> R: Scalar = R>, for<'a> R: ScalarRef<'a, ScalarType = R, ColumnType = PrimitiveColumn>, { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_date_or_datetime(args[0])?; + if args.len() > 1 { + assert_numeric(args[1])?; + } + Ok(Box::new(WeekFunction:: { display_name: display_name.to_string(), t: PhantomData, @@ -113,11 +118,7 @@ where self.display_name.as_str() } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_date_or_datetime(args[0])?; - if args.len() > 1 { - assert_numeric(args[1])?; - } + fn return_type(&self) -> DataTypePtr { T::return_type() } diff --git 
a/common/functions/src/scalars/expressions/cast.rs b/common/functions/src/scalars/expressions/cast.rs index 55f9d188c9ca..4d0df897f242 100644 --- a/common/functions/src/scalars/expressions/cast.rs +++ b/common/functions/src/scalars/expressions/cast.rs @@ -64,8 +64,8 @@ impl Function for CastFunction { "CastFunction" } - fn return_type(&self, _args: &[&DataTypePtr]) -> Result { - Ok(self.cast_type.clone()) + fn return_type(&self) -> DataTypePtr { + self.cast_type.clone() } fn eval( diff --git a/common/functions/src/scalars/expressions/expression.rs b/common/functions/src/scalars/expressions/expression.rs index 875c29de7a5a..986bd8f9038c 100644 --- a/common/functions/src/scalars/expressions/expression.rs +++ b/common/functions/src/scalars/expressions/expression.rs @@ -38,7 +38,7 @@ impl ToCastFunction { }; let function_creator: FactoryCreator = - Box::new(move |display_name| CastFunction::create(display_name, type_name)); + Box::new(move |display_name, _args| CastFunction::create(display_name, type_name)); Ok(FunctionDescription::creator(function_creator).features(features)) } diff --git a/common/functions/src/scalars/function.rs b/common/functions/src/scalars/function.rs index 6667c35eeff9..25bd63e9b63a 100644 --- a/common/functions/src/scalars/function.rs +++ b/common/functions/src/scalars/function.rs @@ -43,7 +43,7 @@ pub trait Function: fmt::Display + Sync + Send + DynClone { } /// The method returns the return_type of this function. - fn return_type(&self, args: &[&DataTypePtr]) -> Result; + fn return_type(&self) -> DataTypePtr; /// Evaluate the function, e.g. run/execute the function. 
fn eval( diff --git a/common/functions/src/scalars/function_adapter.rs b/common/functions/src/scalars/function_adapter.rs index 6b4754b52d60..e142eb59222e 100644 --- a/common/functions/src/scalars/function_adapter.rs +++ b/common/functions/src/scalars/function_adapter.rs @@ -35,53 +35,47 @@ use common_datavalues::TypeID; use common_exception::Result; use super::Function; +use super::FunctionDescription; use super::Monotonicity; -use super::TypedFunctionDescription; +use crate::scalars::FunctionContext; #[derive(Clone)] pub struct FunctionAdapter { inner: Option>, - passthrough_null: bool, + has_nullable: bool, } impl FunctionAdapter { - pub fn create(inner: Box, passthrough_null: bool) -> Box { + pub fn create(inner: Box, has_nullable: bool) -> Box { Box::new(Self { inner: Some(inner), - passthrough_null, + has_nullable, }) } - pub fn create_some( - inner: Option>, - passthrough_null: bool, - ) -> Box { - Box::new(Self { - inner, - passthrough_null, - }) - } - - pub fn try_create_by_typed( - desc: &TypedFunctionDescription, + pub fn try_create( + desc: &FunctionDescription, name: &str, args: &[&DataTypePtr], ) -> Result> { - let passthrough_null = desc.features.passthrough_null; - - let inner = if passthrough_null { + let (inner, has_nullable) = if desc.features.passthrough_null { // one is null, result is null if args.iter().any(|v| v.data_type_id() == TypeID::Null) { - return Ok(Self::create_some(None, true)); + return Ok(Box::new(Self { + inner: None, + has_nullable: false, + })); } + + let has_nullable = args.iter().any(|v| v.is_nullable()); let types = args.iter().map(|v| remove_nullable(v)).collect::>(); let types = types.iter().collect::>(); - (desc.typed_function_creator)(name, &types)? + ((desc.function_creator)(name, &types)?, has_nullable) } else { - (desc.typed_function_creator)(name, args)? 
+ ((desc.function_creator)(name, args)?, false) }; - Ok(Self::create(inner, passthrough_null)) + Ok(Self::create(inner, has_nullable)) } } @@ -90,31 +84,18 @@ impl Function for FunctionAdapter { self.inner.as_ref().map_or("null", |v| v.name()) } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { + fn return_type(&self) -> DataTypePtr { if self.inner.is_none() { - return Ok(NullType::arc()); + return NullType::arc(); } let inner = self.inner.as_ref().unwrap(); + let typ = inner.return_type(); - if self.passthrough_null { - let has_null = args.iter().any(|v| v.is_null()); - if has_null { - return Ok(NullType::arc()); - } - - let has_nullable = args.iter().any(|v| v.is_nullable()); - let types = args.iter().map(|v| remove_nullable(v)).collect::>(); - let types = types.iter().collect::>(); - let typ = inner.return_type(&types)?; - - if has_nullable { - Ok(wrap_nullable(&typ)) - } else { - Ok(typ) - } + if self.has_nullable { + wrap_nullable(&typ) } else { - inner.return_type(args) + typ } } @@ -147,7 +128,7 @@ impl Function for FunctionAdapter { .collect::>(); let col = self.eval(&columns, 1, func_ctx)?; - let col = if col.is_const() && col.len() == 1 { + let col = if col.is_const() && col.len() != input_rows { col.replicate(&[input_rows]) } else if col.is_null() { NullColumn::new(input_rows).arc() @@ -157,65 +138,56 @@ impl Function for FunctionAdapter { return Ok(col); } - // nullable or null - if self.passthrough_null { - if columns - .iter() - .any(|v| v.data_type().data_type_id() == TypeID::Null) - { - return Ok(Arc::new(NullColumn::new(input_rows))); + // nullable + if self.has_nullable && columns.iter().any(|v| v.data_type().is_nullable()) { + let mut validity: Option = None; + + let mut input = Vec::with_capacity(columns.len()); + for v in columns.iter() { + let (is_all_null, valid) = v.column().validity(); + if is_all_null { + // If only null, return null directly. 
+ let inner_type = remove_nullable(&inner.return_type()); + return Ok(NullableColumn::wrap_inner( + inner_type + .create_constant_column(&inner_type.default_value(), input_rows)?, + Some(valid.unwrap().clone()), + )); + } + validity = combine_validities_2(validity.clone(), valid.cloned()); + + let ty = remove_nullable(v.data_type()); + let f = v.field(); + let col = Series::remove_nullable(v.column()); + let col = ColumnWithField::new(col, DataField::new(f.name(), ty)); + input.push(col); } - if columns.iter().any(|v| v.data_type().is_nullable()) { - let mut validity: Option = None; - let mut has_all_null = false; - - let columns = columns - .iter() - .map(|v| { - let (is_all_null, valid) = v.column().validity(); - if is_all_null { - has_all_null = true; - let mut v = MutableBitmap::with_capacity(input_rows); - v.extend_constant(input_rows, false); - validity = Some(v.into()); - } else if !has_all_null { - validity = combine_validities_2(validity.clone(), valid.cloned()); - } - - let ty = remove_nullable(v.data_type()); - let f = v.field(); - let col = Series::remove_nullable(v.column()); - ColumnWithField::new(col, DataField::new(f.name(), ty)) - }) - .collect::>(); - - let col = self.eval(&columns, input_rows, func_ctx)?; - - // The'try' series functions always return Null when they failed the try. - // For example, try_inet_aton("helloworld") will return Null because it failed to parse "helloworld" to a valid IP address. - // The same thing may happen on other 'try' functions. So we need to merge the validity. 
- if col.is_nullable() { - let (_, bitmap) = col.validity(); - validity = validity.map_or(combine_validities(bitmap, None), |v| { - combine_validities(bitmap, Some(&v)) - }) - } + let col = self.eval(&input, input_rows, func_ctx)?; - let validity = validity.unwrap_or({ - let mut v = MutableBitmap::with_capacity(input_rows); - v.extend_constant(input_rows, true); - v.into() - }); - - let col = if col.is_nullable() { - let nullable_column: &NullableColumn = Series::check_get(&col)?; - NullableColumn::wrap_inner(nullable_column.inner().clone(), Some(validity)) - } else { - NullableColumn::wrap_inner(col, Some(validity)) - }; - return Ok(col); + // The'try' series functions always return Null when they failed the try. + // For example, try_inet_aton("helloworld") will return Null because it failed to parse "helloworld" to a valid IP address. + // The same thing may happen on other 'try' functions. So we need to merge the validity. + if col.is_nullable() { + let (_, bitmap) = col.validity(); + validity = validity.map_or(combine_validities(bitmap, None), |v| { + combine_validities(bitmap, Some(&v)) + }) } + + let validity = validity.unwrap_or({ + let mut v = MutableBitmap::with_capacity(input_rows); + v.extend_constant(input_rows, true); + v.into() + }); + + let col = if col.is_nullable() { + let nullable_column: &NullableColumn = Series::check_get(&col)?; + NullableColumn::wrap_inner(nullable_column.inner().clone(), Some(validity)) + } else { + NullableColumn::wrap_inner(col, Some(validity)) + }; + return Ok(col); } inner.eval(columns, input_rows, func_ctx) @@ -245,4 +217,3 @@ impl std::fmt::Display for FunctionAdapter { } } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/function_factory.rs b/common/functions/src/scalars/function_factory.rs index 47d13433db4d..9359835275a9 100644 --- a/common/functions/src/scalars/function_factory.rs +++ b/common/functions/src/scalars/function_factory.rs @@ -38,14 +38,12 @@ use 
super::TupleClassFunction; use crate::scalars::DateFunction; use crate::scalars::UUIDFunction; -pub type FactoryCreator = Box Result> + Send + Sync>; - -pub type FactoryCreatorWithTypes = +pub type FactoryCreator = Box Result> + Send + Sync>; pub struct FunctionDescription { pub(crate) features: FunctionFeatures, - function_creator: FactoryCreator, + pub(crate) function_creator: FactoryCreator, } impl FunctionDescription { @@ -63,29 +61,8 @@ impl FunctionDescription { } } -pub struct TypedFunctionDescription { - pub(crate) features: FunctionFeatures, - pub typed_function_creator: FactoryCreatorWithTypes, -} - -impl TypedFunctionDescription { - pub fn creator(creator: FactoryCreatorWithTypes) -> TypedFunctionDescription { - TypedFunctionDescription { - typed_function_creator: creator, - features: FunctionFeatures::default(), - } - } - - #[must_use] - pub fn features(mut self, features: FunctionFeatures) -> TypedFunctionDescription { - self.features = features; - self - } -} - pub struct FunctionFactory { case_insensitive_desc: HashMap, - case_insensitive_typed_desc: HashMap, } static FUNCTION_FACTORY: Lazy> = Lazy::new(|| { @@ -113,7 +90,6 @@ impl FunctionFactory { pub(in crate::scalars::function_factory) fn create() -> FunctionFactory { FunctionFactory { case_insensitive_desc: Default::default(), - case_insensitive_typed_desc: Default::default(), } } @@ -126,66 +102,46 @@ impl FunctionFactory { case_insensitive_desc.insert(name.to_lowercase(), desc); } - pub fn register_typed(&mut self, name: &str, desc: TypedFunctionDescription) { - let case_insensitive_typed_desc = &mut self.case_insensitive_typed_desc; - case_insensitive_typed_desc.insert(name.to_lowercase(), desc); - } - pub fn get(&self, name: impl AsRef, args: &[&DataTypePtr]) -> Result> { let origin_name = name.as_ref(); let lowercase_name = origin_name.to_lowercase(); - match self.case_insensitive_desc.get(&lowercase_name) { - // TODO(Winter): we should write similar function names into error message if 
function name is not found. - None => match self.case_insensitive_typed_desc.get(&lowercase_name) { - None => Err(ErrorCode::UnknownFunction(format!( - "Unsupported Function: {}", - origin_name - ))), - Some(desc) => FunctionAdapter::try_create_by_typed(desc, origin_name, args), - }, - Some(desc) => { - let inner = (desc.function_creator)(origin_name)?; - Ok(FunctionAdapter::create( - inner, - desc.features.passthrough_null, - )) - } - } + let desc = self + .case_insensitive_desc + .get(&lowercase_name) + .ok_or_else(|| { + // TODO(Winter): we should write similar function names into error message if function name is not found. + ErrorCode::UnknownFunction(format!("Unsupported Function: {}", origin_name)) + })?; + + FunctionAdapter::try_create(desc, origin_name, args) } pub fn get_features(&self, name: impl AsRef) -> Result { let origin_name = name.as_ref(); let lowercase_name = origin_name.to_lowercase(); - match self.case_insensitive_desc.get(&lowercase_name) { - // TODO(Winter): we should write similar function names into error message if function name is not found. - None => match self.case_insensitive_typed_desc.get(&lowercase_name) { - None => Err(ErrorCode::UnknownFunction(format!( - "Unsupported Function: {}", - origin_name - ))), - Some(desc) => Ok(desc.features.clone()), - }, - Some(desc) => Ok(desc.features.clone()), - } + let desc = self + .case_insensitive_desc + .get(&lowercase_name) + .ok_or_else(|| { + // TODO(Winter): we should write similar function names into error message if function name is not found. 
+ ErrorCode::UnknownFunction(format!("Unsupported Function: {}", origin_name)) + })?; + + Ok(desc.features.clone()) } pub fn check(&self, name: impl AsRef) -> bool { let origin_name = name.as_ref(); let lowercase_name = origin_name.to_lowercase(); - if self.case_insensitive_desc.contains_key(&lowercase_name) { - return true; - } - self.case_insensitive_typed_desc - .contains_key(&lowercase_name) + self.case_insensitive_desc.contains_key(&lowercase_name) } pub fn registered_names(&self) -> Vec { self.case_insensitive_desc .keys() - .chain(self.case_insensitive_typed_desc.keys()) .cloned() .collect::>() } @@ -195,12 +151,6 @@ impl FunctionFactory { .values() .into_iter() .map(|v| &v.features) - .chain( - self.case_insensitive_typed_desc - .values() - .into_iter() - .map(|v| &v.features), - ) .cloned() .collect::>() } diff --git a/common/functions/src/scalars/hashes/city64_with_seed.rs b/common/functions/src/scalars/hashes/city64_with_seed.rs index 0af20f9b3077..d13f6b0a6cc7 100644 --- a/common/functions/src/scalars/hashes/city64_with_seed.rs +++ b/common/functions/src/scalars/hashes/city64_with_seed.rs @@ -61,27 +61,7 @@ pub struct City64WithSeedFunction { // CityHash64WithSeed(value, seed) impl City64WithSeedFunction { - pub fn try_create(display_name: &str) -> Result> { - Ok(Box::new(City64WithSeedFunction { - display_name: display_name.to_string(), - })) - } - - pub fn desc() -> FunctionDescription { - FunctionDescription::creator(Box::new(Self::try_create)) - .features(FunctionFeatures::default().deterministic().num_arguments(2)) - } -} - -impl Function for City64WithSeedFunction { - fn name(&self) -> &str { - &*self.display_name - } - - fn return_type( - &self, - args: &[&common_datavalues::DataTypePtr], - ) -> Result { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { if !matches!( args[0].data_type_id(), TypeID::UInt8 @@ -113,7 +93,25 @@ impl Function for City64WithSeedFunction { args[1] ))); } - Ok(UInt64Type::arc()) + + 
Ok(Box::new(City64WithSeedFunction { + display_name: display_name.to_string(), + })) + } + + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create)) + .features(FunctionFeatures::default().deterministic().num_arguments(2)) + } +} + +impl Function for City64WithSeedFunction { + fn name(&self) -> &str { + &*self.display_name + } + + fn return_type(&self) -> DataTypePtr { + UInt64Type::arc() } fn eval( diff --git a/common/functions/src/scalars/hashes/hash_base.rs b/common/functions/src/scalars/hashes/hash_base.rs index ebe648b640fd..f07db9c8823a 100644 --- a/common/functions/src/scalars/hashes/hash_base.rs +++ b/common/functions/src/scalars/hashes/hash_base.rs @@ -57,7 +57,7 @@ where H: Hasher + Default + Clone + Sync + Send + 'static, R: Scalar + Clone + FromPrimitive + ToDataType + Sync + Send, { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(BaseHashFunction:: { display_name: display_name.to_string(), h: PhantomData, @@ -80,11 +80,8 @@ where self.display_name.as_str() } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(R::to_data_type()) + fn return_type(&self) -> DataTypePtr { + R::to_data_type() } fn eval( diff --git a/common/functions/src/scalars/hashes/sha2hash.rs b/common/functions/src/scalars/hashes/sha2hash.rs index 529e6d477ebf..ee04c5b65222 100644 --- a/common/functions/src/scalars/hashes/sha2hash.rs +++ b/common/functions/src/scalars/hashes/sha2hash.rs @@ -34,7 +34,21 @@ pub struct Sha2HashFunction { } impl Sha2HashFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + if args[0].data_type_id() != TypeID::String { + return Err(ErrorCode::IllegalDataType(format!( + "Expected first arg as string type, but got {:?}", + args[0] + ))); + } + + if !args[1].data_type_id().is_numeric() { + return 
Err(ErrorCode::IllegalDataType(format!( + "Expected second arg as integer type, but got {:?}", + args[1] + ))); + } + Ok(Box::new(Sha2HashFunction { display_name: display_name.to_string(), })) @@ -51,24 +65,8 @@ impl Function for Sha2HashFunction { &*self.display_name } - fn return_type( - &self, - args: &[&common_datavalues::DataTypePtr], - ) -> Result { - if args[0].data_type_id() != TypeID::String { - return Err(ErrorCode::IllegalDataType(format!( - "Expected first arg as string type, but got {:?}", - args[0] - ))); - } - - if !args[1].data_type_id().is_numeric() { - return Err(ErrorCode::IllegalDataType(format!( - "Expected second arg as integer type, but got {:?}", - args[1] - ))); - } - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( diff --git a/common/functions/src/scalars/logics/and.rs b/common/functions/src/scalars/logics/and.rs index 83b60ebdbfba..7c315efda14a 100644 --- a/common/functions/src/scalars/logics/and.rs +++ b/common/functions/src/scalars/logics/and.rs @@ -12,24 +12,37 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use common_datavalues::prelude::*; use common_exception::Result; +use super::logic::LogicExpression; +use super::logic::LogicFunctionImpl; use super::logic::LogicOperator; -use super::LogicFunction; +use crate::calcute; +use crate::impl_logic_expression; +use crate::scalars::cast_column_field; use crate::scalars::Function; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; +impl_logic_expression!(LogicAndExpression, &, |lhs: bool, rhs: bool, lhs_v: bool, rhs_v: bool| -> (bool, bool) { + (lhs & rhs, (lhs_v & rhs_v) | (!lhs & lhs_v) | (!rhs & rhs_v)) +}); + #[derive(Clone)] pub struct LogicAndFunction; impl LogicAndFunction { - pub fn try_create(_display_name: &str) -> Result> { - LogicFunction::try_create(LogicOperator::And) + pub fn try_create(_display_name: &str, args: &[&DataTypePtr]) -> Result> { + LogicFunctionImpl::::try_create(LogicOperator::And, args) } pub fn desc() -> FunctionDescription { - FunctionDescription::creator(Box::new(Self::try_create)) - .features(FunctionFeatures::default().deterministic().num_arguments(2)) + FunctionDescription::creator(Box::new(Self::try_create)).features( + FunctionFeatures::default() + .deterministic() + .disable_passthrough_null() + .num_arguments(2), + ) } } diff --git a/common/functions/src/scalars/logics/logic.rs b/common/functions/src/scalars/logics/logic.rs index 5cf377309a9e..02763de08508 100644 --- a/common/functions/src/scalars/logics/logic.rs +++ b/common/functions/src/scalars/logics/logic.rs @@ -12,211 +12,86 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::sync::Arc; +use std::marker::PhantomData; use common_datavalues::prelude::*; -use common_exception::ErrorCode; use common_exception::Result; use super::xor::LogicXorFunction; use super::LogicAndFunction; use super::LogicNotFunction; use super::LogicOrFunction; -use crate::scalars::cast_column_field; use crate::scalars::Function; use crate::scalars::FunctionContext; use crate::scalars::FunctionFactory; #[derive(Clone)] -pub struct LogicFunction { - op: LogicOperator, -} - -#[derive(Clone, Debug)] -pub enum LogicOperator { - Not, - And, - Or, - Xor, -} +pub struct LogicFunction; impl LogicFunction { - pub fn try_create(op: LogicOperator) -> Result> { - Ok(Box::new(Self { op })) - } - pub fn register(factory: &mut FunctionFactory) { factory.register("and", LogicAndFunction::desc()); factory.register("or", LogicOrFunction::desc()); factory.register("not", LogicNotFunction::desc()); factory.register("xor", LogicXorFunction::desc()); } +} - fn eval_not(&self, columns: &ColumnsWithField, input_rows: usize) -> Result { - let mut nullable = false; - if columns[0].data_type().is_nullable() { - nullable = true; - } - - let dt = if nullable { - Arc::new(NullableType::create(BooleanType::arc())) - } else { - BooleanType::arc() - }; - - let col = cast_column_field(&columns[0], &dt)?; - - let col_viewer = bool::try_create_viewer(&col)?; - - if nullable { - let mut builder = NullableColumnBuilder::::with_capacity(input_rows); +#[derive(Clone, Debug)] +pub enum LogicOperator { + Not, + And, + Or, + Xor, +} - for (idx, data) in col_viewer.iter().enumerate() { - builder.append(!data, col_viewer.valid_at(idx)); - } +#[derive(Clone)] +pub struct LogicFunctionImpl { + op: LogicOperator, + nullable: bool, + f: PhantomData, +} - Ok(builder.build(input_rows)) - } else { - let mut builder = ColumnBuilder::::with_capacity(input_rows); +pub trait LogicExpression: Sync + Send { + fn eval(columns: &ColumnsWithField, input_rows: usize, nullable: bool) -> Result; +} - for value in 
col_viewer.iter() { - builder.append(!value); +impl LogicFunctionImpl +where F: LogicExpression + Clone + 'static +{ + pub fn try_create(op: LogicOperator, args: &[&DataTypePtr]) -> Result> { + let nullable = match op { + LogicOperator::And | LogicOperator::Or + if args[0].is_nullable() + || args[1].is_nullable() + || args[0].is_null() + || args[1].is_null() => + { + true } - Ok(builder.build(input_rows)) - } - } - - fn eval_and_not_or(&self, columns: &ColumnsWithField, input_rows: usize) -> Result { - let mut nullable = false; - if columns[0].data_type().is_nullable() || columns[1].data_type().is_nullable() { - nullable = true; - } - - let dt = if nullable { - Arc::new(NullableType::create(BooleanType::arc())) - } else { - BooleanType::arc() + _ => false, }; - let lhs = cast_column_field(&columns[0], &dt)?; - let rhs = cast_column_field(&columns[1], &dt)?; - - if nullable { - let lhs_viewer = bool::try_create_viewer(&lhs)?; - let rhs_viewer = bool::try_create_viewer(&rhs)?; - - let lhs_viewer_iter = lhs_viewer.iter(); - let rhs_viewer_iter = rhs_viewer.iter(); - - let mut builder = NullableColumnBuilder::::with_capacity(input_rows); - - macro_rules! 
calcute_with_null { - ($lhs_viewer: expr, $rhs_viewer: expr, $lhs_viewer_iter: expr, $rhs_viewer_iter: expr, $builder: expr, $func: expr) => { - for (a, (idx, b)) in $lhs_viewer_iter.zip($rhs_viewer_iter.enumerate()) { - let (val, valid) = - $func(a, b, $lhs_viewer.valid_at(idx), $rhs_viewer.valid_at(idx)); - $builder.append(val, valid); - } - }; - } - - match self.op { - LogicOperator::And => calcute_with_null!( - lhs_viewer, - rhs_viewer, - lhs_viewer_iter, - rhs_viewer_iter, - builder, - |lhs: bool, rhs: bool, l_valid: bool, r_valid: bool| -> (bool, bool) { - (lhs & rhs, l_valid & r_valid) - } - ), - LogicOperator::Or => calcute_with_null!( - lhs_viewer, - rhs_viewer, - lhs_viewer_iter, - rhs_viewer_iter, - builder, - |lhs: bool, rhs: bool, _l_valid: bool, _r_valid: bool| -> (bool, bool) { - (lhs || rhs, lhs || rhs) - } - ), - LogicOperator::Xor => calcute_with_null!( - lhs_viewer, - rhs_viewer, - lhs_viewer_iter, - rhs_viewer_iter, - builder, - |lhs: bool, rhs: bool, l_valid: bool, r_valid: bool| -> (bool, bool) { - (lhs ^ rhs, l_valid & r_valid) - } - ), - LogicOperator::Not => return Err(ErrorCode::LogicalError("never happen")), - }; - - Ok(builder.build(input_rows)) - } else { - let lhs_viewer = bool::try_create_viewer(&lhs)?; - let rhs_viewer = bool::try_create_viewer(&rhs)?; - - let mut builder = ColumnBuilder::::with_capacity(input_rows); - - macro_rules! 
calcute { - ($lhs_viewer: expr, $rhs_viewer: expr, $builder: expr, $func: expr) => { - for (a, b) in ($lhs_viewer.iter().zip($rhs_viewer.iter())) { - $builder.append($func(a, b)); - } - }; - } - - match self.op { - LogicOperator::And => calcute!( - lhs_viewer, - rhs_viewer, - builder, - |lhs: bool, rhs: bool| -> bool { lhs & rhs } - ), - LogicOperator::Or => calcute!( - lhs_viewer, - rhs_viewer, - builder, - |lhs: bool, rhs: bool| -> bool { lhs || rhs } - ), - LogicOperator::Xor => calcute!( - lhs_viewer, - rhs_viewer, - builder, - |lhs: bool, rhs: bool| -> bool { lhs ^ rhs } - ), - LogicOperator::Not => return Err(ErrorCode::LogicalError("never happen")), - }; - - Ok(builder.build(input_rows)) - } + Ok(Box::new(Self { + op, + nullable, + f: PhantomData, + })) } } -impl Function for LogicFunction { +impl Function for LogicFunctionImpl +where F: LogicExpression + Clone +{ fn name(&self) -> &str { "LogicFunction" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - match self.op { - LogicOperator::Not => { - if args[0].is_nullable() { - Ok(Arc::new(NullableType::create(BooleanType::arc()))) - } else { - Ok(BooleanType::arc()) - } - } - _ => { - if args[0].is_nullable() || args[1].is_nullable() { - Ok(Arc::new(NullableType::create(BooleanType::arc()))) - } else { - Ok(BooleanType::arc()) - } - } + fn return_type(&self) -> DataTypePtr { + if self.nullable { + NullableType::arc(BooleanType::arc()) + } else { + BooleanType::arc() } } @@ -226,14 +101,13 @@ impl Function for LogicFunction { input_rows: usize, _func_ctx: FunctionContext, ) -> Result { - match self.op { - LogicOperator::Not => self.eval_not(columns, input_rows), - _ => self.eval_and_not_or(columns, input_rows), - } + F::eval(columns, input_rows, self.nullable) } } -impl std::fmt::Display for LogicFunction { +impl std::fmt::Display for LogicFunctionImpl +where F: LogicExpression +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}", self.op) } diff --git 
a/common/functions/src/scalars/logics/macros.rs b/common/functions/src/scalars/logics/macros.rs new file mode 100644 index 000000000000..082bea777e78 --- /dev/null +++ b/common/functions/src/scalars/logics/macros.rs @@ -0,0 +1,73 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[macro_export] +macro_rules! calcute { + ($lhs_viewer: expr, $rhs_viewer: expr, $builder: expr, $func: expr) => { + for (a, b) in ($lhs_viewer.iter().zip($rhs_viewer.iter())) { + $builder.append($func(a, b)); + } + }; +} + +#[macro_export] +macro_rules! 
impl_logic_expression { + ($name: ident, $method: tt, $func: expr) => { + #[derive(Clone)] + pub struct $name; + + impl LogicExpression for $name { + fn eval(columns: &ColumnsWithField, input_rows: usize, nullable: bool) -> Result { + let dt = if nullable { + NullableType::arc(BooleanType::arc()) + } else { + BooleanType::arc() + }; + + let lhs = cast_column_field(&columns[0], &dt)?; + let rhs = cast_column_field(&columns[1], &dt)?; + + if nullable { + let lhs_viewer = bool::try_create_viewer(&lhs)?; + let rhs_viewer = bool::try_create_viewer(&rhs)?; + + let lhs_viewer_iter = lhs_viewer.iter(); + let rhs_viewer_iter = rhs_viewer.iter(); + + let mut builder = NullableColumnBuilder::::with_capacity(input_rows); + + for (a, (idx, b)) in lhs_viewer_iter.zip(rhs_viewer_iter.enumerate()) { + let (val, valid) = $func(a, b, lhs_viewer.valid_at(idx), rhs_viewer.valid_at(idx)); + builder.append(val, valid); + } + + Ok(builder.build(input_rows)) + } else { + let lhs_viewer = bool::try_create_viewer(&lhs)?; + let rhs_viewer = bool::try_create_viewer(&rhs)?; + + let mut builder = ColumnBuilder::::with_capacity(input_rows); + + calcute!(lhs_viewer, rhs_viewer, builder, |lhs: bool, + rhs: bool| + -> bool { + lhs $method rhs + }); + + Ok(builder.build(input_rows)) + } + } + } + }; +} diff --git a/common/functions/src/scalars/logics/mod.rs b/common/functions/src/scalars/logics/mod.rs index 2ba26a397dcf..08011b08ea98 100644 --- a/common/functions/src/scalars/logics/mod.rs +++ b/common/functions/src/scalars/logics/mod.rs @@ -17,6 +17,8 @@ mod logic; mod not; mod or; mod xor; +#[macro_use] +mod macros; pub use and::LogicAndFunction; pub use logic::LogicFunction; diff --git a/common/functions/src/scalars/logics/not.rs b/common/functions/src/scalars/logics/not.rs index c5a6c89a4834..d1d0cfb55ef3 100644 --- a/common/functions/src/scalars/logics/not.rs +++ b/common/functions/src/scalars/logics/not.rs @@ -12,20 +12,41 @@ // See the License for the specific language governing permissions 
and // limitations under the License. +use common_datavalues::prelude::*; use common_exception::Result; +use super::logic::LogicExpression; +use super::logic::LogicFunctionImpl; use super::logic::LogicOperator; -use super::LogicFunction; +use crate::scalars::cast_column_field; use crate::scalars::Function; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; +#[derive(Clone)] +pub struct LogicNotExpression; + +impl LogicExpression for LogicNotExpression { + fn eval(columns: &ColumnsWithField, input_rows: usize, _nullable: bool) -> Result { + let col = cast_column_field(&columns[0], &BooleanType::arc())?; + + let col_viewer = bool::try_create_viewer(&col)?; + + let mut builder = ColumnBuilder::::with_capacity(input_rows); + + for value in col_viewer.iter() { + builder.append(!value); + } + Ok(builder.build(input_rows)) + } +} + #[derive(Clone)] pub struct LogicNotFunction; impl LogicNotFunction { - pub fn try_create(_display_name: &str) -> Result> { - LogicFunction::try_create(LogicOperator::Not) + pub fn try_create(_display_name: &str, args: &[&DataTypePtr]) -> Result> { + LogicFunctionImpl::::try_create(LogicOperator::Not, args) } pub fn desc() -> FunctionDescription { diff --git a/common/functions/src/scalars/logics/or.rs b/common/functions/src/scalars/logics/or.rs index c5987113a93d..f8f9bfa135cb 100644 --- a/common/functions/src/scalars/logics/or.rs +++ b/common/functions/src/scalars/logics/or.rs @@ -12,20 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use common_datavalues::prelude::*; use common_exception::Result; +use super::logic::LogicExpression; +use super::logic::LogicFunctionImpl; use super::logic::LogicOperator; -use super::LogicFunction; +use crate::calcute; +use crate::impl_logic_expression; +use crate::scalars::cast_column_field; use crate::scalars::Function; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; +impl_logic_expression!(LogicOrExpression, |, |lhs: bool, rhs: bool, lhs_v: bool, rhs_v: bool| -> (bool, bool) { + (lhs | rhs, (lhs_v & rhs_v) | (lhs | rhs)) +}); + #[derive(Clone)] pub struct LogicOrFunction; impl LogicOrFunction { - pub fn try_create(_display_name: &str) -> Result> { - LogicFunction::try_create(LogicOperator::Or) + pub fn try_create(_display_name: &str, args: &[&DataTypePtr]) -> Result> { + LogicFunctionImpl::::try_create(LogicOperator::Or, args) } pub fn desc() -> FunctionDescription { @@ -33,7 +42,7 @@ impl LogicOrFunction { FunctionFeatures::default() .deterministic() .disable_passthrough_null() - .num_arguments(1), + .num_arguments(2), ) } } diff --git a/common/functions/src/scalars/logics/xor.rs b/common/functions/src/scalars/logics/xor.rs index e11032aa501f..d28c26e66f5d 100644 --- a/common/functions/src/scalars/logics/xor.rs +++ b/common/functions/src/scalars/logics/xor.rs @@ -12,20 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use common_datavalues::prelude::*; use common_exception::Result; +use super::logic::LogicExpression; +use super::logic::LogicFunctionImpl; use super::logic::LogicOperator; -use super::LogicFunction; +use crate::calcute; +use crate::scalars::cast_column_field; use crate::scalars::Function; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; +#[derive(Clone)] +pub struct LogicXorExpression; + +impl LogicExpression for LogicXorExpression { + fn eval(columns: &ColumnsWithField, input_rows: usize, _nullable: bool) -> Result { + let lhs = cast_column_field(&columns[0], &BooleanType::arc())?; + let rhs = cast_column_field(&columns[1], &BooleanType::arc())?; + let lhs_viewer = bool::try_create_viewer(&lhs)?; + let rhs_viewer = bool::try_create_viewer(&rhs)?; + + let mut builder = ColumnBuilder::::with_capacity(input_rows); + + calcute!(lhs_viewer, rhs_viewer, builder, |lhs: bool, + rhs: bool| + -> bool { + lhs ^ rhs + }); + Ok(builder.build(input_rows)) + } +} + #[derive(Clone)] pub struct LogicXorFunction; impl LogicXorFunction { - pub fn try_create(_display_name: &str) -> Result> { - LogicFunction::try_create(LogicOperator::Xor) + pub fn try_create(_display_name: &str, args: &[&DataTypePtr]) -> Result> { + LogicFunctionImpl::::try_create(LogicOperator::Xor, args) } pub fn desc() -> FunctionDescription { diff --git a/common/functions/src/scalars/maths/abs.rs b/common/functions/src/scalars/maths/abs.rs index 21a1e5d8664a..6e71c855273f 100644 --- a/common/functions/src/scalars/maths/abs.rs +++ b/common/functions/src/scalars/maths/abs.rs @@ -30,12 +30,28 @@ use crate::scalars::Monotonicity; #[derive(Clone)] pub struct AbsFunction { _display_name: String, + result_type: DataTypePtr, } impl AbsFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; + let result_type = match args[0].data_type_id() { + TypeID::Int8 => u8::to_data_type(), + 
TypeID::Int16 => u16::to_data_type(), + TypeID::Int32 => u32::to_data_type(), + TypeID::Int64 => u64::to_data_type(), + TypeID::UInt8 => u8::to_data_type(), + TypeID::UInt16 => u16::to_data_type(), + TypeID::UInt32 => u32::to_data_type(), + TypeID::UInt64 => u64::to_data_type(), + TypeID::Float32 => f32::to_data_type(), + TypeID::Float64 => f64::to_data_type(), + _ => unreachable!(), + }; Ok(Box::new(AbsFunction { _display_name: display_name.to_string(), + result_type, })) } @@ -76,22 +92,8 @@ impl Function for AbsFunction { "abs" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_numeric(args[0])?; - let data_type = match args[0].data_type_id() { - TypeID::Int8 => u8::to_data_type(), - TypeID::Int16 => u16::to_data_type(), - TypeID::Int32 => u32::to_data_type(), - TypeID::Int64 => u64::to_data_type(), - TypeID::UInt8 => u8::to_data_type(), - TypeID::UInt16 => u16::to_data_type(), - TypeID::UInt32 => u32::to_data_type(), - TypeID::UInt64 => u64::to_data_type(), - TypeID::Float32 => f32::to_data_type(), - TypeID::Float64 => f64::to_data_type(), - _ => unreachable!(), - }; - Ok(data_type) + fn return_type(&self) -> DataTypePtr { + self.result_type.clone() } fn eval( diff --git a/common/functions/src/scalars/maths/angle.rs b/common/functions/src/scalars/maths/angle.rs index 81919a12b1fe..64fd048524f6 100644 --- a/common/functions/src/scalars/maths/angle.rs +++ b/common/functions/src/scalars/maths/angle.rs @@ -41,7 +41,8 @@ pub trait AngleConvertFunction { impl AngleFunction where T: AngleConvertFunction + Clone + Sync + Send + 'static { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; Ok(Box::new(AngleFunction:: { _display_name: display_name.to_string(), t: PhantomData, @@ -61,9 +62,8 @@ where T: AngleConvertFunction + Clone + Sync + Send + 'static "AngleFunction" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - 
assert_numeric(args[0])?; - Ok(Float64Type::arc()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( diff --git a/common/functions/src/scalars/maths/ceil.rs b/common/functions/src/scalars/maths/ceil.rs index 020db20318ee..4b4887c2a1b5 100644 --- a/common/functions/src/scalars/maths/ceil.rs +++ b/common/functions/src/scalars/maths/ceil.rs @@ -35,7 +35,8 @@ pub struct CeilFunction { } impl CeilFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; Ok(Box::new(CeilFunction { display_name: display_name.to_string(), })) @@ -61,9 +62,8 @@ impl Function for CeilFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_numeric(args[0])?; - Ok(Float64Type::arc()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( diff --git a/common/functions/src/scalars/maths/exp.rs b/common/functions/src/scalars/maths/exp.rs index 2c97c65d0aa1..748af22ba44f 100644 --- a/common/functions/src/scalars/maths/exp.rs +++ b/common/functions/src/scalars/maths/exp.rs @@ -33,7 +33,8 @@ pub struct ExpFunction { } impl ExpFunction { - pub fn try_create(_display_name: &str) -> Result> { + pub fn try_create(_display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; Ok(Box::new(ExpFunction { _display_name: _display_name.to_string(), })) @@ -55,9 +56,8 @@ impl Function for ExpFunction { &*self._display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_numeric(args[0])?; - Ok(Float64Type::arc()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( diff --git a/common/functions/src/scalars/maths/floor.rs b/common/functions/src/scalars/maths/floor.rs index 164656e5e1b4..d6f8b8f0c25f 100644 --- a/common/functions/src/scalars/maths/floor.rs +++ b/common/functions/src/scalars/maths/floor.rs @@ -25,6 +25,7 @@ use 
crate::scalars::function_factory::FunctionDescription; use crate::scalars::scalar_unary_op; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionFeatures; use crate::scalars::Monotonicity; @@ -34,7 +35,8 @@ pub struct FloorFunction { } impl FloorFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; Ok(Box::new(FloorFunction { display_name: display_name.to_string(), })) @@ -60,9 +62,8 @@ impl Function for FloorFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_numeric(args[0])?; - Ok(Float64Type::arc()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( @@ -92,4 +93,3 @@ impl fmt::Display for FloorFunction { write!(f, "{}", self.display_name.to_uppercase()) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/maths/log.rs b/common/functions/src/scalars/maths/log.rs index 2629f0b685d0..19a21c39a940 100644 --- a/common/functions/src/scalars/maths/log.rs +++ b/common/functions/src/scalars/maths/log.rs @@ -70,7 +70,10 @@ pub struct GenericLogFunction { } impl GenericLogFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } Ok(Box::new(Self { display_name: display_name.to_string(), t: PhantomData, @@ -104,11 +107,8 @@ impl Function for GenericLogFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_numeric(*arg)?; - } - Ok(f64::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( diff --git a/common/functions/src/scalars/maths/pi.rs b/common/functions/src/scalars/maths/pi.rs index ec70cfe1fe84..9870310cf1a1 100644 --- 
a/common/functions/src/scalars/maths/pi.rs +++ b/common/functions/src/scalars/maths/pi.rs @@ -20,6 +20,7 @@ use common_exception::Result; use crate::scalars::function_factory::FunctionDescription; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionFeatures; #[derive(Clone)] @@ -28,7 +29,7 @@ pub struct PiFunction { } impl PiFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(PiFunction { display_name: display_name.to_string(), })) @@ -45,8 +46,8 @@ impl Function for PiFunction { &*self.display_name } - fn return_type(&self, _args: &[&DataTypePtr]) -> Result { - Ok(Float64Type::arc()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( @@ -64,4 +65,3 @@ impl fmt::Display for PiFunction { write!(f, "{}", self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/maths/pow.rs b/common/functions/src/scalars/maths/pow.rs index 3221020fdd93..c5fe29d13c66 100644 --- a/common/functions/src/scalars/maths/pow.rs +++ b/common/functions/src/scalars/maths/pow.rs @@ -35,7 +35,10 @@ pub struct PowFunction { } impl PowFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } Ok(Box::new(PowFunction { display_name: display_name.to_string(), })) @@ -61,11 +64,8 @@ impl Function for PowFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_numeric(*arg)?; - } - Ok(f64::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( diff --git a/common/functions/src/scalars/maths/random.rs b/common/functions/src/scalars/maths/random.rs index e14f577d9296..af7839110a5b 100644 --- a/common/functions/src/scalars/maths/random.rs +++ 
b/common/functions/src/scalars/maths/random.rs @@ -35,7 +35,10 @@ pub struct RandomFunction { } impl RandomFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } Ok(Box::new(RandomFunction { display_name: display_name.to_string(), })) @@ -52,11 +55,8 @@ impl Function for RandomFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_numeric(*arg)?; - } - Ok(f64::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( diff --git a/common/functions/src/scalars/maths/round.rs b/common/functions/src/scalars/maths/round.rs index 02997d069fc0..ef44a17ed112 100644 --- a/common/functions/src/scalars/maths/round.rs +++ b/common/functions/src/scalars/maths/round.rs @@ -85,11 +85,8 @@ impl Function for RoundingFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_numeric(*arg)?; - } - Ok(f64::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( @@ -175,7 +172,10 @@ pub struct RoundingFunction { } impl RoundingFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } Ok(Box::new(Self { display_name: display_name.to_string(), })) diff --git a/common/functions/src/scalars/maths/sign.rs b/common/functions/src/scalars/maths/sign.rs index a92b4b491503..d4c2de9765ea 100644 --- a/common/functions/src/scalars/maths/sign.rs +++ b/common/functions/src/scalars/maths/sign.rs @@ -36,7 +36,8 @@ pub struct SignFunction { } impl SignFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; Ok(Box::new(SignFunction { 
display_name: display_name.to_string(), })) @@ -66,9 +67,8 @@ impl Function for SignFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_numeric(args[0])?; - Ok(i8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Int8Type::arc() } fn eval( diff --git a/common/functions/src/scalars/maths/sqrt.rs b/common/functions/src/scalars/maths/sqrt.rs index 69d9764627c1..18687197264c 100644 --- a/common/functions/src/scalars/maths/sqrt.rs +++ b/common/functions/src/scalars/maths/sqrt.rs @@ -33,7 +33,8 @@ pub struct SqrtFunction { } impl SqrtFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; Ok(Box::new(SqrtFunction { display_name: display_name.to_string(), })) @@ -55,9 +56,8 @@ impl Function for SqrtFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_numeric(args[0])?; - Ok(Float64Type::arc()) + fn return_type(&self) -> DataTypePtr { + Float64Type::arc() } fn eval( diff --git a/common/functions/src/scalars/maths/trigonometric.rs b/common/functions/src/scalars/maths/trigonometric.rs index 6bdec116c1b7..0bd7bf195da9 100644 --- a/common/functions/src/scalars/maths/trigonometric.rs +++ b/common/functions/src/scalars/maths/trigonometric.rs @@ -73,11 +73,8 @@ impl Function for TrigonometricFunction { "TrigonometricFunction" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_numeric(*arg)?; - } - Ok(f64::to_data_type()) + fn return_type(&self) -> DataTypePtr { + f64::to_data_type() } fn eval( @@ -168,7 +165,13 @@ impl fmt::Display for TrigonometricFunction { pub struct TrigonometricSinFunction; impl TrigonometricSinFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + args: &[&DataTypePtr], + ) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } 
TrigonometricFunction::try_create_func(Trigonometric::SIN) } @@ -181,7 +184,13 @@ impl TrigonometricSinFunction { pub struct TrigonometricCosFunction; impl TrigonometricCosFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + args: &[&DataTypePtr], + ) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } TrigonometricFunction::try_create_func(Trigonometric::COS) } @@ -194,7 +203,13 @@ impl TrigonometricCosFunction { pub struct TrigonometricTanFunction; impl TrigonometricTanFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + args: &[&DataTypePtr], + ) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } TrigonometricFunction::try_create_func(Trigonometric::TAN) } @@ -207,7 +222,13 @@ impl TrigonometricTanFunction { pub struct TrigonometricCotFunction; impl TrigonometricCotFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + args: &[&DataTypePtr], + ) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } TrigonometricFunction::try_create_func(Trigonometric::COT) } @@ -220,7 +241,13 @@ impl TrigonometricCotFunction { pub struct TrigonometricAsinFunction; impl TrigonometricAsinFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + args: &[&DataTypePtr], + ) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } TrigonometricFunction::try_create_func(Trigonometric::ASIN) } @@ -233,7 +260,13 @@ impl TrigonometricAsinFunction { pub struct TrigonometricAcosFunction; impl TrigonometricAcosFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + args: &[&DataTypePtr], + ) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } TrigonometricFunction::try_create_func(Trigonometric::ACOS) } @@ -246,7 +279,13 @@ 
impl TrigonometricAcosFunction { pub struct TrigonometricAtanFunction; impl TrigonometricAtanFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + args: &[&DataTypePtr], + ) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } TrigonometricFunction::try_create_func(Trigonometric::ATAN) } @@ -262,7 +301,13 @@ impl TrigonometricAtanFunction { pub struct TrigonometricAtan2Function; impl TrigonometricAtan2Function { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + args: &[&DataTypePtr], + ) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } TrigonometricFunction::try_create_func(Trigonometric::ATAN2) } diff --git a/common/functions/src/scalars/others/exists.rs b/common/functions/src/scalars/others/exists.rs index d6314c3af2f9..7043d6618e0c 100644 --- a/common/functions/src/scalars/others/exists.rs +++ b/common/functions/src/scalars/others/exists.rs @@ -27,7 +27,7 @@ use crate::scalars::FunctionFeatures; pub struct ExistsFunction; impl ExistsFunction { - pub fn try_create(_display_name: &str) -> Result> { + pub fn try_create(_display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(ExistsFunction {})) } @@ -42,11 +42,8 @@ impl Function for ExistsFunction { "ExistsFunction" } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(bool::to_data_type()) + fn return_type(&self) -> DataTypePtr { + bool::to_data_type() } fn eval( diff --git a/common/functions/src/scalars/others/ignore.rs b/common/functions/src/scalars/others/ignore.rs index 9f64762434e4..c78fc8d00852 100644 --- a/common/functions/src/scalars/others/ignore.rs +++ b/common/functions/src/scalars/others/ignore.rs @@ -36,7 +36,7 @@ pub struct IgnoreFunction { } impl IgnoreFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result> { 
Ok(Box::new(IgnoreFunction { display_name: display_name.to_string(), })) @@ -64,8 +64,8 @@ impl Function for IgnoreFunction { &*self.display_name } - fn return_type(&self, _args: &[&DataTypePtr]) -> Result { - Ok(BooleanType::arc()) + fn return_type(&self) -> DataTypePtr { + BooleanType::arc() } fn eval( diff --git a/common/functions/src/scalars/others/inet_aton.rs b/common/functions/src/scalars/others/inet_aton.rs index ce51237ddc05..fbb72c0d5591 100644 --- a/common/functions/src/scalars/others/inet_aton.rs +++ b/common/functions/src/scalars/others/inet_aton.rs @@ -15,12 +15,12 @@ use std::fmt; use std::net::Ipv4Addr; use std::str; -use std::sync::Arc; use common_datavalues::prelude::*; use common_exception::ErrorCode; use common_exception::Result; +use crate::scalars::assert_string; use crate::scalars::Function; use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; @@ -38,20 +38,17 @@ pub struct InetAtonFunctionImpl { } impl InetAtonFunctionImpl { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; + Ok(Box::new(InetAtonFunctionImpl:: { display_name: display_name.to_string(), })) } pub fn desc() -> FunctionDescription { - let mut features = FunctionFeatures::default().deterministic().num_arguments(1); - // Null will cause parse error when SUPPRESS_PARSE_ERROR is false. - // In this case we need to check null and skip the parsing, so passthrough_null should be false. 
- if !SUPPRESS_PARSE_ERROR { - features = features.disable_passthrough_null() - } - FunctionDescription::creator(Box::new(Self::try_create)).features(features) + FunctionDescription::creator(Box::new(Self::try_create)) + .features(FunctionFeatures::default().deterministic().num_arguments(1)) } } @@ -60,26 +57,11 @@ impl Function for InetAtonFunctionImpl Result { - let input_type = remove_nullable(args[0]); - let output_type = match input_type.data_type_id() { - TypeID::Null => return Ok(NullType::arc()), - TypeID::String => Ok(type_primitive::UInt32Type::arc()), - _ => Err(ErrorCode::IllegalDataType(format!( - "Expected string or null type, but got {}", - args[0].name() - ))), - }?; - + fn return_type(&self) -> DataTypePtr { if SUPPRESS_PARSE_ERROR { - // For invalid input, we suppress parse error and return null. So the return type must be nullable. - return Ok(Arc::new(NullableType::create(output_type))); - } - - if args[0].is_nullable() { - Ok(Arc::new(NullableType::create(output_type))) + NullableType::arc(UInt32Type::arc()) } else { - Ok(output_type) + UInt32Type::arc() } } @@ -89,10 +71,6 @@ impl Function for InetAtonFunctionImpl Result { - if columns[0].column().data_type_id() == TypeID::Null { - return NullType::arc().create_constant_column(&DataValue::Null, input_rows); - } - let viewer = Vu8::try_create_viewer(columns[0].column())?; let viewer_iter = viewer.iter(); @@ -100,7 +78,6 @@ impl Function for InetAtonFunctionImpl::with_capacity(input_rows); for (i, input) in viewer_iter.enumerate() { - // We skip the null check because the function has passthrough_null is true. // This is arguably correct because the address parsing is not optimized by SIMD, not quite sure how much we can gain from skipping branch prediction. // Think about the case if we have 1000 rows and 999 are Nulls. 
let addr_str = String::from_utf8_lossy(input); @@ -115,48 +92,24 @@ impl Function for InetAtonFunctionImpl::with_capacity(input_rows); - for (i, input) in viewer_iter.enumerate() { - if viewer.null_at(i) { - builder.append_null(); - continue; - } - - let addr_str = String::from_utf8_lossy(input); - match addr_str.parse::() { - Ok(addr) => { - let addr_binary: u32 = u32::from(addr); - builder.append(addr_binary, viewer.valid_at(i)); - } - Err(err) => { - return Err(ErrorCode::StrParseError(format!( - "Failed to parse '{}' into a IPV4 address, {}", - addr_str, err - ))); - } + // We skip the null check because the function has passthrough_null is true. + let mut builder = ColumnBuilder::::with_capacity(input_rows); + for input in viewer_iter { + let addr_str = String::from_utf8_lossy(input); + match addr_str.parse::() { + Ok(addr) => { + let addr_binary: u32 = u32::from(addr); + builder.append(addr_binary); } - } - Ok(builder.build(input_rows)) - } else { - let mut builder = ColumnBuilder::::with_capacity(input_rows); - for input in viewer_iter { - let addr_str = String::from_utf8_lossy(input); - match addr_str.parse::() { - Ok(addr) => { - let addr_binary: u32 = u32::from(addr); - builder.append(addr_binary); - } - Err(err) => { - return Err(ErrorCode::StrParseError(format!( - "Failed to parse '{}' into a IPV4 address, {}", - addr_str, err - ))); - } + Err(err) => { + return Err(ErrorCode::StrParseError(format!( + "Failed to parse '{}' into a IPV4 address, {}", + addr_str, err + ))); } } - Ok(builder.build(input_rows)) } + Ok(builder.build(input_rows)) } } diff --git a/common/functions/src/scalars/others/inet_ntoa.rs b/common/functions/src/scalars/others/inet_ntoa.rs index e45bedd68b5e..59125777f75b 100644 --- a/common/functions/src/scalars/others/inet_ntoa.rs +++ b/common/functions/src/scalars/others/inet_ntoa.rs @@ -15,12 +15,11 @@ use std::fmt; use std::net::Ipv4Addr; use std::str; -use std::sync::Arc; use common_datavalues::prelude::*; -use 
common_exception::ErrorCode; use common_exception::Result; +use crate::scalars::assert_numeric; use crate::scalars::cast_with_type; use crate::scalars::CastOptions; use crate::scalars::ExceptionMode; @@ -42,7 +41,9 @@ pub struct InetNtoaFunctionImpl { } impl InetNtoaFunctionImpl { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; + Ok(Box::new(InetNtoaFunctionImpl:: { display_name: display_name.to_string(), })) @@ -59,26 +60,11 @@ impl Function for InetNtoaFunctionImpl Result { - let input_type = args[0]; - if input_type.data_type_id() == TypeID::Null { - return Ok(NullType::arc()); - } - - let output_type = if input_type.data_type_id().is_numeric() { - Ok(StringType::arc()) - } else { - Err(ErrorCode::IllegalDataType(format!( - "Expected numeric or null type, but got {}", - args[0].name() - ))) - }?; - + fn return_type(&self) -> DataTypePtr { if SUPPRESS_CAST_ERROR { - // For invalid input, the function should return null. So the return type must be nullable. 
- Ok(Arc::new(NullableType::create(output_type))) + NullableType::arc(StringType::arc()) } else { - Ok(output_type) + StringType::arc() } } @@ -88,12 +74,8 @@ impl Function for InetNtoaFunctionImpl Result { - if columns[0].column().data_type_id() == TypeID::Null { - return NullType::arc().create_constant_column(&DataValue::Null, input_rows); - } - if SUPPRESS_CAST_ERROR { - let cast_to: DataTypePtr = Arc::new(NullableType::create(UInt32Type::arc())); + let cast_to: DataTypePtr = NullableType::arc(UInt32Type::arc()); let cast_options = CastOptions { // we allow cast failure exception_mode: ExceptionMode::Zero, diff --git a/common/functions/src/scalars/others/running_difference_function.rs b/common/functions/src/scalars/others/running_difference_function.rs index 9ec7187723ec..083f8eee61a9 100644 --- a/common/functions/src/scalars/others/running_difference_function.rs +++ b/common/functions/src/scalars/others/running_difference_function.rs @@ -15,7 +15,6 @@ use std::fmt; use std::ops::Sub; use std::str; -use std::sync::Arc; use common_datavalues::prelude::*; use common_exception::ErrorCode; @@ -29,12 +28,40 @@ use crate::scalars::FunctionFeatures; #[derive(Clone)] pub struct RunningDifferenceFunction { display_name: String, + result_type: DataTypePtr, } impl RunningDifferenceFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + let nullable = args.iter().any(|arg| arg.is_nullable()); + let dt = remove_nullable(args[0]); + + let output_type = match dt.data_type_id() { + TypeID::Int8 | TypeID::UInt8 => Ok(type_primitive::Int16Type::arc()), + TypeID::Int16 | TypeID::UInt16 | TypeID::Date16 => Ok(type_primitive::Int32Type::arc()), + TypeID::Int32 + | TypeID::UInt32 + | TypeID::Int64 + | TypeID::UInt64 + | TypeID::Date32 + | TypeID::DateTime32 + | TypeID::DateTime64 + | TypeID::Interval => Ok(type_primitive::Int64Type::arc()), + TypeID::Float32 | TypeID::Float64 => 
Ok(type_primitive::Float64Type::arc()), + _ => Err(ErrorCode::IllegalDataType( + "Argument for function runningDifference must have numeric type", + )), + }?; + + let result_type = if nullable { + NullableType::arc(output_type) + } else { + output_type + }; + Ok(Box::new(RunningDifferenceFunction { display_name: display_name.to_string(), + result_type, })) } @@ -56,32 +83,8 @@ impl Function for RunningDifferenceFunction { self.display_name.as_str() } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - let nullable = args.iter().any(|arg| arg.is_nullable()); - let dt = remove_nullable(args[0]); - - let output_type = match dt.data_type_id() { - TypeID::Int8 | TypeID::UInt8 => Ok(type_primitive::Int16Type::arc()), - TypeID::Int16 | TypeID::UInt16 | TypeID::Date16 => Ok(type_primitive::Int32Type::arc()), - TypeID::Int32 - | TypeID::UInt32 - | TypeID::Int64 - | TypeID::UInt64 - | TypeID::Date32 - | TypeID::DateTime32 - | TypeID::DateTime64 - | TypeID::Interval => Ok(type_primitive::Int64Type::arc()), - TypeID::Float32 | TypeID::Float64 => Ok(type_primitive::Float64Type::arc()), - _ => Result::Err(ErrorCode::IllegalDataType( - "Argument for function runningDifference must have numeric type", - )), - }?; - - if nullable { - Ok(Arc::new(NullableType::create(output_type))) - } else { - Ok(output_type) - } + fn return_type(&self) -> DataTypePtr { + self.result_type.clone() } fn eval( diff --git a/common/functions/src/scalars/others/sleep.rs b/common/functions/src/scalars/others/sleep.rs index 6b82cb23db5e..8b4c16463202 100644 --- a/common/functions/src/scalars/others/sleep.rs +++ b/common/functions/src/scalars/others/sleep.rs @@ -15,11 +15,13 @@ use std::fmt; use std::time::Duration; +use common_datavalues::DataTypePtr; use common_datavalues::DataValue; use common_datavalues::Int8Type; use common_exception::ErrorCode; use common_exception::Result; +use crate::scalars::assert_numeric; use crate::scalars::Function; use crate::scalars::FunctionContext; use 
crate::scalars::FunctionDescription; @@ -31,7 +33,8 @@ pub struct SleepFunction { } impl SleepFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; Ok(Box::new(SleepFunction { display_name: display_name.to_string(), })) @@ -48,18 +51,8 @@ impl Function for SleepFunction { "SleepFunction" } - fn return_type( - &self, - args: &[&common_datavalues::DataTypePtr], - ) -> Result { - if !args[0].data_type_id().is_numeric() { - return Err(ErrorCode::BadArguments(format!( - "Illegal type {} of argument of function {}, expected numeric", - args[0].data_type_id(), - self.display_name - ))); - } - Ok(Int8Type::arc()) + fn return_type(&self) -> DataTypePtr { + Int8Type::arc() } fn eval( diff --git a/common/functions/src/scalars/others/to_type_name.rs b/common/functions/src/scalars/others/to_type_name.rs index cbd0a283e7ef..08dd3800b3ac 100644 --- a/common/functions/src/scalars/others/to_type_name.rs +++ b/common/functions/src/scalars/others/to_type_name.rs @@ -14,6 +14,7 @@ use std::fmt; +use common_datavalues::DataTypePtr; use common_datavalues::DataValue; use common_datavalues::StringType; use common_exception::Result; @@ -29,7 +30,7 @@ pub struct ToTypeNameFunction { } impl ToTypeNameFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(ToTypeNameFunction { _display_name: display_name.to_string(), })) @@ -50,11 +51,8 @@ impl Function for ToTypeNameFunction { "ToTypeNameFunction" } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( diff --git a/common/functions/src/scalars/semi_structureds/check_json.rs b/common/functions/src/scalars/semi_structureds/check_json.rs index e67268284f18..ab9bfa103f61 100644 --- 
a/common/functions/src/scalars/semi_structureds/check_json.rs +++ b/common/functions/src/scalars/semi_structureds/check_json.rs @@ -13,9 +13,7 @@ // limitations under the License. use std::fmt; -use std::sync::Arc; -use common_arrow::arrow::bitmap::Bitmap; use common_datavalues::prelude::*; use common_exception::ErrorCode; use common_exception::Result; @@ -32,19 +30,15 @@ pub struct CheckJsonFunction { } impl CheckJsonFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(CheckJsonFunction { display_name: display_name.to_string(), })) } pub fn desc() -> FunctionDescription { - FunctionDescription::creator(Box::new(Self::try_create)).features( - FunctionFeatures::default() - .deterministic() - .monotonicity() - .num_arguments(1), - ) + FunctionDescription::creator(Box::new(Self::try_create)) + .features(FunctionFeatures::default().deterministic().num_arguments(1)) } } @@ -53,12 +47,8 @@ impl Function for CheckJsonFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if args[0].data_type_id() == TypeID::Null { - return Ok(NullType::arc()); - } - - Ok(Arc::new(NullableType::create(StringType::arc()))) + fn return_type(&self) -> DataTypePtr { + NullableType::arc(StringType::arc()) } fn eval( @@ -67,19 +57,8 @@ impl Function for CheckJsonFunction { input_rows: usize, _func_ctx: FunctionContext, ) -> Result { - let data_type = remove_nullable(columns[0].field().data_type()); - let mut column = columns[0].column(); - let mut _all_null = false; - let mut source_valids: Option<&Bitmap> = None; - if column.is_nullable() { - (_all_null, source_valids) = column.validity(); - let nullable_column: &NullableColumn = Series::check_get(column)?; - column = nullable_column.inner(); - } - - if data_type.data_type_id() == TypeID::Null { - return NullType::arc().create_constant_column(&DataValue::Null, input_rows); - } + let data_type = 
columns[0].field().data_type(); + let column = columns[0].column(); let mut builder = NullableColumnBuilder::::with_capacity(input_rows); @@ -89,14 +68,7 @@ impl Function for CheckJsonFunction { } } else if data_type.data_type_id() == TypeID::String { let c: &StringColumn = Series::check_get(column)?; - for (i, v) in c.iter().enumerate() { - if let Some(source_valids) = source_valids { - if !source_valids.get_bit(i) { - builder.append_null(); - continue; - } - } - + for v in c.iter() { match std::str::from_utf8(v) { Ok(v) => match serde_json::from_str::(v) { Ok(_v) => builder.append_null(), diff --git a/common/functions/src/scalars/semi_structureds/get.rs b/common/functions/src/scalars/semi_structureds/get.rs index f66836303c68..46bc760dcbfa 100644 --- a/common/functions/src/scalars/semi_structureds/get.rs +++ b/common/functions/src/scalars/semi_structureds/get.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::fmt; -use std::sync::Arc; use common_datavalues::prelude::*; use common_exception::ErrorCode; @@ -41,26 +40,7 @@ pub struct GetFunctionImpl { } impl GetFunctionImpl { - pub fn try_create(display_name: &str) -> Result> { - Ok(Box::new(GetFunctionImpl:: { - display_name: display_name.to_string(), - })) - } - - pub fn desc() -> FunctionDescription { - FunctionDescription::creator(Box::new(Self::try_create)) - .features(FunctionFeatures::default().deterministic().num_arguments(2)) - } -} - -impl Function - for GetFunctionImpl -{ - fn name(&self) -> &str { - &*self.display_name - } - - fn return_type(&self, args: &[&DataTypePtr]) -> Result { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { let data_type = args[0]; let path_type = args[1]; @@ -76,13 +56,32 @@ impl Function { return Err(ErrorCode::IllegalDataType(format!( "Invalid argument types for function '{}': ({:?}, {:?})", - self.display_name.to_uppercase(), + display_name.to_uppercase(), data_type, path_type ))); } - 
Ok(Arc::new(NullableType::create(VariantType::arc()))) + Ok(Box::new(GetFunctionImpl:: { + display_name: display_name.to_string(), + })) + } + + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create)) + .features(FunctionFeatures::default().deterministic().num_arguments(2)) + } +} + +impl Function + for GetFunctionImpl +{ + fn name(&self) -> &str { + &*self.display_name + } + + fn return_type(&self) -> DataTypePtr { + NullableType::arc(VariantType::arc()) } fn eval( diff --git a/common/functions/src/scalars/semi_structureds/parse_json.rs b/common/functions/src/scalars/semi_structureds/parse_json.rs index 60d02fc3b3de..21099054ba8e 100644 --- a/common/functions/src/scalars/semi_structureds/parse_json.rs +++ b/common/functions/src/scalars/semi_structureds/parse_json.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::fmt; -use std::sync::Arc; use common_datavalues::prelude::*; use common_exception::ErrorCode; @@ -32,12 +31,33 @@ pub type ParseJsonFunction = ParseJsonFunctionImpl; #[derive(Clone)] pub struct ParseJsonFunctionImpl { display_name: String, + result_type: DataTypePtr, } impl ParseJsonFunctionImpl { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + let data_type = remove_nullable(args[0]); + if data_type.data_type_id() == TypeID::VariantArray + || data_type.data_type_id() == TypeID::VariantObject + { + return Err(ErrorCode::BadDataValueType(format!( + "Invalid argument types for function '{}': ({})", + display_name, + data_type.name() + ))); + } + + let result_type = if args[0].data_type_id() == TypeID::Null { + NullType::arc() + } else if args[0].is_nullable() || SUPPRESS_PARSE_ERROR { + NullableType::arc(VariantType::arc()) + } else { + VariantType::arc() + }; + Ok(Box::new(ParseJsonFunctionImpl:: { display_name: display_name.to_string(), + result_type, })) } @@ -57,20 +77,8 @@ impl Function for ParseJsonFunctionImpl 
Result { - if args[0].data_type_id() == TypeID::Null { - return Ok(NullType::arc()); - } - - if SUPPRESS_PARSE_ERROR { - // For invalid input, we suppress parse error and return null. So the return type must be nullable. - return Ok(Arc::new(NullableType::create(VariantType::arc()))); - } - - if args[0].is_nullable() { - return Ok(Arc::new(NullableType::create(VariantType::arc()))); - } - Ok(VariantType::arc()) + fn return_type(&self) -> DataTypePtr { + self.result_type.clone() } fn eval( @@ -80,15 +88,7 @@ impl Function for ParseJsonFunctionImpl Result { let data_type = columns[0].field().data_type(); - if data_type.data_type_id() == TypeID::VariantArray - || data_type.data_type_id() == TypeID::VariantObject - { - return Err(ErrorCode::BadDataValueType(format!( - "Invalid argument types for function '{}': ({})", - self.display_name, - data_type.name() - ))); - } else if data_type.data_type_id() == TypeID::Null { + if data_type.data_type_id() == TypeID::Null { return NullType::arc().create_constant_column(&DataValue::Null, input_rows); } @@ -124,19 +124,7 @@ impl Function for ParseJsonFunctionImpl::with_capacity(input_rows); if data_type.data_type_id().is_numeric() diff --git a/common/functions/src/scalars/strings/bin.rs b/common/functions/src/scalars/strings/bin.rs index 0af6f810e106..703899d84e0f 100644 --- a/common/functions/src/scalars/strings/bin.rs +++ b/common/functions/src/scalars/strings/bin.rs @@ -18,8 +18,10 @@ use common_datavalues::prelude::*; use common_exception::ErrorCode; use common_exception::Result; +use crate::scalars::assert_numeric; use crate::scalars::cast_column_field; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -29,7 +31,8 @@ pub struct BinFunction { } impl BinFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; 
Ok(Box::new(BinFunction { _display_name: display_name.to_string(), })) @@ -46,15 +49,8 @@ impl Function for BinFunction { "bin" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if !args[0].data_type_id().is_numeric() { - return Err(ErrorCode::IllegalDataType(format!( - "Expected integer but got {}", - args[0].data_type_id() - ))); - } - - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( @@ -114,4 +110,3 @@ impl fmt::Display for BinFunction { write!(f, "BIN") } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/char_.rs b/common/functions/src/scalars/strings/char_.rs index 77e68434a34c..1d271b095016 100644 --- a/common/functions/src/scalars/strings/char_.rs +++ b/common/functions/src/scalars/strings/char_.rs @@ -20,6 +20,7 @@ use common_exception::Result; use crate::scalars::assert_numeric; use crate::scalars::default_column_cast; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -29,7 +30,10 @@ pub struct CharFunction { } impl CharFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for arg in args { + assert_numeric(*arg)?; + } Ok(Box::new(CharFunction { _display_name: display_name.to_string(), })) @@ -49,11 +53,8 @@ impl Function for CharFunction { "char" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_numeric(*arg)?; - } - Ok(Vu8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Vu8::to_data_type() } fn eval( @@ -104,4 +105,3 @@ impl fmt::Display for CharFunction { write!(f, "CHAR") } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/concat.rs b/common/functions/src/scalars/strings/concat.rs index ea2d2a076cc0..0fc776c51ea6 100644 --- a/common/functions/src/scalars/strings/concat.rs +++ 
b/common/functions/src/scalars/strings/concat.rs @@ -19,6 +19,7 @@ use common_exception::Result; use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -28,7 +29,10 @@ pub struct ConcatFunction { } impl ConcatFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for arg in args { + assert_string(*arg)?; + } Ok(Box::new(ConcatFunction { _display_name: display_name.to_string(), })) @@ -48,11 +52,8 @@ impl Function for ConcatFunction { "concat" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_string(*arg)?; - } - Ok(Vu8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Vu8::to_data_type() } fn eval( @@ -87,4 +88,3 @@ impl fmt::Display for ConcatFunction { write!(f, "CONCAT") } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/concat_ws.rs b/common/functions/src/scalars/strings/concat_ws.rs index 7b741ce28824..2583678c6a64 100644 --- a/common/functions/src/scalars/strings/concat_ws.rs +++ b/common/functions/src/scalars/strings/concat_ws.rs @@ -19,18 +19,38 @@ use common_exception::Result; use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; #[derive(Clone)] pub struct ConcatWsFunction { _display_name: String, + result_type: DataTypePtr, } impl ConcatWsFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + let result_type = if args[0].is_null() { + NullType::arc() + } else { + for arg in args { + let arg = remove_nullable(*arg); + if !arg.is_null() { + assert_string(&arg)?; + } + } + + let dt = Vu8::to_data_type(); + match args[0].is_nullable() { + true 
=> wrap_nullable(&dt), + false => dt, + } + }; + Ok(Box::new(ConcatWsFunction { _display_name: display_name.to_string(), + result_type, })) } @@ -144,23 +164,8 @@ impl Function for ConcatWsFunction { "concat_ws" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if args[0].is_null() { - return Ok(NullType::arc()); - } - - for arg in args { - let arg = remove_nullable(*arg); - if !arg.is_null() { - assert_string(&arg)?; - } - } - - let dt = Vu8::to_data_type(); - match args[0].is_nullable() { - true => Ok(wrap_nullable(&dt)), - false => Ok(dt), - } + fn return_type(&self) -> DataTypePtr { + self.result_type.clone() } fn eval( @@ -169,11 +174,12 @@ impl Function for ConcatWsFunction { input_rows: usize, _func_ctx: FunctionContext, ) -> Result { - let seperator = &columns[0]; - if seperator.data_type().is_null() { + if self.result_type.is_null() { return Ok(NullColumn::new(input_rows).arc()); } + let seperator = &columns[0]; + // remove other null columns let cols: Vec = columns[1..] 
.iter() @@ -193,7 +199,7 @@ impl Function for ConcatWsFunction { ); } - match columns[0].data_type().is_nullable() { + match self.result_type.is_nullable() { false => Self::concat_column_nonull(&columns[0], &cols, input_rows), true => Self::concat_column_null(&columns[0], &cols, input_rows), } @@ -205,4 +211,3 @@ impl fmt::Display for ConcatWsFunction { write!(f, "CONCAT_WS") } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/elt.rs b/common/functions/src/scalars/strings/elt.rs index cd5af75cb60c..d3bc0e29fb10 100644 --- a/common/functions/src/scalars/strings/elt.rs +++ b/common/functions/src/scalars/strings/elt.rs @@ -23,21 +23,44 @@ use num_traits::AsPrimitive; use crate::scalars::assert_numeric; use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; #[derive(Clone)] pub struct EltFunction { display_name: String, + result_type: DataTypePtr, } //MySQL ELT() returns the string at the index number specified in the list of arguments. The first argument indicates the index of the string to be retrieved from the list of arguments. 
// Note: According to Wikipedia ELT stands for Extract, Load, Transform (ELT), a data manipulation process impl EltFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + let result_type = if args[0].is_null() { + NullType::arc() + } else { + let arg = remove_nullable(args[0]); + assert_numeric(&arg)?; + + for arg in args[1..].iter() { + let arg = remove_nullable(*arg); + if !arg.is_null() { + assert_string(&arg)?; + } + } + + let dt = Vu8::to_data_type(); + match args.iter().any(|f| f.is_nullable()) { + true => wrap_nullable(&dt), + false => dt, + } + }; + Ok(Box::new(EltFunction { display_name: display_name.to_string(), + result_type, })) } @@ -56,26 +79,8 @@ impl Function for EltFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if args[0].is_null() { - return Ok(NullType::arc()); - } - - let arg = remove_nullable(args[0]); - assert_numeric(&arg)?; - - for arg in args[1..].iter() { - let arg = remove_nullable(*arg); - if !arg.is_null() { - assert_string(&arg)?; - } - } - - let dt = Vu8::to_data_type(); - match args.iter().any(|f| f.is_nullable()) { - true => Ok(wrap_nullable(&dt)), - false => Ok(dt), - } + fn return_type(&self) -> DataTypePtr { + self.result_type.clone() } fn eval( @@ -84,7 +89,7 @@ impl Function for EltFunction { input_rows: usize, _func_ctx: FunctionContext, ) -> Result { - if columns[0].data_type().is_null() { + if self.result_type.is_null() { return Ok(NullColumn::new(input_rows).arc()); } let nullable = columns.iter().any(|c| c.data_type().is_nullable()); @@ -174,4 +179,3 @@ impl fmt::Display for EltFunction { write!(f, "{}", self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/export_set.rs b/common/functions/src/scalars/strings/export_set.rs index c82cb14965cc..b8eca8e1715f 100644 --- a/common/functions/src/scalars/strings/export_set.rs +++ 
b/common/functions/src/scalars/strings/export_set.rs @@ -23,6 +23,7 @@ use crate::scalars::assert_numeric; use crate::scalars::assert_string; use crate::scalars::cast_with_type; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; use crate::scalars::DEFAULT_CAST_OPTIONS; @@ -33,7 +34,19 @@ pub struct ExportSetFunction { } impl ExportSetFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; + assert_string(args[1])?; + assert_string(args[2])?; + + if args.len() >= 4 { + assert_string(args[3])?; + } + + if args.len() >= 5 { + assert_numeric(args[4])?; + } + Ok(Box::new(Self { display_name: display_name.to_string(), })) @@ -53,20 +66,8 @@ impl Function for ExportSetFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_numeric(args[0])?; - assert_string(args[1])?; - assert_string(args[2])?; - - if args.len() >= 4 { - assert_string(args[3])?; - } - - if args.len() >= 5 { - assert_numeric(args[4])?; - } - - Ok(Vu8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Vu8::to_data_type() } fn eval( @@ -167,4 +168,3 @@ fn export_set<'a>( } } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/field.rs b/common/functions/src/scalars/strings/field.rs index 0ac3f077d901..98cc156c15fa 100644 --- a/common/functions/src/scalars/strings/field.rs +++ b/common/functions/src/scalars/strings/field.rs @@ -19,6 +19,7 @@ use common_exception::Result; use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -28,7 +29,10 @@ pub struct FieldFunction { } impl FieldFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: 
&[&DataTypePtr]) -> Result> { + for arg in args { + assert_string(*arg)?; + } Ok(Box::new(FieldFunction { display_name: display_name.to_string(), })) @@ -48,11 +52,8 @@ impl Function for FieldFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_string(*arg)?; - } - Ok(u64::to_data_type()) + fn return_type(&self) -> DataTypePtr { + u64::to_data_type() } fn eval( @@ -95,4 +96,3 @@ impl fmt::Display for FieldFunction { write!(f, "{}", self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/find_in_set.rs b/common/functions/src/scalars/strings/find_in_set.rs index 0f5fa91909c0..be32e5f27014 100644 --- a/common/functions/src/scalars/strings/find_in_set.rs +++ b/common/functions/src/scalars/strings/find_in_set.rs @@ -21,6 +21,7 @@ use crate::scalars::assert_string; use crate::scalars::scalar_binary_op; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -30,7 +31,9 @@ pub struct FindInSetFunction { } impl FindInSetFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; + assert_string(args[1])?; Ok(Box::new(Self { display_name: display_name.to_string(), })) @@ -47,10 +50,8 @@ impl Function for FindInSetFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_string(args[0])?; - assert_string(args[1])?; - Ok(u64::to_data_type()) + fn return_type(&self) -> DataTypePtr { + u64::to_data_type() } fn eval( @@ -90,4 +91,3 @@ fn find_in_set(str: &[u8], list: &[u8], _ctx: &mut EvalContext) -> u64 { } 0 } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/format.rs b/common/functions/src/scalars/strings/format.rs index 
320d096e3659..30a0cdd89c6a 100644 --- a/common/functions/src/scalars/strings/format.rs +++ b/common/functions/src/scalars/strings/format.rs @@ -27,6 +27,7 @@ use crate::scalars::assert_string; use crate::scalars::scalar_binary_op; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -41,7 +42,12 @@ pub struct FormatFunction { // Formats the number X to a format like '#,###,###.##', rounded to D decimal places, and returns the result as a string. // If D is 0, the result has no decimal point or fractional part. impl FormatFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; + assert_numeric(args[1])?; + if args.len() >= 3 { + assert_string(args[2])?; + } Ok(Box::new(FormatFunction { _display_name: display_name.to_string(), })) @@ -61,13 +67,8 @@ impl Function for FormatFunction { "format" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_numeric(args[0])?; - assert_numeric(args[1])?; - if args.len() >= 3 { - assert_string(args[2])?; - } - Ok(Vu8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Vu8::to_data_type() } fn eval( @@ -121,4 +122,3 @@ impl fmt::Display for FormatFunction { write!(f, "FORMAT") } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/hex.rs b/common/functions/src/scalars/strings/hex.rs index 80d06a47791f..f8500af17149 100644 --- a/common/functions/src/scalars/strings/hex.rs +++ b/common/functions/src/scalars/strings/hex.rs @@ -22,6 +22,7 @@ use common_exception::Result; use crate::scalars::cast_column_field; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -31,7 +32,14 @@ pub struct HexFunction { } impl HexFunction { - pub fn 
try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + if !args[0].data_type_id().is_numeric() && !args[0].data_type_id().is_string() { + return Err(ErrorCode::IllegalDataType(format!( + "Expected integer or string but got {}", + args[0].data_type_id() + ))); + } + Ok(Box::new(HexFunction { _display_name: display_name.to_string(), })) @@ -48,15 +56,8 @@ impl Function for HexFunction { "hex" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if !args[0].data_type_id().is_numeric() && !args[0].data_type_id().is_string() { - return Err(ErrorCode::IllegalDataType(format!( - "Expected integer or string but got {}", - args[0].data_type_id() - ))); - } - - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( @@ -107,4 +108,3 @@ impl fmt::Display for HexFunction { write!(f, "HEX") } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/insert.rs b/common/functions/src/scalars/strings/insert.rs index cd07188d1b84..a5d26dadc8ba 100644 --- a/common/functions/src/scalars/strings/insert.rs +++ b/common/functions/src/scalars/strings/insert.rs @@ -22,6 +22,7 @@ use num_traits::AsPrimitive; use crate::scalars::assert_numeric; use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -56,7 +57,12 @@ pub struct InsertFunction { } impl InsertFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; + assert_numeric(args[1])?; + assert_numeric(args[2])?; + assert_string(args[3])?; + Ok(Box::new(Self { display_name: display_name.to_string(), })) @@ -73,13 +79,8 @@ impl Function for InsertFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - 
assert_string(args[0])?; - assert_numeric(args[1])?; - assert_numeric(args[2])?; - assert_string(args[3])?; - - Ok(Vu8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Vu8::to_data_type() } fn eval( @@ -127,4 +128,3 @@ impl fmt::Display for InsertFunction { write!(f, "{}", self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/leftright.rs b/common/functions/src/scalars/strings/leftright.rs index 15610b1db421..67f36265fc87 100644 --- a/common/functions/src/scalars/strings/leftright.rs +++ b/common/functions/src/scalars/strings/leftright.rs @@ -25,6 +25,7 @@ use crate::scalars::assert_string; use crate::scalars::scalar_binary_op_ref; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -57,7 +58,9 @@ pub struct LeftRightFunction { } impl LeftRightFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; + assert_numeric(args[1])?; Ok(Box::new(Self { display_name: display_name.to_string(), })) @@ -74,10 +77,8 @@ impl Function for LeftRightFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_string(args[0])?; - assert_numeric(args[1])?; - Ok(Vu8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Vu8::to_data_type() } fn eval( @@ -112,4 +113,3 @@ impl fmt::Display for LeftRightFunction { f.write_str(&self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/locate.rs b/common/functions/src/scalars/strings/locate.rs index 20a130de8f57..e8230a3e8613 100644 --- a/common/functions/src/scalars/strings/locate.rs +++ b/common/functions/src/scalars/strings/locate.rs @@ -21,6 +21,7 @@ use crate::scalars::assert_numeric; use crate::scalars::assert_string; use 
crate::scalars::default_column_cast; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -38,7 +39,12 @@ pub struct LocatingFunction { } impl LocatingFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; + assert_string(args[1])?; + if args.len() > 2 { + assert_numeric(args[2])?; + } Ok(Box::new(LocatingFunction:: { display_name: display_name.to_string(), })) @@ -61,13 +67,8 @@ impl Function for LocatingFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_string(args[0])?; - assert_string(args[1])?; - if args.len() > 2 { - assert_numeric(args[2])?; - } - Ok(u64::to_data_type()) + fn return_type(&self) -> DataTypePtr { + u64::to_data_type() } fn eval( @@ -125,4 +126,3 @@ fn find_at(str: &[u8], substr: &[u8], pos: u64) -> u64 { 0_u64 } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/oct.rs b/common/functions/src/scalars/strings/oct.rs index 4c22922f53e3..522a7959e3b9 100644 --- a/common/functions/src/scalars/strings/oct.rs +++ b/common/functions/src/scalars/strings/oct.rs @@ -16,11 +16,12 @@ use std::cmp::Ordering; use std::fmt; use common_datavalues::prelude::*; -use common_exception::ErrorCode; use common_exception::Result; +use crate::scalars::assert_numeric; use crate::scalars::cast_column_field; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -56,7 +57,8 @@ pub struct OctFunction { } impl OctFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; Ok(Box::new(OctFunction { _display_name: display_name.to_string(), })) @@ -73,15 +75,8 @@ impl Function for 
OctFunction { "oct" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if !args[0].data_type_id().is_numeric() { - return Err(ErrorCode::IllegalDataType(format!( - "Expected integer but got {}", - args[0].data_type_id() - ))); - } - - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( @@ -117,4 +112,3 @@ impl fmt::Display for OctFunction { write!(f, "OCT") } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/pad.rs b/common/functions/src/scalars/strings/pad.rs index 003aa5f1340c..e9e506c33bdc 100644 --- a/common/functions/src/scalars/strings/pad.rs +++ b/common/functions/src/scalars/strings/pad.rs @@ -23,6 +23,7 @@ use num_traits::AsPrimitive; use crate::scalars::assert_numeric; use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -95,7 +96,10 @@ pub struct PadFunction { } impl PadFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; + assert_numeric(args[1])?; + assert_string(args[2])?; Ok(Box::new(Self { display_name: display_name.to_string(), _marker: PhantomData, @@ -113,11 +117,8 @@ impl Function for PadFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_string(args[0])?; - assert_numeric(args[1])?; - assert_string(args[2])?; - Ok(Vu8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Vu8::to_data_type() } fn eval( @@ -149,4 +150,3 @@ impl fmt::Display for PadFunction { f.write_str(&self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/regexp_instr.rs b/common/functions/src/scalars/strings/regexp_instr.rs index 5dd2388346da..8ed5f27e91cb 100644 --- a/common/functions/src/scalars/strings/regexp_instr.rs +++ 
b/common/functions/src/scalars/strings/regexp_instr.rs @@ -27,6 +27,7 @@ use crate::scalars::assert_string; use crate::scalars::cast_column_field; use crate::scalars::strings::regexp_like::build_regexp_from_pattern; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -36,7 +37,18 @@ pub struct RegexpInStrFunction { } impl RegexpInStrFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for (i, arg) in args.iter().enumerate() { + if i < 2 || i == 5 { + assert_string(*arg)?; + } else if !arg.data_type_id().is_integer() && !arg.data_type_id().is_string() { + return Err(ErrorCode::IllegalDataType(format!( + "Expected integer or string or null, but got {}", + args[i].data_type_id() + ))); + } + } + Ok(Box::new(Self { display_name: display_name.to_string(), })) @@ -56,23 +68,10 @@ impl Function for RegexpInStrFunction { &self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for (i, arg) in args.iter().enumerate() { - if i < 2 || i == 5 { - assert_string(*arg)?; - } else if !arg.data_type_id().is_integer() - && !arg.data_type_id().is_string() - && !arg.data_type_id().is_null() - { - return Err(ErrorCode::IllegalDataType(format!( - "Expected integer or string or null, but got {}", - args[i].data_type_id() - ))); - } - } - - Ok(u64::to_data_type()) + fn return_type(&self) -> DataTypePtr { + u64::to_data_type() } + // Notes: https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-instr fn eval( &self, @@ -278,4 +277,3 @@ impl fmt::Display for RegexpInStrFunction { write!(f, "{}", self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/regexp_like.rs b/common/functions/src/scalars/strings/regexp_like.rs index 35852fca716b..f57d873fe157 100644 --- a/common/functions/src/scalars/strings/regexp_like.rs 
+++ b/common/functions/src/scalars/strings/regexp_like.rs @@ -26,6 +26,7 @@ use regex::bytes::RegexBuilder as BytesRegexBuilder; use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -35,7 +36,11 @@ pub struct RegexpLikeFunction { } impl RegexpLikeFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for arg in args { + assert_string(*arg)?; + } + Ok(Box::new(Self { display_name: display_name.to_string(), })) @@ -56,12 +61,8 @@ impl Function for RegexpLikeFunction { &self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_string(*arg)?; - } - - Ok(BooleanType::arc()) + fn return_type(&self) -> DataTypePtr { + BooleanType::arc() } // Notes: https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-like fn eval( @@ -245,4 +246,3 @@ pub fn build_regexp_from_pattern( )) }) } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/regexp_substr.rs b/common/functions/src/scalars/strings/regexp_substr.rs index fa4ed2649a82..b6a6fe2ae486 100644 --- a/common/functions/src/scalars/strings/regexp_substr.rs +++ b/common/functions/src/scalars/strings/regexp_substr.rs @@ -14,7 +14,6 @@ use std::collections::HashMap; use std::fmt; -use std::sync::Arc; use bstr::ByteSlice; use common_datavalues::prelude::*; @@ -27,6 +26,7 @@ use crate::scalars::assert_string; use crate::scalars::cast_column_field; use crate::scalars::strings::regexp_like::build_regexp_from_pattern; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -36,7 +36,18 @@ pub struct RegexpSubStrFunction { } impl RegexpSubStrFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn 
try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for (i, arg) in args.iter().enumerate() { + if i < 2 || i == 4 { + assert_string(*arg)?; + } else if !arg.data_type_id().is_integer() && !arg.data_type_id().is_string() { + return Err(ErrorCode::IllegalDataType(format!( + "Expected integer or string or null, but got {}", + args[i].data_type_id() + ))); + } + } + Ok(Box::new(Self { display_name: display_name.to_string(), })) @@ -56,22 +67,8 @@ impl Function for RegexpSubStrFunction { &self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for (i, arg) in args.iter().enumerate() { - if i < 2 || i == 4 { - assert_string(*arg)?; - } else if !arg.data_type_id().is_integer() - && !arg.data_type_id().is_string() - && !arg.data_type_id().is_null() - { - return Err(ErrorCode::IllegalDataType(format!( - "Expected integer or string or null, but got {}", - args[i].data_type_id() - ))); - } - } - - Ok(Arc::new(NullableType::create(StringType::arc()))) + fn return_type(&self) -> DataTypePtr { + NullableType::arc(StringType::arc()) } // Notes: https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-substr @@ -247,4 +244,3 @@ impl fmt::Display for RegexpSubStrFunction { write!(f, "{}", self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/repeat.rs b/common/functions/src/scalars/strings/repeat.rs index d34176b22d5e..ff3e645125a2 100644 --- a/common/functions/src/scalars/strings/repeat.rs +++ b/common/functions/src/scalars/strings/repeat.rs @@ -20,6 +20,7 @@ use common_exception::Result; use crate::scalars::cast_column_field; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -31,7 +32,21 @@ pub struct RepeatFunction { } impl RepeatFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + if 
!args[0].data_type_id().is_string() { + return Err(ErrorCode::IllegalDataType(format!( + "Expected parameter 1 is string, but got {}", + args[0].data_type_id() + ))); + } + + if !args[1].data_type_id().is_unsigned_integer() { + return Err(ErrorCode::IllegalDataType(format!( + "Expected parameter 2 is unsigned integer or null, but got {}", + args[1].data_type_id() + ))); + } + Ok(Box::new(RepeatFunction { _display_name: display_name.to_string(), })) @@ -48,22 +63,8 @@ impl Function for RepeatFunction { "repeat" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if !args[0].data_type_id().is_string() && !args[0].data_type_id().is_null() { - return Err(ErrorCode::IllegalDataType(format!( - "Expected parameter 1 is string, but got {}", - args[0].data_type_id() - ))); - } - - if !args[1].data_type_id().is_unsigned_integer() && !args[1].data_type_id().is_null() { - return Err(ErrorCode::IllegalDataType(format!( - "Expected parameter 2 is unsigned integer or null, but got {}", - args[1].data_type_id() - ))); - } - - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( @@ -72,8 +73,7 @@ impl Function for RepeatFunction { input_rows: usize, _func_ctx: FunctionContext, ) -> Result { - let col1 = cast_column_field(&columns[0], &StringType::arc())?; - let col1_viewer = Vu8::try_create_viewer(&col1)?; + let col1_viewer = Vu8::try_create_viewer(columns[0].column())?; let col2 = cast_column_field(&columns[1], &UInt64Type::arc())?; let col2_viewer = u64::try_create_viewer(&col2)?; @@ -106,4 +106,3 @@ fn repeat(string: impl AsRef<[u8]>, times: u64) -> Result> { } Ok(string.as_ref().repeat(times as usize)) } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/replace.rs b/common/functions/src/scalars/strings/replace.rs index 3b2581a36312..584d1f773786 100644 --- a/common/functions/src/scalars/strings/replace.rs +++ b/common/functions/src/scalars/strings/replace.rs @@ -19,6 +19,7 @@ use 
common_exception::Result; use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -49,7 +50,10 @@ pub struct ReplaceFunction { } impl ReplaceFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for arg in args { + assert_string(*arg)?; + } Ok(Box::new(Self { display_name: display_name.to_string(), })) @@ -66,11 +70,8 @@ impl Function for ReplaceFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_string(*arg)?; - } - Ok(Vu8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + Vu8::to_data_type() } fn eval( @@ -106,4 +107,3 @@ impl fmt::Display for ReplaceFunction { write!(f, "{}", self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/space.rs b/common/functions/src/scalars/strings/space.rs index 9d17bb563545..d18f6e8ee840 100644 --- a/common/functions/src/scalars/strings/space.rs +++ b/common/functions/src/scalars/strings/space.rs @@ -24,6 +24,7 @@ use crate::scalars::assert_numeric; use crate::scalars::scalar_unary_op; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -34,7 +35,8 @@ pub struct SpaceFunction { } impl SpaceFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_numeric(args[0])?; Ok(Box::new(Self { display_name: display_name.to_string(), })) @@ -51,9 +53,8 @@ impl Function for SpaceFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - assert_numeric(args[0])?; - Ok(Vu8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + 
Vu8::to_data_type() } fn eval( @@ -78,4 +79,3 @@ impl fmt::Display for SpaceFunction { write!(f, "{}", self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/strcmp.rs b/common/functions/src/scalars/strings/strcmp.rs index 59c9c558fb2a..5c897c7b0088 100644 --- a/common/functions/src/scalars/strings/strcmp.rs +++ b/common/functions/src/scalars/strings/strcmp.rs @@ -23,6 +23,7 @@ use crate::scalars::assert_string; use crate::scalars::scalar_binary_op; use crate::scalars::EvalContext; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -32,7 +33,10 @@ pub struct StrcmpFunction { } impl StrcmpFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + for arg in args { + assert_string(*arg)?; + } Ok(Box::new(StrcmpFunction { display_name: display_name.to_string(), })) @@ -49,11 +53,8 @@ impl Function for StrcmpFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - for arg in args { - assert_string(*arg)?; - } - Ok(i8::to_data_type()) + fn return_type(&self) -> DataTypePtr { + i8::to_data_type() } fn eval( @@ -102,4 +103,3 @@ fn strcmp(s1: &[u8], s2: &[u8], _ctx: &mut EvalContext) -> i8 { Ordering::Less => -1, } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/string2number.rs b/common/functions/src/scalars/strings/string2number.rs index 1295a7f60ef5..9cd7af9970c3 100644 --- a/common/functions/src/scalars/strings/string2number.rs +++ b/common/functions/src/scalars/strings/string2number.rs @@ -16,10 +16,11 @@ use std::marker::PhantomData; use std::sync::Arc; use common_datavalues::prelude::*; -use common_exception::ErrorCode; use common_exception::Result; +use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use 
crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -42,7 +43,9 @@ where T: NumberOperator, R: PrimitiveType + Clone + ToDataType, { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; + Ok(Box::new(Self { display_name: display_name.to_string(), t: PhantomData, @@ -76,19 +79,8 @@ where &self.display_name } - fn return_type( - &self, - args: &[&common_datavalues::DataTypePtr], - ) -> Result { - // We allow string AND null as input - if args[0].data_type_id().is_string() { - Ok(R::to_data_type()) - } else { - Err(ErrorCode::IllegalDataType(format!( - "Expected string, numeric or null, but got {:?}", - args[0] - ))) - } + fn return_type(&self) -> DataTypePtr { + R::to_data_type() } fn eval( @@ -114,4 +106,3 @@ impl fmt::Display for String2NumberFunction { write!(f, "{}()", self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/string2string.rs b/common/functions/src/scalars/strings/string2string.rs index 2c8f69d7c3b6..585cd3d7e0cc 100644 --- a/common/functions/src/scalars/strings/string2string.rs +++ b/common/functions/src/scalars/strings/string2string.rs @@ -17,13 +17,12 @@ use std::sync::Arc; use common_datavalues::prelude::*; use common_datavalues::StringType; -use common_datavalues::TypeID; -use common_exception::ErrorCode; use common_exception::Result; +use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; -// use common_tracing::tracing; use crate::scalars::FunctionFeatures; pub trait StringOperator: Send + Sync + Clone + Default + 'static { @@ -44,7 +43,9 @@ pub struct String2StringFunction { } impl String2StringFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; + 
Ok(Box::new(Self { display_name: display_name.to_string(), _marker: PhantomData, @@ -62,17 +63,8 @@ impl Function for String2StringFunction { &self.display_name } - fn return_type( - &self, - args: &[&common_datavalues::DataTypePtr], - ) -> Result { - if args[0].data_type_id() != TypeID::String { - return Err(ErrorCode::IllegalDataType(format!( - "Expected string arg, but got {:?}", - args[0] - ))); - } - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( @@ -96,4 +88,3 @@ impl fmt::Display for String2StringFunction { f.write_str(&self.display_name) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/substring.rs b/common/functions/src/scalars/strings/substring.rs index cbd6117ab336..179a7548359b 100644 --- a/common/functions/src/scalars/strings/substring.rs +++ b/common/functions/src/scalars/strings/substring.rs @@ -19,8 +19,10 @@ use common_exception::ErrorCode; use common_exception::Result; use itertools::izip; +use crate::scalars::assert_string; use crate::scalars::cast_column_field; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -30,7 +32,26 @@ pub struct SubstringFunction { } impl SubstringFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; + + if !args[1].data_type_id().is_integer() && !args[1].data_type_id().is_string() { + return Err(ErrorCode::IllegalDataType(format!( + "Expected integer or string or null, but got {}", + args[1].data_type_id() + ))); + } + + if args.len() > 2 + && !args[2].data_type_id().is_integer() + && !args[2].data_type_id().is_string() + { + return Err(ErrorCode::IllegalDataType(format!( + "Expected integer or string or null, but got {}", + args[2].data_type_id() + ))); + } + Ok(Box::new(SubstringFunction { display_name: 
display_name.to_string(), })) @@ -50,36 +71,8 @@ impl Function for SubstringFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if !args[0].data_type_id().is_string() && !args[0].data_type_id().is_null() { - return Err(ErrorCode::IllegalDataType(format!( - "Expected string or null, but got {}", - args[0].data_type_id() - ))); - } - - if !args[1].data_type_id().is_integer() - && !args[1].data_type_id().is_string() - && !args[1].data_type_id().is_null() - { - return Err(ErrorCode::IllegalDataType(format!( - "Expected integer or string or null, but got {}", - args[1].data_type_id() - ))); - } - - if args.len() > 2 - && !args[2].data_type_id().is_integer() - && !args[2].data_type_id().is_string() - && !args[2].data_type_id().is_null() - { - return Err(ErrorCode::IllegalDataType(format!( - "Expected integer or string or null, but got {}", - args[2].data_type_id() - ))); - } - - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( @@ -156,4 +149,3 @@ fn substr<'a>(str: &'a [u8], pos: &i64, len: &u64) -> &'a [u8] { } &str[0..0] } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/substring_index.rs b/common/functions/src/scalars/strings/substring_index.rs index 594a885bcb31..b262c396576b 100644 --- a/common/functions/src/scalars/strings/substring_index.rs +++ b/common/functions/src/scalars/strings/substring_index.rs @@ -21,6 +21,7 @@ use itertools::izip; use crate::scalars::cast_column_field; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -30,7 +31,25 @@ pub struct SubstringIndexFunction { } impl SubstringIndexFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + if !args[0].data_type_id().is_numeric() && !args[0].data_type_id().is_string() { + return 
Err(ErrorCode::IllegalDataType(format!( + "Expected string or null, but got {}", + args[0].data_type_id() + ))); + } + if !args[1].data_type_id().is_numeric() && !args[1].data_type_id().is_string() { + return Err(ErrorCode::IllegalDataType(format!( + "Expected integer or string or null, but got {}", + args[1].data_type_id() + ))); + } + if !args[2].data_type_id().is_integer() && !args[2].data_type_id().is_string() { + return Err(ErrorCode::IllegalDataType(format!( + "Expected integer or string or null, but got {}", + args[2].data_type_id() + ))); + } Ok(Box::new(SubstringIndexFunction { display_name: display_name.to_string(), })) @@ -47,35 +66,8 @@ impl Function for SubstringIndexFunction { &*self.display_name } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if !args[0].data_type_id().is_numeric() - && !args[0].data_type_id().is_string() - && !args[0].data_type_id().is_null() - { - return Err(ErrorCode::IllegalDataType(format!( - "Expected string or null, but got {}", - args[0].data_type_id() - ))); - } - if !args[1].data_type_id().is_numeric() - && !args[1].data_type_id().is_string() - && !args[1].data_type_id().is_null() - { - return Err(ErrorCode::IllegalDataType(format!( - "Expected integer or string or null, but got {}", - args[1].data_type_id() - ))); - } - if !args[2].data_type_id().is_integer() - && !args[2].data_type_id().is_string() - && !args[2].data_type_id().is_null() - { - return Err(ErrorCode::IllegalDataType(format!( - "Expected integer or string or null, but got {}", - args[2].data_type_id() - ))); - } - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( @@ -142,4 +134,3 @@ fn substring_index<'a>(str: &'a [u8], delim: &'a [u8], count: &i64) -> &'a [u8] } str } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/strings/unhex.rs b/common/functions/src/scalars/strings/unhex.rs index 9cddec791d0f..ea960e4c4e92 100644 --- a/common/functions/src/scalars/strings/unhex.rs 
+++ b/common/functions/src/scalars/strings/unhex.rs @@ -18,8 +18,9 @@ use common_datavalues::prelude::*; use common_exception::ErrorCode; use common_exception::Result; -use crate::scalars::cast_column_field; +use crate::scalars::assert_string; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -29,7 +30,8 @@ pub struct UnhexFunction { } impl UnhexFunction { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + assert_string(args[0])?; Ok(Box::new(UnhexFunction { _display_name: display_name.to_string(), })) @@ -46,15 +48,8 @@ impl Function for UnhexFunction { "unhex" } - fn return_type(&self, args: &[&DataTypePtr]) -> Result { - if !args[0].data_type_id().is_string() && !args[0].data_type_id().is_null() { - return Err(ErrorCode::IllegalDataType(format!( - "Expected string or null, but got {}", - args[0].data_type_id() - ))); - } - - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( @@ -65,8 +60,11 @@ impl Function for UnhexFunction { ) -> Result { const BUFFER_SIZE: usize = 32; - let col = cast_column_field(&columns[0], &StringType::arc())?; - let col = col.as_any().downcast_ref::().unwrap(); + let col = columns[0] + .column() + .as_any() + .downcast_ref::() + .unwrap(); let mut builder: ColumnBuilder = ColumnBuilder::with_capacity(input_rows); @@ -103,4 +101,3 @@ impl fmt::Display for UnhexFunction { write!(f, "UNHEX") } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/tuples/tuple.rs b/common/functions/src/scalars/tuples/tuple.rs index 8aa6c49892c0..0cc57f9a03ac 100644 --- a/common/functions/src/scalars/tuples/tuple.rs +++ b/common/functions/src/scalars/tuples/tuple.rs @@ -15,23 +15,36 @@ use std::fmt; use std::sync::Arc; +use common_datavalues::DataTypePtr; use common_datavalues::StructColumn; use 
common_datavalues::StructType; use common_exception::Result; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; #[derive(Clone)] pub struct TupleFunction { _display_name: String, + result_type: DataTypePtr, } impl TupleFunction { - pub fn try_create_func(_display_name: &str) -> Result> { + pub fn try_create_func( + _display_name: &str, + args: &[&DataTypePtr], + ) -> Result> { + let names = (0..args.len()) + .map(|i| format!("item_{}", i)) + .collect::>(); + let types = args.iter().map(|x| (*x).clone()).collect::>(); + let result_type = Arc::new(StructType::create(names, types)); + Ok(Box::new(TupleFunction { _display_name: "tuple".to_string(), + result_type, })) } @@ -50,16 +63,8 @@ impl Function for TupleFunction { "TupleFunction" } - fn return_type( - &self, - args: &[&common_datavalues::DataTypePtr], - ) -> Result { - let names = (0..args.len()) - .map(|i| format!("item_{}", i)) - .collect::>(); - let types = args.iter().map(|x| (*x).clone()).collect::>(); - let t = Arc::new(StructType::create(names, types)); - Ok(t) + fn return_type(&self) -> DataTypePtr { + self.result_type.clone() } fn eval( @@ -68,21 +73,11 @@ impl Function for TupleFunction { _input_rows: usize, _func_ctx: FunctionContext, ) -> Result { - let mut cols = vec![]; - let mut types = vec![]; - - let names = (0..columns.len()) - .map(|i| format!("item_{}", i)) + let cols = columns + .iter() + .map(|v| v.column().clone()) .collect::>(); - - for c in columns { - cols.push(c.column().clone()); - types.push(c.data_type().clone()); - } - - let t = Arc::new(StructType::create(names, types)); - - let arr: StructColumn = StructColumn::from_data(cols, t); + let arr: StructColumn = StructColumn::from_data(cols, self.result_type.clone()); Ok(Arc::new(arr)) } } @@ -92,4 +87,3 @@ impl std::fmt::Display for TupleFunction { write!(f, "TUPLE") } } -use crate::scalars::FunctionContext; diff --git 
a/common/functions/src/scalars/uuids/uuid_creator.rs b/common/functions/src/scalars/uuids/uuid_creator.rs index 61b7893c470e..82f73f2d769c 100644 --- a/common/functions/src/scalars/uuids/uuid_creator.rs +++ b/common/functions/src/scalars/uuids/uuid_creator.rs @@ -17,6 +17,7 @@ use std::marker::PhantomData; use common_datavalues::Column; use common_datavalues::ConstColumn; +use common_datavalues::DataTypePtr; use common_datavalues::NewColumn; use common_datavalues::StringColumn; use common_datavalues::StringType; @@ -24,6 +25,7 @@ use common_exception::Result; use uuid::Uuid; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -39,7 +41,7 @@ pub struct UUIDCreatorFunction { impl UUIDCreatorFunction where T: UUIDCreator + Clone + Sync + Send + 'static { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result> { Ok(Box::new(UUIDCreatorFunction:: { display_name: display_name.to_string(), t: PhantomData, @@ -87,11 +89,8 @@ where T: UUIDCreator + Clone + Sync + Send + 'static self.display_name.as_str() } - fn return_type( - &self, - _args: &[&common_datavalues::DataTypePtr], - ) -> Result { - Ok(StringType::arc()) + fn return_type(&self) -> DataTypePtr { + StringType::arc() } fn eval( @@ -106,4 +105,3 @@ where T: UUIDCreator + Clone + Sync + Send + 'static Ok(ConstColumn::new(col.arc(), input_rows).arc()) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/src/scalars/uuids/uuid_verifier.rs b/common/functions/src/scalars/uuids/uuid_verifier.rs index 96733c39c995..cd5877363230 100644 --- a/common/functions/src/scalars/uuids/uuid_verifier.rs +++ b/common/functions/src/scalars/uuids/uuid_verifier.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use common_datavalues::BooleanColumn; use common_datavalues::BooleanType; +use common_datavalues::DataTypePtr; use common_datavalues::Scalar; use 
common_datavalues::ScalarColumn; use common_datavalues::ScalarViewer; @@ -29,6 +30,7 @@ use common_exception::Result; use uuid::Uuid; use crate::scalars::Function; +use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; use crate::scalars::FunctionFeatures; @@ -44,7 +46,14 @@ pub struct UUIDVerifierFunction { impl UUIDVerifierFunction where T: UUIDVerifier + Clone + Sync + Send + 'static { - pub fn try_create(display_name: &str) -> Result> { + pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result> { + if args[0].data_type_id() != TypeID::String && args[0].data_type_id() != TypeID::Null { + return Err(ErrorCode::IllegalDataType(format!( + "Expected string or null, but got {:?}", + args[0] + ))); + } + Ok(Box::new(UUIDVerifierFunction:: { display_name: display_name.to_string(), t: PhantomData, @@ -105,18 +114,8 @@ where T: UUIDVerifier + Clone + Sync + Send + 'static self.display_name.as_str() } - fn return_type( - &self, - args: &[&common_datavalues::DataTypePtr], - ) -> Result { - if args[0].data_type_id() != TypeID::String && args[0].data_type_id() != TypeID::Null { - return Err(ErrorCode::IllegalDataType(format!( - "Expected string or null, but got {:?}", - args[0] - ))); - } - - Ok(BooleanType::arc()) + fn return_type(&self) -> DataTypePtr { + BooleanType::arc() } fn eval( @@ -145,4 +144,3 @@ where T: UUIDVerifier + Clone + Sync + Send + 'static Ok(Arc::new(result_column)) } } -use crate::scalars::FunctionContext; diff --git a/common/functions/tests/it/scalars/arithmetics.rs b/common/functions/tests/it/scalars/arithmetics.rs index 148dc766447e..e83ec9083a16 100644 --- a/common/functions/tests/it/scalars/arithmetics.rs +++ b/common/functions/tests/it/scalars/arithmetics.rs @@ -15,102 +15,82 @@ use common_datavalues::chrono; use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use super::scalar_function2_test::test_scalar_functions; -use 
super::scalar_function2_test::ScalarFunctionTest; +use super::scalar_function_test::test_scalar_functions; +use super::scalar_function_test::test_scalar_functions_with_type; +use super::scalar_function_test::ScalarFunctionTest; +use super::scalar_function_test::ScalarFunctionWithFieldTest; #[test] fn test_arithmetic_function() -> Result<()> { let tests = vec![ - ( - ArithmeticPlusFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - ScalarFunctionTest { - name: "add-int64-passed", - columns: vec![ - Series::from_data(vec![4i64, 3, 2, 1]), - Series::from_data(vec![1i64, 2, 3, 4]), - ], - expect: Series::from_data(vec![5i64, 5, 5, 5]), - error: "", - }, - ), - ( - ArithmeticPlusFunction::try_create_func("", &[&Int16Type::arc(), &Int64Type::arc()])?, - ScalarFunctionTest { - name: "add-diff-passed", - columns: vec![ - Series::from_data(vec![1i16, 2, 3, 4]), - Series::from_data(vec![1i64, 2, 3, 4]), - ], - expect: Series::from_data(vec![2i64, 4, 6, 8]), - error: "", - }, - ), - ( - ArithmeticMinusFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - ScalarFunctionTest { - name: "sub-int64-passed", - columns: vec![ - Series::from_data(vec![4i64, 3, 2]), - Series::from_data(vec![1i64, 2, 3]), - ], - expect: Series::from_data(vec![3i64, 1, -1]), - error: "", - }, - ), - ( - ArithmeticMulFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - ScalarFunctionTest { - name: "mul-int64-passed", - columns: vec![ - Series::from_data(vec![4i64, 3, 2]), - Series::from_data(vec![1i64, 2, 3]), - ], - expect: Series::from_data(vec![4i64, 6, 6]), - error: "", - }, - ), - ( - ArithmeticDivFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - ScalarFunctionTest { - name: "div-int64-passed", - columns: vec![ - Series::from_data(vec![4i64, 3, 2]), - Series::from_data(vec![1i64, 2, 3]), - ], - expect: Series::from_data(vec![4.0, 1.5, 0.6666666666666666]), - error: "", - }, - ), - ( - 
ArithmeticIntDivFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - ScalarFunctionTest { - name: "intdiv-int64-passed", - columns: vec![ - Series::from_data(vec![4i64, 3, 2]), - Series::from_data(vec![1i64, 2, 3]), - ], - expect: Series::from_data(vec![4i64, 1, 0]), - error: "", - }, - ), - ( - ArithmeticModuloFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - ScalarFunctionTest { - name: "mod-int64-passed", - columns: vec![ - Series::from_data(vec![4i64, 3, 2]), - Series::from_data(vec![1i64, 2, 3]), - ], - expect: Series::from_data(vec![0i64, 1, 2]), - error: "", - }, - ), + ("+", ScalarFunctionTest { + name: "add-int64-passed", + columns: vec![ + Series::from_data(vec![4i64, 3, 2, 1]), + Series::from_data(vec![1i64, 2, 3, 4]), + ], + expect: Series::from_data(vec![5i64, 5, 5, 5]), + error: "", + }), + ("plus", ScalarFunctionTest { + name: "add-diff-passed", + columns: vec![ + Series::from_data(vec![1i16, 2, 3, 4]), + Series::from_data(vec![1i64, 2, 3, 4]), + ], + expect: Series::from_data(vec![2i64, 4, 6, 8]), + error: "", + }), + ("-", ScalarFunctionTest { + name: "sub-int64-passed", + columns: vec![ + Series::from_data(vec![4i64, 3, 2]), + Series::from_data(vec![1i64, 2, 3]), + ], + expect: Series::from_data(vec![3i64, 1, -1]), + error: "", + }), + ("multiply", ScalarFunctionTest { + name: "mul-int64-passed", + columns: vec![ + Series::from_data(vec![4i64, 3, 2]), + Series::from_data(vec![1i64, 2, 3]), + ], + expect: Series::from_data(vec![4i64, 6, 6]), + error: "", + }), + ("/", ScalarFunctionTest { + name: "div-int64-passed", + columns: vec![ + Series::from_data(vec![4i64, 3, 2]), + Series::from_data(vec![1i64, 2, 3]), + ], + expect: Series::from_data(vec![4.0, 1.5, 0.6666666666666666]), + error: "", + }), + ("div", ScalarFunctionTest { + name: "intdiv-int64-passed", + columns: vec![ + Series::from_data(vec![4i64, 3, 2]), + Series::from_data(vec![1i64, 2, 3]), + ], + expect: Series::from_data(vec![4i64, 1, 0]), 
+ error: "", + }), + ("modulo", ScalarFunctionTest { + name: "mod-int64-passed", + columns: vec![ + Series::from_data(vec![4i64, 3, 2]), + Series::from_data(vec![1i64, 2, 3]), + ], + expect: Series::from_data(vec![0i64, 1, 2]), + error: "", + }), ]; - for (test_function, test) in tests { - test_scalar_functions(test_function, &[test], true)? + for (op, test) in tests { + test_scalar_functions(op, &[test])?; } Ok(()) @@ -141,241 +121,241 @@ fn test_arithmetic_date_interval() -> Result<()> { }; let tests = vec![ - ( - ArithmeticPlusFunction::try_create_func("", &[ - &Date16Type::arc(), - &IntervalType::arc(IntervalKind::Year), - ])?, - ScalarFunctionTest { - name: "date16-add-years-passed", - columns: vec![ + ("+", ScalarFunctionWithFieldTest { + name: "date16-add-years-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_day16(2020, 2, 29), /* 2020-2-29 */ to_day16(2016, 2, 29), /* 2016-2-29 */ ]), + DataField::new("dummy_0", Date16Type::arc()), + ), + ColumnWithField::new( Series::from_data(vec![-1i64, 4]), - ], - expect: Series::from_data(vec![ - to_day16(2019, 2, 28), /* 2019-2-28 */ - to_day16(2020, 2, 29), /* 2020-2-29 */ - ]), - error: "", - }, - ), - ( - ArithmeticMinusFunction::try_create_func("", &[ - &Date32Type::arc(), - &IntervalType::arc(IntervalKind::Year), - ])?, - ScalarFunctionTest { - name: "date32-sub-years-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Year)), + ), + ], + expect: Series::from_data(vec![ + to_day16(2019, 2, 28), /* 2019-2-28 */ + to_day16(2020, 2, 29), /* 2020-2-29 */ + ]), + error: "", + }), + ("-", ScalarFunctionWithFieldTest { + name: "date32-sub-years-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_day32(2400, 2, 29), /* 2400-2-29 */ to_day32(1960, 2, 29), /* 1960-2-29 */ ]), + DataField::new("dummy_0", Date32Type::arc()), + ), + ColumnWithField::new( Series::from_data(vec![1i64, -4]), - ], - expect: Series::from_data(vec![ - 
to_day32(2399, 2, 28), /* 2399-2-28 */ - to_day32(1964, 2, 29), /* 1964-2-29 */ - ]), - error: "", - }, - ), - ( - ArithmeticPlusFunction::try_create_func("", &[ - &DateTime32Type::arc(None), - &IntervalType::arc(IntervalKind::Year), - ])?, - ScalarFunctionTest { - name: "datetime32-add-years-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Year)), + ), + ], + expect: Series::from_data(vec![ + to_day32(2399, 2, 28), /* 2399-2-28 */ + to_day32(1964, 2, 29), /* 1964-2-29 */ + ]), + error: "", + }), + ("+", ScalarFunctionWithFieldTest { + name: "datetime32-add-years-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_seconds(2020, 2, 29, 10, 30, 00), /* 2020-2-29 10:30:00 */ to_seconds(2021, 2, 28, 10, 30, 00), /* 2021-2-28 10:30:00 */ ]), + DataField::new("dummy_0", DateTime32Type::arc(None)), + ), + ColumnWithField::new( Series::from_data(vec![1i64, -1]), - ], - expect: Series::from_data(vec![ - to_seconds(2021, 2, 28, 10, 30, 00), /* 2021-2-28 10:30:00 */ - to_seconds(2020, 2, 28, 10, 30, 00), /* 2020-2-28 10:30:00 */ - ]), - error: "", - }, - ), - ( - ArithmeticMinusFunction::try_create_func("", &[ - &DateTime64Type::arc(3, None), - &IntervalType::arc(IntervalKind::Year), - ])?, - ScalarFunctionTest { - name: "datetime64-sub-years-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Year)), + ), + ], + expect: Series::from_data(vec![ + to_seconds(2021, 2, 28, 10, 30, 00), /* 2021-2-28 10:30:00 */ + to_seconds(2020, 2, 28, 10, 30, 00), /* 2020-2-28 10:30:00 */ + ]), + error: "", + }), + ("-", ScalarFunctionWithFieldTest { + name: "datetime64-sub-years-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_milliseconds(2020, 2, 29, 10, 30, 00, 000), /* 2020-2-29 10:30:00.000 */ to_milliseconds(1960, 2, 29, 10, 30, 00, 000), /* 1960-2-29 10:30:00.000 */ ]), + DataField::new("dummy_0", DateTime64Type::arc(3, None)), + ), + ColumnWithField::new( 
Series::from_data(vec![1i64, -4]), - ], - expect: Series::from_data(vec![ - to_milliseconds(2019, 2, 28, 10, 30, 00, 000), /* 2019-2-28 10:30:00.000 */ - to_milliseconds(1964, 2, 29, 10, 30, 00, 000), /* 1964-2-29 10:30:00.000 */ - ]), - error: "", - }, - ), - ( - ArithmeticPlusFunction::try_create_func("", &[ - &Date16Type::arc(), - &IntervalType::arc(IntervalKind::Month), - ])?, - ScalarFunctionTest { - name: "date16-add-months-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Year)), + ), + ], + expect: Series::from_data(vec![ + to_milliseconds(2019, 2, 28, 10, 30, 00, 000), /* 2019-2-28 10:30:00.000 */ + to_milliseconds(1964, 2, 29, 10, 30, 00, 000), /* 1964-2-29 10:30:00.000 */ + ]), + error: "", + }), + ("+", ScalarFunctionWithFieldTest { + name: "date16-add-months-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_day16(2020, 3, 31), /* 2020-3-31 */ to_day16(2000, 1, 31), /* 2000-1-31 */ ]), + DataField::new("dummy_0", Date16Type::arc()), + ), + ColumnWithField::new( Series::from_data(vec![-1i64, 241]), - ], - expect: Series::from_data(vec![ - to_day16(2020, 2, 29), /* 2020-2-29 */ - to_day16(2020, 2, 29), /* 2020-2-29 */ - ]), - error: "", - }, - ), - ( - ArithmeticPlusFunction::try_create_func("", &[ - &DateTime32Type::arc(None), - &IntervalType::arc(IntervalKind::Month), - ])?, - ScalarFunctionTest { - name: "datetime32-add-months-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Month)), + ), + ], + expect: Series::from_data(vec![ + to_day16(2020, 2, 29), /* 2020-2-29 */ + to_day16(2020, 2, 29), /* 2020-2-29 */ + ]), + error: "", + }), + ("+", ScalarFunctionWithFieldTest { + name: "datetime32-add-months-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_seconds(2020, 3, 31, 10, 30, 00), /* 2020-3-31 10:30:00 */ to_seconds(2000, 1, 31, 10, 30, 00), /* 2000-1-31 10:30:00 */ ]), + DataField::new("dummy_0", DateTime32Type::arc(None)), + 
), + ColumnWithField::new( Series::from_data(vec![-1i64, 241]), - ], - expect: Series::from_data(vec![ - to_seconds(2020, 2, 29, 10, 30, 00), /* 2020-2-29 10:30:00 */ - to_seconds(2020, 2, 29, 10, 30, 00), /* 2020-2-29 10:30:00 */ - ]), - error: "", - }, - ), - ( - ArithmeticMinusFunction::try_create_func("", &[ - &Date32Type::arc(), - &IntervalType::arc(IntervalKind::Day), - ])?, - ScalarFunctionTest { - name: "date32-sub-days-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Month)), + ), + ], + expect: Series::from_data(vec![ + to_seconds(2020, 2, 29, 10, 30, 00), /* 2020-2-29 10:30:00 */ + to_seconds(2020, 2, 29, 10, 30, 00), /* 2020-2-29 10:30:00 */ + ]), + error: "", + }), + ("-", ScalarFunctionWithFieldTest { + name: "date32-sub-days-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_day32(2400, 2, 29), /* 2400-2-29 */ to_day32(1960, 2, 29), /* 1960-2-29 */ ]), + DataField::new("dummy_0", Date32Type::arc()), + ), + ColumnWithField::new( Series::from_data(vec![30i64, -30]), - ], - expect: Series::from_data(vec![ - to_day32(2400, 1, 30), /* 2400-1-30 */ - to_day32(1960, 3, 30), /* 1960-3-30 */ - ]), - error: "", - }, - ), - ( - ArithmeticPlusFunction::try_create_func("", &[ - &DateTime64Type::arc(3, None), - &IntervalType::arc(IntervalKind::Day), - ])?, - ScalarFunctionTest { - name: "datetime64-add-days-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Day)), + ), + ], + expect: Series::from_data(vec![ + to_day32(2400, 1, 30), /* 2400-1-30 */ + to_day32(1960, 3, 30), /* 1960-3-30 */ + ]), + error: "", + }), + ("+", ScalarFunctionWithFieldTest { + name: "datetime64-add-days-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_milliseconds(2020, 2, 29, 10, 30, 00, 000), /* 2020-2-29 10:30:00.000 */ to_milliseconds(1960, 2, 29, 10, 30, 00, 000), /* 1960-2-29 10:30:00.000 */ ]), + DataField::new("dummy_0", DateTime64Type::arc(3, None)), + ), + 
ColumnWithField::new( Series::from_data(vec![-30i64, 30]), - ], - expect: Series::from_data(vec![ - to_milliseconds(2020, 1, 30, 10, 30, 00, 000), /* 2020-1-30 10:30:00.000 */ - to_milliseconds(1960, 3, 30, 10, 30, 00, 000), /* 1960-3-30 10:30:00.000 */ - ]), - error: "", - }, - ), - ( - ArithmeticPlusFunction::try_create_func("", &[ - &Date16Type::arc(), - &IntervalType::arc(IntervalKind::Hour), - ])?, - ScalarFunctionTest { - name: "date16-add-hours-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Day)), + ), + ], + expect: Series::from_data(vec![ + to_milliseconds(2020, 1, 30, 10, 30, 00, 000), /* 2020-1-30 10:30:00.000 */ + to_milliseconds(1960, 3, 30, 10, 30, 00, 000), /* 1960-3-30 10:30:00.000 */ + ]), + error: "", + }), + ("+", ScalarFunctionWithFieldTest { + name: "date16-add-hours-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_day16(2020, 3, 1), /* 2020-3-31 */ to_day16(2000, 1, 31), /* 2000-1-31 */ ]), + DataField::new("dummy_0", Date16Type::arc()), + ), + ColumnWithField::new( Series::from_data(vec![-1i64, 1]), - ], - expect: Series::from_data(vec![ - to_seconds(2020, 2, 29, 23, 00, 00), /* 2020-2-29 23:00:00 */ - to_seconds(2000, 1, 31, 1, 00, 00), /* 2000-1-31 1:00:00 */ - ]), - error: "", - }, - ), - ( - ArithmeticMinusFunction::try_create_func("", &[ - &Date32Type::arc(), - &IntervalType::arc(IntervalKind::Minute), - ])?, - ScalarFunctionTest { - name: "date32-sub-minutes-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Hour)), + ), + ], + expect: Series::from_data(vec![ + to_seconds(2020, 2, 29, 23, 00, 00), /* 2020-2-29 23:00:00 */ + to_seconds(2000, 1, 31, 1, 00, 00), /* 2000-1-31 1:00:00 */ + ]), + error: "", + }), + ("-", ScalarFunctionWithFieldTest { + name: "date32-sub-minutes-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_day32(2400, 2, 29), /* 2400-2-29 */ to_day32(1960, 2, 29), /* 1960-2-29 */ ]), + 
DataField::new("dummy_0", Date32Type::arc()), + ), + ColumnWithField::new( Series::from_data(vec![61i64, -30]), - ], - expect: Series::from_data(vec![ - to_milliseconds(2400, 2, 28, 22, 59, 00, 000) / 1000, /* 2400-2-28 22:59:00 */ - to_milliseconds(1960, 2, 29, 00, 30, 00, 000) / 1000, /* 1960-2-29 00:30:00 */ - ]), - error: "", - }, - ), - ( - ArithmeticMinusFunction::try_create_func("", &[ - &DateTime32Type::arc(None), - &IntervalType::arc(IntervalKind::Second), - ])?, - ScalarFunctionTest { - name: "datetime32-sub-seconds-passed", - columns: vec![ + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Minute)), + ), + ], + expect: Series::from_data(vec![ + to_milliseconds(2400, 2, 28, 22, 59, 00, 000) / 1000, /* 2400-2-28 22:59:00 */ + to_milliseconds(1960, 2, 29, 00, 30, 00, 000) / 1000, /* 1960-2-29 00:30:00 */ + ]), + error: "", + }), + ("-", ScalarFunctionWithFieldTest { + name: "datetime32-sub-seconds-passed", + columns: vec![ + ColumnWithField::new( Series::from_data(vec![ to_seconds(2020, 3, 31, 10, 30, 00), /* 2020-3-31 10:30:00 */ to_seconds(2000, 1, 31, 10, 30, 00), /* 2000-1-31 10:30:00 */ ]), + DataField::new("dummy_0", DateTime32Type::arc(None)), + ), + ColumnWithField::new( Series::from_data(vec![-120i64, 23]), - ], - expect: Series::from_data(vec![ - to_seconds(2020, 3, 31, 10, 32, 00), /* 2020-3-31 10:32:00 */ - to_seconds(2000, 1, 31, 10, 29, 37), /* 2000-1-31 10:29:37 */ - ]), - error: "", - }, - ), + DataField::new("dummy_1", IntervalType::arc(IntervalKind::Second)), + ), + ], + expect: Series::from_data(vec![ + to_seconds(2020, 3, 31, 10, 32, 00), /* 2020-3-31 10:32:00 */ + to_seconds(2000, 1, 31, 10, 29, 37), /* 2000-1-31 10:29:37 */ + ]), + error: "", + }), ]; - for (test_function, test) in tests { - test_scalar_functions(test_function, &[test], true)? 
+ for (op, test) in tests { + test_scalar_functions_with_type(op, &[test])?; } Ok(()) diff --git a/common/functions/tests/it/scalars/comparisons.rs b/common/functions/tests/it/scalars/comparisons.rs index 167cfccc691a..7247dbf5f981 100644 --- a/common/functions/tests/it/scalars/comparisons.rs +++ b/common/functions/tests/it/scalars/comparisons.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use super::scalar_function2_test::test_scalar_functions; -use super::scalar_function2_test::ScalarFunctionTest; +use super::scalar_function_test::test_scalar_functions; +use super::scalar_function_test::ScalarFunctionTest; #[test] fn test_eq_comparison_function() -> Result<()> { @@ -31,11 +30,7 @@ fn test_eq_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonEqFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - &tests, - true, - ) + test_scalar_functions("=", &tests) } #[test] @@ -50,11 +45,7 @@ fn test_gt_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonGtFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - &tests, - true, - ) + test_scalar_functions(">", &tests) } #[test] @@ -69,11 +60,7 @@ fn test_gt_eq_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonGtEqFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - &tests, - true, - ) + test_scalar_functions(">=", &tests) } #[test] @@ -88,11 +75,7 @@ fn test_lt_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonLtFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - &tests, - true, - ) + test_scalar_functions("<", &tests) } #[test] @@ -107,11 +90,7 @@ fn test_lt_eq_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonLtEqFunction::try_create_func("", &[&Int64Type::arc(), 
&Int64Type::arc()])?, - &tests, - true, - ) + test_scalar_functions("<=", &tests) } #[test] @@ -126,11 +105,7 @@ fn test_not_eq_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonNotEqFunction::try_create_func("", &[&Int64Type::arc(), &Int64Type::arc()])?, - &tests, - true, - ) + test_scalar_functions("<>", &tests) } #[test] @@ -145,11 +120,7 @@ fn test_like_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonLikeFunction::try_create_func("", &[&StringType::arc(), &StringType::arc()])?, - &tests, - true, - ) + test_scalar_functions("like", &tests) } #[test] @@ -164,11 +135,7 @@ fn test_not_like_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonNotLikeFunction::try_create_func("", &[&StringType::arc(), &StringType::arc()])?, - &tests, - true, - ) + test_scalar_functions("not like", &tests) } #[test] @@ -183,11 +150,7 @@ fn test_regexp_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonRegexpFunction::try_create_func("", &[&StringType::arc(), &StringType::arc()])?, - &tests, - true, - ) + test_scalar_functions("regexp", &tests) } #[test] @@ -202,12 +165,5 @@ fn test_not_regexp_comparison_function() -> Result<()> { error: "", }]; - test_scalar_functions( - ComparisonNotRegexpFunction::try_create_func("", &[ - &StringType::arc(), - &StringType::arc(), - ])?, - &tests, - true, - ) + test_scalar_functions("not regexp", &tests) } diff --git a/common/functions/tests/it/scalars/conditionals.rs b/common/functions/tests/it/scalars/conditionals.rs index 5dc4ef6455e0..09798d9bd3e4 100644 --- a/common/functions/tests/it/scalars/conditionals.rs +++ b/common/functions/tests/it/scalars/conditionals.rs @@ -15,10 +15,9 @@ use std::sync::Arc; use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::IfFunction; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use 
crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_if_function() -> Result<()> { @@ -105,5 +104,5 @@ fn test_if_function() -> Result<()> { }, ]; - test_scalar_functions(IfFunction::try_create("if")?, &tests, false) + test_scalar_functions("if", &tests) } diff --git a/common/functions/tests/it/scalars/dates/date.rs b/common/functions/tests/it/scalars/dates/date.rs index 19d87902bd6c..e1e10bd48f57 100644 --- a/common/functions/tests/it/scalars/dates/date.rs +++ b/common/functions/tests/it/scalars/dates/date.rs @@ -15,32 +15,33 @@ use common_datavalues::prelude::*; use common_datavalues::ColumnWithField; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions_with_type; -use crate::scalars::scalar_function2_test::ScalarFunctionWithFieldTest; +use crate::scalars::scalar_function_test::test_scalar_functions_with_type; +use crate::scalars::scalar_function_test::ScalarFunctionWithFieldTest; #[test] fn test_round_function() -> Result<()> { - let mut tests = vec![]; - - for r in &[1, 60, 60 * 10, 60 * 15, 60 * 30, 60 * 60, 60 * 60 * 24] { - tests.push(( - RoundFunction::try_create("toStartOfCustom", *r)?, - ScalarFunctionWithFieldTest { - name: "test-timeSlot-now", - columns: vec![ColumnWithField::new( - Series::from_data(vec![1630812366u32, 1630839682u32]), - DataField::new("dummy_1", DateTime32Type::arc(None)), - )], - expect: Series::from_data(vec![1630812366u32 / r * r, 1630839682u32 / r * r]), - error: "", - }, - )); - } - - for (test_function, test) in tests { - test_scalar_functions_with_type(test_function, &[test], true)?; + let ops = vec![ + "toStartOfSecond", + "toStartOfMinute", + "toStartOfTenMinutes", + "toStartOfFifteenMinutes", + "timeSlot", + "toStartOfHour", + "toStartOfDay", + ]; + let rounds = vec![1, 60, 60 * 10, 60 * 15, 60 * 30, 
60 * 60, 60 * 60 * 24]; + + for (op, r) in ops.iter().cloned().zip(rounds.iter()) { + test_scalar_functions_with_type(op, &[ScalarFunctionWithFieldTest { + name: "test-timeSlot-now", + columns: vec![ColumnWithField::new( + Series::from_data(vec![1630812366u32, 1630839682u32]), + DataField::new("dummy_1", DateTime32Type::arc(None)), + )], + expect: Series::from_data(vec![1630812366u32 / r * r, 1630839682u32 / r * r]), + error: "", + }])?; } Ok(()) @@ -58,9 +59,5 @@ fn test_to_start_of_function() -> Result<()> { error: "", }]; - test_scalar_functions_with_type( - ToStartOfQuarterFunction::try_create("toStartOfWeek")?, - &test, - true, - ) + test_scalar_functions_with_type("toStartOfQuarter", &test) } diff --git a/common/functions/tests/it/scalars/dates/date_function.rs b/common/functions/tests/it/scalars/dates/date_function.rs index e1cfbe50da7b..e70c785752c6 100644 --- a/common/functions/tests/it/scalars/dates/date_function.rs +++ b/common/functions/tests/it/scalars/dates/date_function.rs @@ -17,14 +17,12 @@ use std::sync::Arc; use common_datavalues::prelude::*; use common_datavalues::ColumnWithField; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions_with_type; -use crate::scalars::scalar_function2_test::ScalarFunctionWithFieldTest; +use crate::scalars::scalar_function_test::test_scalar_functions_with_type; +use crate::scalars::scalar_function_test::ScalarFunctionWithFieldTest; #[test] fn test_toyyyymm_function() -> Result<()> { - // use common_datavalues::types::*; let tests = vec![ ScalarFunctionWithFieldTest { name: "test_toyyyymm_date16", @@ -82,7 +80,7 @@ fn test_toyyyymm_function() -> Result<()> { }, ]; - test_scalar_functions_with_type(ToYYYYMMFunction::try_create("c")?, &tests, true) + test_scalar_functions_with_type("toYYYYMM", &tests) } #[test] @@ -171,7 +169,7 @@ fn test_to_yyyymmdd_function() -> Result<()> { }, ]; - 
test_scalar_functions_with_type(ToYYYYMMDDFunction::try_create("c")?, &tests, true) + test_scalar_functions_with_type("toYYYYMMDD", &tests) } #[test] @@ -233,7 +231,7 @@ fn test_toyyyymmddhhmmss_function() -> Result<()> { }, ]; - test_scalar_functions_with_type(ToYYYYMMDDhhmmssFunction::try_create("a")?, &tests, true) + test_scalar_functions_with_type("toYYYYMMDDhhmmss", &tests) } #[test] @@ -295,7 +293,7 @@ fn test_tomonth_function() -> Result<()> { }, ]; - test_scalar_functions_with_type(ToMonthFunction::try_create("c")?, &tests, true) + test_scalar_functions_with_type("toMonth", &tests) } #[test] @@ -357,7 +355,7 @@ fn test_todayofyear_function() -> Result<()> { }, ]; - test_scalar_functions_with_type(ToDayOfYearFunction::try_create("a")?, &tests, true) + test_scalar_functions_with_type("toDayOfYear", &tests) } #[test] @@ -419,7 +417,7 @@ fn test_todatofweek_function() -> Result<()> { }, ]; - test_scalar_functions_with_type(ToDayOfWeekFunction::try_create("a")?, &tests, true) + test_scalar_functions_with_type("toDayOfWeek", &tests) } #[test] @@ -481,7 +479,7 @@ fn test_todayofmonth_function() -> Result<()> { }, ]; - test_scalar_functions_with_type(ToDayOfMonthFunction::try_create("a")?, &tests, true) + test_scalar_functions_with_type("toDayOfMonth", &tests) } #[test] @@ -543,7 +541,7 @@ fn test_tohour_function() -> Result<()> { }, ]; - test_scalar_functions_with_type(ToHourFunction::try_create("a")?, &tests, true) + test_scalar_functions_with_type("toHour", &tests) } #[test] @@ -605,7 +603,7 @@ fn test_tominute_function() -> Result<()> { }, ]; - test_scalar_functions_with_type(ToMinuteFunction::try_create("a")?, &tests, true) + test_scalar_functions_with_type("toMinute", &tests) } #[test] @@ -667,7 +665,7 @@ fn test_tosecond_function() -> Result<()> { }, ]; - test_scalar_functions_with_type(ToSecondFunction::try_create("a")?, &tests, true) + test_scalar_functions_with_type("toSecond", &tests) } #[test] @@ -702,5 +700,5 @@ fn test_tomonday_function() -> 
Result<()> { }, ]; - test_scalar_functions_with_type(ToMondayFunction::try_create("a")?, &tests, true) + test_scalar_functions_with_type("toMonday", &tests) } diff --git a/common/functions/tests/it/scalars/expressions.rs b/common/functions/tests/it/scalars/expressions.rs index c4208ed22c7f..ea53d5bed5e9 100644 --- a/common/functions/tests/it/scalars/expressions.rs +++ b/common/functions/tests/it/scalars/expressions.rs @@ -19,142 +19,199 @@ use common_exception::Result; use common_functions::scalars::*; use serde_json::json; -use super::scalar_function2_test::ScalarFunctionWithFieldTest; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::test_scalar_functions_with_type; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use super::scalar_function_test::ScalarFunctionWithFieldTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::test_scalar_functions_with_type; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_cast_function() -> Result<()> { let tests = vec![ - (CastFunction::create("cast", "int8")?, ScalarFunctionTest { + ("toInt8", ScalarFunctionTest { name: "cast-int64-to-int8-passed", columns: vec![Series::from_data(vec![4i64, 3, 2, 4])], expect: Series::from_data(vec![4i8, 3, 2, 4]), error: "", }), - (CastFunction::create("cast", "int8")?, ScalarFunctionTest { + ("toInt8", ScalarFunctionTest { name: "cast-string-to-int8-passed", columns: vec![Series::from_data(vec!["4", "3", "2", "4"])], expect: Series::from_data(vec![4i8, 3, 2, 4]), error: "", }), - (CastFunction::create("cast", "int16")?, ScalarFunctionTest { + ("toInt16", ScalarFunctionTest { name: "cast-string-to-int16-passed", columns: vec![Series::from_data(vec!["4", "3", "2", "4"])], expect: Series::from_data(vec![4i16, 3, 2, 4]), error: "", }), - (CastFunction::create("cast", "int32")?, ScalarFunctionTest { + ("toInt32", ScalarFunctionTest { 
name: "cast-string-to-int32-passed", columns: vec![Series::from_data(vec!["4", "3", "2", "4"])], expect: Series::from_data(vec![4i32, 3, 2, 4]), error: "", }), - (CastFunction::create("cast", "int32")?, ScalarFunctionTest { + ("toInt32", ScalarFunctionTest { name: "cast-string-to-int32-error-passed", columns: vec![Series::from_data(vec!["X4", "3", "2", "4"])], expect: Series::from_data(vec![4i32, 3, 2, 4]), error: "Cast error happens in casting from String to Int32", }), - (CastFunction::create("cast", "int32")?, ScalarFunctionTest { + ("toInt32", ScalarFunctionTest { name: "cast-string-to-int32-error-as_null-passed", columns: vec![Series::from_data(vec!["X4", "3", "2", "4"])], expect: Series::from_data(vec![Some(0i32), Some(3), Some(2), Some(4)]), error: "Cast error happens in casting from String to Int32", }), - (CastFunction::create("cast", "int64")?, ScalarFunctionTest { + ("toInt64", ScalarFunctionTest { name: "cast-string-to-int64-passed", columns: vec![Series::from_data(vec!["4", "3", "2", "4"])], expect: Series::from_data(vec![4i64, 3, 2, 4]), error: "", }), + ("toDate16", ScalarFunctionTest { + name: "cast-string-to-date16-passed", + columns: vec![Series::from_data(vec!["2021-03-05", "2021-10-24"])], + expect: Series::from_data(vec![18691u16, 18924]), + error: "", + }), + ("toDate32", ScalarFunctionTest { + name: "cast-string-to-date32-passed", + columns: vec![Series::from_data(vec!["2021-03-05", "2021-10-24"])], + expect: Series::from_data(vec![18691i32, 18924]), + error: "", + }), + ("toDateTime32", ScalarFunctionTest { + name: "cast-string-to-datetime32-passed", + columns: vec![Series::from_data(vec![ + "2021-03-05 01:01:01", + "2021-10-24 10:10:10", + ])], + expect: Series::from_data(vec![1614906061u32, 1635070210]), + error: "", + }), + ("toDateTime64", ScalarFunctionTest { + name: "cast-string-to-datetime64-passed", + columns: vec![Series::from_data(vec![ + "2021-03-05 01:01:01.123", + "2021-10-24 10:10:10.123", + ])], + expect: 
Series::from_data(vec![1614906061123i64, 1635070210123]), + error: "", + }), + ]; + + for (op, test) in tests { + test_scalar_functions(op, &[test])?; + } + + Ok(()) +} + +#[test] +fn test_datetime_cast_function() -> Result<()> { + let tests = vec![ + ("toString", ScalarFunctionWithFieldTest { + name: "cast-date32-to-string-passed", + columns: vec![ColumnWithField::new( + Series::from_data(vec![18691i32, 18924]), + DataField::new("dummy_1", Date32Type::arc()), + )], + expect: Series::from_data(vec!["2021-03-05", "2021-10-24"]), + error: "", + }), + ("toString", ScalarFunctionWithFieldTest { + name: "cast-datetime-to-string-passed", + columns: vec![ColumnWithField::new( + Series::from_data(vec![1614906061u32, 1635070210]), + DataField::new("dummy_1", DateTime32Type::arc(None)), + )], + expect: Series::from_data(vec!["2021-03-05 01:01:01", "2021-10-24 10:10:10"]), + error: "", + }), + ]; + + for (op, test) in tests { + test_scalar_functions_with_type(op, &[test])?; + } + + Ok(()) +} + +#[test] +fn test_cast_variant_function() -> Result<()> { + let tests = vec![ ( - CastFunction::create("cast", "date16")?, - ScalarFunctionTest { - name: "cast-string-to-date16-passed", - columns: vec![Series::from_data(vec!["2021-03-05", "2021-10-24"])], - expect: Series::from_data(vec![18691u16, 18924]), - error: "", - }, - ), - ( - CastFunction::create("cast", "date32")?, - ScalarFunctionTest { - name: "cast-string-to-date32-passed", - columns: vec![Series::from_data(vec!["2021-03-05", "2021-10-24"])], - expect: Series::from_data(vec![18691i32, 18924]), - error: "", - }, - ), - ( - CastFunction::create("cast", "datetime32")?, - ScalarFunctionTest { - name: "cast-string-to-datetime32-passed", - columns: vec![Series::from_data(vec![ - "2021-03-05 01:01:01", - "2021-10-24 10:10:10", - ])], - expect: Series::from_data(vec![1614906061u32, 1635070210]), - error: "", - }, - ), - ( - CastFunction::create("cast", "datetime64")?, - ScalarFunctionTest { - name: 
"cast-string-to-datetime64-passed", - columns: vec![Series::from_data(vec![ - "2021-03-05 01:01:01.123", - "2021-10-24 10:10:10.123", - ])], - expect: Series::from_data(vec![1614906061123i64, 1635070210123]), - error: "", + CastFunction::create("cast", "variant")?, + ScalarFunctionWithFieldTest { + name: "cast-date32-to-variant-error", + columns: vec![ColumnWithField::new( + Series::from_data(vec![18691i32, 18924]), + DataField::new("dummy_1", Date32Type::arc()), + )], + expect: Arc::new(NullColumn::new(2)), + error: "Expression type does not match column data type, expecting VARIANT but got Date32", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-bool-to-variant-passed", - columns: vec![Series::from_data(vec![true, false])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![true, false]), + DataField::new("dummy_1", BooleanType::arc()), + )], expect: Series::from_data(vec![json!(true), json!(false)]), error: "", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-int8-to-variant-passed", - columns: vec![Series::from_data(vec![-128i8, 127])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![-128i8, 127]), + DataField::new("dummy_1", Int8Type::arc()), + )], expect: Series::from_data(vec![json!(-128i8), json!(127i8)]), error: "", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-int16-to-variant-passed", - columns: vec![Series::from_data(vec![-32768i16, 32767])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![-32768i16, 32767]), + DataField::new("dummy_1", Int16Type::arc()), + )], expect: Series::from_data(vec![json!(-32768i16), json!(32767i16)]), error: "", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-int32-to-variant-passed", - columns: 
vec![Series::from_data(vec![-2147483648i32, 2147483647])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![-2147483648i32, 2147483647]), + DataField::new("dummy_1", Int32Type::arc()), + )], expect: Series::from_data(vec![json!(-2147483648i32), json!(2147483647i32)]), error: "", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-int64-to-variant-passed", - columns: vec![Series::from_data(vec![ - -9223372036854775808i64, - 9223372036854775807, - ])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![ + -9223372036854775808i64, + 9223372036854775807, + ]), + DataField::new("dummy_1", Int64Type::arc()), + )], expect: Series::from_data(vec![ json!(-9223372036854775808i64), json!(9223372036854775807i64), @@ -164,57 +221,73 @@ fn test_cast_function() -> Result<()> { ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-uint8-to-variant-passed", - columns: vec![Series::from_data(vec![0u8, 255])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![0u8, 255]), + DataField::new("dummy_1", UInt8Type::arc()), + )], expect: Series::from_data(vec![json!(0u8), json!(255u8)]), error: "", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-uint16-to-variant-passed", - columns: vec![Series::from_data(vec![0u16, 65535])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![0u16, 65535]), + DataField::new("dummy_1", UInt16Type::arc()), + )], expect: Series::from_data(vec![json!(0u16), json!(65535u16)]), error: "", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-uint32-to-variant-passed", - columns: vec![Series::from_data(vec![0u32, 4294967295])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![0u32, 4294967295]), + DataField::new("dummy_1", UInt32Type::arc()), + )], 
expect: Series::from_data(vec![json!(0u32), json!(4294967295u32)]), error: "", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-uint64-to-variant-passed", - columns: vec![Series::from_data(vec![0u64, 18446744073709551615])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![0u64, 18446744073709551615]), + DataField::new("dummy_1", UInt64Type::arc()), + )], expect: Series::from_data(vec![json!(0u64), json!(18446744073709551615u64)]), error: "", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-float32-to-variant-passed", - columns: vec![Series::from_data(vec![0.12345679f32, 12.34])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![0.12345679f32, 12.34]), + DataField::new("dummy_1", Float32Type::arc()), + )], expect: Series::from_data(vec![json!(0.12345679f32), json!(12.34f32)]), error: "", }, ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { + ScalarFunctionWithFieldTest { name: "cast-float64-to-variant-passed", - columns: vec![Series::from_data(vec![ - 0.12345678912121212f64, - 12.345678912, - ])], + columns: vec![ColumnWithField::new( + Series::from_data(vec![0.12345678912121212f64, + 12.345678912,]), + DataField::new("dummy_1", Float64Type::arc()), + )], expect: Series::from_data(vec![ json!(0.12345678912121212f64), json!(12.345678912f64), @@ -224,68 +297,36 @@ fn test_cast_function() -> Result<()> { ), ( CastFunction::create("cast", "variant")?, - ScalarFunctionTest { - name: "cast-string-to-variant-error", - columns: vec![Series::from_data(vec![ - "abc", - "123", - ])], - expect: Arc::new(NullColumn::new(2)), - error: "Expression type does not match column data type, expecting VARIANT but got String", - }, - ), - ]; - - for (test_func, test) in tests { - test_scalar_functions(test_func, &[test], false)?; - } - - Ok(()) -} - -#[test] -fn test_datetime_cast_function() -> Result<()> { 
- let tests = vec![ - ( - CastFunction::create("cast", "string")?, - ScalarFunctionWithFieldTest { - name: "cast-date32-to-string-passed", - columns: vec![ColumnWithField::new( - Series::from_data(vec![18691i32, 18924]), - DataField::new("dummy_1", Date32Type::arc()), - )], - expect: Series::from_data(vec!["2021-03-05", "2021-10-24"]), - error: "", - }, - ), - ( - CastFunction::create("cast", "string")?, ScalarFunctionWithFieldTest { - name: "cast-datetime-to-string-passed", - columns: vec![ColumnWithField::new( - Series::from_data(vec![1614906061u32, 1635070210]), - DataField::new("dummy_1", DateTime32Type::arc(None)), - )], - expect: Series::from_data(vec!["2021-03-05 01:01:01", "2021-10-24 10:10:10"]), - error: "", - }, - ), - ( - CastFunction::create("cast", "variant")?, - ScalarFunctionWithFieldTest { - name: "cast-date32-to-variant-error", + name: "cast-string-to-variant-error", columns: vec![ColumnWithField::new( - Series::from_data(vec![18691i32, 18924]), - DataField::new("dummy_1", Date32Type::arc()), + Series::from_data(vec![ + "abc", + "123", + ]), + DataField::new("dummy_1", StringType::arc()), )], expect: Arc::new(NullColumn::new(2)), - error: "Expression type does not match column data type, expecting VARIANT but got Date32", + error: "Expression type does not match column data type, expecting VARIANT but got String", }, ), ]; for (test_func, test) in tests { - test_scalar_functions_with_type(test_func, &[test], false)?; + match test_func.eval( + &test.columns, + test.columns[0].column().len(), + FunctionContext { tz: None }, + ) { + Ok(v) => { + let v = v.convert_full_column(); + + assert_eq!(test.expect, v, "{}", test.name); + } + Err(cause) => { + assert_eq!(test.error, cause.message(), "{}", test.name); + } + } } Ok(()) @@ -294,7 +335,7 @@ fn test_datetime_cast_function() -> Result<()> { #[test] fn test_variant_cast_function() -> Result<()> { let tests = vec![ - (CastFunction::create("cast", "uint8")?, ScalarFunctionTest { + ("toUInt8", 
ScalarFunctionTest { name: "cast-variant-to-uint8-passed", columns: vec![Series::from_data(vec![ json!(4u64), @@ -305,63 +346,51 @@ fn test_variant_cast_function() -> Result<()> { expect: Series::from_data(vec![4u8, 3, 2, 4]), error: "", }), - ( - CastFunction::create("cast", "uint16")?, - ScalarFunctionTest { - name: "cast-variant-to-uint16-passed", - columns: vec![Series::from_data(vec![ - json!(4u64), - json!(3u64), - json!("2"), - json!("4"), - ])], - expect: Series::from_data(vec![4u16, 3, 2, 4]), - error: "", - }, - ), - ( - CastFunction::create("cast", "uint32")?, - ScalarFunctionTest { - name: "cast-variant-to-uint32-passed", - columns: vec![Series::from_data(vec![ - json!(4u64), - json!(3u64), - json!("2"), - json!("4"), - ])], - expect: Series::from_data(vec![4u32, 3, 2, 4]), - error: "", - }, - ), - ( - CastFunction::create("cast", "uint64")?, - ScalarFunctionTest { - name: "cast-variant-to-uint64-passed", - columns: vec![Series::from_data(vec![ - json!(4u64), - json!(3u64), - json!("2"), - json!("4"), - ])], - expect: Series::from_data(vec![4u64, 3, 2, 4]), - error: "", - }, - ), - ( - CastFunction::create("cast", "uint64")?, - ScalarFunctionTest { - name: "cast-variant-to-uint64-error", - columns: vec![Series::from_data(vec![ - json!("X4"), - json!(3u64), - json!("2"), - json!("4"), - ])], - expect: Series::from_data(vec![4u64, 3, 2, 4]), - error: "Cast error happens in casting from Variant to UInt64", - }, - ), - (CastFunction::create("cast", "int8")?, ScalarFunctionTest { + ("toUInt16", ScalarFunctionTest { + name: "cast-variant-to-uint16-passed", + columns: vec![Series::from_data(vec![ + json!(4u64), + json!(3u64), + json!("2"), + json!("4"), + ])], + expect: Series::from_data(vec![4u16, 3, 2, 4]), + error: "", + }), + ("toUInt32", ScalarFunctionTest { + name: "cast-variant-to-uint32-passed", + columns: vec![Series::from_data(vec![ + json!(4u64), + json!(3u64), + json!("2"), + json!("4"), + ])], + expect: Series::from_data(vec![4u32, 3, 2, 4]), + 
error: "", + }), + ("toUInt64", ScalarFunctionTest { + name: "cast-variant-to-uint64-passed", + columns: vec![Series::from_data(vec![ + json!(4u64), + json!(3u64), + json!("2"), + json!("4"), + ])], + expect: Series::from_data(vec![4u64, 3, 2, 4]), + error: "", + }), + ("toUInt64", ScalarFunctionTest { + name: "cast-variant-to-uint64-error", + columns: vec![Series::from_data(vec![ + json!("X4"), + json!(3u64), + json!("2"), + json!("4"), + ])], + expect: Series::from_data(vec![4u64, 3, 2, 4]), + error: "Cast error happens in casting from Variant to UInt64", + }), + ("toInt8", ScalarFunctionTest { name: "cast-variant-to-int8-passed", columns: vec![Series::from_data(vec![ json!(4i64), @@ -372,7 +401,7 @@ fn test_variant_cast_function() -> Result<()> { expect: Series::from_data(vec![4i8, -3, 2, -4]), error: "", }), - (CastFunction::create("cast", "int16")?, ScalarFunctionTest { + ("toInt16", ScalarFunctionTest { name: "cast-variant-to-int16-passed", columns: vec![Series::from_data(vec![ json!(4i64), @@ -383,7 +412,7 @@ fn test_variant_cast_function() -> Result<()> { expect: Series::from_data(vec![4i16, -3, 2, -4]), error: "", }), - (CastFunction::create("cast", "int32")?, ScalarFunctionTest { + ("toInt32", ScalarFunctionTest { name: "cast-variant-to-int32-passed", columns: vec![Series::from_data(vec![ json!(4i64), @@ -394,7 +423,7 @@ fn test_variant_cast_function() -> Result<()> { expect: Series::from_data(vec![4i32, -3, 2, -4]), error: "", }), - (CastFunction::create("cast", "int64")?, ScalarFunctionTest { + ("toInt64", ScalarFunctionTest { name: "cast-variant-to-int64-passed", columns: vec![Series::from_data(vec![ json!(4i64), @@ -405,7 +434,7 @@ fn test_variant_cast_function() -> Result<()> { expect: Series::from_data(vec![4i64, -3, 2, -4]), error: "", }), - (CastFunction::create("cast", "int64")?, ScalarFunctionTest { + ("toInt64", ScalarFunctionTest { name: "cast-variant-to-int64-error", columns: vec![Series::from_data(vec![ json!("X4"), @@ -416,190 +445,148 @@ 
fn test_variant_cast_function() -> Result<()> { expect: Series::from_data(vec![4i64, -3, 2, -4]), error: "Cast error happens in casting from Variant to Int64", }), - ( - CastFunction::create("cast", "float32")?, - ScalarFunctionTest { - name: "cast-variant-to-float32-passed", - columns: vec![Series::from_data(vec![ - json!(1.2f64), - json!(-1.3f64), - json!("2.1"), - json!("-4.2"), - ])], - expect: Series::from_data(vec![1.2f32, -1.3, 2.1, -4.2]), - error: "", - }, - ), - ( - CastFunction::create("cast", "float32")?, - ScalarFunctionTest { - name: "cast-variant-to-float32-error", - columns: vec![Series::from_data(vec![ - json!("X4"), - json!(-1.3f64), - json!("2.1"), - json!("-4.2"), - ])], - expect: Series::from_data(vec![1.2f32, -1.3, 2.1, -4.2]), - error: "Cast error happens in casting from Variant to Float32", - }, - ), - ( - CastFunction::create("cast", "float64")?, - ScalarFunctionTest { - name: "cast-variant-to-float64-passed", - columns: vec![Series::from_data(vec![ - json!(1.2f64), - json!(-1.3f64), - json!("2.1"), - json!("-4.2"), - ])], - expect: Series::from_data(vec![1.2f64, -1.3, 2.1, -4.2]), - error: "", - }, - ), - ( - CastFunction::create("cast", "float64")?, - ScalarFunctionTest { - name: "cast-variant-to-float64-error", - columns: vec![Series::from_data(vec![ - json!("X4"), - json!(-1.3f64), - json!("2.1"), - json!("-4.2"), - ])], - expect: Series::from_data(vec![1.2f64, -1.3, 2.1, -4.2]), - error: "Cast error happens in casting from Variant to Float64", - }, - ), - ( - CastFunction::create("cast", "boolean")?, - ScalarFunctionTest { - name: "cast-variant-to-boolean-passed", - columns: vec![Series::from_data(vec![ - json!(true), - json!(false), - json!("true"), - json!("false"), - ])], - expect: Series::from_data(vec![true, false, true, false]), - error: "", - }, - ), - ( - CastFunction::create("cast", "boolean")?, - ScalarFunctionTest { - name: "cast-variant-to-boolean-error", - columns: vec![Series::from_data(vec![ - json!(1), - json!("test"), 
- json!(true), - json!(false), - ])], - expect: Series::from_data(vec![true, false, true, false]), - error: "Cast error happens in casting from Variant to Boolean", - }, - ), - ( - CastFunction::create("cast", "date16")?, - ScalarFunctionTest { - name: "cast-variant-to-date16-passed", - columns: vec![Series::from_data(vec![ - json!("2021-03-05"), - json!("2021-10-24"), - ])], - expect: Series::from_data(vec![18691u16, 18924]), - error: "", - }, - ), - ( - CastFunction::create("cast", "date16")?, - ScalarFunctionTest { - name: "cast-variant-to-date16-error", - columns: vec![Series::from_data(vec![ - json!("a2021-03-05"), - json!("2021-10-24"), - ])], - expect: Series::from_data(vec![18691u16, 18924]), - error: "Cast error happens in casting from Variant to Date16", - }, - ), - ( - CastFunction::create("cast", "date32")?, - ScalarFunctionTest { - name: "cast-variant-to-date32-passed", - columns: vec![Series::from_data(vec![ - json!("2021-03-05"), - json!("2021-10-24"), - ])], - expect: Series::from_data(vec![18691i32, 18924]), - error: "", - }, - ), - ( - CastFunction::create("cast", "date32")?, - ScalarFunctionTest { - name: "cast-variant-to-date32-error", - columns: vec![Series::from_data(vec![ - json!("a2021-03-05"), - json!("2021-10-24"), - ])], - expect: Series::from_data(vec![18691i32, 18924]), - error: "Cast error happens in casting from Variant to Date32", - }, - ), - ( - CastFunction::create("cast", "datetime32")?, - ScalarFunctionTest { - name: "cast-variant-to-datetime32-passed", - columns: vec![Series::from_data(vec![ - json!("2021-03-05 01:01:01"), - json!("2021-10-24 10:10:10"), - ])], - expect: Series::from_data(vec![1614906061u32, 1635070210]), - error: "", - }, - ), - ( - CastFunction::create("cast", "datetime32")?, - ScalarFunctionTest { - name: "cast-variant-to-datetime32-error", - columns: vec![Series::from_data(vec![ - json!("a2021-03-05 01:01:01"), - json!("2021-10-24 10:10:10"), - ])], - expect: Series::from_data(vec![1614906061u32, 
1635070210]), - error: "Cast error happens in casting from Variant to DateTime32", - }, - ), - ( - CastFunction::create("cast", "datetime64")?, - ScalarFunctionTest { - name: "cast-variant-to-datetime64-passed", - columns: vec![Series::from_data(vec![ - json!("2021-03-05 01:01:01.123"), - json!("2021-10-24 10:10:10.123"), - ])], - expect: Series::from_data(vec![1614906061123i64, 1635070210123]), - error: "", - }, - ), - ( - CastFunction::create("cast", "datetime64")?, - ScalarFunctionTest { - name: "cast-variant-to-datetime64-error", - columns: vec![Series::from_data(vec![ - json!("a2021-03-05 01:01:01.123"), - json!("2021-10-24 10:10:10.123456789"), - ])], - expect: Series::from_data(vec![1614906061123i64, 1635070210123]), - error: "Cast error happens in casting from Variant to DateTime64(3)", - }, - ), + ("toFloat32", ScalarFunctionTest { + name: "cast-variant-to-float32-passed", + columns: vec![Series::from_data(vec![ + json!(1.2f64), + json!(-1.3f64), + json!("2.1"), + json!("-4.2"), + ])], + expect: Series::from_data(vec![1.2f32, -1.3, 2.1, -4.2]), + error: "", + }), + ("toFloat32", ScalarFunctionTest { + name: "cast-variant-to-float32-error", + columns: vec![Series::from_data(vec![ + json!("X4"), + json!(-1.3f64), + json!("2.1"), + json!("-4.2"), + ])], + expect: Series::from_data(vec![1.2f32, -1.3, 2.1, -4.2]), + error: "Cast error happens in casting from Variant to Float32", + }), + ("toFloat64", ScalarFunctionTest { + name: "cast-variant-to-float64-passed", + columns: vec![Series::from_data(vec![ + json!(1.2f64), + json!(-1.3f64), + json!("2.1"), + json!("-4.2"), + ])], + expect: Series::from_data(vec![1.2f64, -1.3, 2.1, -4.2]), + error: "", + }), + ("toFloat64", ScalarFunctionTest { + name: "cast-variant-to-float64-error", + columns: vec![Series::from_data(vec![ + json!("X4"), + json!(-1.3f64), + json!("2.1"), + json!("-4.2"), + ])], + expect: Series::from_data(vec![1.2f64, -1.3, 2.1, -4.2]), + error: "Cast error happens in casting from Variant to 
Float64", + }), + ("toBoolean", ScalarFunctionTest { + name: "cast-variant-to-boolean-passed", + columns: vec![Series::from_data(vec![ + json!(true), + json!(false), + json!("true"), + json!("false"), + ])], + expect: Series::from_data(vec![true, false, true, false]), + error: "", + }), + ("toBoolean", ScalarFunctionTest { + name: "cast-variant-to-boolean-error", + columns: vec![Series::from_data(vec![ + json!(1), + json!("test"), + json!(true), + json!(false), + ])], + expect: Series::from_data(vec![true, false, true, false]), + error: "Cast error happens in casting from Variant to Boolean", + }), + ("toDate16", ScalarFunctionTest { + name: "cast-variant-to-date16-passed", + columns: vec![Series::from_data(vec![ + json!("2021-03-05"), + json!("2021-10-24"), + ])], + expect: Series::from_data(vec![18691u16, 18924]), + error: "", + }), + ("toDate16", ScalarFunctionTest { + name: "cast-variant-to-date16-error", + columns: vec![Series::from_data(vec![ + json!("a2021-03-05"), + json!("2021-10-24"), + ])], + expect: Series::from_data(vec![18691u16, 18924]), + error: "Cast error happens in casting from Variant to Date16", + }), + ("toDate32", ScalarFunctionTest { + name: "cast-variant-to-date32-passed", + columns: vec![Series::from_data(vec![ + json!("2021-03-05"), + json!("2021-10-24"), + ])], + expect: Series::from_data(vec![18691i32, 18924]), + error: "", + }), + ("toDate32", ScalarFunctionTest { + name: "cast-variant-to-date32-error", + columns: vec![Series::from_data(vec![ + json!("a2021-03-05"), + json!("2021-10-24"), + ])], + expect: Series::from_data(vec![18691i32, 18924]), + error: "Cast error happens in casting from Variant to Date32", + }), + ("toDateTime32", ScalarFunctionTest { + name: "cast-variant-to-datetime32-passed", + columns: vec![Series::from_data(vec![ + json!("2021-03-05 01:01:01"), + json!("2021-10-24 10:10:10"), + ])], + expect: Series::from_data(vec![1614906061u32, 1635070210]), + error: "", + }), + ("toDateTime32", ScalarFunctionTest { + name: 
"cast-variant-to-datetime32-error", + columns: vec![Series::from_data(vec![ + json!("a2021-03-05 01:01:01"), + json!("2021-10-24 10:10:10"), + ])], + expect: Series::from_data(vec![1614906061u32, 1635070210]), + error: "Cast error happens in casting from Variant to DateTime32", + }), + ("toDateTime64", ScalarFunctionTest { + name: "cast-variant-to-datetime64-passed", + columns: vec![Series::from_data(vec![ + json!("2021-03-05 01:01:01.123"), + json!("2021-10-24 10:10:10.123"), + ])], + expect: Series::from_data(vec![1614906061123i64, 1635070210123]), + error: "", + }), + ("toDateTime64", ScalarFunctionTest { + name: "cast-variant-to-datetime64-error", + columns: vec![Series::from_data(vec![ + json!("a2021-03-05 01:01:01.123"), + json!("2021-10-24 10:10:10.123456789"), + ])], + expect: Series::from_data(vec![1614906061123i64, 1635070210123]), + error: "Cast error happens in casting from Variant to DateTime64(3)", + }), ]; - for (test_func, test) in tests { - test_scalar_functions(test_func, &[test], false)?; + for (op, test) in tests { + test_scalar_functions(op, &[test])?; } let tests = vec![ @@ -650,7 +637,20 @@ fn test_variant_cast_function() -> Result<()> { ]; for (test_func, test) in tests { - test_scalar_functions_with_type(test_func, &[test], false)?; + match test_func.eval( + &test.columns, + test.columns[0].column().len(), + FunctionContext { tz: None }, + ) { + Ok(v) => { + let v = v.convert_full_column(); + + assert_eq!(test.expect, v, "{}", test.name); + } + Err(cause) => { + assert_eq!(test.error, cause.message(), "{}", test.name); + } + } } Ok(()) diff --git a/common/functions/tests/it/scalars/hashes.rs b/common/functions/tests/it/scalars/hashes.rs index 2e0cccb4a649..18e822915970 100644 --- a/common/functions/tests/it/scalars/hashes.rs +++ b/common/functions/tests/it/scalars/hashes.rs @@ -17,19 +17,11 @@ use std::hash::Hasher; use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::Blake3HashFunction; -use 
common_functions::scalars::City64WithSeedFunction; -use common_functions::scalars::Md5HashFunction; -use common_functions::scalars::Sha1HashFunction; -use common_functions::scalars::Sha2HashFunction; -use common_functions::scalars::SipHash64Function; -use common_functions::scalars::XxHash32Function; -use common_functions::scalars::XxHash64Function; use naive_cityhash::cityhash64_with_seed; use twox_hash::XxHash32; -use super::scalar_function2_test::test_scalar_functions; -use super::scalar_function2_test::ScalarFunctionTest; +use super::scalar_function_test::test_scalar_functions; +use super::scalar_function_test::ScalarFunctionTest; #[test] fn test_siphash_function() -> Result<()> { @@ -136,7 +128,7 @@ fn test_siphash_function() -> Result<()> { }, ]; - test_scalar_functions(SipHash64Function::try_create("siphash")?, &tests, true) + test_scalar_functions("siphash64", &tests) } #[test] @@ -148,7 +140,7 @@ fn test_md5hash_function() -> Result<()> { error: "", }]; - test_scalar_functions(Md5HashFunction::try_create("md5")?, &tests, true) + test_scalar_functions("md5", &tests) } #[test] @@ -160,69 +152,69 @@ fn test_sha1hash_function() -> Result<()> { error: "", }]; - test_scalar_functions(Sha1HashFunction::try_create("sha1")?, &tests, true) + test_scalar_functions("sha1", &tests) } #[test] fn test_sha2hash_function() -> Result<()> { let tests = vec![ ScalarFunctionTest { - name: "Sha0 (256)", - columns: vec![Series::from_data(["abc"]), Series::from_data([0_u32])], - expect: Series::from_data(["ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"]), - error: "", - }, + name: "Sha0 (256)", + columns: vec![Series::from_data(["abc"]), Series::from_data([0_u32])], + expect: Series::from_data(["ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"]), + error: "", + }, ScalarFunctionTest { - name: "Sha224", - columns: vec![Series::from_data(["abc"]), Series::from_data([224_u32])], - expect: 
Series::from_data(["23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7"]), - error: "", - }, + name: "Sha224", + columns: vec![Series::from_data(["abc"]), Series::from_data([224_u32])], + expect: Series::from_data(["23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7"]), + error: "", + }, ScalarFunctionTest { - name: "Sha256", - columns: vec![Series::from_data(["abc"]), Series::from_data([256_u32])], - expect: Series::from_data(["ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"]), - error: "", - }, + name: "Sha256", + columns: vec![Series::from_data(["abc"]), Series::from_data([256_u32])], + expect: Series::from_data(["ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"]), + error: "", + }, ScalarFunctionTest { - name: "Sha384", - columns: vec![Series::from_data(["abc"]), Series::from_data([384_u32])], - expect: Series::from_data(["cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7"]), - error: "", - }, + name: "Sha384", + columns: vec![Series::from_data(["abc"]), Series::from_data([384_u32])], + expect: Series::from_data(["cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7"]), + error: "", + }, ScalarFunctionTest { - name: "Sha512", - columns: vec![Series::from_data(["abc"]), Series::from_data([512_u32])], - expect: Series::from_data(["ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"]), - error: "", - }, + name: "Sha512", + columns: vec![Series::from_data(["abc"]), Series::from_data([512_u32])], + expect: Series::from_data(["ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"]), + error: "", + }, ScalarFunctionTest { - name: "InvalidSha", - columns: vec![Series::from_data(["abc"]), Series::from_data([1_u32])], - expect: Series::from_data([Option::<&str>::None]), - error: 
"Expected [0, 224, 256, 384, 512] as sha2 encode options, but got 1", - }, + name: "InvalidSha", + columns: vec![Series::from_data(["abc"]), Series::from_data([1_u32])], + expect: Series::from_data([Option::<&str>::None]), + error: "Expected [0, 224, 256, 384, 512] as sha2 encode options, but got 1", + }, ScalarFunctionTest { - name: "Sha Length as Const Field", - columns: vec![ - Series::from_data(["abc"]), - Series::from_data([224_u16]), - ], - expect: Series::from_data(["23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7"]), - error: "", - }, + name: "Sha Length as Const Field", + columns: vec![ + Series::from_data(["abc"]), + Series::from_data([224_u16]), + ], + expect: Series::from_data(["23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7"]), + error: "", + }, ScalarFunctionTest { - name: "Sha Length with null value", - columns: vec![ - Series::from_data([Option::<&str>::None]), - Series::from_data([Option::::None]), - ], - expect: Series::from_data([Option::<&str>::None]), - error: "", - }, + name: "Sha Length with null value", + columns: vec![ + Series::from_data([Option::<&str>::None]), + Series::from_data([Option::::None]), + ], + expect: Series::from_data([Option::<&str>::None]), + error: "", + }, ]; - test_scalar_functions(Sha2HashFunction::try_create("sha2")?, &tests, true) + test_scalar_functions("sha2", &tests) } #[test] @@ -236,7 +228,7 @@ fn test_blake3hash_function() -> Result<()> { error: "", }]; - test_scalar_functions(Blake3HashFunction::try_create("blake3")?, &tests, true) + test_scalar_functions("blake3", &tests) } #[test] @@ -248,7 +240,7 @@ fn test_xxhash32_function() -> Result<()> { error: "", }]; - test_scalar_functions(XxHash32Function::try_create("xxhash32")?, &tests, true) + test_scalar_functions("xxhash32", &tests) } #[test] @@ -260,7 +252,7 @@ fn test_xxhash64_function() -> Result<()> { error: "", }]; - test_scalar_functions(XxHash64Function::try_create("xxhash64")?, &tests, true) + test_scalar_functions("xxhash64", 
&tests) } #[test] @@ -313,11 +305,7 @@ fn test_cityhash64_with_seed_u8() -> Result<()> { }; let tests = vec![test0, test1]; - test_scalar_functions( - City64WithSeedFunction::try_create("city64WithSeed")?, - &tests, - true, - ) + test_scalar_functions("city64WithSeed", &tests) } #[test] @@ -342,6 +330,8 @@ fn test_cityhash64_with_seed_string() -> Result<()> { error: "", }; + test_scalar_functions("city64WithSeed", &[test0])?; + let to_hash = vec![Some("Superman"), None, None]; let seed = 100u64; //constant seed let mut expected_result = Vec::with_capacity(to_hash.len()); @@ -364,10 +354,5 @@ fn test_cityhash64_with_seed_string() -> Result<()> { error: "", }; - let tests = vec![test0, test1]; - test_scalar_functions( - City64WithSeedFunction::try_create("city64WithSeed")?, - &tests, - true, - ) + test_scalar_functions("city64WithSeed", &[test1]) } diff --git a/common/functions/tests/it/scalars/logics.rs b/common/functions/tests/it/scalars/logics.rs index 9c0fbe941563..2e4032f7cb60 100644 --- a/common/functions/tests/it/scalars/logics.rs +++ b/common/functions/tests/it/scalars/logics.rs @@ -16,13 +16,9 @@ use std::sync::Arc; use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::LogicAndFunction; -use common_functions::scalars::LogicNotFunction; -use common_functions::scalars::LogicOrFunction; -use common_functions::scalars::LogicXorFunction; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_logic_not_function() -> Result<()> { @@ -34,7 +30,7 @@ fn test_logic_not_function() -> Result<()> { error: "", }, ScalarFunctionTest { - name: "not-null", + name: "not-nullable", columns: vec![Series::from_data(vec![None, Some(true), Some(false)])], expect: Series::from_data(vec![None, Some(false), Some(true)]), 
error: "", @@ -46,7 +42,7 @@ fn test_logic_not_function() -> Result<()> { error: "", }, ]; - test_scalar_functions(LogicNotFunction::try_create("not")?, &tests, true) + test_scalar_functions("not", &tests) } #[test] @@ -73,14 +69,14 @@ fn test_logic_and_function() -> Result<()> { ScalarFunctionTest { name: "and-null", columns: vec![ - Series::from_data(vec![None, Some(true), Some(true), Some(false)]), - Arc::new(NullColumn::new(4)), + Series::from_data(vec![None, Some(true), Some(false)]), + Arc::new(NullColumn::new(3)), ], - expect: Arc::new(NullColumn::new(4)), + expect: Series::from_data(vec![None, None, Some(false)]), error: "", }, ]; - test_scalar_functions(LogicAndFunction::try_create("and")?, &tests, true) + test_scalar_functions("and", &tests) } #[test] @@ -98,19 +94,10 @@ fn test_logic_or_function() -> Result<()> { ScalarFunctionTest { name: "or-null", columns: vec![ - Series::from_data(vec![None, None, None, Some(false)]), - Series::from_data(vec![Some(true), Some(false), None, Some(true)]), + Series::from_data(vec![None, None, None, Some(false), Some(false)]), + Series::from_data(vec![Some(true), Some(false), None, Some(true), Some(false)]), ], - expect: Series::from_data(vec![Some(true), None, None, Some(true)]), - error: "", - }, - ScalarFunctionTest { - name: "or-null", - columns: vec![ - Series::from_data(vec![None, None, None, Some(false)]), - Series::from_data(vec![Some(true), Some(false), None, Some(true)]), - ], - expect: Series::from_data(vec![Some(true), None, None, Some(true)]), + expect: Series::from_data(vec![Some(true), None, None, Some(true), Some(false)]), error: "", }, ScalarFunctionTest { @@ -123,7 +110,8 @@ fn test_logic_or_function() -> Result<()> { error: "", }, ]; - test_scalar_functions(LogicOrFunction::try_create("or")?, &tests, false) + + test_scalar_functions("or", &tests) } #[test] @@ -157,5 +145,5 @@ fn test_logic_xor_function() -> Result<()> { error: "", }, ]; - test_scalar_functions(LogicXorFunction::try_create("or")?, 
&tests, true) + test_scalar_functions("xor", &tests) } diff --git a/common/functions/tests/it/scalars/maths/abs.rs b/common/functions/tests/it/scalars/maths/abs.rs index e2e79392930a..251bc70b6941 100644 --- a/common/functions/tests/it/scalars/maths/abs.rs +++ b/common/functions/tests/it/scalars/maths/abs.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_abs_function() -> Result<()> { @@ -48,5 +47,5 @@ fn test_abs_function() -> Result<()> { }, ]; - test_scalar_functions(AbsFunction::try_create("abs(false)")?, &tests, true) + test_scalar_functions("abs", &tests) } diff --git a/common/functions/tests/it/scalars/maths/angle.rs b/common/functions/tests/it/scalars/maths/angle.rs index bca358072930..6f88b37e743d 100644 --- a/common/functions/tests/it/scalars/maths/angle.rs +++ b/common/functions/tests/it/scalars/maths/angle.rs @@ -16,10 +16,9 @@ use std::f64::consts::PI; use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_degress_function() -> Result<()> { @@ -30,7 +29,7 @@ fn test_degress_function() -> Result<()> { error: "", }]; - test_scalar_functions(DegressFunction::try_create("degrees")?, &tests, true) + test_scalar_functions("degrees", &tests) } #[test] @@ -42,5 +41,5 @@ fn test_radians_function() -> Result<()> { error: "", }]; - test_scalar_functions(RadiansFunction::try_create("radians")?, &tests, true) 
+ test_scalar_functions("radians", &tests) } diff --git a/common/functions/tests/it/scalars/maths/ceil.rs b/common/functions/tests/it/scalars/maths/ceil.rs index 7c78e7529a1a..b572e9235e4d 100644 --- a/common/functions/tests/it/scalars/maths/ceil.rs +++ b/common/functions/tests/it/scalars/maths/ceil.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_ceil_function() -> Result<()> { @@ -78,5 +77,5 @@ fn test_ceil_function() -> Result<()> { }, ]; - test_scalar_functions(CeilFunction::try_create("ceil")?, &tests, true) + test_scalar_functions("ceil", &tests) } diff --git a/common/functions/tests/it/scalars/maths/crc32.rs b/common/functions/tests/it/scalars/maths/crc32.rs index 6e7640b7b782..e02224d329e5 100644 --- a/common/functions/tests/it/scalars/maths/crc32.rs +++ b/common/functions/tests/it/scalars/maths/crc32.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_crc32_function() -> Result<()> { @@ -42,5 +41,5 @@ fn test_crc32_function() -> Result<()> { }, ]; - test_scalar_functions(CRC32Function::try_create("crc")?, &tests, true) + test_scalar_functions("crc32", &tests) } diff --git a/common/functions/tests/it/scalars/maths/exp.rs b/common/functions/tests/it/scalars/maths/exp.rs index 9ea6e782f238..66e4f6cdb376 100644 --- a/common/functions/tests/it/scalars/maths/exp.rs +++ 
b/common/functions/tests/it/scalars/maths/exp.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_exp_function() -> Result<()> { @@ -36,5 +35,5 @@ fn test_exp_function() -> Result<()> { }, ]; - test_scalar_functions(ExpFunction::try_create("exp")?, &tests, true) + test_scalar_functions("exp", &tests) } diff --git a/common/functions/tests/it/scalars/maths/floor.rs b/common/functions/tests/it/scalars/maths/floor.rs index a00c64782c57..1795b5951964 100644 --- a/common/functions/tests/it/scalars/maths/floor.rs +++ b/common/functions/tests/it/scalars/maths/floor.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_floor_function() -> Result<()> { @@ -78,5 +77,5 @@ fn test_floor_function() -> Result<()> { }, ]; - test_scalar_functions(FloorFunction::try_create("floor")?, &tests, true) + test_scalar_functions("floor", &tests) } diff --git a/common/functions/tests/it/scalars/maths/log.rs b/common/functions/tests/it/scalars/maths/log.rs index c56002fe4423..15c2d1c62291 100644 --- a/common/functions/tests/it/scalars/maths/log.rs +++ b/common/functions/tests/it/scalars/maths/log.rs @@ -16,10 +16,9 @@ use std::f64::consts::E; use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use 
crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_log_function() -> Result<()> { @@ -88,7 +87,7 @@ fn test_log_function() -> Result<()> { }, ]; - test_scalar_functions(LogFunction::try_create("log")?, &tests, true) + test_scalar_functions("log", &tests) } #[test] @@ -108,7 +107,7 @@ fn test_ln_function() -> Result<()> { }, ]; - test_scalar_functions(LnFunction::try_create("ln")?, &tests, true) + test_scalar_functions("ln", &tests) } #[test] @@ -120,7 +119,7 @@ fn test_log2_function() -> Result<()> { error: "", }]; - test_scalar_functions(Log2Function::try_create("log2")?, &tests, true) + test_scalar_functions("log2", &tests) } #[test] @@ -132,5 +131,5 @@ fn test_log10_function() -> Result<()> { error: "", }]; - test_scalar_functions(Log10Function::try_create("log10")?, &tests, true) + test_scalar_functions("log10", &tests) } diff --git a/common/functions/tests/it/scalars/maths/pi.rs b/common/functions/tests/it/scalars/maths/pi.rs index b2ccd22e7bdf..e36a32af617f 100644 --- a/common/functions/tests/it/scalars/maths/pi.rs +++ b/common/functions/tests/it/scalars/maths/pi.rs @@ -16,10 +16,9 @@ use std::f64::consts::PI; use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_pi_function() -> Result<()> { @@ -31,5 +30,5 @@ fn test_pi_function() -> Result<()> { error: "", }]; - test_scalar_functions(PiFunction::try_create("pi()")?, &tests, true) + test_scalar_functions("pi", &tests) } diff --git a/common/functions/tests/it/scalars/maths/pow.rs b/common/functions/tests/it/scalars/maths/pow.rs index 
d4abe91c4c1e..62962c12cd86 100644 --- a/common/functions/tests/it/scalars/maths/pow.rs +++ b/common/functions/tests/it/scalars/maths/pow.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_pow_function() -> Result<()> { @@ -63,5 +62,5 @@ fn test_pow_function() -> Result<()> { }, ]; - test_scalar_functions(PowFunction::try_create("pow")?, &tests, true) + test_scalar_functions("pow", &tests) } diff --git a/common/functions/tests/it/scalars/maths/round.rs b/common/functions/tests/it/scalars/maths/round.rs index 10cf37f1dba2..898da696de3b 100644 --- a/common/functions/tests/it/scalars/maths/round.rs +++ b/common/functions/tests/it/scalars/maths/round.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_round_number_function() -> Result<()> { @@ -156,7 +155,7 @@ fn test_round_number_function() -> Result<()> { }, ]; - test_scalar_functions(RoundNumberFunction::try_create("round")?, &tests, true) + test_scalar_functions("round", &tests) } #[test] @@ -206,7 +205,7 @@ fn test_trunc_number_function() -> Result<()> { name: "first arg is const, second is series", columns: vec![ ConstColumn::new(Series::from_data(vec![11.11f64]), 4).arc(), - Series::from_data([None, Some(-1), Some(0), Some(1)]), + Series::from_data([None, Some(-1i64), Some(0), Some(1)]), ], expect: Series::from_data([None, 
Some(10.0), Some(11.0), Some(11.1)]), error: "", @@ -222,12 +221,12 @@ fn test_trunc_number_function() -> Result<()> { Some(33.33), Some(44.44), ]), - Series::from_data([None, Some(1), None, Some(0), Some(-1), Some(1)]), + Series::from_data([None, Some(1i64), None, Some(0), Some(-1), Some(1)]), ], expect: Series::from_data([None, None, None, Some(22.0), Some(30.0), Some(44.4)]), error: "", }, ]; - test_scalar_functions(TruncNumberFunction::try_create("trunc")?, &tests, true) + test_scalar_functions("truncate", &tests) } diff --git a/common/functions/tests/it/scalars/maths/sign.rs b/common/functions/tests/it/scalars/maths/sign.rs index d5ef80517f19..77cee9f7784a 100644 --- a/common/functions/tests/it/scalars/maths/sign.rs +++ b/common/functions/tests/it/scalars/maths/sign.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_sign_function() -> Result<()> { @@ -108,7 +107,5 @@ fn test_sign_function() -> Result<()> { }, ]; - let sign_f = SignFunction::try_create("sign")?; - let sign_f = FunctionAdapter::create(sign_f, true); - test_scalar_functions(sign_f, &tests, true) + test_scalar_functions("sign", &tests) } diff --git a/common/functions/tests/it/scalars/maths/sqrt.rs b/common/functions/tests/it/scalars/maths/sqrt.rs index 25e846f17276..d223a43f7b28 100644 --- a/common/functions/tests/it/scalars/maths/sqrt.rs +++ b/common/functions/tests/it/scalars/maths/sqrt.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use 
crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_sqrt_function() -> Result<()> { @@ -42,5 +41,5 @@ fn test_sqrt_function() -> Result<()> { }, ]; - test_scalar_functions(SqrtFunction::try_create("sqrt")?, &tests, true) + test_scalar_functions("sqrt", &tests) } diff --git a/common/functions/tests/it/scalars/maths/trigonometric.rs b/common/functions/tests/it/scalars/maths/trigonometric.rs index 2d6374cdd724..9f57005bfa78 100644 --- a/common/functions/tests/it/scalars/maths/trigonometric.rs +++ b/common/functions/tests/it/scalars/maths/trigonometric.rs @@ -18,10 +18,9 @@ use std::f64::consts::PI; use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_trigonometric_sin_function() -> Result<()> { @@ -58,11 +57,7 @@ fn test_trigonometric_sin_function() -> Result<()> { }, ]; - test_scalar_functions( - TrigonometricSinFunction::try_create_func("sin")?, - &tests, - true, - ) + test_scalar_functions("sin", &tests) } #[test] @@ -74,11 +69,7 @@ fn test_trigonometric_cos_function() -> Result<()> { error: "", }]; - test_scalar_functions( - TrigonometricCosFunction::try_create_func("cos")?, - &tests, - true, - ) + test_scalar_functions("cos", &tests) } #[test] @@ -90,11 +81,7 @@ fn test_trigonometric_tan_function() -> Result<()> { error: "", }]; - test_scalar_functions( - TrigonometricTanFunction::try_create_func("tan")?, - &tests, - true, - ) + test_scalar_functions("tan", &tests) } #[test] @@ -114,11 +101,7 @@ fn test_trigonometric_cot_function() -> Result<()> { }, ]; - test_scalar_functions( - TrigonometricCotFunction::try_create_func("cot")?, - &tests, - true, - ) + 
test_scalar_functions("cot", &tests) } #[test] @@ -130,11 +113,7 @@ fn test_trigonometric_asin_function() -> Result<()> { error: "", }]; - test_scalar_functions( - TrigonometricAsinFunction::try_create_func("asin")?, - &tests, - true, - ) + test_scalar_functions("asin", &tests) } #[test] @@ -146,11 +125,7 @@ fn test_trigonometric_acos_function() -> Result<()> { error: "", }]; - test_scalar_functions( - TrigonometricAcosFunction::try_create_func("acos")?, - &tests, - true, - ) + test_scalar_functions("acos", &tests) } #[test] @@ -173,11 +148,7 @@ fn test_trigonometric_atan_function() -> Result<()> { }, ]; - test_scalar_functions( - TrigonometricAtanFunction::try_create_func("atan")?, - &tests, - true, - ) + test_scalar_functions("atan", &tests) } #[test] @@ -212,9 +183,5 @@ fn test_trigonometric_atan2_function() -> Result<()> { }, ]; - test_scalar_functions( - TrigonometricAtan2Function::try_create_func("atan2")?, - &tests, - true, - ) + test_scalar_functions("atan2", &tests) } diff --git a/common/functions/tests/it/scalars/mod.rs b/common/functions/tests/it/scalars/mod.rs index 3e7fc9095d60..89b1aa4cfae9 100644 --- a/common/functions/tests/it/scalars/mod.rs +++ b/common/functions/tests/it/scalars/mod.rs @@ -22,7 +22,7 @@ mod logics; mod maths; mod nullables; mod others; -mod scalar_function2_test; +mod scalar_function_test; mod semi_structureds; mod strings; mod tuples; diff --git a/common/functions/tests/it/scalars/nullables.rs b/common/functions/tests/it/scalars/nullables.rs index 0c9cf439e609..8443307a676d 100644 --- a/common/functions/tests/it/scalars/nullables.rs +++ b/common/functions/tests/it/scalars/nullables.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use super::scalar_function2_test::test_scalar_functions; -use super::scalar_function2_test::ScalarFunctionTest; +use super::scalar_function_test::test_scalar_functions; +use super::scalar_function_test::ScalarFunctionTest; #[test] 
fn test_is_null_function() -> Result<()> { @@ -28,7 +27,7 @@ fn test_is_null_function() -> Result<()> { error: "", }]; - test_scalar_functions(IsNullFunction::try_create_func("")?, &tests, false) + test_scalar_functions("isNull", &tests) } #[test] @@ -40,5 +39,5 @@ fn test_is_not_null_function() -> Result<()> { error: "", }]; - test_scalar_functions(IsNotNullFunction::try_create_func("")?, &tests, false) + test_scalar_functions("isNotNull", &tests) } diff --git a/common/functions/tests/it/scalars/others.rs b/common/functions/tests/it/scalars/others.rs index 3a67025274d9..928d7a3834ad 100644 --- a/common/functions/tests/it/scalars/others.rs +++ b/common/functions/tests/it/scalars/others.rs @@ -12,24 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - +use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::InetAtonFunction; -use common_functions::scalars::InetNtoaFunction; -use common_functions::scalars::RunningDifferenceFunction; -use common_functions::scalars::TryInetAtonFunction; -use common_functions::scalars::TryInetNtoaFunction; -use super::scalar_function2_test::test_scalar_functions; -use super::scalar_function2_test::ScalarFunctionTest; -use crate::scalars::scalar_function2_test::test_scalar_functions_with_type; -use crate::scalars::scalar_function2_test::ScalarFunctionWithFieldTest; +use super::scalar_function_test::test_scalar_functions; +use super::scalar_function_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions_with_type; +use crate::scalars::scalar_function_test::ScalarFunctionWithFieldTest; #[test] fn test_running_difference_first_null() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ ScalarFunctionTest { name: "i8_first_null", @@ -225,23 +217,17 @@ fn test_running_difference_first_null() -> Result<()> { }, ]; - 
test_scalar_functions(RunningDifferenceFunction::try_create("a")?, &tests, false) + test_scalar_functions("runningDifference", &tests) } #[test] fn test_running_difference_datetime32_first_null() -> Result<()> { - use common_datavalues::prelude::*; - use common_datavalues::type_datetime32::DateTime32Type; - let tests = vec![ ScalarFunctionWithFieldTest { name: "datetime32_first_null", columns: vec![ColumnWithField::new( Series::from_data([None, Some(3_u32), None, Some(4), Some(10)]), - DataField::new( - "dummy_1", - Arc::new(NullableType::create(DateTime32Type::arc(None))), - ), + DataField::new("dummy_1", NullableType::arc(DateTime32Type::arc(None))), )], expect: Series::from_data([None, None, None, None, Some(6_i64)]), error: "", @@ -250,23 +236,18 @@ fn test_running_difference_datetime32_first_null() -> Result<()> { name: "datetime32_first_not_null", columns: vec![ColumnWithField::new( Series::from_data([Some(2_u32), Some(3), None, Some(4), Some(10)]), - DataField::new( - "dummy_1", - Arc::new(NullableType::create(DateTime32Type::arc(None))), - ), + DataField::new("dummy_1", NullableType::arc(DateTime32Type::arc(None))), )], expect: Series::from_data([Some(0_i64), Some(1), None, None, Some(6)]), error: "", }, ]; - test_scalar_functions_with_type(RunningDifferenceFunction::try_create("a")?, &tests, false) + test_scalar_functions_with_type("runningDifference", &tests) } #[test] fn test_try_inet_aton_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ ScalarFunctionTest { name: "valid input", @@ -288,14 +269,11 @@ fn test_try_inet_aton_function() -> Result<()> { }, ]; - let test_func = TryInetAtonFunction::try_create("try_inet_aton")?; - test_scalar_functions(test_func, &tests, true) + test_scalar_functions("try_inet_aton", &tests) } #[test] fn test_inet_aton_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ ScalarFunctionTest { name: "valid input", @@ -323,14 +301,11 @@ fn test_inet_aton_function() 
-> Result<()> { }, ]; - let test_func = InetAtonFunction::try_create("inet_aton")?; - test_scalar_functions(test_func, &tests, false) + test_scalar_functions("inet_aton", &tests) } #[test] fn test_try_inet_ntoa_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ // integer input test cases ScalarFunctionTest { @@ -369,18 +344,15 @@ fn test_try_inet_ntoa_function() -> Result<()> { name: "string_input_u32", columns: vec![Series::from_data(vec!["3232235777"])], expect: Series::from_data(vec![Some("192.168.1.1")]), - error: "Expected numeric or null type, but got String", + error: "Expected a numeric type, but got String", }, ]; - let test_func = TryInetNtoaFunction::try_create("try_inet_ntoa")?; - test_scalar_functions(test_func, &tests, true) + test_scalar_functions("try_inet_ntoa", &tests) } #[test] fn test_inet_ntoa_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ // integer input test cases ScalarFunctionTest { @@ -419,10 +391,21 @@ fn test_inet_ntoa_function() -> Result<()> { name: "string_input_empty", columns: vec![Series::from_data([""])], expect: Series::from_data([""]), - error: "Expected numeric or null type, but got String", + error: "Expected a numeric type, but got String", }, ]; - let test_func = InetNtoaFunction::try_create("inet_ntoa")?; - test_scalar_functions(test_func, &tests, true) + test_scalar_functions("inet_ntoa", &tests) +} + +#[test] +fn test_to_type_name_function() -> Result<()> { + let tests = vec![ScalarFunctionTest { + name: "to_type_name-example-passed", + columns: vec![Series::from_data([true, true, true, false])], + expect: Series::from_data(["Boolean", "Boolean", "Boolean", "Boolean"]), + error: "", + }]; + + test_scalar_functions("totypename", &tests) } diff --git a/common/functions/tests/it/scalars/scalar_function2_test.rs b/common/functions/tests/it/scalars/scalar_function_test.rs similarity index 75% rename from 
common/functions/tests/it/scalars/scalar_function2_test.rs rename to common/functions/tests/it/scalars/scalar_function_test.rs index 510147b88e48..220f5a8e5431 100644 --- a/common/functions/tests/it/scalars/scalar_function2_test.rs +++ b/common/functions/tests/it/scalars/scalar_function_test.rs @@ -14,9 +14,8 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::Function; -use common_functions::scalars::FunctionAdapter; use common_functions::scalars::FunctionContext; +use common_functions::scalars::FunctionFactory; use pretty_assertions::assert_eq; pub struct ScalarFunctionTest { @@ -33,11 +32,7 @@ pub struct ScalarFunctionWithFieldTest { pub error: &'static str, } -pub fn test_scalar_functions( - test_function: Box, - tests: &[ScalarFunctionTest], - passthrough_null: bool, -) -> Result<()> { +pub fn test_scalar_functions(op: &str, tests: &[ScalarFunctionTest]) -> Result<()> { let mut tests_with_type = Vec::with_capacity(tests.len()); for test in tests { let mut arguments = Vec::with_capacity(test.columns.len()); @@ -59,13 +54,12 @@ pub fn test_scalar_functions( }) } - test_scalar_functions_with_type(test_function, &tests_with_type, passthrough_null) + test_scalar_functions_with_type(op, &tests_with_type) } pub fn test_scalar_functions_with_type( - test_function: Box, + op: &str, tests: &[ScalarFunctionWithFieldTest], - passthrough_null: bool, ) -> Result<()> { for test in tests { let mut rows_size = 0; @@ -76,13 +70,7 @@ pub fn test_scalar_functions_with_type( rows_size = c.column().len(); } - match test_eval_with_type( - &test_function, - rows_size, - &test.columns, - &arguments_type, - passthrough_null, - ) { + match test_eval_with_type(op, rows_size, &test.columns, &arguments_type) { Ok(v) => { let v = v.convert_full_column(); @@ -98,11 +86,7 @@ pub fn test_scalar_functions_with_type( } #[allow(clippy::borrowed_box)] -pub fn test_eval( - test_function: &Box, - columns: &[ColumnRef], - passthrough_null: bool, -) -> 
Result { +pub fn test_eval(op: &str, columns: &[ColumnRef]) -> Result { let mut rows_size = 0; let mut arguments = Vec::with_capacity(columns.len()); let mut arguments_type = Vec::with_capacity(columns.len()); @@ -124,25 +108,18 @@ pub fn test_eval( types.push(t); } - test_eval_with_type( - test_function, - rows_size, - &arguments, - &types, - passthrough_null, - ) + test_eval_with_type(op, rows_size, &arguments, &types) } #[allow(clippy::borrowed_box)] pub fn test_eval_with_type( - test_function: &Box, + op: &str, rows_size: usize, arguments: &[ColumnWithField], arguments_type: &[&DataTypePtr], - passthrough_null: bool, ) -> Result { - let adaptor = FunctionAdapter::create(test_function.clone(), passthrough_null); - adaptor.return_type(arguments_type)?; + let func = FunctionFactory::instance().get(op, arguments_type)?; + func.return_type(); let func_ctx = FunctionContext { tz: None }; - adaptor.eval(arguments, rows_size, func_ctx) + func.eval(arguments, rows_size, func_ctx) } diff --git a/common/functions/tests/it/scalars/semi_structureds/check_json.rs b/common/functions/tests/it/scalars/semi_structureds/check_json.rs index fee0590208fa..f50749c09978 100644 --- a/common/functions/tests/it/scalars/semi_structureds/check_json.rs +++ b/common/functions/tests/it/scalars/semi_structureds/check_json.rs @@ -14,17 +14,15 @@ use std::sync::Arc; +use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::CheckJsonFunction; use serde_json::json; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_check_json_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ ScalarFunctionTest { name: "check_json_bool", @@ -203,5 +201,5 @@ fn test_check_json_function() -> Result<()> { }, ]; - 
test_scalar_functions(CheckJsonFunction::try_create("check_json")?, &tests, false) + test_scalar_functions("check_json", &tests) } diff --git a/common/functions/tests/it/scalars/semi_structureds/get.rs b/common/functions/tests/it/scalars/semi_structureds/get.rs index af571370a335..d7266d4d213e 100644 --- a/common/functions/tests/it/scalars/semi_structureds/get.rs +++ b/common/functions/tests/it/scalars/semi_structureds/get.rs @@ -12,19 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::GetFunction; -use common_functions::scalars::GetIgnoreCaseFunction; -use common_functions::scalars::GetPathFunction; use serde_json::json; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_get_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ ScalarFunctionTest { name: "get_by_field_name", @@ -70,13 +66,11 @@ fn test_get_function() -> Result<()> { }, ]; - test_scalar_functions(GetFunction::try_create("get")?, &tests, false) + test_scalar_functions("get", &tests) } #[test] fn test_get_ignore_case_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ ScalarFunctionTest { name: "get_by_field_name", @@ -101,17 +95,11 @@ fn test_get_ignore_case_function() -> Result<()> { }, ]; - test_scalar_functions( - GetIgnoreCaseFunction::try_create("get_ignore_case")?, - &tests, - false, - ) + test_scalar_functions("get_ignore_case", &tests) } #[test] fn test_get_path_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ ScalarFunctionTest { name: "get_by_path", @@ -139,5 +127,5 @@ fn test_get_path_function() -> Result<()> { }, 
]; - test_scalar_functions(GetPathFunction::try_create("get_path")?, &tests, false) + test_scalar_functions("get_path", &tests) } diff --git a/common/functions/tests/it/scalars/semi_structureds/parse_json.rs b/common/functions/tests/it/scalars/semi_structureds/parse_json.rs index e744d7c673b7..864d033a297d 100644 --- a/common/functions/tests/it/scalars/semi_structureds/parse_json.rs +++ b/common/functions/tests/it/scalars/semi_structureds/parse_json.rs @@ -12,19 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::ParseJsonFunction; -use common_functions::scalars::TryParseJsonFunction; use serde_json::json; use serde_json::Value as JsonValue; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_parse_json_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ ScalarFunctionTest { name: "parse_json_bool", @@ -168,13 +165,11 @@ fn test_parse_json_function() -> Result<()> { }, ]; - test_scalar_functions(ParseJsonFunction::try_create("parse_json")?, &tests, false) + test_scalar_functions("parse_json", &tests) } #[test] fn test_try_parse_json_function() -> Result<()> { - use common_datavalues::prelude::*; - let tests = vec![ ScalarFunctionTest { name: "parse_json_bool", @@ -248,9 +243,5 @@ fn test_try_parse_json_function() -> Result<()> { }, ]; - test_scalar_functions( - TryParseJsonFunction::try_create("try_parse_json")?, - &tests, - false, - ) + test_scalar_functions("try_parse_json", &tests) } diff --git a/common/functions/tests/it/scalars/strings/locate.rs b/common/functions/tests/it/scalars/strings/locate.rs index 456144c5c2d6..b8034bec4e8c 100644 --- 
a/common/functions/tests/it/scalars/strings/locate.rs +++ b/common/functions/tests/it/scalars/strings/locate.rs @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; use crate::scalars::scalar_function_test::test_scalar_functions; use crate::scalars::scalar_function_test::ScalarFunctionTest; @@ -24,114 +25,119 @@ fn test_locate_function() -> Result<()> { let tests = vec![ ScalarFunctionTest { name: "none, none, none", - nullable: true, columns: vec![ - Series::new([Option::<&str>::None]).into(), - Series::new([Option::<&str>::None]).into(), - Series::new([Option::<&str>::None]).into(), + Series::from_data([Option::<&str>::None]), + Series::from_data([Option::<&str>::None]), + Series::from_data([Option::::None]), ], - expect: Series::new([Option::::None]).into(), + expect: Series::from_data([Option::::None]), error: "", }, ScalarFunctionTest { name: "const, const, const", - nullable: false, columns: vec![ - DataColumn::Constant(DataValue::String(Some(b"ab".to_vec())), 1), - DataColumn::Constant(DataValue::String(Some(b"abcdabcd".to_vec())), 1), - DataColumn::Constant(DataValue::UInt64(Some(2)), 1), + Arc::new(ConstColumn::new(Series::from_data(vec![Some("ab")]), 1)), + Arc::new(ConstColumn::new( + Series::from_data(vec![Some("abcdabcd")]), + 1, + )), + Arc::new(ConstColumn::new(Series::from_data(vec![Some(2u64)]), 1)), ], - expect: DataColumn::Constant(DataValue::UInt64(Some(5)), 1), + expect: Arc::new(ConstColumn::new(Series::from_data(vec![Some(5u64)]), 1)), error: "", }, ScalarFunctionTest { name: "const, const, none", - nullable: false, columns: vec![ - DataColumn::Constant(DataValue::String(Some(b"ab".to_vec())), 1), - DataColumn::Constant(DataValue::String(Some(b"abcdabcd".to_vec())), 1), + Arc::new(ConstColumn::new(Series::from_data(vec![Some("ab")]), 1)), + 
Arc::new(ConstColumn::new( + Series::from_data(vec![Some("abcdabcd")]), + 1, + )), ], - expect: DataColumn::Constant(DataValue::UInt64(Some(1)), 1), + expect: Arc::new(ConstColumn::new(Series::from_data(vec![Some(1u64)]), 1)), error: "", }, ScalarFunctionTest { name: "series, series, const", - nullable: false, columns: vec![ - Series::new(["abcd", "efgh"]).into(), - Series::new(["_abcd_", "__efgh__"]).into(), - DataColumn::Constant(DataValue::UInt64(Some(1)), 1), + Series::from_data(["abcd", "efgh"]), + Series::from_data(["_abcd_", "__efgh__"]), + Arc::new(ConstColumn::new(Series::from_data(vec![Some(1u64)]), 2)), ], - expect: Series::new([2_u64, 3_u64]).into(), + expect: Series::from_data([Some(2_u64), Some(3_u64)]), error: "", }, ScalarFunctionTest { name: "const, series, const", - nullable: false, columns: vec![ - DataColumn::Constant(DataValue::String(Some(b"11".to_vec())), 1), - DataColumn::Array(Series::new(["_11_", "__11__"])), - DataColumn::Constant(DataValue::UInt64(Some(1)), 1), + Arc::new(ConstColumn::new(Series::from_data(vec![Some("11")]), 2)), + Series::from_data(["_11_", "__11__"]), + Arc::new(ConstColumn::new(Series::from_data(vec![Some(1u64)]), 2)), ], - expect: Series::new([2_u64, 3_u64]).into(), + expect: Series::from_data([Some(2_u64), Some(3_u64)]), error: "", }, ScalarFunctionTest { name: "series, const, const", - nullable: false, columns: vec![ - DataColumn::Array(Series::new(["11", "22"])), - DataColumn::Constant(DataValue::String(Some(b"_11_22_".to_vec())), 1), - DataColumn::Constant(DataValue::UInt64(Some(1)), 1), + Series::from_data(["11", "22"]), + Arc::new(ConstColumn::new( + Series::from_data(vec![Some("_11_22_")]), + 2, + )), + Arc::new(ConstColumn::new(Series::from_data(vec![Some(1u64)]), 2)), ], - expect: Series::new([2_u64, 5_u64]).into(), + expect: Series::from_data([Some(2_u64), Some(5_u64)]), error: "", }, ScalarFunctionTest { name: "const, const, series", - nullable: false, columns: vec![ - 
DataColumn::Constant(DataValue::String(Some(b"11".to_vec())), 1), - DataColumn::Constant(DataValue::String(Some(b"_11_11_".to_vec())), 1), - DataColumn::Array(Series::new([1_u64, 3_u64])), + Arc::new(ConstColumn::new(Series::from_data(vec![Some("11")]), 2)), + Arc::new(ConstColumn::new( + Series::from_data(vec![Some("_11_11_")]), + 2, + )), + Series::from_data([1_u64, 3_u64]), ], - expect: Series::new([2_u64, 5_u64]).into(), + expect: Series::from_data([Some(2_u64), Some(5_u64)]), error: "", }, ScalarFunctionTest { name: "series, const, series", - nullable: false, columns: vec![ - DataColumn::Array(Series::new(["11", "22"])), - DataColumn::Constant(DataValue::String(Some(b"_11_22_".to_vec())), 1), - DataColumn::Array(Series::new([1_u64, 3_u64])), + Series::from_data(["11", "22"]), + Arc::new(ConstColumn::new( + Series::from_data(vec![Some("_11_22_")]), + 2, + )), + Series::from_data([1_u64, 3_u64]), ], - expect: Series::new([2_u64, 5_u64]).into(), + expect: Series::from_data([Some(2_u64), Some(5_u64)]), error: "", }, ScalarFunctionTest { name: "const, series, series", - nullable: false, columns: vec![ - DataColumn::Constant(DataValue::String(Some(b"11".to_vec())), 1), - DataColumn::Array(Series::new(["_11_", "__11__"])), - DataColumn::Array(Series::new([1_u64, 2_u64])), + Arc::new(ConstColumn::new(Series::from_data(vec![Some("11")]), 2)), + Series::from_data(["_11_", "__11__"]), + Series::from_data([1_u64, 2_u64]), ], - expect: Series::new([2_u64, 3_u64]).into(), + expect: Series::from_data([Some(2_u64), Some(3_u64)]), error: "", }, ScalarFunctionTest { name: "series, series, series", - nullable: false, columns: vec![ - DataColumn::Array(Series::new(["11", "22"])), - DataColumn::Array(Series::new(["_11_", "__22__"])), - DataColumn::Array(Series::new([1_u64, 2_u64])), + Series::from_data(["11", "22"]), + Series::from_data(["_11_", "__22__"]), + Series::from_data([1_u64, 2_u64]), ], - expect: Series::new([2_u64, 3_u64]).into(), + expect: Series::from_data([2_u64, 
3_u64]), error: "", }, ]; - test_scalar_functions(LocateFunction::try_create("locate")?, &tests, true) + test_scalar_functions("locate", &tests) } diff --git a/common/functions/tests/it/scalars/strings/lower.rs b/common/functions/tests/it/scalars/strings/lower.rs index be9ce23b1961..f8f0e08893ab 100644 --- a/common/functions/tests/it/scalars/strings/lower.rs +++ b/common/functions/tests/it/scalars/strings/lower.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::LowerFunction; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_lower_function() -> Result<()> { @@ -42,7 +41,7 @@ fn test_lower_function() -> Result<()> { }, ]; - test_scalar_functions(LowerFunction::try_create("lower")?, &tests, true) + test_scalar_functions("lower", &tests) } #[test] @@ -54,5 +53,5 @@ fn test_lower_nullable() -> Result<()> { error: "", }]; - test_scalar_functions(LowerFunction::try_create("lcase")?, &tests, true) + test_scalar_functions("lcase", &tests) } diff --git a/common/functions/tests/it/scalars/strings/mod.rs b/common/functions/tests/it/scalars/strings/mod.rs index 74b31c895606..3fdd46b9c4b1 100644 --- a/common/functions/tests/it/scalars/strings/mod.rs +++ b/common/functions/tests/it/scalars/strings/mod.rs @@ -13,11 +13,11 @@ // limitations under the License. 
// mod locate; +mod locate; mod lower; mod regexp_instr; mod regexp_like; mod regexp_substr; mod substring; mod trim; - mod upper; diff --git a/common/functions/tests/it/scalars/strings/regexp_instr.rs b/common/functions/tests/it/scalars/strings/regexp_instr.rs index a35ae2931c7b..a1866e386688 100644 --- a/common/functions/tests/it/scalars/strings/regexp_instr.rs +++ b/common/functions/tests/it/scalars/strings/regexp_instr.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::RegexpInStrFunction; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_regexp_instr_function() -> Result<()> { @@ -141,18 +140,14 @@ fn test_regexp_instr_function() -> Result<()> { }, ]; - test_scalar_functions( - RegexpInStrFunction::try_create("regexp_instr")?, - &tests, - true, - ) + test_scalar_functions("regexp_instr", &tests) } #[test] fn test_regexp_instr_constant_column() -> Result<()> { - let data_type = DataValue::String("dog".as_bytes().into()); - let data_value1 = StringType::arc().create_constant_column(&data_type, 3)?; - let data_value2 = StringType::arc().create_constant_column(&data_type, 3)?; + let data = DataValue::String("dog".as_bytes().into()); + let data_value1 = StringType::arc().create_constant_column(&data, 3)?; + let data_value2 = StringType::arc().create_constant_column(&data, 3)?; let tests = vec![ ScalarFunctionTest { @@ -181,9 +176,5 @@ fn test_regexp_instr_constant_column() -> Result<()> { }, ]; - test_scalar_functions( - RegexpInStrFunction::try_create("regexp_instr")?, - &tests, - true, - ) + test_scalar_functions("regexp_instr", &tests) } diff --git a/common/functions/tests/it/scalars/strings/regexp_like.rs b/common/functions/tests/it/scalars/strings/regexp_like.rs index 
1b40a1427abf..3589085b8fbb 100644 --- a/common/functions/tests/it/scalars/strings/regexp_like.rs +++ b/common/functions/tests/it/scalars/strings/regexp_like.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::RegexpLikeFunction; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_regexp_like_function() -> Result<()> { @@ -82,7 +81,7 @@ fn test_regexp_like_function() -> Result<()> { }, ]; - test_scalar_functions(RegexpLikeFunction::try_create("regexp_like")?, &tests, true) + test_scalar_functions("regexp_like", &tests) } #[test] @@ -110,5 +109,5 @@ fn test_regexp_like_match_type_joiner() -> Result<()> { }, ]; - test_scalar_functions(RegexpLikeFunction::try_create("regexp_like")?, &tests, true) + test_scalar_functions("regexp_like", &tests) } diff --git a/common/functions/tests/it/scalars/strings/regexp_substr.rs b/common/functions/tests/it/scalars/strings/regexp_substr.rs index 76474b6d71f8..8a31c8da4154 100644 --- a/common/functions/tests/it/scalars/strings/regexp_substr.rs +++ b/common/functions/tests/it/scalars/strings/regexp_substr.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::RegexpSubStrFunction; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_regexp_instr_function() -> Result<()> { @@ -93,11 +92,7 @@ fn test_regexp_instr_function() -> Result<()> { }, ]; - test_scalar_functions( - RegexpSubStrFunction::try_create("regexp_substr")?, - &tests, - true, - ) + 
test_scalar_functions("regexp_substr", &tests) } #[test] @@ -134,9 +129,5 @@ fn test_regexp_substr_constant_column() -> Result<()> { }, ]; - test_scalar_functions( - RegexpSubStrFunction::try_create("regexp_substr")?, - &tests, - true, - ) + test_scalar_functions("regexp_substr", &tests) } diff --git a/common/functions/tests/it/scalars/strings/substring.rs b/common/functions/tests/it/scalars/strings/substring.rs index d81dad63c531..6ee279e22fa9 100644 --- a/common/functions/tests/it/scalars/strings/substring.rs +++ b/common/functions/tests/it/scalars/strings/substring.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::SubstringFunction; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_substring_function() -> Result<()> { @@ -73,7 +72,7 @@ fn test_substring_function() -> Result<()> { }, ]; - test_scalar_functions(SubstringFunction::try_create("substring")?, &tests, true) + test_scalar_functions("substring", &tests) } #[test] @@ -89,5 +88,5 @@ fn test_substring_nullable() -> Result<()> { error: "", }]; - test_scalar_functions(SubstringFunction::try_create("substring")?, &tests, true) + test_scalar_functions("substr", &tests) } diff --git a/common/functions/tests/it/scalars/strings/trim.rs b/common/functions/tests/it/scalars/strings/trim.rs index f029c941815f..c28abf60d245 100644 --- a/common/functions/tests/it/scalars/strings/trim.rs +++ b/common/functions/tests/it/scalars/strings/trim.rs @@ -14,12 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::LTrimFunction; -use common_functions::scalars::RTrimFunction; -use common_functions::scalars::TrimFunction; -use crate::scalars::scalar_function2_test::test_scalar_functions; 
-use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_ltrim_function() -> Result<()> { @@ -30,7 +27,7 @@ fn test_ltrim_function() -> Result<()> { error: "", }]; - test_scalar_functions(LTrimFunction::try_create("ltrim")?, &tests, true) + test_scalar_functions("ltrim", &tests) } #[test] @@ -42,7 +39,7 @@ fn test_rtrim_function() -> Result<()> { error: "", }]; - test_scalar_functions(RTrimFunction::try_create("rtrim")?, &tests, true) + test_scalar_functions("rtrim", &tests) } #[test] @@ -62,7 +59,7 @@ fn test_trim_function() -> Result<()> { }, ]; - test_scalar_functions(TrimFunction::try_create("trim")?, &tests, true) + test_scalar_functions("trim", &tests) } #[test] @@ -74,5 +71,5 @@ fn test_trim_nullable() -> Result<()> { error: "", }]; - test_scalar_functions(TrimFunction::try_create("trim")?, &tests, true) + test_scalar_functions("trim", &tests) } diff --git a/common/functions/tests/it/scalars/strings/upper.rs b/common/functions/tests/it/scalars/strings/upper.rs index 72a65dacc74c..58f18b431e9b 100644 --- a/common/functions/tests/it/scalars/strings/upper.rs +++ b/common/functions/tests/it/scalars/strings/upper.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::UpperFunction; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_upper_function() -> Result<()> { @@ -42,7 +41,7 @@ fn test_upper_function() -> Result<()> { }, ]; - test_scalar_functions(UpperFunction::try_create("upper")?, &tests, true) + test_scalar_functions("upper", &tests) } #[test] @@ -54,5 +53,5 @@ fn test_upper_nullable() -> Result<()> { error: "", }]; - 
test_scalar_functions(UpperFunction::try_create("ucase")?, &tests, true) + test_scalar_functions("ucase", &tests) } diff --git a/common/functions/tests/it/scalars/tuples.rs b/common/functions/tests/it/scalars/tuples.rs index 7b4e9e9eb7a0..c0391e844a60 100644 --- a/common/functions/tests/it/scalars/tuples.rs +++ b/common/functions/tests/it/scalars/tuples.rs @@ -14,43 +14,42 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use super::scalar_function2_test::test_eval; -use super::scalar_function2_test::ScalarFunctionTest; +use super::scalar_function_test::test_eval; +use super::scalar_function_test::ScalarFunctionTest; #[test] fn test_tuple_function() -> Result<()> { let tests = vec![ - ScalarFunctionTest { - name: "one element to tuple", - columns: vec![Series::from_data([0_u8])], - expect: Series::from_data([0_u8]), - error: "", - }, - ScalarFunctionTest { - name: "more element to tuple", - columns: vec![Series::from_data([0_u8]), Series::from_data([0_u8])], - expect: Series::from_data([0_u8]), - error: "", - }, + ( + vec![DataValue::Struct(vec![DataValue::UInt64(0)])], + ScalarFunctionTest { + name: "one element to tuple", + columns: vec![Series::from_data([0_u8])], + expect: Series::from_data([0_u8]), + error: "", + }, + ), + ( + vec![DataValue::Struct(vec![ + DataValue::UInt64(0), + DataValue::UInt64(0), + ])], + ScalarFunctionTest { + name: "more element to tuple", + columns: vec![Series::from_data([0_u8]), Series::from_data([0_u8])], + expect: Series::from_data([0_u8]), + error: "", + }, + ), ]; - let v1 = vec![DataValue::Struct(vec![DataValue::UInt64(0)])]; - let v2 = vec![DataValue::Struct(vec![ - DataValue::UInt64(0), - DataValue::UInt64(0), - ])]; - - let values = vec![v1, v2]; - - for (t, v) in tests.iter().zip(values.iter()) { - let func = TupleFunction::try_create_func("")?; - let result = test_eval(&func, &t.columns, false)?; + for (val, test) in tests.iter() { + let result = test_eval("tuple", 
&test.columns)?; let result = result.convert_full_column(); let result = (0..result.len()).map(|i| result.get(i)).collect::<Vec<_>>(); assert!(&result == val) } Ok(()) diff --git a/common/functions/tests/it/scalars/udfs/database.rs b/common/functions/tests/it/scalars/udfs/database.rs index 98a9bbe28715..c14159af94db 100644 --- a/common/functions/tests/it/scalars/udfs/database.rs +++ b/common/functions/tests/it/scalars/udfs/database.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_database_function() -> Result<()> { @@ -28,5 +27,5 @@ fn test_database_function() -> Result<()> { error: "", }]; - test_scalar_functions(DatabaseFunction::try_create("database")?, &tests, true) + test_scalar_functions("database", &tests) } diff --git a/common/functions/tests/it/scalars/udfs/mod.rs b/common/functions/tests/it/scalars/udfs/mod.rs index 37122d5fb98b..605168b490aa 100644 --- a/common/functions/tests/it/scalars/udfs/mod.rs +++ b/common/functions/tests/it/scalars/udfs/mod.rs @@ -13,5 +13,4 @@ // limitations under the License. mod database; -mod to_type_name; mod version; diff --git a/common/functions/tests/it/scalars/udfs/to_type_name.rs b/common/functions/tests/it/scalars/udfs/to_type_name.rs deleted file mode 100644 index 38f62424ca50..000000000000 --- a/common/functions/tests/it/scalars/udfs/to_type_name.rs +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2021 Datafuse Labs. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use common_datavalues::prelude::*; -use common_exception::Result; -use common_functions::scalars::*; - -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; - -#[test] -fn test_to_type_name_function() -> Result<()> { - let tests = vec![ScalarFunctionTest { - name: "to_type_name-example-passed", - columns: vec![Series::from_data([true, true, true, false])], - expect: Series::from_data(["Boolean", "Boolean", "Boolean", "Boolean"]), - error: "", - }]; - - test_scalar_functions(ToTypeNameFunction::try_create("toTypeName")?, &tests, false) -} diff --git a/common/functions/tests/it/scalars/udfs/version.rs b/common/functions/tests/it/scalars/udfs/version.rs index c7076c2d6e4d..53747ee68f39 100644 --- a/common/functions/tests/it/scalars/udfs/version.rs +++ b/common/functions/tests/it/scalars/udfs/version.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_version_function() -> Result<()> { @@ -28,5 +27,5 @@ fn test_version_function() -> Result<()> { error: "", }]; - test_scalar_functions(VersionFunction::try_create("version")?, &tests, true) + test_scalar_functions("version", &tests) } diff --git 
a/common/functions/tests/it/scalars/uuids/uuid_creator.rs b/common/functions/tests/it/scalars/uuids/uuid_creator.rs index 6791abe6c5e6..e74af364d580 100644 --- a/common/functions/tests/it/scalars/uuids/uuid_creator.rs +++ b/common/functions/tests/it/scalars/uuids/uuid_creator.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_uuid_creator_functions() -> Result<()> { @@ -40,5 +39,5 @@ fn test_uuid_creator_functions() -> Result<()> { error: "", }]; - test_scalar_functions(UUIDZeroFunction::try_create("")?, &tests, true) + test_scalar_functions("zeroUUID", &tests) } diff --git a/common/functions/tests/it/scalars/uuids/uuid_verifier.rs b/common/functions/tests/it/scalars/uuids/uuid_verifier.rs index 8f156896736a..219d7a688510 100644 --- a/common/functions/tests/it/scalars/uuids/uuid_verifier.rs +++ b/common/functions/tests/it/scalars/uuids/uuid_verifier.rs @@ -14,10 +14,9 @@ use common_datavalues::prelude::*; use common_exception::Result; -use common_functions::scalars::*; -use crate::scalars::scalar_function2_test::test_scalar_functions; -use crate::scalars::scalar_function2_test::ScalarFunctionTest; +use crate::scalars::scalar_function_test::test_scalar_functions; +use crate::scalars::scalar_function_test::ScalarFunctionTest; #[test] fn test_uuid_is_empty_functions() -> Result<()> { @@ -31,7 +30,7 @@ fn test_uuid_is_empty_functions() -> Result<()> { error: "", }]; - test_scalar_functions(UUIDIsEmptyFunction::try_create("")?, &tests, false) + test_scalar_functions("isemptyUUID", &tests) } #[test] @@ -46,5 +45,5 @@ fn test_uuid_is_not_empty_functions() -> Result<()> { error: "", }]; - 
test_scalar_functions(UUIDIsNotEmptyFunction::try_create("")?, &tests, false) + test_scalar_functions("isnotemptyUUID", &tests) } diff --git a/common/planners/src/plan_expression_chain.rs b/common/planners/src/plan_expression_chain.rs index 636a1c90c5a7..1e0fe3e20351 100644 --- a/common/planners/src/plan_expression_chain.rs +++ b/common/planners/src/plan_expression_chain.rs @@ -127,7 +127,7 @@ impl ExpressionChain { let arg_types = vec![nested_expr.to_data_type(&self.schema)?]; let arg_types2: Vec<&DataTypePtr> = arg_types.iter().collect(); let func = FunctionFactory::instance().get(op, &arg_types2)?; - let return_type = func.return_type(&arg_types2)?; + let return_type = func.return_type(); let function = ActionFunction { name: expr.column_name(), @@ -149,7 +149,7 @@ impl ExpressionChain { let arg_types2: Vec<&DataTypePtr> = arg_types.iter().collect(); let func = FunctionFactory::instance().get(op, &arg_types2)?; - let return_type = func.return_type(&arg_types2)?; + let return_type = func.return_type(); let function = ActionFunction { name: expr.column_name(), @@ -172,7 +172,7 @@ impl ExpressionChain { let arg_types2: Vec<&DataTypePtr> = arg_types.iter().collect(); let func = FunctionFactory::instance().get(op, &arg_types2)?; - let return_type = func.return_type(&arg_types2)?; + let return_type = func.return_type(); let function = ActionFunction { name: expr.column_name(), diff --git a/common/planners/src/plan_expression_common.rs b/common/planners/src/plan_expression_common.rs index 261c9941bb8e..ed3c2293f462 100644 --- a/common/planners/src/plan_expression_common.rs +++ b/common/planners/src/plan_expression_common.rs @@ -409,7 +409,7 @@ impl ExpressionDataTypeVisitor { let arguments: Vec<&DataTypePtr> = arguments.iter().collect(); let function = FunctionFactory::instance().get(op, &arguments)?; - let return_type = function.return_type(&arguments)?; + let return_type = function.return_type(); self.stack.push(return_type); Ok(self) } diff --git 
a/common/planners/src/plan_expression_monotonicity.rs b/common/planners/src/plan_expression_monotonicity.rs index a833a882c508..dadc984091da 100644 --- a/common/planners/src/plan_expression_monotonicity.rs +++ b/common/planners/src/plan_expression_monotonicity.rs @@ -135,7 +135,7 @@ impl ExpressionMonotonicityVisitor { let arg_types: Vec<&DataTypePtr> = arg_types.iter().collect(); let func = instance.get(op, &arg_types)?; - let return_type = func.return_type(&arg_types)?; + let return_type = func.return_type(); let mut monotonic = match self.single_point { false => func.get_monotonicity(monotonicity_vec.as_ref())?, true => { diff --git a/query/src/sql/statements/analyzer_expr.rs b/query/src/sql/statements/analyzer_expr.rs index f5ed9ed884b7..38947d655dfa 100644 --- a/query/src/sql/statements/analyzer_expr.rs +++ b/query/src/sql/statements/analyzer_expr.rs @@ -541,7 +541,7 @@ impl ExprRPNBuilder { Expr::TryCast { data_type, .. } => { let mut ty = SQLCommon::make_data_type(data_type)?; if ty.can_inside_nullable() { - ty = Arc::new(NullableType::create(ty)) + ty = NullableType::arc(ty) } self.rpn.push(ExprRPNItem::Cast(ty)); } diff --git a/tests/suites/0_stateless/02_function/02_0026_function_string_substring_index.sql b/tests/suites/0_stateless/02_function/02_0026_function_string_substring_index.sql index 548409be23c1..06cdd0f446a6 100644 --- a/tests/suites/0_stateless/02_function/02_0026_function_string_substring_index.sql +++ b/tests/suites/0_stateless/02_function/02_0026_function_string_substring_index.sql @@ -35,4 +35,3 @@ SELECT SUBSTRING_INDEX(number + 10, number, 1) FROM numbers(5) ORDER BY number; SELECT '=== series, series, series ==='; SELECT SUBSTRING_INDEX(number + 10, number, number) FROM numbers(5) ORDER BY number; - diff --git a/tests/suites/0_stateless/02_function/02_0046_function_logic.result b/tests/suites/0_stateless/02_function/02_0046_function_logic.result index 14e256742e62..dbf511aa78c6 100644 --- 
a/tests/suites/0_stateless/02_function/02_0046_function_logic.result +++ b/tests/suites/0_stateless/02_function/02_0046_function_logic.result @@ -4,6 +4,8 @@ 0 1 NULL +0 +NULL 6 7 1 @@ -13,8 +15,9 @@ NULL 1 0 1 +NULL 1 -0 +NULL 0 1 8 diff --git a/tests/suites/0_stateless/02_function/02_0046_function_logic.sql b/tests/suites/0_stateless/02_function/02_0046_function_logic.sql index 8aefddcde2a1..64cc9bb927d6 100644 --- a/tests/suites/0_stateless/02_function/02_0046_function_logic.sql +++ b/tests/suites/0_stateless/02_function/02_0046_function_logic.sql @@ -5,6 +5,8 @@ SELECT false and false; SELECT 1 and 0; SELECT 1 and 1; SELECT 1 and null; +SELECT 0 and null; +SELECT null and null; SELECT number from numbers(10) WHERE number > 5 AND number < 8 ORDER BY number; -- or, result: [line9, line21] SELECT true OR false; @@ -14,6 +16,7 @@ SELECT 1 OR 0; SELECT 1 OR 1; SELECT 0 OR 0; SELECT 1 OR null; +SELECT 0 OR null; SELECT null OR 1; SELECT null OR null; SELECT number from numbers(10) WHERE number > 7 OR number < 2 ORDER BY number;