feat(planner): support having and scalar expression in group by for new planner #5200

Merged (2 commits) on May 6, 2022
43 changes: 39 additions & 4 deletions query/src/sql/exec/data_schema_builder.rs
@@ -17,7 +17,9 @@ use std::sync::Arc;
use common_datavalues::DataField;
use common_datavalues::DataSchema;
use common_datavalues::DataSchemaRef;
use common_datavalues::DataTypeImpl;
use common_exception::Result;
use common_planners::Expression;

use crate::sql::exec::util::format_field_name;
use crate::sql::plans::PhysicalScan;
@@ -77,8 +79,11 @@ impl<'a> DataSchemaBuilder<'a> {
for index in plan.columns.iter() {
let column_entry = self.metadata.column(*index);
let field_name = format_field_name(column_entry.name.as_str(), *index);
let field =
DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone());
let field = if matches!(column_entry.data_type, DataTypeImpl::Nullable(_)) {
DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone())
} else {
DataField::new(field_name.as_str(), column_entry.data_type.clone())
};

fields.push(field);
}
@@ -91,12 +96,42 @@ impl<'a> DataSchemaBuilder<'a> {
for index in columns {
let column_entry = self.metadata.column(*index);
let field_name = column_entry.name.clone();
let field =
DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone());
let field = if matches!(column_entry.data_type, DataTypeImpl::Nullable(_)) {
DataField::new_nullable(field_name.as_str(), column_entry.data_type.clone())
} else {
DataField::new(field_name.as_str(), column_entry.data_type.clone())
};

fields.push(field);
}

Arc::new(DataSchema::new(fields))
}

pub fn build_group_by(
&self,
input_schema: DataSchemaRef,
exprs: &[Expression],
) -> Result<DataSchemaRef> {
if !exprs
.iter()
.any(|expr| !matches!(expr, Expression::Column(_)))
{
return Ok(input_schema);
}
let mut fields = input_schema.fields().clone();
for expr in exprs.iter() {
let expr_name = expr.column_name().clone();
if input_schema.has_field(expr_name.as_str()) {
continue;
}
let field = if expr.nullable(&input_schema)? {
DataField::new_nullable(expr_name.as_str(), expr.to_data_type(&input_schema)?)
} else {
DataField::new(expr_name.as_str(), expr.to_data_type(&input_schema)?)
};
fields.push(field);
}
Ok(Arc::new(DataSchema::new(fields)))
}
}
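
Note: `build_group_by` returns the input schema unchanged when every GROUP BY key is a plain column, and otherwise appends a field for each derived key (such as `a+1`) that the schema does not already contain, choosing a nullable or non-nullable field from the expression's type. A minimal standalone sketch of that decision, using simplified stand-in types rather than the Databend API:

```rust
// Simplified model of the schema-extension rule (hypothetical types).
#[derive(Debug)]
struct Field {
    name: String,
    nullable: bool,
}

enum GroupExpr {
    // A bare column such as `a`: already present in the input schema.
    Column(String),
    // A derived scalar such as `a + 1`: needs a new output field.
    Scalar { name: String, nullable: bool },
}

fn build_group_by(input: Vec<Field>, exprs: &[GroupExpr]) -> Vec<Field> {
    // Fast path: every key is a plain column, keep the input schema as-is.
    if exprs.iter().all(|e| matches!(e, GroupExpr::Column(_))) {
        return input;
    }
    let mut fields = input;
    for expr in exprs {
        if let GroupExpr::Scalar { name, nullable } = expr {
            // Skip keys that already have a field of the same name.
            if fields.iter().any(|f| &f.name == name) {
                continue;
            }
            fields.push(Field { name: name.clone(), nullable: *nullable });
        }
    }
    fields
}

fn main() {
    let input = vec![Field { name: "a".into(), nullable: false }];
    let keys = [GroupExpr::Scalar { name: "(a + 1)".into(), nullable: false }];
    // The schema gains one field for the derived key `a + 1`.
    println!("{:?}", build_group_by(input, &keys));
}
```
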
2 changes: 1 addition & 1 deletion query/src/sql/exec/expression_builder.rs
@@ -158,7 +158,7 @@ impl<'a> ExpressionBuilder<'a> {
}

// Transform aggregator expression to column expression
fn normalize_aggr_to_col(&self, expr: Expression) -> Result<Expression> {
pub(crate) fn normalize_aggr_to_col(&self, expr: Expression) -> Result<Expression> {
match expr.clone() {
Expression::BinaryExpression { left, op, right } => {
return Ok(Expression::BinaryExpression {
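Making `normalize_aggr_to_col` `pub(crate)` lets the pipeline builder reuse it for HAVING predicates: after aggregation the value of `sum(a)` lives in an output column named after the aggregate, so a predicate like `sum(a) > 1` is rewritten to compare that column instead of re-evaluating the aggregate. A rough standalone sketch of the idea, with a hypothetical `Expr` type rather than the real `Expression`:

```rust
#[derive(Debug)]
enum Expr {
    Column(String),
    Literal(i64),
    // e.g. `sum(a)`; `display` is the name of its output column.
    Aggregate { display: String },
    BinaryOp { op: String, left: Box<Expr>, right: Box<Expr> },
}

fn normalize_aggr_to_col(expr: Expr) -> Expr {
    match expr {
        // Replace the aggregate with a reference to its output column.
        Expr::Aggregate { display } => Expr::Column(display),
        Expr::BinaryOp { op, left, right } => Expr::BinaryOp {
            op,
            left: Box::new(normalize_aggr_to_col(*left)),
            right: Box::new(normalize_aggr_to_col(*right)),
        },
        other => other,
    }
}

fn main() {
    // HAVING sum(a) > 1, as an expression tree.
    let having = Expr::BinaryOp {
        op: ">".into(),
        left: Box::new(Expr::Aggregate { display: "sum(a)".into() }),
        right: Box::new(Expr::Literal(1)),
    };
    // Becomes: Column("sum(a)") > Literal(1)
    println!("{:?}", normalize_aggr_to_col(having));
}
```
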
51 changes: 44 additions & 7 deletions query/src/sql/exec/mod.rs
@@ -23,6 +23,8 @@ use common_datavalues::DataSchema;
use common_datavalues::DataSchemaRef;
use common_exception::ErrorCode;
use common_exception::Result;
use common_planners::find_aggregate_exprs;
use common_planners::find_aggregate_exprs_in_expr;
use common_planners::Expression;
use common_planners::RewriteHelper;
pub use util::decode_field_name;
@@ -31,6 +33,7 @@ pub use util::format_field_name;
use super::plans::BasePlan;
use crate::pipelines::new::processors::AggregatorParams;
use crate::pipelines::new::processors::AggregatorTransformParams;
use crate::pipelines::new::processors::ExpressionTransform;
use crate::pipelines::new::processors::ProjectionTransform;
use crate::pipelines::new::processors::TransformAggregator;
use crate::pipelines::new::processors::TransformFilter;
@@ -183,7 +186,16 @@ impl PipelineBuilder {
let output_schema = input_schema.clone();
let eb = ExpressionBuilder::create(&self.metadata);
let scalar = &filter.predicate;
let pred = eb.build(scalar)?;
let mut pred = eb.build(scalar)?;
let no_agg_expression = find_aggregate_exprs_in_expr(&pred).is_empty();
if !no_agg_expression && !filter.is_having {
return Err(ErrorCode::SyntaxException(
"WHERE clause cannot contain aggregate functions",
));
}
if !no_agg_expression && filter.is_having {
pred = eb.normalize_aggr_to_col(pred.clone())?;
}
self.pipeline
.add_transform(|transform_input_port, transform_output_port| {
TransformFilter::try_create(
@@ -247,16 +259,42 @@ impl PipelineBuilder {
agg_expressions.push(expr);
}

let schema_builder = DataSchemaBuilder::new(&self.metadata);
let partial_data_fields =
RewriteHelper::exprs_to_fields(agg_expressions.as_slice(), &input_schema)?;
let partial_schema = schema_builder.build_aggregate(partial_data_fields, &input_schema)?;

let mut group_expressions = Vec::with_capacity(aggregate.group_expr.len());
for scalar in aggregate.group_expr.iter() {
let expr = expr_builder.build(scalar)?;
group_expressions.push(expr);
}

if !find_aggregate_exprs(&group_expressions).is_empty() {
return Err(ErrorCode::SyntaxException(
"Group by clause cannot contain aggregate functions",
));
}

// Process group by with scalar expression, such as `a+1`
// TODO(xudong963): move to aggregate transform
let schema_builder = DataSchemaBuilder::new(&self.metadata);
let pre_input_schema = input_schema.clone();
let input_schema =
schema_builder.build_group_by(input_schema, group_expressions.as_slice())?;
self.pipeline
.add_transform(|transform_input_port, transform_output_port| {
ExpressionTransform::try_create(
transform_input_port,
transform_output_port,
pre_input_schema.clone(),
input_schema.clone(),
group_expressions.clone(),
self.ctx.clone(),
)
})?;

// Get partial schema from agg_expressions
let partial_data_fields =
RewriteHelper::exprs_to_fields(agg_expressions.as_slice(), &input_schema)?;
let partial_schema = schema_builder.build_aggregate(partial_data_fields, &input_schema)?;

// Get final schema from agg_expression and group expression
let mut final_exprs = agg_expressions.to_owned();
final_exprs.extend_from_slice(group_expressions.as_slice());
let final_data_fields =
@@ -288,7 +326,6 @@
&input_schema,
&final_schema,
)?;

self.pipeline
.add_transform(|transform_input_port, transform_output_port| {
TransformAggregator::try_create_final(
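Taken together, the pipeline builder changes reject aggregates in WHERE and in GROUP BY keys with a `SyntaxException`, rewrite aggregates in HAVING into column references, and insert an `ExpressionTransform` so scalar group keys like `a+1` are materialized before partial aggregation. A rough sketch of the resulting stage order for `select sum(a+1) from t group by a+1 having sum(a+1) = 2`, using hypothetical stage names rather than the Databend pipeline API:

```rust
// Rough, assumed stage order implied by this change (hypothetical names).
#[derive(Debug)]
enum Stage {
    // New: evaluates scalar group keys such as `a + 1` so the aggregator
    // can consume them as ordinary input columns.
    ExpressionTransform { group_keys: Vec<&'static str> },
    PartialAggregate,
    FinalAggregate,
    // The HAVING predicate, with the aggregate rewritten to a column
    // reference over the aggregation output.
    Filter { predicate: &'static str, is_having: bool },
    Projection,
}

fn build_stages() -> Vec<Stage> {
    vec![
        Stage::ExpressionTransform { group_keys: vec!["(a + 1)"] },
        Stage::PartialAggregate,
        Stage::FinalAggregate,
        Stage::Filter { predicate: "sum((a + 1)) = 2", is_having: true },
        Stage::Projection,
    ]
}

fn main() {
    for stage in build_stages() {
        println!("{:?}", stage);
    }
}
```
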
19 changes: 16 additions & 3 deletions query/src/sql/planner/binder/select.rs
@@ -76,7 +76,7 @@ impl Binder {
};

if let Some(expr) = &stmt.selection {
self.bind_where(expr, &mut input_context)?;
self.bind_where(expr, &mut input_context, false)?;
}

// Output of current `SELECT` statement.
@@ -90,6 +90,11 @@
output_context.expression = input_context.expression.clone();
}

if let Some(expr) = &stmt.having {
self.bind_where(expr, &mut input_context, true)?;
output_context.expression = input_context.expression.clone();
}

self.bind_projection(&mut output_context)?;

Ok(output_context)
@@ -200,10 +205,18 @@ impl Binder {
Ok(bind_context)
}

pub(super) fn bind_where(&mut self, expr: &Expr, bind_context: &mut BindContext) -> Result<()> {
pub(super) fn bind_where(
&mut self,
expr: &Expr,
bind_context: &mut BindContext,
is_having: bool,
) -> Result<()> {
let scalar_binder = ScalarBinder::new(bind_context);
let (scalar, _) = scalar_binder.bind_expr(expr)?;
let filter_plan = FilterPlan { predicate: scalar };
let filter_plan = FilterPlan {
predicate: scalar,
is_having,
};
let new_expr =
SExpr::create_unary(filter_plan.into(), bind_context.expression.clone().unwrap());
bind_context.expression = Some(new_expr);
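The binder now routes both clauses through `bind_where`, tagging the resulting `FilterPlan` so later stages can tell a HAVING filter from a WHERE filter. A tiny standalone sketch of that tagging, with simplified stand-in types rather than the real binder signature:

```rust
// Simplified model: one binding entry point for both clauses.
#[derive(Debug)]
struct FilterPlan {
    predicate: String,
    is_having: bool,
}

fn bind_filter(predicate: &str, is_having: bool) -> FilterPlan {
    FilterPlan { predicate: predicate.to_string(), is_having }
}

fn main() {
    // WHERE runs before aggregation, HAVING after it.
    let where_plan = bind_filter("a > 0", false);
    let having_plan = bind_filter("sum(a) > 1", true);
    println!("{:?}\n{:?}", where_plan, having_plan);
}
```
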
2 changes: 2 additions & 0 deletions query/src/sql/planner/plans/filter.rs
@@ -27,6 +27,8 @@ use crate::sql::plans::Scalar;
pub struct FilterPlan {
// TODO: split predicate into conjunctions
pub predicate: Scalar,
// True if the plan represents having, else the plan represents where
pub is_having: bool,
}

impl BasePlan for FilterPlan {
4 changes: 4 additions & 0 deletions tests/suites/0_stateless/20+_others/20_0001_planner_v2.result
@@ -34,3 +34,7 @@
1
1
3
3
2
2
4
3 changes: 3 additions & 0 deletions tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql
@@ -30,5 +30,8 @@ select count() from t group by a;
select count(1) from t;
select count(1) from t group by a;
select count(*) from t;
select sum(a) from t group by a having sum(a) > 1;
select sum(a+1) from t group by a+1 having sum(a+1) = 2;
select sum(a+1) from t group by a+1, b having sum(a+1) > 3;
drop table t;
set enable_planner_v2 = 0;