diff --git a/query/src/pipelines/new/processors/transforms/transform_apply.rs b/query/src/pipelines/new/processors/transforms/transform_apply.rs index 8eb552d58c667..fb1454463559d 100644 --- a/query/src/pipelines/new/processors/transforms/transform_apply.rs +++ b/query/src/pipelines/new/processors/transforms/transform_apply.rs @@ -150,12 +150,14 @@ impl OuterRefRewriter { Scalar::AndExpr(expr) => Ok(AndExpr { left: Box::new(self.rewrite_scalar(&expr.left)?), right: Box::new(self.rewrite_scalar(&expr.right)?), + return_type: expr.return_type.clone(), } .into()), Scalar::OrExpr(expr) => Ok(OrExpr { left: Box::new(self.rewrite_scalar(&expr.left)?), right: Box::new(self.rewrite_scalar(&expr.right)?), + return_type: expr.return_type.clone(), } .into()), @@ -163,6 +165,7 @@ impl OuterRefRewriter { op: expr.op.clone(), left: Box::new(self.rewrite_scalar(&expr.left)?), right: Box::new(self.rewrite_scalar(&expr.right)?), + return_type: expr.return_type.clone(), } .into()), diff --git a/query/src/sql/exec/expression_builder.rs b/query/src/sql/exec/expression_builder.rs index 088188c54c330..a81ed10e6a839 100644 --- a/query/src/sql/exec/expression_builder.rs +++ b/query/src/sql/exec/expression_builder.rs @@ -57,9 +57,9 @@ impl ExpressionBuilder { Scalar::ConstantExpr(ConstantExpr { value, data_type }) => { self.build_literal(value, data_type) } - Scalar::ComparisonExpr(ComparisonExpr { op, left, right }) => { - self.build_binary_operator(left, right, op.to_func_name()) - } + Scalar::ComparisonExpr(ComparisonExpr { + op, left, right, .. + }) => self.build_binary_operator(left, right, op.to_func_name()), Scalar::AggregateFunction(AggregateFunction { func_name, distinct, @@ -67,7 +67,7 @@ impl ExpressionBuilder { args, .. }) => self.build_aggr_function(func_name.clone(), *distinct, params.clone(), args), - Scalar::AndExpr(AndExpr { left, right }) => { + Scalar::AndExpr(AndExpr { left, right, .. }) => { let left = self.build(&**left)?; let right = self.build(&**right)?; Ok(Expression::BinaryExpression { @@ -76,7 +76,7 @@ impl ExpressionBuilder { right: Box::new(right), }) } - Scalar::OrExpr(OrExpr { left, right }) => { + Scalar::OrExpr(OrExpr { left, right, .. }) => { let left = self.build(&**left)?; let right = self.build(&**right)?; Ok(Expression::BinaryExpression { diff --git a/query/src/sql/exec/mod.rs b/query/src/sql/exec/mod.rs index ccd1a5136f61f..1efb6afe0c700 100644 --- a/query/src/sql/exec/mod.rs +++ b/query/src/sql/exec/mod.rs @@ -30,6 +30,7 @@ use common_datavalues::ToDataType; use common_datavalues::Vu8; use common_exception::ErrorCode; use common_exception::Result; +use common_functions::scalars::FunctionFactory; use common_planners::find_aggregate_exprs; use common_planners::find_aggregate_exprs_in_expr; use common_planners::Expression; @@ -334,9 +335,13 @@ impl PipelineBuilder { let eb = ExpressionBuilder::create(self.metadata.clone()); let scalars = &filter.predicates; let pred = scalars.iter().cloned().reduce(|acc, v| { + let func = FunctionFactory::instance() + .get("and", &[&acc.data_type(), &v.data_type()]) + .unwrap(); AndExpr { left: Box::new(acc), right: Box::new(v), + return_type: func.return_type(), } .into() }); diff --git a/query/src/sql/optimizer/rule/rewrite/rule_push_down_filter_join.rs b/query/src/sql/optimizer/rule/rewrite/rule_push_down_filter_join.rs index 42acb4ac4d683..f4aaa8b19e98a 100644 --- a/query/src/sql/optimizer/rule/rewrite/rule_push_down_filter_join.rs +++ b/query/src/sql/optimizer/rule/rewrite/rule_push_down_filter_join.rs @@ -66,6 +66,7 @@ impl<'a> Predicate<'a> { op: ComparisonOp::Equal, left, right, + .. }) = scalar { if satisfied_by(left, left_prop) && satisfied_by(right, right_prop) { diff --git a/query/src/sql/planner/binder/aggregate.rs b/query/src/sql/planner/binder/aggregate.rs index f2d581a56c056..e5eaf4929abf1 100644 --- a/query/src/sql/planner/binder/aggregate.rs +++ b/query/src/sql/planner/binder/aggregate.rs @@ -95,17 +95,20 @@ impl<'a> AggregateRewriter<'a> { Scalar::AndExpr(scalar) => Ok(AndExpr { left: Box::new(self.visit(&scalar.left)?), right: Box::new(self.visit(&scalar.right)?), + return_type: scalar.return_type.clone(), } .into()), Scalar::OrExpr(scalar) => Ok(OrExpr { left: Box::new(self.visit(&scalar.left)?), right: Box::new(self.visit(&scalar.right)?), + return_type: scalar.return_type.clone(), } .into()), Scalar::ComparisonExpr(scalar) => Ok(ComparisonExpr { op: scalar.op.clone(), left: Box::new(self.visit(&scalar.left)?), right: Box::new(self.visit(&scalar.right)?), + return_type: scalar.return_type.clone(), } .into()), Scalar::FunctionCall(func) => { diff --git a/query/src/sql/planner/binder/scalar_common.rs b/query/src/sql/planner/binder/scalar_common.rs index b8f8316d21b1a..b98cf5c03fc45 100644 --- a/query/src/sql/planner/binder/scalar_common.rs +++ b/query/src/sql/planner/binder/scalar_common.rs @@ -64,7 +64,7 @@ where F: Fn(&Scalar) -> bool pub fn split_conjunctions(scalar: &Scalar) -> Vec { match scalar { - Scalar::AndExpr(AndExpr { left, right }) => { + Scalar::AndExpr(AndExpr { left, right, .. }) => { vec![split_conjunctions(left), split_conjunctions(right)].concat() } _ => { @@ -75,11 +75,9 @@ pub fn split_conjunctions(scalar: &Scalar) -> Vec { pub fn split_equivalent_predicate(scalar: &Scalar) -> Option<(Scalar, Scalar)> { match scalar { - Scalar::ComparisonExpr(ComparisonExpr { op, left, right }) - if *op == ComparisonOp::Equal => - { - Some((*left.clone(), *right.clone())) - } + Scalar::ComparisonExpr(ComparisonExpr { + op, left, right, .. + }) if *op == ComparisonOp::Equal => Some((*left.clone(), *right.clone())), _ => None, } } diff --git a/query/src/sql/planner/binder/scalar_visitor.rs b/query/src/sql/planner/binder/scalar_visitor.rs index 9c255a0d997b5..0d7977e12c361 100644 --- a/query/src/sql/planner/binder/scalar_visitor.rs +++ b/query/src/sql/planner/binder/scalar_visitor.rs @@ -61,11 +61,11 @@ pub trait ScalarVisitor: Sized { stack.push(RecursionProcessing::Call(&**left)); stack.push(RecursionProcessing::Call(&**right)); } - Scalar::AndExpr(AndExpr { left, right }) => { + Scalar::AndExpr(AndExpr { left, right, .. }) => { stack.push(RecursionProcessing::Call(&**left)); stack.push(RecursionProcessing::Call(&**right)); } - Scalar::OrExpr(OrExpr { left, right }) => { + Scalar::OrExpr(OrExpr { left, right, .. }) => { stack.push(RecursionProcessing::Call(&**left)); stack.push(RecursionProcessing::Call(&**right)); } diff --git a/query/src/sql/planner/binder/subquery.rs b/query/src/sql/planner/binder/subquery.rs index b7a7a0f4514ab..c5b359cef7280 100644 --- a/query/src/sql/planner/binder/subquery.rs +++ b/query/src/sql/planner/binder/subquery.rs @@ -17,6 +17,7 @@ use common_datavalues::DataValue; use common_exception::ErrorCode; use common_exception::Result; use common_functions::aggregates::AggregateFunctionFactory; +use common_functions::scalars::FunctionFactory; use crate::sql::binder::ColumnBinding; use crate::sql::optimizer::ColumnSet; @@ -44,6 +45,7 @@ use crate::sql::plans::ScalarItem; use crate::sql::plans::SubqueryExpr; use crate::sql::plans::SubqueryType; use crate::sql::MetadataRef; +use crate::sql::ScalarExpr; /// Rewrite subquery into `Apply` operator pub struct SubqueryRewriter { @@ -197,6 +199,7 @@ impl SubqueryRewriter { } .into(), ), + return_type: agg_func.return_type()?, }; let eval_scalar = EvalScalar { items: vec![ScalarItem { @@ -262,10 +265,13 @@ impl SubqueryRewriter { Scalar::AndExpr(expr) => { let (left, _result_left) = self.try_rewrite_subquery(&expr.left, s_expr)?; let (right, _result_right) = self.try_rewrite_subquery(&expr.right, s_expr)?; + let func = FunctionFactory::instance() + .get("and", &[&left.data_type(), &right.data_type()])?; Ok(( AndExpr { left: Box::new(left), right: Box::new(right), + return_type: func.return_type(), } .into(), s_expr.clone(), @@ -275,10 +281,13 @@ impl SubqueryRewriter { Scalar::OrExpr(expr) => { let (left, s_expr) = self.try_rewrite_subquery(&expr.left, s_expr)?; let (right, s_expr) = self.try_rewrite_subquery(&expr.right, &s_expr)?; + let func = FunctionFactory::instance() + .get("or", &[&left.data_type(), &right.data_type()])?; Ok(( OrExpr { left: Box::new(left), right: Box::new(right), + return_type: func.return_type(), } .into(), s_expr, @@ -288,11 +297,16 @@ impl SubqueryRewriter { Scalar::ComparisonExpr(expr) => { let (left, s_expr) = self.try_rewrite_subquery(&expr.left, s_expr)?; let (right, s_expr) = self.try_rewrite_subquery(&expr.right, &s_expr)?; + let func = FunctionFactory::instance().get(expr.op.to_func_name(), &[ + &left.data_type(), + &right.data_type(), + ])?; Ok(( ComparisonExpr { op: expr.op.clone(), left: Box::new(left), right: Box::new(right), + return_type: func.return_type(), } .into(), s_expr, diff --git a/query/src/sql/planner/format/display_rel_operator.rs b/query/src/sql/planner/format/display_rel_operator.rs index 86177cdfd035c..19a3245262660 100644 --- a/query/src/sql/planner/format/display_rel_operator.rs +++ b/query/src/sql/planner/format/display_rel_operator.rs @@ -15,6 +15,7 @@ use std::fmt::Display; use common_datavalues::format_data_type_sql; +use common_functions::scalars::FunctionFactory; use itertools::Itertools; use super::FormatTreeNode; @@ -37,6 +38,7 @@ use crate::sql::plans::RelOperator; use crate::sql::plans::Scalar; use crate::sql::plans::Sort; use crate::sql::MetadataRef; +use crate::sql::ScalarExpr; pub struct FormatContext { metadata: MetadataRef, @@ -147,18 +149,26 @@ pub fn format_logical_inner_join( .iter() .zip(op.right_conditions.iter()) .map(|(left, right)| { + let func = FunctionFactory::instance() + .get("=", &[&left.data_type(), &right.data_type()]) + .unwrap(); ComparisonExpr { op: ComparisonOp::Equal, left: Box::new(left.clone()), right: Box::new(right.clone()), + return_type: func.return_type(), } .into() }) .collect(); let pred: Scalar = preds.iter().fold(preds[0].clone(), |prev, next| { + let func = FunctionFactory::instance() + .get("and", &[&prev.data_type(), &next.data_type()]) + .unwrap(); Scalar::AndExpr(AndExpr { left: Box::new(prev), right: Box::new(next.clone()), + return_type: func.return_type(), }) }); write!(f, "LogicalInnerJoin: {}", format_scalar(metadata, &pred)) diff --git a/query/src/sql/planner/plans/scalar.rs b/query/src/sql/planner/plans/scalar.rs index e5db6c9e3725b..282b86a5ac79b 100644 --- a/query/src/sql/planner/plans/scalar.rs +++ b/query/src/sql/planner/plans/scalar.rs @@ -311,11 +311,12 @@ impl ScalarExpr for ConstantExpr { pub struct AndExpr { pub left: Box, pub right: Box, + pub return_type: DataTypeImpl, } impl ScalarExpr for AndExpr { fn data_type(&self) -> DataTypeImpl { - BooleanType::new_impl() + self.return_type.clone() } fn used_columns(&self) -> ColumnSet { @@ -333,11 +334,12 @@ impl ScalarExpr for AndExpr { pub struct OrExpr { pub left: Box, pub right: Box, + pub return_type: DataTypeImpl, } impl ScalarExpr for OrExpr { fn data_type(&self) -> DataTypeImpl { - BooleanType::new_impl() + self.return_type.clone() } fn used_columns(&self) -> ColumnSet { @@ -407,6 +409,7 @@ pub struct ComparisonExpr { pub op: ComparisonOp, pub left: Box, pub right: Box, + pub return_type: DataTypeImpl, } impl ScalarExpr for ComparisonExpr { diff --git a/query/src/sql/planner/semantic/grouping_check.rs b/query/src/sql/planner/semantic/grouping_check.rs index cd6f0bbd12c84..d60589613a725 100644 --- a/query/src/sql/planner/semantic/grouping_check.rs +++ b/query/src/sql/planner/semantic/grouping_check.rs @@ -68,17 +68,20 @@ impl<'a> GroupingChecker<'a> { Scalar::AndExpr(scalar) => Ok(AndExpr { left: Box::new(self.resolve(&scalar.left)?), right: Box::new(self.resolve(&scalar.right)?), + return_type: scalar.return_type.clone(), } .into()), Scalar::OrExpr(scalar) => Ok(OrExpr { left: Box::new(self.resolve(&scalar.left)?), right: Box::new(self.resolve(&scalar.right)?), + return_type: scalar.return_type.clone(), } .into()), Scalar::ComparisonExpr(scalar) => Ok(ComparisonExpr { op: scalar.op.clone(), left: Box::new(self.resolve(&scalar.left)?), right: Box::new(self.resolve(&scalar.right)?), + return_type: scalar.return_type.clone(), } .into()), Scalar::FunctionCall(func) => { diff --git a/query/src/sql/planner/semantic/type_check.rs b/query/src/sql/planner/semantic/type_check.rs index 7cffe3f9b9378..fef41572bf18e 100644 --- a/query/src/sql/planner/semantic/type_check.rs +++ b/query/src/sql/planner/semantic/type_check.rs @@ -193,26 +193,19 @@ impl<'a> TypeChecker<'a> { if !*not { // Rewrite `expr BETWEEN low AND high` // into `expr >= low AND expr <= high` - let (ge_func, _) = self - .resolve_function( - span, - ">=", - &[&**expr, &**low], - Some(BooleanType::new_impl()), - ) + let (ge_func, left_type) = self + .resolve_function(span, ">=", &[&**expr, &**low], None) .await?; - let (le_func, _) = self - .resolve_function( - span, - "<=", - &[&**expr, &**high], - Some(BooleanType::new_impl()), - ) + let (le_func, right_type) = self + .resolve_function(span, "<=", &[&**expr, &**high], None) .await?; + let func = + FunctionFactory::instance().get("and", &[&left_type, &right_type])?; ( AndExpr { left: Box::new(ge_func), right: Box::new(le_func), + return_type: func.return_type(), } .into(), BooleanType::new_impl(), @@ -220,26 +213,18 @@ impl<'a> TypeChecker<'a> { } else { // Rewrite `expr NOT BETWEEN low AND high` // into `expr < low OR expr > high` - let (lt_func, _) = self - .resolve_function( - span, - "<", - &[&**expr, &**low], - Some(BooleanType::new_impl()), - ) + let (lt_func, left_type) = self + .resolve_function(span, "<", &[&**expr, &**low], None) .await?; - let (gt_func, _) = self - .resolve_function( - span, - ">", - &[&**expr, &**high], - Some(BooleanType::new_impl()), - ) + let (gt_func, right_type) = self + .resolve_function(span, ">", &[&**expr, &**high], None) .await?; + let func = FunctionFactory::instance().get("or", &[&left_type, &right_type])?; ( OrExpr { left: Box::new(lt_func), right: Box::new(gt_func), + return_type: func.return_type(), } .into(), BooleanType::new_impl(), @@ -651,30 +636,29 @@ impl<'a> TypeChecker<'a> { let op = ComparisonOp::try_from(op)?; let (left, _) = self.resolve(left, None).await?; let (right, _) = self.resolve(right, None).await?; - let mut data_type = BooleanType::new_impl(); - if left.data_type() == DataTypeImpl::Null(NullType {}) - || right.data_type() == DataTypeImpl::Null(NullType {}) - { - data_type = NullType::new_impl(); - } + let func = FunctionFactory::instance() + .get(op.to_func_name(), &[&left.data_type(), &right.data_type()])?; Ok(( ComparisonExpr { op, left: Box::new(left), right: Box::new(right), + return_type: func.return_type(), } .into(), - data_type, + func.return_type(), )) } BinaryOperator::And => { let (left, _) = self.resolve(left, Some(BooleanType::new_impl())).await?; let (right, _) = self.resolve(right, Some(BooleanType::new_impl())).await?; - + let func = FunctionFactory::instance() + .get("and", &[&left.data_type(), &right.data_type()])?; Ok(( AndExpr { left: Box::new(left), right: Box::new(right), + return_type: func.return_type(), } .into(), BooleanType::new_impl(), @@ -683,11 +667,13 @@ impl<'a> TypeChecker<'a> { BinaryOperator::Or => { let (left, _) = self.resolve(left, Some(BooleanType::new_impl())).await?; let (right, _) = self.resolve(right, Some(BooleanType::new_impl())).await?; - + let func = FunctionFactory::instance() + .get("or", &[&left.data_type(), &right.data_type()])?; Ok(( OrExpr { left: Box::new(left), right: Box::new(right), + return_type: func.return_type(), } .into(), BooleanType::new_impl(), diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result index d9acf91ecb1af..fd8a70d03595f 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result @@ -365,3 +365,14 @@ NULL NULL 6 8 1 2 1 4 3 4 NULL NULL 7 8 NULL NULL +====NULL==== +0 0 0 0 +NULL NULL NULL NULL +NULL NULL NULL 1 +NULL NULL NULL 1 +NULL NULL NULL 1 +NULL NULL NULL NULL +12 1 0 1 +NULL NULL NULL NULL +NULL NULL NULL 1 +NULL NULL NULL 1 diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql index b51b13e0fea55..ae0c69ad5c997 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql @@ -319,4 +319,10 @@ select * from t1 right join t2 on t1.a = t2.c; select * from t1 left join t2 on t1.a = t2.c; drop table t1; drop table t2; + +-- NULL +select '====NULL===='; +create table n( a int null, b int null) ; +insert into n select if (number % 3, null, number), if (number % 2, null, number) from numbers(10); +select a + b, a and b, a - b, a or b from n; set enable_planner_v2 = 0;