diff --git a/Cargo.lock b/Cargo.lock index a49d0f514c3e..6aa4bcb6eea0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2021,6 +2021,7 @@ dependencies = [ "criterion", "dyn-clone", "enum_dispatch", + "enum_extract", "futures", "headers", "http", @@ -2050,6 +2051,7 @@ dependencies = [ "regex", "reqwest", "rsa", + "segment-tree", "semver", "serde", "serde-bridge", @@ -2320,6 +2322,12 @@ dependencies = [ "syn", ] +[[package]] +name = "enum_extract" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11578f9e8496eeb626c549e1b340a57e479840ff7ae13fe12d4dbd97c18779d1" + [[package]] name = "enumflags2" version = "0.7.5" @@ -5937,6 +5945,12 @@ dependencies = [ "libc", ] +[[package]] +name = "segment-tree" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f7dbd0d32cabaa6c7c3286d756268247538d613b621227bfe59237d7bbb271a" + [[package]] name = "semver" version = "1.0.9" diff --git a/common/ast/src/udfs/udf_expr_visitor.rs b/common/ast/src/udfs/udf_expr_visitor.rs index 1ed2b0ec0290..24b1faacc95d 100644 --- a/common/ast/src/udfs/udf_expr_visitor.rs +++ b/common/ast/src/udfs/udf_expr_visitor.rs @@ -163,6 +163,15 @@ pub trait UDFExprVisitor: Sized + Send { }; } + if let Some(over) = &function.over { + for partition_by in &over.partition_by { + UDFExprTraverser::accept(partition_by, self)?; + } + for order_by in &over.order_by { + UDFExprTraverser::accept(&order_by.expr, self)?; + } + } + Ok(()) } diff --git a/common/ast/tests/it/parser.rs b/common/ast/tests/it/parser.rs index 56f4c9d52f69..4369adfec56f 100644 --- a/common/ast/tests/it/parser.rs +++ b/common/ast/tests/it/parser.rs @@ -151,7 +151,7 @@ fn test_statements_in_legacy_suites() { // TODO(andylokandy): support all cases eventually // Remove currently unimplemented cases let file_str = regex::Regex::new( - "(?i).*(SLAVE|MASTER|COMMIT|START|ROLLBACK|FIELDS|GRANT|COPY|ROLE|STAGE|ENGINES|UNDROP).*\n", + "(?i).*(SLAVE|MASTER|COMMIT|START|ROLLBACK|FIELDS|GRANT|COPY|ROLE|STAGE|ENGINES|UNDROP|OVER).*\n", ) .unwrap() .replace_all(&file_str, "") diff --git a/common/functions/src/lib.rs b/common/functions/src/lib.rs index 9340791ea80c..a17897503d1a 100644 --- a/common/functions/src/lib.rs +++ b/common/functions/src/lib.rs @@ -19,6 +19,7 @@ pub mod aggregates; pub mod rdoc; pub mod scalars; +pub mod window; use aggregates::AggregateFunctionFactory; use scalars::FunctionFactory; diff --git a/common/functions/src/window/function.rs b/common/functions/src/window/function.rs new file mode 100644 index 000000000000..d614d587efa3 --- /dev/null +++ b/common/functions/src/window/function.rs @@ -0,0 +1,18 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub enum WindowFunction { + AggregateFunction, + BuiltInFunction, +} diff --git a/common/functions/src/window/mod.rs b/common/functions/src/window/mod.rs new file mode 100644 index 000000000000..ecb24c836d55 --- /dev/null +++ b/common/functions/src/window/mod.rs @@ -0,0 +1,19 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod function; +mod window_frame; + +pub use function::*; +pub use window_frame::*; diff --git a/common/functions/src/window/window_frame.rs b/common/functions/src/window/window_frame.rs new file mode 100644 index 000000000000..e4a0e6e1bf09 --- /dev/null +++ b/common/functions/src/window/window_frame.rs @@ -0,0 +1,189 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Code in this file is mainly copied from apache/arrow-datafusion +// Original project code: https://github.com/apache/arrow-datafusion/blob/a6e93a10ab2659500eb4f838b7b53f138e545be3/datafusion/expr/src/window_frame.rs#L39 +// PR: https://github.com/datafuselabs/databend/pull/5401 + +use std::cmp::Ordering; +use std::fmt; +use std::hash::Hash; +use std::hash::Hasher; + +use common_exception::ErrorCode; +use sqlparser::ast; + +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash, serde::Serialize, serde::Deserialize, +)] +pub struct WindowFrame { + pub units: WindowFrameUnits, + pub start_bound: WindowFrameBound, + pub end_bound: WindowFrameBound, +} + +impl TryFrom for WindowFrame { + type Error = ErrorCode; + + fn try_from(value: ast::WindowFrame) -> Result { + let start_bound = value.start_bound.into(); + let end_bound = value + .end_bound + .map(WindowFrameBound::from) + .unwrap_or(WindowFrameBound::CurrentRow); + + if let WindowFrameBound::Following(None) = start_bound { + Err(ErrorCode::LogicalError( + "Invalid window frame: start bound cannot be unbounded following".to_owned(), + )) + } else if let WindowFrameBound::Preceding(None) = end_bound { + Err(ErrorCode::LogicalError( + "Invalid window frame: end bound cannot be unbounded preceding".to_owned(), + )) + } else if start_bound > end_bound { + Err(ErrorCode::LogicalError(format!( + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + start_bound, end_bound + ))) + } else { + let units = value.units.into(); + Ok(Self { + units, + start_bound, + end_bound, + }) + } + } +} + +impl Default for WindowFrame { + fn default() -> Self { + WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(None), + end_bound: WindowFrameBound::CurrentRow, + } + } +} + +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash, serde::Serialize, serde::Deserialize, +)] +pub enum WindowFrameUnits { + Range, + Rows, +} + +impl fmt::Display for WindowFrameUnits { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + WindowFrameUnits::Range => "RANGE", + WindowFrameUnits::Rows => "ROWS", + }) + } +} + +impl From for WindowFrameUnits { + fn from(value: ast::WindowFrameUnits) -> Self { + match value { + ast::WindowFrameUnits::Range => Self::Range, + ast::WindowFrameUnits::Rows => Self::Rows, + _ => unimplemented!(), + } + } +} + +#[derive(Debug, Clone, Copy, Eq, serde::Serialize, serde::Deserialize)] +pub enum WindowFrameBound { + Preceding(Option), + CurrentRow, + Following(Option), +} + +impl WindowFrameBound { + fn get_rank(&self) -> (u8, u64) { + match self { + WindowFrameBound::Preceding(None) => (0, 0), + WindowFrameBound::Following(None) => (4, 0), + WindowFrameBound::Preceding(Some(0)) + | WindowFrameBound::CurrentRow + | WindowFrameBound::Following(Some(0)) => (2, 0), + WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), + WindowFrameBound::Following(Some(v)) => (3, *v), + } + } +} + +impl From for WindowFrameBound { + fn from(value: ast::WindowFrameBound) -> Self { + match value { + ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), + ast::WindowFrameBound::Following(v) => Self::Following(v), + ast::WindowFrameBound::CurrentRow => Self::CurrentRow, + } + } +} + +impl fmt::Display for WindowFrameBound { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), + WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), + WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), + WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), + WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), + } + } +} + +impl PartialEq for WindowFrameBound { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl PartialOrd for WindowFrameBound { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for WindowFrameBound { + fn cmp(&self, other: &Self) -> Ordering { + self.get_rank().cmp(&other.get_rank()) + } +} + +impl Hash for WindowFrameBound { + fn hash(&self, state: &mut H) { + self.get_rank().hash(state) + } +} diff --git a/common/planners/src/lib.rs b/common/planners/src/lib.rs index 967096a925e7..c68928236ff2 100644 --- a/common/planners/src/lib.rs +++ b/common/planners/src/lib.rs @@ -103,6 +103,7 @@ mod plan_user_udf_drop; mod plan_view_alter; mod plan_view_create; mod plan_view_drop; +mod plan_window_func; pub use plan_aggregator_final::AggregatorFinalPlan; pub use plan_aggregator_partial::AggregatorPartialPlan; @@ -135,6 +136,8 @@ pub use plan_expression_common::extract_aliases; pub use plan_expression_common::find_aggregate_exprs; pub use plan_expression_common::find_aggregate_exprs_in_expr; pub use plan_expression_common::find_columns_not_satisfy_exprs; +pub use plan_expression_common::find_window_exprs; +pub use plan_expression_common::find_window_exprs_in_expr; pub use plan_expression_common::rebase_expr; pub use plan_expression_common::rebase_expr_from_input; pub use plan_expression_common::resolve_aliases_to_exprs; @@ -233,3 +236,4 @@ pub use plan_user_udf_drop::DropUserUDFPlan; pub use plan_view_alter::AlterViewPlan; pub use plan_view_create::CreateViewPlan; pub use plan_view_drop::DropViewPlan; +pub use plan_window_func::WindowFuncPlan; diff --git a/common/planners/src/plan_expression.rs b/common/planners/src/plan_expression.rs index 275266cf9252..37f181c3c78a 100644 --- a/common/planners/src/plan_expression.rs +++ b/common/planners/src/plan_expression.rs @@ -21,6 +21,7 @@ use common_exception::ErrorCode; use common_exception::Result; use common_functions::aggregates::AggregateFunctionFactory; use common_functions::aggregates::AggregateFunctionRef; +use common_functions::window::WindowFrame; use once_cell::sync::Lazy; use crate::plan_expression_common::ExpressionDataTypeVisitor; @@ -94,6 +95,22 @@ pub enum Expression { args: Vec, }, + /// WindowFunction + WindowFunction { + /// operation performed + op: String, + /// params + params: Vec, + /// arguments + args: Vec, + /// partition by + partition_by: Vec, + /// order by + order_by: Vec, + /// window frame + window_frame: Option, + }, + /// A sort expression, that can be used to sort values. Sort { /// The expression to sort on @@ -421,6 +438,53 @@ impl fmt::Debug for Expression { Ok(()) } + Expression::WindowFunction { + op, + params, + args, + partition_by, + order_by, + window_frame, + } => { + let args_column_name = args.iter().map(Expression::column_name).collect::>(); + let params_name = params + .iter() + .map(|v| DataValue::custom_display(v, true)) + .collect::>(); + + if params.is_empty() { + write!(f, "{}", op)?; + } else { + write!(f, "{}({})", op, params_name.join(", "))?; + } + + write!(f, "({})", args_column_name.join(","))?; + + write!(f, " OVER(")?; + if !partition_by.is_empty() { + write!(f, "PARTITION BY {:?}", partition_by)?; + } + if !order_by.is_empty() { + if !partition_by.is_empty() { + write!(f, " ")?; + } + write!(f, "ORDER BY {:?}", order_by)?; + } + if let Some(window_frame) = window_frame { + if !partition_by.is_empty() || !order_by.is_empty() { + write!(f, " ")?; + } + write!( + f, + "{} BETWEEN {} AND {}", + window_frame.units, window_frame.start_bound, window_frame.end_bound + )?; + } + write!(f, ")")?; + + Ok(()) + } + Expression::Sort { expr, .. } => write!(f, "{:?}", expr), Expression::Wildcard => write!(f, "*"), Expression::Cast { diff --git a/common/planners/src/plan_expression_chain.rs b/common/planners/src/plan_expression_chain.rs index 8640d74a217c..5c19f789a926 100644 --- a/common/planners/src/plan_expression_chain.rs +++ b/common/planners/src/plan_expression_chain.rs @@ -192,7 +192,11 @@ impl ExpressionChain { "Action must be a non-aggregated function.", )); } + + Expression::WindowFunction { .. } => {} + Expression::Wildcard | Expression::Sort { .. } => {} + Expression::Cast { expr: sub_expr, data_type, diff --git a/common/planners/src/plan_expression_common.rs b/common/planners/src/plan_expression_common.rs index 2ef5efb34053..a612ed67dec4 100644 --- a/common/planners/src/plan_expression_common.rs +++ b/common/planners/src/plan_expression_common.rs @@ -18,6 +18,7 @@ use std::collections::HashSet; use common_datavalues::prelude::*; use common_exception::ErrorCode; use common_exception::Result; +use common_functions::aggregates::AggregateFunctionFactory; use common_functions::scalars::FunctionFactory; use crate::validate_function_arg; @@ -52,9 +53,21 @@ pub fn find_aggregate_exprs_in_expr(expr: &Expression) -> Vec { }) } +/// Collect all deeply nested `Expression::WindowFunction`. +pub fn find_window_exprs(exprs: &[Expression]) -> Vec { + find_exprs_in_exprs(exprs, &|nested_expr| { + matches!(nested_expr, Expression::WindowFunction { .. }) + }) +} + +pub fn find_window_exprs_in_expr(expr: &Expression) -> Vec { + find_exprs_in_expr(expr, &|nest_exprs| { + matches!(nest_exprs, Expression::WindowFunction { .. }) + }) +} + /// Collect all arguments from aggregation function and append to this exprs /// [ColumnExpr(b), Aggr(sum(a, b))] ---> [ColumnExpr(b), ColumnExpr(a)] - pub fn expand_aggregate_arg_exprs(exprs: &[Expression]) -> Vec { let mut res = vec![]; for expr in exprs { @@ -304,6 +317,31 @@ where F: Fn(&Expression) -> Result> { .collect::>>()?, }), + Expression::WindowFunction { + op, + params, + args, + partition_by, + order_by, + window_frame, + } => Ok(Expression::WindowFunction { + op: op.clone(), + params: params.clone(), + args: args + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + partition_by: partition_by + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + order_by: order_by + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + window_frame: window_frame.to_owned(), + }), + Expression::Sort { expr: nested_expr, asc, @@ -485,6 +523,37 @@ impl ExpressionVisitor for ExpressionDataTypeVisitor { self.stack.push(return_type); Ok(self) } + Expression::WindowFunction { + op, + params, + args, + partition_by, + order_by, + .. + } => { + for _ in 0..partition_by.len() + order_by.len() { + self.stack.remove(0); + } + + if !AggregateFunctionFactory::instance().check(op) { + return Err(ErrorCode::LogicalError( + "not yet support non aggr window function", + )); + } + + let mut fields = Vec::with_capacity(args.len()); + for arg in args.iter() { + let arg_type = self.stack.remove(0); + fields.push(DataField::new(&arg.column_name(), arg_type)); + } + + let aggregate_window_function = + AggregateFunctionFactory::instance().get(op, params.clone(), fields); + let return_type = aggregate_window_function.unwrap().return_type().unwrap(); + + self.stack.push(return_type); + Ok(self) + } Expression::Cast { data_type, .. } => { let inner_type = match self.stack.pop() { None => Err(ErrorCode::LogicalError( diff --git a/common/planners/src/plan_expression_rewriter.rs b/common/planners/src/plan_expression_rewriter.rs index 4c9565f584ed..17066997abdc 100644 --- a/common/planners/src/plan_expression_rewriter.rs +++ b/common/planners/src/plan_expression_rewriter.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use common_datavalues::prelude::*; use common_exception::ErrorCode; use common_exception::Result; +use common_functions::window::WindowFrame; use crate::Expression; use crate::ExpressionVisitor; @@ -124,6 +125,25 @@ pub trait ExpressionRewriter: Sized { }) } + fn mutate_window_function( + &mut self, + op: String, + params: Vec, + args: Vec, + partition_by: Vec, + order_by: Vec, + window_frame: Option, + ) -> Result { + Ok(Expression::WindowFunction { + op, + params, + args, + partition_by, + order_by, + window_frame, + }) + } + fn mutate_cast( &mut self, typ: &DataTypeImpl, @@ -300,6 +320,61 @@ impl ExpressionVisitor for ExpressionRewriteVisitor { self.stack.push(new_expr); Ok(self) } + Expression::WindowFunction { + op, + params, + args, + partition_by, + order_by, + window_frame, + } => { + let mut new_args = Vec::with_capacity(args.len()); + for i in 0..args.len() { + match self.stack.pop() { + Some(expr) => new_args.push(expr), + None => { + return Err(ErrorCode::LogicalError(format!( + "WindowFunction expects {} partition by arguments, actual {}", + partition_by.len(), + i + ))) + } + } + } + + let mut new_partition_by = Vec::with_capacity(partition_by.len()); + let mut new_order_by = Vec::with_capacity(order_by.len()); + for i in 0..partition_by.len() + order_by.len() { + match self.stack.pop() { + Some(expr) => { + if i < partition_by.len() { + new_partition_by.push(expr); + } else { + new_order_by.push(expr); + } + } + None => { + return Err(ErrorCode::LogicalError(format!( + "WindowFunction expects {} partition by arguments, actual {}", + partition_by.len(), + i + ))) + } + } + } + + let new_expr = self.inner.mutate_window_function( + op.clone(), + params.to_owned(), + new_args, + new_partition_by, + new_order_by, + window_frame.to_owned(), + )?; + + self.stack.push(new_expr); + Ok(self) + } Expression::Cast { data_type, pg_style, diff --git a/common/planners/src/plan_expression_visitor.rs b/common/planners/src/plan_expression_visitor.rs index 345694e0a8ef..2c73df11867c 100644 --- a/common/planners/src/plan_expression_visitor.rs +++ b/common/planners/src/plan_expression_visitor.rs @@ -66,6 +66,22 @@ pub trait ExpressionVisitor: Sized { stack.push(RecursionProcessing::Call(arg)); } } + Expression::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + for arg_expr in args { + stack.push(RecursionProcessing::Call(arg_expr)); + } + for part_by_expr in partition_by { + stack.push(RecursionProcessing::Call(part_by_expr)); + } + for order_by_expr in order_by { + stack.push(RecursionProcessing::Call(order_by_expr)); + } + } Expression::Cast { expr, .. } => { stack.push(RecursionProcessing::Call(expr)); } diff --git a/common/planners/src/plan_node.rs b/common/planners/src/plan_node.rs index 17c04f4b6011..9c8c3db6355c 100644 --- a/common/planners/src/plan_node.rs +++ b/common/planners/src/plan_node.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use common_datavalues::DataSchemaRef; use crate::plan_table_undrop::UnDropTablePlan; +use crate::plan_window_func::WindowFuncPlan; use crate::AggregatorFinalPlan; use crate::AggregatorPartialPlan; use crate::AlterClusterKeyPlan; @@ -90,6 +91,7 @@ pub enum PlanNode { AggregatorFinal(AggregatorFinalPlan), Filter(FilterPlan), Having(HavingPlan), + WindowFunc(WindowFuncPlan), Sort(SortPlan), Limit(LimitPlan), LimitBy(LimitByPlan), @@ -196,6 +198,7 @@ impl PlanNode { PlanNode::AggregatorFinal(v) => v.schema(), PlanNode::Filter(v) => v.schema(), PlanNode::Having(v) => v.schema(), + PlanNode::WindowFunc(v) => v.schema(), PlanNode::Limit(v) => v.schema(), PlanNode::LimitBy(v) => v.schema(), PlanNode::ReadSource(v) => v.schema(), @@ -301,6 +304,7 @@ impl PlanNode { PlanNode::AggregatorFinal(_) => "AggregatorFinalPlan", PlanNode::Filter(_) => "FilterPlan", PlanNode::Having(_) => "HavingPlan", + PlanNode::WindowFunc(_) => "WindowFuncPlan", PlanNode::Limit(_) => "LimitPlan", PlanNode::LimitBy(_) => "LimitByPlan", PlanNode::ReadSource(_) => "ReadSourcePlan", @@ -403,6 +407,7 @@ impl PlanNode { PlanNode::AggregatorFinal(v) => vec![v.input.clone()], PlanNode::Filter(v) => vec![v.input.clone()], PlanNode::Having(v) => vec![v.input.clone()], + PlanNode::WindowFunc(v) => vec![v.input.clone()], PlanNode::Limit(v) => vec![v.input.clone()], PlanNode::Explain(v) => vec![v.input.clone()], PlanNode::Select(v) => vec![v.input.clone()], diff --git a/common/planners/src/plan_node_builder.rs b/common/planners/src/plan_node_builder.rs index fbdea7b9ae1d..a47fe0ac13d1 100644 --- a/common/planners/src/plan_node_builder.rs +++ b/common/planners/src/plan_node_builder.rs @@ -20,6 +20,7 @@ use common_exception::Result; use crate::col; use crate::plan_subqueries_set::SubQueriesSetPlan; +use crate::plan_window_func::WindowFuncPlan; use crate::validate_expression; use crate::AggregatorFinalPlan; use crate::AggregatorPartialPlan; @@ -214,6 +215,22 @@ impl PlanBuilder { }))) } + /// Apply a window function + pub fn window_func(&self, expr: Expression) -> Result { + let window_func = expr.clone(); + let input = self.wrap_subquery_plan(&[expr])?; + let input_schema = input.schema(); + let mut input_fields = input_schema.fields().to_owned(); + let window_field = window_func.to_data_field(&input_schema).unwrap(); + input_fields.push(window_field); + let schema = Arc::new(DataSchema::new(input_fields)); + Ok(Self::from(&PlanNode::WindowFunc(WindowFuncPlan { + window_func, + input, + schema, + }))) + } + pub fn sort(&self, exprs: &[Expression]) -> Result { Ok(Self::from(&PlanNode::Sort(SortPlan { order_by: exprs.to_vec(), diff --git a/common/planners/src/plan_node_display_indent.rs b/common/planners/src/plan_node_display_indent.rs index 068ad8837e68..d3700d3a6327 100644 --- a/common/planners/src/plan_node_display_indent.rs +++ b/common/planners/src/plan_node_display_indent.rs @@ -73,6 +73,9 @@ impl<'a> fmt::Display for PlanNodeIndentFormatDisplay<'a> { PlanNode::AggregatorFinal(plan) => Self::format_aggregator_final(f, plan), PlanNode::Filter(plan) => write!(f, "Filter: {:?}", plan.predicate), PlanNode::Having(plan) => write!(f, "Having: {:?}", plan.predicate), + PlanNode::WindowFunc(plan) => { + write!(f, "WindowFunc: {:?}", plan.window_func) + } PlanNode::Sort(plan) => Self::format_sort(f, plan), PlanNode::Limit(plan) => Self::format_limit(f, plan), PlanNode::SubQueryExpression(plan) => Self::format_subquery_expr(f, plan), diff --git a/common/planners/src/plan_node_rewriter.rs b/common/planners/src/plan_node_rewriter.rs index 4b69db9f6c41..39dbc714642d 100644 --- a/common/planners/src/plan_node_rewriter.rs +++ b/common/planners/src/plan_node_rewriter.rs @@ -24,6 +24,7 @@ use common_exception::Result; use crate::plan_broadcast::BroadcastPlan; use crate::plan_subqueries_set::SubQueriesSetPlan; use crate::plan_table_undrop::UnDropTablePlan; +use crate::plan_window_func::WindowFuncPlan; use crate::AggregatorFinalPlan; use crate::AggregatorPartialPlan; use crate::AlterClusterKeyPlan; @@ -118,6 +119,7 @@ pub trait PlanRewriter: Sized { PlanNode::Broadcast(plan) => self.rewrite_broadcast(plan), PlanNode::Remote(plan) => self.rewrite_remote(plan), PlanNode::Having(plan) => self.rewrite_having(plan), + PlanNode::WindowFunc(plan) => self.rewrite_window_func(plan), PlanNode::Expression(plan) => self.rewrite_expression(plan), PlanNode::Sort(plan) => self.rewrite_sort(plan), PlanNode::Limit(plan) => self.rewrite_limit(plan), @@ -322,6 +324,14 @@ pub trait PlanRewriter: Sized { PlanBuilder::from(&new_input).having(new_predicate)?.build() } + fn rewrite_window_func(&mut self, plan: &WindowFuncPlan) -> Result { + let new_input = self.rewrite_plan_node(plan.input.as_ref())?; + let new_window_func = self.rewrite_expr(&new_input.schema(), &plan.window_func)?; + PlanBuilder::from(&new_input) + .window_func(new_window_func)? + .build() + } + fn rewrite_sort(&mut self, plan: &SortPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; let new_order_by = self.rewrite_exprs(&new_input.schema(), &plan.order_by)?; @@ -669,6 +679,44 @@ impl RewriteHelper { } } + Expression::WindowFunction { + op, + params, + args, + partition_by, + order_by, + window_frame, + } => { + let new_args: Result> = args + .iter() + .map(|v| RewriteHelper::expr_rewrite_alias(v, data)) + .collect(); + + let new_partition_by: Result> = partition_by + .iter() + .map(|v| RewriteHelper::expr_rewrite_alias(v, data)) + .collect(); + + let new_order_by: Result> = order_by + .iter() + .map(|v| RewriteHelper::expr_rewrite_alias(v, data)) + .collect(); + + match (new_args, new_partition_by, new_order_by) { + (Ok(new_args), Ok(new_partition_by), Ok(new_order_by)) => { + Ok(Expression::WindowFunction { + op: op.clone(), + params: params.clone(), + args: new_args, + partition_by: new_partition_by, + order_by: new_order_by, + window_frame: *window_frame, + }) + } + (Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => Err(e), + } + } + Expression::Alias(alias, plan) => { if data.inside_aliases.contains(alias) { return Result::Err(ErrorCode::SyntaxException(format!( @@ -779,6 +827,17 @@ impl RewriteHelper { } Expression::ScalarFunction { args, .. } => args.clone(), Expression::AggregateFunction { args, .. } => args.clone(), + Expression::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + let mut v = args.clone(); + v.extend(partition_by.clone()); + v.extend(order_by.clone()); + v + } Expression::Wildcard => vec![], Expression::Sort { expr, .. } => vec![expr.as_ref().clone()], Expression::Cast { expr, .. } => vec![expr.as_ref().clone()], @@ -818,6 +877,24 @@ impl RewriteHelper { } v } + Expression::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + let mut v = vec![]; + for arg_expr in args { + v.append(&mut Self::expression_plan_columns(arg_expr)?) + } + for part_by_expr in partition_by { + v.append(&mut Self::expression_plan_columns(part_by_expr)?) + } + for order_by_expr in order_by { + v.append(&mut Self::expression_plan_columns(order_by_expr)?) + } + v + } Expression::Wildcard => vec![], Expression::Sort { expr, .. } => Self::expression_plan_columns(expr)?, Expression::Cast { expr, .. } => Self::expression_plan_columns(expr)?, diff --git a/common/planners/src/plan_node_visitor.rs b/common/planners/src/plan_node_visitor.rs index d2ae626c39c4..e705c7d3fc92 100644 --- a/common/planners/src/plan_node_visitor.rs +++ b/common/planners/src/plan_node_visitor.rs @@ -17,6 +17,7 @@ use common_exception::Result; use crate::plan_broadcast::BroadcastPlan; use crate::plan_subqueries_set::SubQueriesSetPlan; use crate::plan_table_undrop::UnDropTablePlan; +use crate::plan_window_func::WindowFuncPlan; use crate::AggregatorFinalPlan; use crate::AggregatorPartialPlan; use crate::AlterClusterKeyPlan; @@ -131,6 +132,7 @@ pub trait PlanVisitor { PlanNode::Broadcast(plan) => self.visit_broadcast(plan), PlanNode::Remote(plan) => self.visit_remote(plan), PlanNode::Having(plan) => self.visit_having(plan), + PlanNode::WindowFunc(plan) => self.visit_window_func(plan), PlanNode::Expression(plan) => self.visit_expression(plan), PlanNode::Limit(plan) => self.visit_limit(plan), PlanNode::LimitBy(plan) => self.visit_limit_by(plan), @@ -299,6 +301,11 @@ pub trait PlanVisitor { self.visit_expr(&plan.predicate) } + fn visit_window_func(&mut self, plan: &WindowFuncPlan) -> Result<()> { + self.visit_plan_node(plan.input.as_ref())?; + self.visit_expr(&plan.window_func) + } + fn visit_sort(&mut self, plan: &SortPlan) -> Result<()> { self.visit_plan_node(plan.input.as_ref())?; self.visit_exprs(&plan.order_by) diff --git a/common/planners/src/plan_window_func.rs b/common/planners/src/plan_window_func.rs new file mode 100644 index 000000000000..24ce6a8b3b11 --- /dev/null +++ b/common/planners/src/plan_window_func.rs @@ -0,0 +1,40 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use common_datavalues::DataSchemaRef; + +use crate::Expression; +use crate::PlanNode; + +#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq)] +pub struct WindowFuncPlan { + /// The window function expression + pub window_func: Expression, + /// The incoming logical plan + pub input: Arc, + /// output schema + pub schema: DataSchemaRef, +} + +impl WindowFuncPlan { + pub fn schema(&self) -> DataSchemaRef { + self.schema.clone() + } + + pub fn set_input(&mut self, node: &PlanNode) { + self.input = Arc::new(node.clone()); + } +} diff --git a/query/Cargo.toml b/query/Cargo.toml index fce50401fc46..b936aa619cc0 100644 --- a/query/Cargo.toml +++ b/query/Cargo.toml @@ -78,6 +78,7 @@ chrono-tz = "0.6.1" clap = { version = "3.1.8", features = ["derive", "env"] } dyn-clone = "1.0.5" enum_dispatch = "0.3.8" +enum_extract = "0.1.1" futures = "0.3.21" headers = "0.3.7" http = "0.2.6" @@ -101,6 +102,7 @@ rand = "0.8.5" regex = "1.5.5" reqwest = "0.11.10" rsa = "0.5.0" +segment-tree = "2.0.0" semver = "1.0.9" serde = { version = "1.0.136", features = ["derive"] } serde-bridge = "0.0.3" diff --git a/query/src/common/expression_evaluator.rs b/query/src/common/expression_evaluator.rs index 4b20b2eabecb..b5e8cc055c6d 100644 --- a/query/src/common/expression_evaluator.rs +++ b/query/src/common/expression_evaluator.rs @@ -105,6 +105,10 @@ impl ExpressionEvaluator { "Unsupported AggregateFunction scalar expression", )), + Expression::WindowFunction { .. } => Err(ErrorCode::LogicalError( + "Unsupported WindowFunction scalar expression", + )), + Expression::Sort { .. } => Err(ErrorCode::LogicalError( "Unsupported Sort scalar expression", )), diff --git a/query/src/interpreters/interpreter_select.rs b/query/src/interpreters/interpreter_select.rs index 78e25baf3907..92b5793b223e 100644 --- a/query/src/interpreters/interpreter_select.rs +++ b/query/src/interpreters/interpreter_select.rs @@ -78,12 +78,15 @@ impl Interpreter for SelectInterpreter { { let async_runtime = self.ctx.get_storage_runtime(); let new_pipeline = self.create_new_pipeline()?; + tracing::debug!("get new_pipeline:\n{:?}", new_pipeline); let executor = PipelinePullingExecutor::try_create(async_runtime, new_pipeline)?; let (handler, stream) = ProcessorExecutorStream::create(executor)?; self.ctx.add_source_abort_handle(handler); return Ok(Box::pin(stream)); } let optimized_plan = self.rewrite_plan()?; + tracing::debug!("get optimized plan:\n{:?}", optimized_plan); + plan_schedulers::schedule_query(&self.ctx, &optimized_plan).await } diff --git a/query/src/interpreters/plan_schedulers/plan_scheduler.rs b/query/src/interpreters/plan_schedulers/plan_scheduler.rs index 064e4b412565..0a07af5c03c9 100644 --- a/query/src/interpreters/plan_schedulers/plan_scheduler.rs +++ b/query/src/interpreters/plan_schedulers/plan_scheduler.rs @@ -42,6 +42,7 @@ use common_planners::SortPlan; use common_planners::StageKind; use common_planners::StagePlan; use common_planners::SubQueriesSetPlan; +use common_planners::WindowFuncPlan; use common_tracing::tracing; use crate::api::BroadcastAction; @@ -296,6 +297,7 @@ impl PlanScheduler { match node { PlanNode::AggregatorPartial(plan) => self.visit_aggr_part(plan, tasks), PlanNode::AggregatorFinal(plan) => self.visit_aggr_final(plan, tasks), + PlanNode::WindowFunc(plan) => self.visit_window_func(plan, tasks), PlanNode::Empty(plan) => self.visit_empty(plan, tasks), PlanNode::Projection(plan) => self.visit_projection(plan, tasks), PlanNode::Filter(plan) => self.visit_filter(plan, tasks), @@ -353,6 +355,33 @@ impl PlanScheduler { Ok(()) } + fn visit_window_func(&mut self, plan: &WindowFuncPlan, tasks: &mut Tasks) -> Result<()> { + self.visit_plan_node(plan.input.as_ref(), tasks)?; + match self.running_mode { + RunningMode::Cluster => self.visit_cluster_window_func(plan), + RunningMode::Standalone => self.visit_local_window_func(plan), + } + Ok(()) + } + + fn visit_cluster_window_func(&mut self, plan: &WindowFuncPlan) { + for index in 0..self.nodes_plan.len() { + self.nodes_plan[index] = PlanNode::WindowFunc(WindowFuncPlan { + window_func: plan.window_func.clone(), + input: Arc::new(self.nodes_plan[index].clone()), + schema: plan.schema.clone(), + }) + } + } + + fn visit_local_window_func(&mut self, plan: &WindowFuncPlan) { + self.nodes_plan[self.local_pos] = PlanNode::WindowFunc(WindowFuncPlan { + window_func: plan.window_func.clone(), + input: Arc::new(self.nodes_plan[self.local_pos].clone()), + schema: plan.schema(), + }); + } + fn visit_local_aggr_final(&mut self, plan: &AggregatorFinalPlan) { self.nodes_plan[self.local_pos] = PlanNode::AggregatorFinal(AggregatorFinalPlan { schema: plan.schema.clone(), diff --git a/query/src/interpreters/plan_schedulers/plan_scheduler_query.rs b/query/src/interpreters/plan_schedulers/plan_scheduler_query.rs index 2ee2930b5aa2..f355822ad9a5 100644 --- a/query/src/interpreters/plan_schedulers/plan_scheduler_query.rs +++ b/query/src/interpreters/plan_schedulers/plan_scheduler_query.rs @@ -50,6 +50,7 @@ pub async fn schedule_query( let pipeline_builder = PipelineBuilder::create(ctx.clone()); let mut in_local_pipeline = pipeline_builder.build(&scheduled_tasks.get_local_task())?; + tracing::log::debug!("local_pipeline:\n{:?}", in_local_pipeline); match in_local_pipeline.execute().await { Ok(stream) => Ok(ScheduledStream::create(ctx.clone(), scheduled, stream)), diff --git a/query/src/optimizers/optimizer_scatters.rs b/query/src/optimizers/optimizer_scatters.rs index 373da3caae1e..d76b0fee96ab 100644 --- a/query/src/optimizers/optimizer_scatters.rs +++ b/query/src/optimizers/optimizer_scatters.rs @@ -31,6 +31,8 @@ use common_planners::ReadDataSourcePlan; use common_planners::SortPlan; use common_planners::StageKind; use common_planners::StagePlan; +use common_planners::WindowFuncPlan; +use enum_extract::let_extract; use crate::optimizers::Optimizer; use crate::sessions::QueryContext; @@ -109,6 +111,66 @@ impl ScattersOptimizerImpl { } } + fn cluster_window(&mut self, plan: &WindowFuncPlan) -> Result { + match self.input.take() { + None => Err(ErrorCode::LogicalError("Cluster window input is None")), + Some(input) => { + let_extract!( + Expression::WindowFunction { + op: _op, + params: _params, + args: _args, + partition_by: partition_by, + order_by: _order_by, + window_frame: _window_frame + }, + &plan.window_func, + panic!() + ); + + let stage_input = if !partition_by.is_empty() { + let mut concat_ws_args = vec![Expression::create_literal(DataValue::String( + "#".as_bytes().to_vec(), + ))]; + concat_ws_args.extend(partition_by.to_owned()); + let concat_partition_by = + Expression::create_scalar_function("concat_ws", concat_ws_args); + + let scatters_expr = + Expression::create_scalar_function("sipHash", vec![concat_partition_by]); + + PlanNode::Stage(StagePlan { + scatters_expr, + kind: StageKind::Normal, + input, + }) + } else { + self.running_mode = RunningMode::Standalone; + PlanNode::Stage(StagePlan { + scatters_expr: Expression::create_literal(DataValue::UInt64(0)), + kind: StageKind::Convergent, + input, + }) + }; + + Ok(PlanNode::WindowFunc(WindowFuncPlan { + window_func: plan.window_func.to_owned(), + input: Arc::new(stage_input), + schema: plan.schema.to_owned(), + })) + } + } + } + + fn standalone_window(&mut self, plan: &WindowFuncPlan) -> Result { + match self.input.take() { + None => Err(ErrorCode::LogicalError("Standalone window input is None")), + Some(input) => PlanBuilder::from(input.as_ref()) + .window_func(plan.window_func.to_owned())? + .build(), + } + } + fn cluster_sort(&mut self, plan: &SortPlan) -> Result { // Order by we convergent it in local node self.running_mode = RunningMode::Standalone; @@ -251,6 +313,17 @@ impl PlanRewriter for ScattersOptimizerImpl { } } + fn rewrite_window_func(&mut self, plan: &WindowFuncPlan) -> Result { + let new_input = Arc::new(self.rewrite_plan_node(&plan.input)?); + + self.input = Some(new_input); + + match self.running_mode { + RunningMode::Cluster => self.cluster_window(plan), + RunningMode::Standalone => self.standalone_window(plan), + } + } + fn rewrite_sort(&mut self, plan: &SortPlan) -> Result { self.input = Some(Arc::new(self.rewrite_plan_node(plan.input.as_ref())?)); diff --git a/query/src/pipelines/new/pipeline_builder.rs b/query/src/pipelines/new/pipeline_builder.rs index f4b911500824..e032539441b5 100644 --- a/query/src/pipelines/new/pipeline_builder.rs +++ b/query/src/pipelines/new/pipeline_builder.rs @@ -30,6 +30,7 @@ use common_planners::ReadDataSourcePlan; use common_planners::SelectPlan; use common_planners::SortPlan; use common_planners::SubQueriesSetPlan; +use common_planners::WindowFuncPlan; use super::processors::SortMergeCompactor; use crate::pipelines::new::pipeline::NewPipeline; @@ -86,6 +87,7 @@ impl PlanVisitor for QueryPipelineBuilder { PlanNode::Expression(n) => self.visit_expression(n), PlanNode::AggregatorPartial(n) => self.visit_aggregate_partial(n), PlanNode::AggregatorFinal(n) => self.visit_aggregate_final(n), + PlanNode::WindowFunc(n) => self.visit_window_func(n), PlanNode::Filter(n) => self.visit_filter(n), PlanNode::Having(n) => self.visit_having(n), PlanNode::Sort(n) => self.visit_sort(n), @@ -147,6 +149,11 @@ impl PlanVisitor for QueryPipelineBuilder { }) } + fn visit_window_func(&mut self, plan: &WindowFuncPlan) -> Result<()> { + self.visit_plan_node(&plan.input)?; + unimplemented!("window function cannot work in new scheduler framework now") + } + fn visit_projection(&mut self, plan: &ProjectionPlan) -> Result<()> { self.visit_plan_node(&plan.input)?; self.pipeline diff --git a/query/src/pipelines/processors/pipeline_builder.rs b/query/src/pipelines/processors/pipeline_builder.rs index 5ead93a4a1bf..033ad05871db 100644 --- a/query/src/pipelines/processors/pipeline_builder.rs +++ b/query/src/pipelines/processors/pipeline_builder.rs @@ -32,6 +32,7 @@ use common_planners::SinkPlan; use common_planners::SortPlan; use common_planners::StagePlan; use common_planners::SubQueriesSetPlan; +use common_planners::WindowFuncPlan; use common_tracing::tracing; use crate::api::FlightTicket; @@ -53,6 +54,7 @@ use crate::pipelines::transforms::SortPartialTransform; use crate::pipelines::transforms::SourceTransform; use crate::pipelines::transforms::SubQueriesPuller; use crate::pipelines::transforms::WhereTransform; +use crate::pipelines::transforms::WindowFuncTransform; use crate::sessions::QueryContext; pub struct PipelineBuilder { @@ -89,6 +91,7 @@ impl PipelineBuilder { PlanNode::Projection(node) => self.visit_projection(node), PlanNode::AggregatorPartial(node) => self.visit_aggregator_partial(node), PlanNode::AggregatorFinal(node) => self.visit_aggregator_final(node), + PlanNode::WindowFunc(node) => self.visit_window_func(node), PlanNode::Filter(node) => self.visit_filter(node), PlanNode::Having(node) => self.visit_having(node), PlanNode::Sort(node) => self.visit_sort(node), @@ -216,6 +219,18 @@ impl PipelineBuilder { Ok(pipeline) } + fn visit_window_func(&mut self, node: &WindowFuncPlan) -> Result { + let mut pipeline = self.visit(&*node.input)?; + pipeline.add_simple_transform(|| { + Ok(Box::new(WindowFuncTransform::create( + node.window_func.to_owned(), + node.schema.to_owned(), + node.input.schema(), + ))) + })?; + Ok(pipeline) + } + fn visit_filter(&mut self, node: &FilterPlan) -> Result { let mut pipeline = self.visit(&*node.input)?; pipeline.add_simple_transform(|| { diff --git a/query/src/pipelines/transforms/mod.rs b/query/src/pipelines/transforms/mod.rs index b32dc65bfa2e..416ec3ab0847 100644 --- a/query/src/pipelines/transforms/mod.rs +++ b/query/src/pipelines/transforms/mod.rs @@ -27,6 +27,7 @@ mod transform_remote; mod transform_sort_merge; mod transform_sort_partial; mod transform_source; +mod transform_window_func; pub mod group_by; mod streams; @@ -52,3 +53,4 @@ pub use transform_sort_merge::SortMergeTransform; pub use transform_sort_partial::get_sort_descriptions; pub use transform_sort_partial::SortPartialTransform; pub use transform_source::SourceTransform; +pub use transform_window_func::WindowFuncTransform; diff --git a/query/src/pipelines/transforms/transform_window_func.rs b/query/src/pipelines/transforms/transform_window_func.rs new file mode 100644 index 000000000000..f31c1db7884b --- /dev/null +++ b/query/src/pipelines/transforms/transform_window_func.rs @@ -0,0 +1,568 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::cmp::Ordering; +use std::ops::Range; +use std::sync::Arc; + +use bumpalo::Bump; +use common_arrow::arrow::array::ArrayRef; +use common_arrow::arrow::compute::partition::lexicographical_partition_ranges; +use common_arrow::arrow::compute::sort::SortColumn; +use common_datablocks::DataBlock; +use common_datavalues::ColumnRef; +use common_datavalues::ColumnWithField; +use common_datavalues::DataField; +use common_datavalues::DataSchemaRef; +use common_datavalues::DataType; +use common_datavalues::DataValue; +use common_datavalues::Series; +use common_functions::aggregates::AggregateFunctionFactory; +use common_functions::aggregates::AggregateFunctionRef; +use common_functions::aggregates::StateAddr; +use common_functions::scalars::assert_numeric; +use common_functions::window::WindowFrame; +use common_functions::window::WindowFrameBound; +use common_functions::window::WindowFrameUnits; +use common_planners::Expression; +use common_streams::DataBlockStream; +use common_streams::SendableDataBlockStream; +use common_tracing::tracing; +use enum_extract::let_extract; +use futures::StreamExt; +use segment_tree::ops::Commutative; +use segment_tree::ops::Identity; +use segment_tree::ops::Operation; +use segment_tree::SegmentPoint; + +use crate::pipelines::processors::EmptyProcessor; +use crate::pipelines::processors::Processor; +use crate::pipelines::transforms::get_sort_descriptions; + +pub struct WindowFuncTransform { + window_func: Expression, + schema: DataSchemaRef, + input: Arc, + input_schema: DataSchemaRef, +} + +impl WindowFuncTransform { + pub fn create( + window_func: Expression, + schema: DataSchemaRef, + input_schema: DataSchemaRef, + ) -> Self { + WindowFuncTransform { + window_func, + schema, + input: Arc::new(EmptyProcessor::create()), + input_schema, + } + } + + /// evaluate window function for each frame and return the result block + async fn evaluate_window_func(&self, block: &DataBlock) -> common_exception::Result { + // extract the window function + let_extract!( + Expression::WindowFunction { + op, + params, + args, + partition_by, + order_by, + window_frame + }, + &self.window_func, + panic!() + ); + + // sort block by partition_by and order_by exprs + let mut partition_sort_exprs: Vec = + Vec::with_capacity(partition_by.len() + order_by.len()); + partition_sort_exprs.extend( + partition_by + .iter() + .map(|part_by_expr| Expression::Sort { + expr: Box::new(part_by_expr.to_owned()), + asc: true, + nulls_first: false, + origin_expr: Box::new(part_by_expr.to_owned()), + }) + .collect::>(), + ); + partition_sort_exprs.extend(order_by.to_owned()); + + let block = if !partition_sort_exprs.is_empty() { + let sort_column_desc = get_sort_descriptions(block.schema(), &partition_sort_exprs)?; + DataBlock::sort_block(block, &sort_column_desc, None)? + } else { + block.to_owned() + }; + + // set default window frame + let window_frame = match window_frame { + None => { + // compute the whole partition + match order_by.is_empty() { + // RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + false => WindowFrame::default(), + // RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + true => WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(None), + end_bound: WindowFrameBound::Following(None), + }, + } + } + Some(window_frame) => window_frame.to_owned(), + }; + + // determine window frame bounds of each tuple + let frame_bounds = Self::frame_bounds_per_tuple( + &block, + partition_by, + order_by, + window_frame.units, + window_frame.start_bound, + window_frame.end_bound, + ); + + // function calculate + let mut arguments: Vec = Vec::with_capacity(args.len()); + let mut arg_fields = Vec::with_capacity(args.len()); + for arg in args { + let arg_field: DataField = arg.to_data_field(&self.input_schema)?; + arg_fields.push(arg_field.clone()); + let arg_column = block.try_column_by_name(arg_field.name()).unwrap(); + arguments.push(ColumnWithField::new(Arc::clone(arg_column), arg_field)); + } + + let function = if !AggregateFunctionFactory::instance().check(op) { + unimplemented!("not yet impl built-in window func"); + } else { + AggregateFunctionFactory::instance().get(op, params.clone(), arg_fields)? + }; + + let arena = Arc::new(Bump::with_capacity( + 2 * block.num_rows() * function.state_layout().size(), + )); + let segment_tree_state = Self::new_partition_aggr_func( + function.clone(), + &arguments + .iter() + .map(|a| a.column().clone()) + .collect::>(), + arena, + ); + + let window_col_per_tuple = (0..block.num_rows()) + .map(|i| { + let frame = &frame_bounds[i]; + let frame_start = frame.start; + let frame_end = frame.end; + let state = segment_tree_state.query(frame_start, frame_end); + let mut builder = function.return_type().unwrap().create_mutable(1); + function.merge_result(state, builder.as_mut()).unwrap(); + builder.to_column() + }) + .collect::>(); + + let window_col = Series::concat(&window_col_per_tuple).unwrap(); + + block.add_column( + window_col, + self.window_func.to_data_field(&self.input_schema).unwrap(), + ) + } + + /// compute frame range for each tuple + fn frame_bounds_per_tuple( + block: &DataBlock, + partition_by: &[Expression], + order_by: &[Expression], + frame_units: WindowFrameUnits, + start: WindowFrameBound, + end: WindowFrameBound, + ) -> Vec> { + let partition_by_arrow_array = partition_by + .iter() + .map(|expr| { + block + .try_column_by_name(&expr.column_name()) + .unwrap() + .as_arrow_array() + }) + .collect::>(); + let partition_by_sort_column = partition_by_arrow_array + .iter() + .map(|array| SortColumn { + values: array.as_ref(), + options: None, + }) + .collect::>(); + + let partition_boundaries = if !partition_by_sort_column.is_empty() { + lexicographical_partition_ranges(&partition_by_sort_column) + .unwrap() + .collect::>() + } else { + vec![0..block.num_rows(); 1] + }; + let mut partition_boundaries = partition_boundaries.into_iter(); + let mut partition = partition_boundaries.next().unwrap(); + + match (frame_units, start, end) { + (_, WindowFrameBound::Preceding(None), WindowFrameBound::Following(None)) => (0..block + .num_rows()) + .map(|i| { + if i >= partition.end && i < block.num_rows() { + partition = partition_boundaries.next().unwrap(); + } + partition.clone() + }) + .collect::>(), + (WindowFrameUnits::Rows, frame_start, frame_end) => (0..block.num_rows()) + .map(|i| { + if i >= partition.end && i < block.num_rows() { + partition = partition_boundaries.next().unwrap(); + } + let mut start = partition.start; + let mut end = partition.end; + match frame_start { + WindowFrameBound::Preceding(Some(preceding)) => { + start = std::cmp::max( + start, + if i < preceding as usize { + 0 + } else { + i - preceding as usize + }, + ); + } + WindowFrameBound::CurrentRow => { + start = i; + } + WindowFrameBound::Following(Some(following)) => { + start = std::cmp::min(end, i + following as usize) + } + _ => (), + } + match frame_end { + WindowFrameBound::Preceding(Some(preceding)) => { + end = std::cmp::max(start, i + 1 - preceding as usize); + } + WindowFrameBound::CurrentRow => { + end = i + 1; + } + WindowFrameBound::Following(Some(following)) => { + end = std::cmp::min(end, i + 1 + following as usize); + } + _ => (), + } + start..end + }) + .collect::>(), + (WindowFrameUnits::Range, frame_start, frame_end) => match (frame_start, frame_end) { + (WindowFrameBound::Preceding(None), WindowFrameBound::CurrentRow) => { + let mut partition_by_sort_column = partition_by_sort_column; + let order_by_arrow_array = order_by + .iter() + .map(|expr| { + block + .try_column_by_name(&expr.column_name()) + .unwrap() + .as_arrow_array() + }) + .collect::>(); + let order_by_sort_column = order_by_arrow_array + .iter() + .map(|array| SortColumn { + values: array.as_ref(), + options: None, + }) + .collect::>(); + partition_by_sort_column.extend(order_by_sort_column); + + let mut peered_boundaries = + lexicographical_partition_ranges(&partition_by_sort_column).unwrap(); + let mut peered = peered_boundaries.next().unwrap(); + (0..block.num_rows()) + .map(|i| { + if i >= partition.end && i < block.num_rows() { + partition = partition_boundaries.next().unwrap(); + } + if i >= peered.end && i < block.num_rows() { + peered = peered_boundaries.next().unwrap(); + } + partition.start..peered.end + }) + .collect::>() + } + (WindowFrameBound::CurrentRow, WindowFrameBound::Following(None)) => { + let mut partition_by_sort_column = partition_by_sort_column; + let order_by_arrow_array = order_by + .iter() + .map(|expr| { + block + .try_column_by_name(&expr.column_name()) + .unwrap() + .as_arrow_array() + }) + .collect::>(); + let order_by_sort_column = order_by_arrow_array + .iter() + .map(|array| SortColumn { + values: array.as_ref(), + options: None, + }) + .collect::>(); + partition_by_sort_column.extend(order_by_sort_column); + + let mut peered_boundaries = + lexicographical_partition_ranges(&partition_by_sort_column).unwrap(); + let mut peered = peered_boundaries.next().unwrap(); + (0..block.num_rows()) + .map(|i| { + if i >= partition.end && i < block.num_rows() { + partition = partition_boundaries.next().unwrap(); + } + if i >= peered.end && i < block.num_rows() { + peered = peered_boundaries.next().unwrap(); + } + peered.start..partition.end + }) + .collect::>() + } + (frame_start, frame_end) => { + assert_eq!( + order_by.len(), + 1, + "Range mode is only possible if the query has exactly one numeric order by expression." + ); + assert_numeric( + block + .schema() + .field_with_name(&order_by[0].column_name()) + .unwrap() + .data_type(), + ) + .expect( + "Range mode is only possible if the query has exactly one numeric order by expression.", + ); + + let order_by_values = block + .try_column_by_name(&order_by[0].column_name()) + .unwrap() + .to_values(); + + (0..block.num_rows()) + .map(|i| { + let current_value = &order_by_values[i]; + if i >= partition.end && i < block.num_rows() { + partition = partition_boundaries.next().unwrap(); + } + let mut start = partition.start; + let mut end = partition.end; + match frame_start { + WindowFrameBound::Preceding(Some(preceding)) => { + let offset = order_by_values[start..end].partition_point(|x| { + let min = DataValue::from( + current_value.as_f64().unwrap() - preceding as f64, + ); + x.cmp(&min) == Ordering::Less + }); + start += offset; + } + WindowFrameBound::CurrentRow => { + let offset = order_by_values[start..end].partition_point(|x| { + let peer = DataValue::from(current_value.as_f64().unwrap()); + x.cmp(&peer) == Ordering::Less + }); + start += offset; + } + WindowFrameBound::Following(Some(following)) => { + let offset = order_by_values[start..end].partition_point(|x| { + let max = DataValue::from( + current_value.as_f64().unwrap() + following as f64, + ); + x.cmp(&max) == Ordering::Less + }); + start += offset; + } + _ => (), + } + match frame_end { + WindowFrameBound::Preceding(Some(preceding)) => { + let offset = order_by_values[start..end].partition_point(|x| { + let min = DataValue::from( + current_value.as_f64().unwrap() - preceding as f64, + ); + x.cmp(&min) != Ordering::Greater + }); + end = start + offset; + } + WindowFrameBound::CurrentRow => { + let offset = order_by_values[start..end].partition_point(|x| { + let peer = DataValue::from(current_value.as_f64().unwrap()); + x.cmp(&peer) != Ordering::Greater + }); + end = start + offset; + } + WindowFrameBound::Following(Some(following)) => { + let offset = order_by_values[start..end].partition_point(|x| { + let max = DataValue::from( + current_value.as_f64().unwrap() + following as f64, + ); + x.cmp(&max) != Ordering::Greater + }); + end = start + offset; + } + _ => (), + } + start..end + }) + .collect::>() + } + }, + } + } + + /// Actually, the computation is row based + fn new_partition_aggr_func( + func: AggregateFunctionRef, + arguments: &[ColumnRef], + arena: Arc, + ) -> SegmentPoint { + let rows = arguments[0].len(); + let state_per_tuple = (0..rows) + .map(|i| { + arguments + .iter() + .map(|c| c.slice(i, 1)) + .collect::>() + }) + .map(|args| { + let place = arena.alloc_layout(func.state_layout()); + let state_addr = place.into(); + func.init_state(state_addr); + func.accumulate(state_addr, &args, None, 1) + .expect("Failed to initialize the state"); + state_addr + }) + .collect::>(); + SegmentPoint::build(state_per_tuple, Agg { func, arena }) + } +} + +#[async_trait::async_trait] +impl Processor for WindowFuncTransform { + fn name(&self) -> &str { + "WindowFuncTransform" + } + + fn connect_to(&mut self, input: Arc) -> common_exception::Result<()> { + self.input = input; + Ok(()) + } + + fn inputs(&self) -> Vec> { + vec![self.input.clone()] + } + + fn as_any(&self) -> &dyn Any { + self + } + + #[tracing::instrument(level = "debug", name = "window_func_execute", skip(self))] + async fn execute(&self) -> common_exception::Result { + let mut stream: SendableDataBlockStream = self.input.execute().await?; + let mut blocks: Vec = vec![]; + while let Some(block) = stream.next().await { + let block = block?; + blocks.push(block); + } + + if blocks.is_empty() { + return Ok(Box::pin(DataBlockStream::create( + self.schema.clone(), + None, + vec![], + ))); + } + + // combine blocks + let schema = blocks[0].schema(); + + let combined_columns = (0..schema.num_fields()) + .map(|i| { + blocks + .iter() + .map(|block| block.column(i).clone()) + .collect::>() + }) + .map(|columns| Series::concat(&columns).unwrap()) + .collect::>(); + + let block = DataBlock::create(schema.clone(), combined_columns); + + // evaluate the window function column + let block = self.evaluate_window_func(&block).await.unwrap(); + + Ok(Box::pin(DataBlockStream::create( + self.schema.clone(), + None, + vec![block], + ))) + } +} + +struct Agg { + func: AggregateFunctionRef, + arena: Arc, +} + +impl Operation for Agg { + fn combine(&self, a: &StateAddr, b: &StateAddr) -> StateAddr { + let place = self.arena.alloc_layout(self.func.state_layout()); + let state_addr = place.into(); + self.func.init_state(state_addr); + self.func + .merge(state_addr, *a) + .expect("Failed to merge states"); + self.func + .merge(state_addr, *b) + .expect("Failed to merge states"); + state_addr + } + + fn combine_mut(&self, a: &mut StateAddr, b: &StateAddr) { + self.func.merge(*a, *b).expect("Failed to merge states"); + } + + fn combine_mut2(&self, a: &StateAddr, b: &mut StateAddr) { + self.func.merge(*b, *a).expect("Failed to merge states"); + } +} + +impl Commutative for Agg {} + +impl Identity for Agg { + fn identity(&self) -> StateAddr { + let place = self.arena.alloc_layout(self.func.state_layout()); + let state_addr = place.into(); + self.func.init_state(state_addr); + state_addr + } +} diff --git a/query/src/sql/plan_parser.rs b/query/src/sql/plan_parser.rs index fd9a962a079e..372f858dc52a 100644 --- a/query/src/sql/plan_parser.rs +++ b/query/src/sql/plan_parser.rs @@ -129,7 +129,10 @@ impl PlanParser { let having = Self::build_having_plan(before_order, data)?; tracing::debug!("Build having plan:\n{:?}", having); - let distinct = Self::build_distinct_plan(having, data)?; + let window = Self::build_window_plan(having, data)?; + tracing::debug!("Build window plan node:\n{:?}", window); + + let distinct = Self::build_distinct_plan(window, data)?; tracing::debug!("Build distinct plan:\n{:?}", distinct); let order_by = Self::build_order_by_plan(distinct, data)?; @@ -255,6 +258,18 @@ impl PlanParser { } } + fn build_window_plan(plan: PlanNode, data: &QueryAnalyzeState) -> Result { + match data.window_expressions.is_empty() { + true => Ok(plan), + false => { + let exprs = data.window_expressions.to_vec(); + exprs.into_iter().try_fold(plan, |input, window_func| { + PlanBuilder::from(&input).window_func(window_func)?.build() + }) + } + } + } + fn build_projection_plan(plan: PlanNode, data: &QueryAnalyzeState) -> Result { PlanBuilder::from(&plan) .project(&data.projection_expressions)? diff --git a/query/src/sql/statements/analyzer_expr.rs b/query/src/sql/statements/analyzer_expr.rs index c8d8b1f4a417..452f69a4411d 100644 --- a/query/src/sql/statements/analyzer_expr.rs +++ b/query/src/sql/statements/analyzer_expr.rs @@ -36,6 +36,7 @@ use sqlparser::ast::Ident; use sqlparser::ast::Query; use sqlparser::ast::UnaryOperator; use sqlparser::ast::Value; +use sqlparser::ast::WindowSpec; use crate::procedures::ContextFunction; use crate::sessions::QueryContext; @@ -184,28 +185,33 @@ impl ExpressionAnalyzer { } fn analyze_function(&self, info: &FunctionExprInfo, args: &mut Vec) -> Result<()> { - let mut arguments = Vec::with_capacity(info.args_count); - for _ in 0..info.args_count { - match args.pop() { - None => { - return Err(ErrorCode::LogicalError("It's a bug.")); - } - Some(arg) => { - arguments.insert(0, arg); + let func = match &info.over { + Some(_) => self.window_function(info, args)?, + None => { + let mut arguments = Vec::with_capacity(info.args_count); + for _ in 0..info.args_count { + match args.pop() { + None => { + return Err(ErrorCode::LogicalError("It's a bug.")); + } + Some(arg) => { + arguments.insert(0, arg); + } + } } + + match AggregateFunctionFactory::instance().check(&info.name) { + true => self.aggr_function(info, &arguments), + false => match info.kind { + OperatorKind::Unary => Self::unary_function(info, &arguments), + OperatorKind::Binary => Self::binary_function(info, &arguments), + OperatorKind::Other => self.other_function(info, &arguments), + }, + }? } - } + }; - args.push( - match AggregateFunctionFactory::instance().check(&info.name) { - true => self.aggr_function(info, &arguments), - false => match info.kind { - OperatorKind::Unary => Self::unary_function(info, &arguments), - OperatorKind::Binary => Self::binary_function(info, &arguments), - OperatorKind::Other => self.other_function(info, &arguments), - }, - }?, - ); + args.push(func); Ok(()) } @@ -300,6 +306,93 @@ impl ExpressionAnalyzer { } } + fn window_function( + &self, + info: &FunctionExprInfo, + args: &mut Vec, + ) -> Result { + // window function partition by and order by args + let mut partition_by = vec![]; + let mut order_by_expr = vec![]; + if let Some(window_spec) = &info.over { + let order_by_args_count = window_spec.order_by.len(); + let partition_by_args_count = window_spec.partition_by.len(); + + for i in 0..partition_by_args_count + order_by_args_count { + match args.pop() { + None => { + return Err(ErrorCode::LogicalError("It's a bug.")); + } + Some(arg) => { + if i < order_by_args_count { + order_by_expr.insert(0, arg); + } else { + partition_by.insert(0, arg); + } + } + } + } + } + + let mut parameters = Vec::with_capacity(info.parameters.len()); + + for parameter in &info.parameters { + match ValueExprAnalyzer::analyze( + parameter, + self.context.get_current_session().get_type(), + )? { + Expression::Literal { value, .. } => { + parameters.push(value); + } + expr => { + return Err(ErrorCode::SyntaxException(format!( + "Unsupported value expression: {:?}, must be datavalue", + expr + ))); + } + }; + } + + let mut arguments = Vec::with_capacity(info.args_count); + for _ in 0..info.args_count { + match args.pop() { + None => { + return Err(ErrorCode::LogicalError("It's a bug.")); + } + Some(arg) => { + arguments.insert(0, arg); + } + } + } + + let window_spec = info.over.as_ref().unwrap(); + + let order_by: Vec = order_by_expr + .into_iter() + .zip(window_spec.order_by.clone()) + .map(|(expr_sort_on, parser_sort_expr)| Expression::Sort { + expr: Box::new(expr_sort_on.clone()), + asc: parser_sort_expr.asc.unwrap_or(true), + nulls_first: parser_sort_expr.nulls_first.unwrap_or(true), + origin_expr: Box::new(expr_sort_on), + }) + .collect(); + + let window_frame = window_spec + .window_frame + .clone() + .map(|wf| wf.try_into().unwrap()); + + Ok(Expression::WindowFunction { + op: info.name.clone(), + params: parameters, + args: arguments, + partition_by, + order_by, + window_frame, + }) + } + fn analyze_identifier(&self, ident: &Ident, arguments: &mut Vec) -> Result<()> { let column_name = ident.clone().value; arguments.push(Expression::Column(column_name)); @@ -512,6 +605,7 @@ struct FunctionExprInfo { args_count: usize, kind: OperatorKind, parameters: Vec, + over: Option, } struct InListInfo { @@ -542,6 +636,7 @@ impl ExprRPNItem { args_count, kind: OperatorKind::Other, parameters: Vec::new(), + over: None, }) } @@ -552,6 +647,7 @@ impl ExprRPNItem { args_count: 2, kind: OperatorKind::Binary, parameters: Vec::new(), + over: None, }) } @@ -562,6 +658,7 @@ impl ExprRPNItem { args_count: 1, kind: OperatorKind::Unary, parameters: Vec::new(), + over: None, }) } } @@ -627,6 +724,7 @@ impl ExprRPNBuilder { args_count: function.args.len(), kind: OperatorKind::Other, parameters: function.params.to_owned(), + over: function.over.clone(), })); } Expr::Cast { diff --git a/query/src/sql/statements/analyzer_statement.rs b/query/src/sql/statements/analyzer_statement.rs index 824a77c28e22..90d9c217d2d7 100644 --- a/query/src/sql/statements/analyzer_statement.rs +++ b/query/src/sql/statements/analyzer_statement.rs @@ -55,6 +55,8 @@ pub struct QueryAnalyzeState { pub group_by_expressions: Vec, pub aggregate_expressions: Vec, + pub window_expressions: Vec, + // rebase on projection expressions without aliases, aggregate and group by expressions pub distinct_expressions: Vec, @@ -104,6 +106,10 @@ impl Debug for QueryAnalyzeState { debug_struct.field("aggregate", &self.aggregate_expressions); } + if !self.window_expressions.is_empty() { + debug_struct.field("window_func", &self.window_expressions); + } + if !self.expressions.is_empty() { match self.order_by_expressions.is_empty() { true => debug_struct.field("before_projection", &self.expressions), diff --git a/query/src/sql/statements/query/query_ast_ir.rs b/query/src/sql/statements/query/query_ast_ir.rs index 33e5b31b1d91..c206b1eb2e26 100644 --- a/query/src/sql/statements/query/query_ast_ir.rs +++ b/query/src/sql/statements/query/query_ast_ir.rs @@ -24,6 +24,7 @@ pub struct QueryASTIR { pub having_predicate: Option, pub group_by_expressions: Vec, pub aggregate_expressions: Vec, + pub window_expressions: Vec, pub projection_expressions: Vec, pub distinct: bool, pub order_by_expressions: Vec, @@ -44,6 +45,7 @@ pub trait QueryASTIRVisitor { Self::visit_group_by(&mut ir.group_by_expressions, data)?; Self::visit_order_by(&mut ir.order_by_expressions, data)?; Self::visit_aggregates(&mut ir.aggregate_expressions, data)?; + Self::visit_window(&mut ir.window_expressions, data)?; Self::visit_projection(&mut ir.projection_expressions, data)?; Ok(()) } @@ -74,6 +76,24 @@ pub trait QueryASTIRVisitor { Ok(()) } + Expression::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + for expr in args { + Self::visit_recursive_expr(expr, data)?; + } + for expr in partition_by { + Self::visit_recursive_expr(expr, data)?; + } + for expr in order_by { + Self::visit_recursive_expr(expr, data)?; + } + + Ok(()) + } Expression::Sort { expr, origin_expr, .. } => { @@ -116,6 +136,14 @@ pub trait QueryASTIRVisitor { Ok(()) } + fn visit_window(exprs: &mut Vec, data: &mut Data) -> Result<()> { + for expr in exprs { + Self::visit_recursive_expr(expr, data)?; + } + + Ok(()) + } + fn visit_order_by(exprs: &mut Vec, data: &mut Data) -> Result<()> { for expr in exprs { Self::visit_recursive_expr(expr, data)?; @@ -153,6 +181,10 @@ impl Debug for QueryASTIR { debug_struct.field("aggregate", &self.aggregate_expressions); } + if !self.window_expressions.is_empty() { + debug_struct.field("window", &self.window_expressions); + } + if self.distinct { debug_struct.field("distinct", &true); } diff --git a/query/src/sql/statements/query/query_normalizer.rs b/query/src/sql/statements/query/query_normalizer.rs index 4a7f1a4322ab..5807c1f4bb12 100644 --- a/query/src/sql/statements/query/query_normalizer.rs +++ b/query/src/sql/statements/query/query_normalizer.rs @@ -19,6 +19,7 @@ use common_exception::ErrorCode; use common_exception::Result; use common_planners::extract_aliases; use common_planners::find_aggregate_exprs_in_expr; +use common_planners::find_window_exprs_in_expr; use common_planners::resolve_aliases_to_exprs; use common_planners::Expression; use common_tracing::tracing; @@ -50,6 +51,7 @@ impl QueryNormalizer { having_predicate: None, group_by_expressions: vec![], aggregate_expressions: vec![], + window_expressions: vec![], distinct: false, order_by_expressions: vec![], projection_expressions: vec![], @@ -118,6 +120,7 @@ impl QueryNormalizer { for projection_expression in &projection_expressions { self.add_aggregate_function(projection_expression)?; + self.add_window_function(projection_expression)?; } self.query_ast_ir.projection_expressions = projection_expressions; @@ -246,4 +249,14 @@ impl QueryNormalizer { Ok(()) } + + fn add_window_function(&mut self, expr: &Expression) -> Result<()> { + for window_expr in find_window_exprs_in_expr(expr) { + if !self.query_ast_ir.window_expressions.contains(&window_expr) { + self.query_ast_ir.window_expressions.push(window_expr); + } + } + + Ok(()) + } } diff --git a/query/src/sql/statements/statement_select.rs b/query/src/sql/statements/statement_select.rs index 5c13d0223624..e26ff7e51940 100644 --- a/query/src/sql/statements/statement_select.rs +++ b/query/src/sql/statements/statement_select.rs @@ -24,6 +24,7 @@ use common_planners::find_aggregate_exprs_in_expr; use common_planners::rebase_expr; use common_planners::Expression; use common_tracing::tracing; +use common_tracing::tracing::log::debug; use sqlparser::ast::Expr; use sqlparser::ast::Offset; use sqlparser::ast::OrderByExpr; @@ -73,6 +74,8 @@ impl AnalyzableStatement for DfQueryStatement { QueryCollectPushDowns::collect_extras(&mut ir, &mut joined_schema, has_aggregation)?; let analyze_state = self.analyze_query(ir).await?; + tracing::debug!("analyze state is:\n{:?}", analyze_state); + self.check_and_finalize(joined_schema, analyze_state, ctx) .await } @@ -148,6 +151,10 @@ impl DfQueryStatement { Self::analyze_aggregate(&ir.aggregate_expressions, &mut analyze_state)?; } + if !ir.window_expressions.is_empty() { + Self::analyze_window(&ir.window_expressions, &mut analyze_state)?; + } + if ir.distinct { Self::analyze_distinct(&ir.projection_expressions, &mut analyze_state)?; } @@ -183,20 +190,70 @@ impl DfQueryStatement { _ => item.clone(), }; - // support select distinct aggr_func()... + let distinct_expr = rebase_expr(&distinct_expr, &state.expressions)?; let distinct_expr = rebase_expr(&distinct_expr, &state.group_by_expressions)?; let distinct_expr = rebase_expr(&distinct_expr, &state.aggregate_expressions)?; + let distinct_expr = rebase_expr(&distinct_expr, &state.window_expressions)?; state.distinct_expressions.push(distinct_expr); } Ok(()) } + fn analyze_window(window_exprs: &[Expression], state: &mut QueryAnalyzeState) -> Result<()> { + for expr in window_exprs { + match expr { + Expression::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + for arg in args { + state.add_expression(arg); + } + for partition_by_expr in partition_by { + state.add_expression(partition_by_expr); + } + for order_by_expr in order_by { + match order_by_expr { + Expression::Sort { expr, .. } => state.add_expression(expr), + _ => { + return Err(ErrorCode::LogicalError(format!( + "Found non-sort expression {:?} while analyzing order by expressions of window expressions", + order_by_expr + ))) + } + } + } + } + _ => { + return Err(ErrorCode::LogicalError(format!( + "Found non-window expression {:?} while analyzing window expressions!", + expr + ))) + } + } + } + + for expr in window_exprs { + let base_exprs = &state.expressions; + state + .window_expressions + .push(rebase_expr(expr, base_exprs)?); + } + + Ok(()) + } + fn analyze_projection(exprs: &[Expression], state: &mut QueryAnalyzeState) -> Result<()> { for item in exprs { - match item { - Expression::Alias(_, expr) => state.add_expression(expr), - _ => state.add_expression(item), + let expr = match item { + Expression::Alias(_, expr) => expr, + _ => item, + }; + if !matches!(expr, Expression::WindowFunction { .. }) { + state.add_expression(expr); } let rebased_expr = rebase_expr(item, &state.expressions)?; @@ -226,6 +283,10 @@ impl DfQueryStatement { ) -> Result { let dry_run_res = Self::verify_with_dry_run(&schema, &state)?; state.finalize_schema = dry_run_res.schema().clone(); + debug!( + "QueryAnalyzeState finalized schema:\n{}", + state.finalize_schema + ); let mut tables_desc = schema.take_tables_desc(); @@ -334,6 +395,28 @@ impl DfQueryStatement { } } + if !state.window_expressions.is_empty() { + let new_len = state.window_expressions.len() + state.expressions.len(); + let mut new_expression = Vec::with_capacity(new_len); + + for expr in &state.window_expressions { + new_expression.push(expr); + } + + for expr in &state.expressions { + new_expression.push(expr); + } + + match Self::dry_run_exprs_ref(&new_expression, &data_block) { + Ok(res) => { + data_block = res; + } + Err(cause) => { + return Err(cause.add_message_back(" (while in select window func)")); + } + } + } + if !state.order_by_expressions.is_empty() { if let Err(cause) = Self::dry_run_exprs(&state.order_by_expressions, &data_block) { return Err(cause.add_message_back(" (while in select order by)")); diff --git a/tests/suites/0_stateless/03_dml/03_0024_select_window_function.result b/tests/suites/0_stateless/03_dml/03_0024_select_window_function.result new file mode 100644 index 000000000000..42730501b478 --- /dev/null +++ b/tests/suites/0_stateless/03_dml/03_0024_select_window_function.result @@ -0,0 +1,304 @@ +================sep================ +China 2001 7845 +China 2001 7845 +Finland 2000 7845 +Finland 2000 7845 +Finland 2001 7845 +India 2000 7845 +India 2000 7845 +India 2000 7845 +USA 2000 7845 +USA 2000 7845 +USA 2001 7845 +USA 2001 7845 +USA 2001 7845 +USA 2001 7845 +USA 2001 7845 +================sep================ +China 2001 310 +China 2001 310 +Finland 2000 1610 +Finland 2000 1610 +Finland 2001 1610 +India 2000 1350 +India 2000 1350 +India 2000 1350 +USA 2000 4575 +USA 2000 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +================sep================ +China 2001 310 +China 2001 310 +Finland 2000 1920 +Finland 2000 1920 +Finland 2001 1920 +India 2000 3270 +India 2000 3270 +India 2000 3270 +USA 2000 7845 +USA 2000 7845 +USA 2001 7845 +USA 2001 7845 +USA 2001 7845 +USA 2001 7845 +USA 2001 7845 +================sep================ +China 2001 310 +China 2001 310 +Finland 2000 1600 +Finland 2000 1600 +Finland 2001 1610 +India 2000 1350 +India 2000 1350 +India 2000 1350 +USA 2000 1575 +USA 2000 1575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +================sep================ +China 2001 310 +China 2001 310 +Finland 2000 1600 +Finland 2000 1610 +Finland 2001 110 +India 2000 150 +India 2000 1350 +India 2000 1275 +USA 2000 1575 +USA 2000 1625 +USA 2001 3050 +USA 2001 2750 +USA 2001 2850 +USA 2001 1450 +USA 2001 250 +================sep================ +China 2001 310 +China 2001 310 +Finland 2000 1600 +Finland 2000 1610 +Finland 2001 1610 +India 2000 150 +India 2000 1350 +India 2000 1350 +USA 2000 1575 +USA 2000 1625 +USA 2001 3125 +USA 2001 4325 +USA 2001 4475 +USA 2001 4575 +USA 2001 4575 +================sep================ +China 2001 310 +China 2001 310 +Finland 2000 1610 +Finland 2000 1610 +Finland 2001 110 +India 2000 1350 +India 2000 1350 +India 2000 1275 +USA 2000 4575 +USA 2000 4575 +USA 2001 4500 +USA 2001 3000 +USA 2001 2950 +USA 2001 1450 +USA 2001 250 +================sep================ +China 2001 110 +China 2001 310 +Finland 2000 1500 +Finland 2000 1600 +Finland 2001 1610 +India 2000 75 +India 2000 150 +India 2000 1350 +USA 2000 75 +USA 2000 1575 +USA 2001 1625 +USA 2001 3125 +USA 2001 4325 +USA 2001 4475 +USA 2001 4575 +================sep================ +China 2001 310 +China 2001 200 +Finland 2000 1610 +Finland 2000 110 +Finland 2001 10 +India 2000 1350 +India 2000 1275 +India 2000 1200 +USA 2000 4575 +USA 2000 4500 +USA 2001 3000 +USA 2001 2950 +USA 2001 1450 +USA 2001 250 +USA 2001 100 +================sep================ +China 2001 310 +China 2001 310 +Finland 2000 1610 +Finland 2000 1610 +Finland 2001 1610 +India 2000 1350 +India 2000 1350 +India 2000 1350 +USA 2000 4575 +USA 2000 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +================sep================ +China 2001 310 +China 2001 310 +Finland 2001 110 +Finland 2000 110 +Finland 2000 1500 +India 2000 150 +India 2000 150 +India 2000 1200 +USA 2001 375 +USA 2000 375 +USA 2001 375 +USA 2001 375 +USA 2001 4200 +USA 2000 4200 +USA 2001 4200 +================sep================ +China 2001 310 +China 2001 310 +Finland 2001 110 +Finland 2000 110 +Finland 2000 1610 +India 2000 150 +India 2000 150 +India 2000 1350 +USA 2001 375 +USA 2000 375 +USA 2001 375 +USA 2001 375 +USA 2001 4575 +USA 2000 4575 +USA 2001 4575 +================sep================ +China 2001 310 +China 2001 310 +Finland 2001 1610 +Finland 2000 1610 +Finland 2000 1500 +India 2000 1350 +India 2000 1350 +India 2000 1200 +USA 2001 4575 +USA 2000 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4200 +USA 2000 4200 +USA 2001 4200 +================sep================ +China 2001 310 +China 2001 200 +Finland 2001 110 +Finland 2000 100 +Finland 2000 1500 +India 2000 150 +India 2000 150 +India 2000 1200 +USA 2001 375 +USA 2000 325 +USA 2001 250 +USA 2001 150 +USA 2001 4200 +USA 2000 3000 +USA 2001 3000 +================sep================ +China 2001 110 +China 2001 310 +Finland 2001 10 +Finland 2000 110 +Finland 2000 1500 +India 2000 150 +India 2000 150 +India 2000 1200 +USA 2001 50 +USA 2000 125 +USA 2001 225 +USA 2001 375 +USA 2001 1200 +USA 2000 4200 +USA 2001 4200 +================sep================ +China 2001 110 +China 2001 310 +Finland 2001 10 +Finland 2000 110 +Finland 2000 1610 +India 2000 150 +India 2000 150 +India 2000 1350 +USA 2001 50 +USA 2000 125 +USA 2001 225 +USA 2001 375 +USA 2001 1575 +USA 2000 4575 +USA 2001 4575 +================sep================ +China 2001 310 +China 2001 200 +Finland 2001 1610 +Finland 2000 1600 +Finland 2000 1500 +India 2000 1350 +India 2000 1350 +India 2000 1200 +USA 2001 4575 +USA 2000 4525 +USA 2001 4450 +USA 2001 4350 +USA 2001 4200 +USA 2000 3000 +USA 2001 3000 +================sep================ +China 2001 310 +China 2001 310 +Finland 2001 1610 +Finland 2000 1610 +Finland 2000 1610 +India 2000 1350 +India 2000 1350 +India 2000 1350 +USA 2001 4575 +USA 2000 4575 +USA 2001 4575 +USA 2001 4575 +USA 2001 4575 +USA 2000 4575 +USA 2001 4575 +================sep================ +China 2001 310 155 +China 2001 310 155 +Finland 2001 110 55 +Finland 2000 110 55 +Finland 2000 1500 1500 +India 2000 150 75 +India 2000 150 75 +India 2000 1200 1200 +USA 2001 375 93.75 +USA 2000 375 93.75 +USA 2001 375 93.75 +USA 2001 375 93.75 +USA 2001 4200 1400 +USA 2000 4200 1400 +USA 2001 4200 1400 diff --git a/tests/suites/0_stateless/03_dml/03_0024_select_window_function.sql b/tests/suites/0_stateless/03_dml/03_0024_select_window_function.sql new file mode 100644 index 000000000000..d0174dd1d79a --- /dev/null +++ b/tests/suites/0_stateless/03_dml/03_0024_select_window_function.sql @@ -0,0 +1,56 @@ +DROP DATABASE IF EXISTS db1; +CREATE DATABASE db1; +USE db1; + +DROP TABLE IF EXISTS sales; +CREATE TABLE `sales` ( + `year` varchar(64) DEFAULT NULL, + `country` varchar(64) DEFAULT NULL, + `product` varchar(64) DEFAULT NULL, + `profit` int DEFAULT NULL +) Engine = Fuse; + +SET enable_new_processor_framework=0; + +INSERT INTO `sales` VALUES ('2000','Finland','Computer',1500),('2000','Finland','Phone',100),('2001','Finland','Phone',10),('2000','India','Calculator',75),('2000','India','Calculator',75),('2000','India','Computer',1200),('2000','USA','Calculator',75),('2000','USA','Computer',1500),('2001','USA','Calculator',50),('2001','USA','Computer',1500),('2001','USA','Computer',1200),('2001','USA','TV',150),('2001','USA','TV',100),('2001','China','TV',110),('2001','China','Computer',200); + +select '================sep================'; +select country, year, sum(profit) over() from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(partition by country) from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(order by country) from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by year) from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by year rows between 1 preceding and 1 following) from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by year rows between unbounded preceding and 1 following) from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by year rows between 1 preceding and unbounded following) from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by year rows between unbounded preceding and current row) from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by year rows between current row and unbounded following) from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by year rows between unbounded preceding and unbounded following) from sales order by country, year; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by profit range between 500 preceding and 500 following) from sales order by country, profit; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by profit range between unbounded preceding and 500 following) from sales order by country, profit; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by profit range between 500 preceding and unbounded following) from sales order by country, profit; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by profit range between current row and 500 following) from sales order by country, profit; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by profit range between 500 preceding and current row) from sales order by country, profit; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by profit range between unbounded preceding and current row) from sales order by country, profit; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by profit range between current row and unbounded following) from sales order by country, profit; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by profit range between unbounded preceding and unbounded following) from sales order by country, profit; +select '================sep================'; +select country, year, sum(profit) over(partition by country order by profit range between 500 preceding and 500 following) as sum, avg(profit) over(partition by country order by profit range between 500 preceding and 500 following) as avg from sales order by country, profit; + +DROP DATABASE db1; \ No newline at end of file