diff --git a/src/query/sql/src/planner/optimizer/dynamic_sample/dynamic_sample.rs b/src/query/sql/src/planner/optimizer/dynamic_sample/dynamic_sample.rs index 5f5249299bc6..a2363d69029b 100644 --- a/src/query/sql/src/planner/optimizer/dynamic_sample/dynamic_sample.rs +++ b/src/query/sql/src/planner/optimizer/dynamic_sample/dynamic_sample.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::Deref; use std::sync::Arc; use std::time::Duration; @@ -25,8 +26,13 @@ use crate::optimizer::RelExpr; use crate::optimizer::SExpr; use crate::optimizer::StatInfo; use crate::planner::query_executor::QueryExecutor; +use crate::plans::Aggregate; +use crate::plans::AggregateMode; +use crate::plans::Limit; use crate::plans::Operator; +use crate::plans::ProjectSet; use crate::plans::RelOperator; +use crate::plans::UnionAll; use crate::MetadataRef; #[async_recursion::async_recursion(#[recursive::recursive])] @@ -78,11 +84,73 @@ pub async fn dynamic_sample( RelOperator::Join(_) => { join_selectivity_sample(ctx, metadata, s_expr, sample_executor).await } - RelOperator::Scan(_) => s_expr.plan().derive_stats(&RelExpr::with_s_expr(s_expr)), - // Todo: add more operators here, and support more query patterns. - _ => { - let rel_expr = RelExpr::with_s_expr(s_expr); - rel_expr.derive_cardinality() + RelOperator::Scan(_) + | RelOperator::DummyTableScan(_) + | RelOperator::CteScan(_) + | RelOperator::ConstantTableScan(_) + | RelOperator::CacheScan(_) + | RelOperator::ExpressionScan(_) + | RelOperator::RecursiveCteScan(_) + | RelOperator::Mutation(_) + | RelOperator::Recluster(_) + | RelOperator::CompactBlock(_) + | RelOperator::MutationSource(_) => { + s_expr.plan().derive_stats(&RelExpr::with_s_expr(s_expr)) + } + + RelOperator::Aggregate(agg) => { + let child_stat_info = + dynamic_sample(ctx, metadata, s_expr.child(0)?, sample_executor).await?; + if agg.mode == AggregateMode::Final { + return Ok(child_stat_info); + } + let agg = Aggregate::try_from(s_expr.plan().clone())?; + agg.derive_agg_stats(child_stat_info) + } + RelOperator::Limit(_) => { + let child_stat_info = + dynamic_sample(ctx, metadata, s_expr.child(0)?, sample_executor).await?; + let limit = Limit::try_from(s_expr.plan().clone())?; + limit.derive_limit_stats(child_stat_info) + } + RelOperator::UnionAll(_) => { + let left_stat_info = dynamic_sample( + ctx.clone(), + metadata.clone(), + s_expr.child(0)?, + sample_executor.clone(), + ) + .await?; + let right_stat_info = + dynamic_sample(ctx, metadata, s_expr.child(1)?, sample_executor).await?; + let union = UnionAll::try_from(s_expr.plan().clone())?; + union.derive_union_stats(left_stat_info, right_stat_info) + } + RelOperator::ProjectSet(_) => { + let mut child_stat_info = + dynamic_sample(ctx, metadata, s_expr.child(0)?, sample_executor) + .await? + .deref() + .clone(); + let project_set = ProjectSet::try_from(s_expr.plan().clone())?; + project_set.derive_project_set_stats(&mut child_stat_info) + } + RelOperator::MaterializedCte(_) => { + let right_stat_info = + dynamic_sample(ctx, metadata, s_expr.child(1)?, sample_executor).await?; + Ok(Arc::new(StatInfo { + cardinality: right_stat_info.cardinality, + statistics: right_stat_info.statistics.clone(), + })) + } + + RelOperator::EvalScalar(_) + | RelOperator::Sort(_) + | RelOperator::Exchange(_) + | RelOperator::Window(_) + | RelOperator::Udf(_) + | RelOperator::AsyncFunction(_) => { + dynamic_sample(ctx, metadata, s_expr.child(0)?, sample_executor).await } } } diff --git a/src/query/sql/src/planner/plans/aggregate.rs b/src/query/sql/src/planner/plans/aggregate.rs index 89b623aa19f3..51a4117a30a8 100644 --- a/src/query/sql/src/planner/plans/aggregate.rs +++ b/src/query/sql/src/planner/plans/aggregate.rs @@ -104,6 +104,59 @@ impl Aggregate { } Ok(col_set) } + + pub fn derive_agg_stats(&self, stat_info: Arc) -> Result> { + let (cardinality, mut statistics) = (stat_info.cardinality, stat_info.statistics.clone()); + let cardinality = if self.group_items.is_empty() { + // Scalar aggregation + 1.0 + } else if self + .group_items + .iter() + .any(|item| !statistics.column_stats.contains_key(&item.index)) + { + cardinality + } else { + // A upper bound + let res = self.group_items.iter().fold(1.0, |acc, item| { + let item_stat = statistics.column_stats.get(&item.index).unwrap(); + acc * item_stat.ndv + }); + for item in self.group_items.iter() { + let item_stat = statistics.column_stats.get_mut(&item.index).unwrap(); + if let Some(histogram) = &mut item_stat.histogram { + let mut num_values = 0.0; + let mut num_distinct = 0.0; + for bucket in histogram.buckets.iter() { + num_distinct += bucket.num_distinct(); + num_values += bucket.num_values(); + } + // When there is a high probability that eager aggregation + // is better, we will update the histogram. + if num_values / num_distinct >= 10.0 { + for bucket in histogram.buckets.iter_mut() { + bucket.aggregate_values(); + } + } + } + } + // To avoid res is very large + f64::min(res, cardinality) + }; + + let precise_cardinality = if self.group_items.is_empty() { + Some(1) + } else { + None + }; + Ok(Arc::new(StatInfo { + cardinality, + statistics: Statistics { + precise_cardinality, + column_stats: statistics.column_stats, + }, + })) + } } impl Operator for Aggregate { @@ -242,56 +295,7 @@ impl Operator for Aggregate { return rel_expr.derive_cardinality_child(0); } let stat_info = rel_expr.derive_cardinality_child(0)?; - let (cardinality, mut statistics) = (stat_info.cardinality, stat_info.statistics.clone()); - let cardinality = if self.group_items.is_empty() { - // Scalar aggregation - 1.0 - } else if self - .group_items - .iter() - .any(|item| !statistics.column_stats.contains_key(&item.index)) - { - cardinality - } else { - // A upper bound - let res = self.group_items.iter().fold(1.0, |acc, item| { - let item_stat = statistics.column_stats.get(&item.index).unwrap(); - acc * item_stat.ndv - }); - for item in self.group_items.iter() { - let item_stat = statistics.column_stats.get_mut(&item.index).unwrap(); - if let Some(histogram) = &mut item_stat.histogram { - let mut num_values = 0.0; - let mut num_distinct = 0.0; - for bucket in histogram.buckets.iter() { - num_distinct += bucket.num_distinct(); - num_values += bucket.num_values(); - } - // When there is a high probability that eager aggregation - // is better, we will update the histogram. - if num_values / num_distinct >= 10.0 { - for bucket in histogram.buckets.iter_mut() { - bucket.aggregate_values(); - } - } - } - } - // To avoid res is very large - f64::min(res, cardinality) - }; - - let precise_cardinality = if self.group_items.is_empty() { - Some(1) - } else { - None - }; - Ok(Arc::new(StatInfo { - cardinality, - statistics: Statistics { - precise_cardinality, - column_stats: statistics.column_stats, - }, - })) + self.derive_agg_stats(stat_info) } fn compute_required_prop_children( diff --git a/src/query/sql/src/planner/plans/limit.rs b/src/query/sql/src/planner/plans/limit.rs index 2efdecc19c7b..def22beedc34 100644 --- a/src/query/sql/src/planner/plans/limit.rs +++ b/src/query/sql/src/planner/plans/limit.rs @@ -33,6 +33,29 @@ pub struct Limit { pub offset: usize, } +impl Limit { + pub fn derive_limit_stats(&self, stat_info: Arc) -> Result> { + let cardinality = match self.limit { + Some(limit) if (limit as f64) < stat_info.cardinality => limit as f64, + _ => stat_info.cardinality, + }; + let precise_cardinality = match (self.limit, stat_info.statistics.precise_cardinality) { + (Some(limit), Some(pc)) => { + Some((pc.saturating_sub(self.offset as u64)).min(limit as u64)) + } + _ => None, + }; + + Ok(Arc::new(StatInfo { + cardinality, + statistics: Statistics { + precise_cardinality, + column_stats: Default::default(), + }, + })) + } +} + impl Operator for Limit { fn rel_op(&self) -> RelOp { RelOp::Limit @@ -67,23 +90,6 @@ impl Operator for Limit { fn derive_stats(&self, rel_expr: &RelExpr) -> Result> { let stat_info = rel_expr.derive_cardinality_child(0)?; - let cardinality = match self.limit { - Some(limit) if (limit as f64) < stat_info.cardinality => limit as f64, - _ => stat_info.cardinality, - }; - let precise_cardinality = match (self.limit, stat_info.statistics.precise_cardinality) { - (Some(limit), Some(pc)) => { - Some((pc.saturating_sub(self.offset as u64)).min(limit as u64)) - } - _ => None, - }; - - Ok(Arc::new(StatInfo { - cardinality, - statistics: Statistics { - precise_cardinality, - column_stats: Default::default(), - }, - })) + self.derive_limit_stats(stat_info) } } diff --git a/src/query/sql/src/planner/plans/project_set.rs b/src/query/sql/src/planner/plans/project_set.rs index b94dea408bf1..6ea32c812ac1 100644 --- a/src/query/sql/src/planner/plans/project_set.rs +++ b/src/query/sql/src/planner/plans/project_set.rs @@ -15,6 +15,8 @@ use std::ops::Deref; use std::sync::Arc; +use databend_common_exception::Result; + use crate::optimizer::RelExpr; use crate::optimizer::RelationalProperty; use crate::optimizer::StatInfo; @@ -30,6 +32,14 @@ pub struct ProjectSet { pub srfs: Vec, } +impl ProjectSet { + pub fn derive_project_set_stats(&self, input_stat: &mut StatInfo) -> Result> { + // ProjectSet is set-returning functions, precise_cardinality set None + input_stat.statistics.precise_cardinality = None; + Ok(Arc::new(input_stat.clone())) + } +} + impl Operator for ProjectSet { fn rel_op(&self) -> RelOp { RelOp::ProjectSet @@ -75,8 +85,6 @@ impl Operator for ProjectSet { fn derive_stats(&self, rel_expr: &RelExpr) -> databend_common_exception::Result> { let mut input_stat = rel_expr.derive_cardinality_child(0)?.deref().clone(); - // ProjectSet is set-returning functions, precise_cardinality set None - input_stat.statistics.precise_cardinality = None; - Ok(Arc::new(input_stat)) + self.derive_project_set_stats(&mut input_stat) } } diff --git a/src/query/sql/src/planner/plans/union_all.rs b/src/query/sql/src/planner/plans/union_all.rs index 4d49eee4107c..31e5828cdda4 100644 --- a/src/query/sql/src/planner/plans/union_all.rs +++ b/src/query/sql/src/planner/plans/union_all.rs @@ -54,6 +54,33 @@ impl UnionAll { } Ok(used_columns) } + + pub fn derive_union_stats( + &self, + left_stat_info: Arc, + right_stat_info: Arc, + ) -> Result> { + let cardinality = left_stat_info.cardinality + right_stat_info.cardinality; + + let precise_cardinality = + left_stat_info + .statistics + .precise_cardinality + .and_then(|left_cardinality| { + right_stat_info + .statistics + .precise_cardinality + .map(|right_cardinality| left_cardinality + right_cardinality) + }); + + Ok(Arc::new(StatInfo { + cardinality, + statistics: Statistics { + precise_cardinality, + column_stats: Default::default(), + }, + })) + } } impl Operator for UnionAll { @@ -117,26 +144,7 @@ impl Operator for UnionAll { fn derive_stats(&self, rel_expr: &RelExpr) -> Result> { let left_stat_info = rel_expr.derive_cardinality_child(0)?; let right_stat_info = rel_expr.derive_cardinality_child(1)?; - let cardinality = left_stat_info.cardinality + right_stat_info.cardinality; - - let precise_cardinality = - left_stat_info - .statistics - .precise_cardinality - .and_then(|left_cardinality| { - right_stat_info - .statistics - .precise_cardinality - .map(|right_cardinality| left_cardinality + right_cardinality) - }); - - Ok(Arc::new(StatInfo { - cardinality, - statistics: Statistics { - precise_cardinality, - column_stats: Default::default(), - }, - })) + self.derive_union_stats(left_stat_info, right_stat_info) } fn compute_required_prop_child( diff --git a/tests/sqllogictests/suites/tpch/sample.test b/tests/sqllogictests/suites/tpch/sample.test index 3d1fbafaabc4..b32f5d9cf653 100644 --- a/tests/sqllogictests/suites/tpch/sample.test +++ b/tests/sqllogictests/suites/tpch/sample.test @@ -399,6 +399,134 @@ select o_custkey from orders where not exists (select * from customer where subs 1 4 +query I +select + supp_nation, + cust_nation, + l_year, + truncate(sum(volume),3) as revenue +from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between to_date('1995-01-01') and to_date('1996-12-31') + ) as shipping +group by + supp_nation, + cust_nation, + l_year +order by + supp_nation, + cust_nation, + l_year; +---- +FRANCE GERMANY 1995 4637235.150 +FRANCE GERMANY 1996 5224779.573 +GERMANY FRANCE 1995 6232818.703 +GERMANY FRANCE 1996 5557312.112 + +query I +select + o_year, + truncate(sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume),8) as mkt_share +from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between to_date('1995-01-01') and to_date('1996-12-31') + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations +group by + o_year +order by + o_year; +---- +1995 0.02864874 +1996 0.01825027 + +query I +select + nation, + o_year, + truncate(truncate(sum(amount),0)/10, 0) as sum_profit +from + ( + select + n_name as nation, + extract(year from o_orderdate) as o_year, + truncate(l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity, 100) as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + sum_profit +limit 5; +---- +MOZAMBIQUE 1998 162042 +JORDAN 1998 181148 +MOROCCO 1998 181533 +JAPAN 1998 184953 +VIETNAM 1998 192431 + statement ok set random_function_seed = 0;