Merge pull request #5200 from xudong963/new_having
feat(planner): support having and scalar expression in group by for new planner
BohuTANG authored May 6, 2022
2 parents 2da916d + 4b38b3e commit f2ef23b
Showing 7 changed files with 109 additions and 15 deletions.
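In short: with the v2 planner enabled, HAVING clauses and scalar expressions such as a+1 used as GROUP BY keys are now planned and executed. The new test cases at the bottom of this diff exercise exactly these query shapes, for example:

select sum(a) from t group by a having sum(a) > 1;
select sum(a+1) from t group by a+1 having sum(a+1) = 2;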
43 changes: 39 additions & 4 deletions query/src/sql/exec/data_schema_builder.rs
@@ -17,7 +17,9 @@ use std::sync::Arc;
use common_datavalues::DataField;
use common_datavalues::DataSchema;
use common_datavalues::DataSchemaRef;
use common_datavalues::DataTypeImpl;
use common_exception::Result;
use common_planners::Expression;

use crate::sql::exec::util::format_field_name;
use crate::sql::plans::PhysicalScan;
@@ -77,8 +79,11 @@ impl<'a> DataSchemaBuilder<'a> {
for index in plan.columns.iter() {
let column_entry = self.metadata.column(*index);
let field_name = format_field_name(column_entry.name.as_str(), *index);
let field =
DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone());
let field = if matches!(column_entry.data_type, DataTypeImpl::Nullable(_)) {
DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone())
} else {
DataField::new(field_name.as_str(), column_entry.data_type.clone())
};

fields.push(field);
}
@@ -91,12 +96,42 @@ impl<'a> DataSchemaBuilder<'a> {
for index in columns {
let column_entry = self.metadata.column(*index);
let field_name = column_entry.name.clone();
let field =
DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone());
let field = if matches!(column_entry.data_type, DataTypeImpl::Nullable(_)) {
DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone())
} else {
DataField::new(field_name.as_str(), column_entry.data_type.clone())
};

fields.push(field);
}

Arc::new(DataSchema::new(fields))
}

pub fn build_group_by(
&self,
input_schema: DataSchemaRef,
exprs: &[Expression],
) -> Result<DataSchemaRef> {
if !exprs
.iter()
.any(|expr| !matches!(expr, Expression::Column(_)))
{
return Ok(input_schema);
}
let mut fields = input_schema.fields().clone();
for expr in exprs.iter() {
let expr_name = expr.column_name().clone();
if input_schema.has_field(expr_name.as_str()) {
continue;
}
let field = if expr.nullable(&input_schema)? {
DataField::new_nullable(expr_name.as_str(), expr.to_data_type(&input_schema)?)
} else {
DataField::new(expr_name.as_str(), expr.to_data_type(&input_schema)?)
};
fields.push(field);
}
Ok(Arc::new(DataSchema::new(fields)))
}
}
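A note on the new build_group_by: when every GROUP BY key is already a plain column, the input schema is returned unchanged; otherwise one field per non-column key is appended, named via Expression::column_name() and typed (including nullability) against the input schema. The extended schema is what the aggregation stage consumes, so for a query like the one below the group key a+1 is materialized as a column before aggregation runs (the exact generated field name comes from column_name() and is not shown in this diff):

select sum(a+1) from t group by a+1;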
2 changes: 1 addition & 1 deletion query/src/sql/exec/expression_builder.rs
@@ -158,7 +158,7 @@ impl<'a> ExpressionBuilder<'a> {
}

// Transform aggregator expression to column expression
fn normalize_aggr_to_col(&self, expr: Expression) -> Result<Expression> {
pub(crate) fn normalize_aggr_to_col(&self, expr: Expression) -> Result<Expression> {
match expr.clone() {
Expression::BinaryExpression { left, op, right } => {
return Ok(Expression::BinaryExpression {
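The only change in this file is visibility: normalize_aggr_to_col becomes pub(crate) so the pipeline builder in mod.rs can apply it to a HAVING predicate. Once aggregation has run, an aggregate such as sum(a) exists as an output column, so the predicate can be rewritten to reference that column instead of re-evaluating the aggregate. Illustrated on one of the new test queries (the rewrite happens on the planner's expression tree; the SQL below is only for orientation):

select sum(a) from t group by a having sum(a) > 1;
-- sum(a) in the HAVING predicate is normalized to a reference to the
-- column produced by the aggregation, and the filter runs over those rows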
51 changes: 44 additions & 7 deletions query/src/sql/exec/mod.rs
@@ -23,6 +23,8 @@ use common_datavalues::DataSchema;
use common_datavalues::DataSchemaRef;
use common_exception::ErrorCode;
use common_exception::Result;
use common_planners::find_aggregate_exprs;
use common_planners::find_aggregate_exprs_in_expr;
use common_planners::Expression;
use common_planners::RewriteHelper;
pub use util::decode_field_name;
@@ -31,6 +33,7 @@ pub use util::format_field_name;
use super::plans::BasePlan;
use crate::pipelines::new::processors::AggregatorParams;
use crate::pipelines::new::processors::AggregatorTransformParams;
use crate::pipelines::new::processors::ExpressionTransform;
use crate::pipelines::new::processors::ProjectionTransform;
use crate::pipelines::new::processors::TransformAggregator;
use crate::pipelines::new::processors::TransformFilter;
@@ -183,7 +186,16 @@ impl PipelineBuilder {
let output_schema = input_schema.clone();
let eb = ExpressionBuilder::create(&self.metadata);
let scalar = &filter.predicate;
let pred = eb.build(scalar)?;
let mut pred = eb.build(scalar)?;
let no_agg_expression = find_aggregate_exprs_in_expr(&pred).is_empty();
if !no_agg_expression && !filter.is_having {
return Err(ErrorCode::SyntaxException(
"WHERE clause cannot contain aggregate functions",
));
}
if !no_agg_expression && filter.is_having {
pred = eb.normalize_aggr_to_col(pred.clone())?;
}
self.pipeline
.add_transform(|transform_input_port, transform_output_port| {
TransformFilter::try_create(
@@ -247,16 +259,42 @@ impl PipelineBuilder {
agg_expressions.push(expr);
}

let schema_builder = DataSchemaBuilder::new(&self.metadata);
let partial_data_fields =
RewriteHelper::exprs_to_fields(agg_expressions.as_slice(), &input_schema)?;
let partial_schema = schema_builder.build_aggregate(partial_data_fields, &input_schema)?;

let mut group_expressions = Vec::with_capacity(aggregate.group_expr.len());
for scalar in aggregate.group_expr.iter() {
let expr = expr_builder.build(scalar)?;
group_expressions.push(expr);
}

if !find_aggregate_exprs(&group_expressions).is_empty() {
return Err(ErrorCode::SyntaxException(
"Group by clause cannot contain aggregate functions",
));
}

// Process GROUP BY keys that are scalar expressions, such as `a+1`
// TODO(xudong963): move to aggregate transform
let schema_builder = DataSchemaBuilder::new(&self.metadata);
let pre_input_schema = input_schema.clone();
let input_schema =
schema_builder.build_group_by(input_schema, group_expressions.as_slice())?;
self.pipeline
.add_transform(|transform_input_port, transform_output_port| {
ExpressionTransform::try_create(
transform_input_port,
transform_output_port,
pre_input_schema.clone(),
input_schema.clone(),
group_expressions.clone(),
self.ctx.clone(),
)
})?;

// Get partial schema from agg_expressions
let partial_data_fields =
RewriteHelper::exprs_to_fields(agg_expressions.as_slice(), &input_schema)?;
let partial_schema = schema_builder.build_aggregate(partial_data_fields, &input_schema)?;

// Get final schema from agg_expression and group expression
let mut final_exprs = agg_expressions.to_owned();
final_exprs.extend_from_slice(group_expressions.as_slice());
let final_data_fields =
@@ -288,7 +326,6 @@ impl PipelineBuilder {
&input_schema,
&final_schema,
)?;

self.pipeline
.add_transform(|transform_input_port, transform_output_port| {
TransformAggregator::try_create_final(
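Two behavioural changes land in this file. First, the filter path now distinguishes WHERE from HAVING: if the built predicate contains an aggregate and is_having is false, planning fails with the SyntaxException added above; if is_having is true, the aggregates are normalized to column references via normalize_aggr_to_col. Second, aggregates are rejected inside GROUP BY keys, and when a key is a non-column scalar expression an ExpressionTransform is inserted ahead of the aggregation to produce the extended schema returned by build_group_by. For example, assuming a table t with an integer column a, the first statement below should now be rejected and the second accepted:

select sum(a) from t where sum(a) > 1;
-- error: WHERE clause cannot contain aggregate functions
select sum(a) from t group by a having sum(a) > 1;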
19 changes: 16 additions & 3 deletions query/src/sql/planner/binder/select.rs
@@ -76,7 +76,7 @@ impl Binder {
};

if let Some(expr) = &stmt.selection {
self.bind_where(expr, &mut input_context)?;
self.bind_where(expr, &mut input_context, false)?;
}

// Output of current `SELECT` statement.
@@ -90,6 +90,11 @@
output_context.expression = input_context.expression.clone();
}

if let Some(expr) = &stmt.having {
self.bind_where(expr, &mut input_context, true)?;
output_context.expression = input_context.expression.clone();
}

self.bind_projection(&mut output_context)?;

Ok(output_context)
@@ -200,10 +205,18 @@ impl Binder {
Ok(bind_context)
}

pub(super) fn bind_where(&mut self, expr: &Expr, bind_context: &mut BindContext) -> Result<()> {
pub(super) fn bind_where(
&mut self,
expr: &Expr,
bind_context: &mut BindContext,
is_having: bool,
) -> Result<()> {
let scalar_binder = ScalarBinder::new(bind_context);
let (scalar, _) = scalar_binder.bind_expr(expr)?;
let filter_plan = FilterPlan { predicate: scalar };
let filter_plan = FilterPlan {
predicate: scalar,
is_having,
};
let new_expr =
SExpr::create_unary(filter_plan.into(), bind_context.expression.clone().unwrap());
bind_context.expression = Some(new_expr);
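On the binder side, HAVING reuses the same bind_where entry point as WHERE, only with is_having set to true, and it is bound just before bind_projection. In principle a single query can therefore carry both filters, each producing its own FilterPlan (this combined form is not among the new test cases):

select sum(a) from t where a > 0 group by a having sum(a) > 1;
-- a > 0 binds with is_having = false, sum(a) > 1 with is_having = true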
2 changes: 2 additions & 0 deletions query/src/sql/planner/plans/filter.rs
@@ -27,6 +27,8 @@ use crate::sql::plans::Scalar;
pub struct FilterPlan {
// TODO: split predicate into conjunctions
pub predicate: Scalar,
// True if this filter comes from a HAVING clause, false if it comes from a WHERE clause
pub is_having: bool,
}

impl BasePlan for FilterPlan {
4 changes: 4 additions & 0 deletions tests/suites/0_stateless/20+_others/20_0001_planner_v2.result
@@ -34,3 +34,7 @@
1
1
3
3
2
2
4
3 changes: 3 additions & 0 deletions tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql
@@ -30,5 +30,8 @@ select count() from t group by a;
select count(1) from t;
select count(1) from t group by a;
select count(*) from t;
select sum(a) from t group by a having sum(a) > 1;
select sum(a+1) from t group by a+1 having sum(a+1) = 2;
select sum(a+1) from t group by a+1, b having sum(a+1) > 3;
drop table t;
set enable_planner_v2 = 0;
