diff --git a/query/src/sql/exec/data_schema_builder.rs b/query/src/sql/exec/data_schema_builder.rs index 76b121876be5..28a024c37b7b 100644 --- a/query/src/sql/exec/data_schema_builder.rs +++ b/query/src/sql/exec/data_schema_builder.rs @@ -17,7 +17,9 @@ use std::sync::Arc; use common_datavalues::DataField; use common_datavalues::DataSchema; use common_datavalues::DataSchemaRef; +use common_datavalues::DataTypeImpl; use common_exception::Result; +use common_planners::Expression; use crate::sql::exec::util::format_field_name; use crate::sql::plans::PhysicalScan; @@ -77,8 +79,11 @@ impl<'a> DataSchemaBuilder<'a> { for index in plan.columns.iter() { let column_entry = self.metadata.column(*index); let field_name = format_field_name(column_entry.name.as_str(), *index); - let field = - DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone()); + let field = if matches!(column_entry.data_type, DataTypeImpl::Nullable(_)) { + DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone()) + } else { + DataField::new(field_name.as_str(), column_entry.data_type.clone()) + }; fields.push(field); } @@ -91,12 +96,42 @@ impl<'a> DataSchemaBuilder<'a> { for index in columns { let column_entry = self.metadata.column(*index); let field_name = column_entry.name.clone(); - let field = - DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone()); + let field = if matches!(column_entry.data_type, DataTypeImpl::Nullable(_)) { + DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone()) + } else { + DataField::new(field_name.as_str(), column_entry.data_type.clone()) + }; fields.push(field); } Arc::new(DataSchema::new(fields)) } + + pub fn build_group_by( + &self, + input_schema: DataSchemaRef, + exprs: &[Expression], + ) -> Result { + if !exprs + .iter() + .any(|expr| !matches!(expr, Expression::Column(_))) + { + return Ok(input_schema); + } + let mut fields = input_schema.fields().clone(); + for expr in exprs.iter() { + let expr_name = expr.column_name().clone(); + if input_schema.has_field(expr_name.as_str()) { + continue; + } + let field = if expr.nullable(&input_schema)? { + DataField::new_nullable(expr_name.as_str(), expr.to_data_type(&input_schema)?) + } else { + DataField::new(expr_name.as_str(), expr.to_data_type(&input_schema)?) + }; + fields.push(field); + } + Ok(Arc::new(DataSchema::new(fields))) + } } diff --git a/query/src/sql/exec/expression_builder.rs b/query/src/sql/exec/expression_builder.rs index 7066c6fbad8f..bfeaadbdeca1 100644 --- a/query/src/sql/exec/expression_builder.rs +++ b/query/src/sql/exec/expression_builder.rs @@ -158,7 +158,7 @@ impl<'a> ExpressionBuilder<'a> { } // Transform aggregator expression to column expression - fn normalize_aggr_to_col(&self, expr: Expression) -> Result { + pub(crate) fn normalize_aggr_to_col(&self, expr: Expression) -> Result { match expr.clone() { Expression::BinaryExpression { left, op, right } => { return Ok(Expression::BinaryExpression { diff --git a/query/src/sql/exec/mod.rs b/query/src/sql/exec/mod.rs index 632bdf44e3d8..f1048322f8df 100644 --- a/query/src/sql/exec/mod.rs +++ b/query/src/sql/exec/mod.rs @@ -23,6 +23,8 @@ use common_datavalues::DataSchema; use common_datavalues::DataSchemaRef; use common_exception::ErrorCode; use common_exception::Result; +use common_planners::find_aggregate_exprs; +use common_planners::find_aggregate_exprs_in_expr; use common_planners::Expression; use common_planners::RewriteHelper; pub use util::decode_field_name; @@ -31,6 +33,7 @@ pub use util::format_field_name; use super::plans::BasePlan; use crate::pipelines::new::processors::AggregatorParams; use crate::pipelines::new::processors::AggregatorTransformParams; +use crate::pipelines::new::processors::ExpressionTransform; use crate::pipelines::new::processors::ProjectionTransform; use crate::pipelines::new::processors::TransformAggregator; use crate::pipelines::new::processors::TransformFilter; @@ -183,7 +186,16 @@ impl PipelineBuilder { let output_schema = input_schema.clone(); let eb = ExpressionBuilder::create(&self.metadata); let scalar = &filter.predicate; - let pred = eb.build(scalar)?; + let mut pred = eb.build(scalar)?; + let no_agg_expression = find_aggregate_exprs_in_expr(&pred).is_empty(); + if !no_agg_expression && !filter.is_having { + return Err(ErrorCode::SyntaxException( + "WHERE clause cannot contain aggregate functions", + )); + } + if !no_agg_expression && filter.is_having { + pred = eb.normalize_aggr_to_col(pred.clone())?; + } self.pipeline .add_transform(|transform_input_port, transform_output_port| { TransformFilter::try_create( @@ -247,16 +259,42 @@ impl PipelineBuilder { agg_expressions.push(expr); } - let schema_builder = DataSchemaBuilder::new(&self.metadata); - let partial_data_fields = - RewriteHelper::exprs_to_fields(agg_expressions.as_slice(), &input_schema)?; - let partial_schema = schema_builder.build_aggregate(partial_data_fields, &input_schema)?; - let mut group_expressions = Vec::with_capacity(aggregate.group_expr.len()); for scalar in aggregate.group_expr.iter() { let expr = expr_builder.build(scalar)?; group_expressions.push(expr); } + + if !find_aggregate_exprs(&group_expressions).is_empty() { + return Err(ErrorCode::SyntaxException( + "Group by clause cannot contain aggregate functions", + )); + } + + // Process group by with scalar expression, such as `a+1` + // TODO(xudong963): move to aggregate transform + let schema_builder = DataSchemaBuilder::new(&self.metadata); + let pre_input_schema = input_schema.clone(); + let input_schema = + schema_builder.build_group_by(input_schema, group_expressions.as_slice())?; + self.pipeline + .add_transform(|transform_input_port, transform_output_port| { + ExpressionTransform::try_create( + transform_input_port, + transform_output_port, + pre_input_schema.clone(), + input_schema.clone(), + group_expressions.clone(), + self.ctx.clone(), + ) + })?; + + // Get partial schema from agg_expressions + let partial_data_fields = + RewriteHelper::exprs_to_fields(agg_expressions.as_slice(), &input_schema)?; + let partial_schema = schema_builder.build_aggregate(partial_data_fields, &input_schema)?; + + // Get final schema from agg_expression and group expression let mut final_exprs = agg_expressions.to_owned(); final_exprs.extend_from_slice(group_expressions.as_slice()); let final_data_fields = @@ -288,7 +326,6 @@ impl PipelineBuilder { &input_schema, &final_schema, )?; - self.pipeline .add_transform(|transform_input_port, transform_output_port| { TransformAggregator::try_create_final( diff --git a/query/src/sql/planner/binder/select.rs b/query/src/sql/planner/binder/select.rs index 55274e5b6063..cdd431690e3b 100644 --- a/query/src/sql/planner/binder/select.rs +++ b/query/src/sql/planner/binder/select.rs @@ -76,7 +76,7 @@ impl Binder { }; if let Some(expr) = &stmt.selection { - self.bind_where(expr, &mut input_context)?; + self.bind_where(expr, &mut input_context, false)?; } // Output of current `SELECT` statement. @@ -90,6 +90,11 @@ impl Binder { output_context.expression = input_context.expression.clone(); } + if let Some(expr) = &stmt.having { + self.bind_where(expr, &mut input_context, true)?; + output_context.expression = input_context.expression.clone(); + } + self.bind_projection(&mut output_context)?; Ok(output_context) @@ -200,10 +205,18 @@ impl Binder { Ok(bind_context) } - pub(super) fn bind_where(&mut self, expr: &Expr, bind_context: &mut BindContext) -> Result<()> { + pub(super) fn bind_where( + &mut self, + expr: &Expr, + bind_context: &mut BindContext, + is_having: bool, + ) -> Result<()> { let scalar_binder = ScalarBinder::new(bind_context); let (scalar, _) = scalar_binder.bind_expr(expr)?; - let filter_plan = FilterPlan { predicate: scalar }; + let filter_plan = FilterPlan { + predicate: scalar, + is_having, + }; let new_expr = SExpr::create_unary(filter_plan.into(), bind_context.expression.clone().unwrap()); bind_context.expression = Some(new_expr); diff --git a/query/src/sql/planner/plans/filter.rs b/query/src/sql/planner/plans/filter.rs index d5f083d327c7..19f24e416a37 100644 --- a/query/src/sql/planner/plans/filter.rs +++ b/query/src/sql/planner/plans/filter.rs @@ -27,6 +27,8 @@ use crate::sql::plans::Scalar; pub struct FilterPlan { // TODO: split predicate into conjunctions pub predicate: Scalar, + // True if the plan represents having, else the plan represents where + pub is_having: bool, } impl BasePlan for FilterPlan { diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result index b13e4fc8fedd..0de278436474 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result @@ -34,3 +34,7 @@ 1 1 3 +3 +2 +2 +4 diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql index d573f880d27d..660248a9bd1c 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql @@ -30,5 +30,8 @@ select count() from t group by a; select count(1) from t; select count(1) from t group by a; select count(*) from t; +select sum(a) from t group by a having sum(a) > 1; +select sum(a+1) from t group by a+1 having sum(a+1) = 2; +select sum(a+1) from t group by a+1, b having sum(a+1) > 3; drop table t; set enable_planner_v2 = 0; \ No newline at end of file