Skip to content

Commit

Permalink
[opt](nereids) optimize push limit to agg (#44042)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
PR #34853 introduced the PushTopnToAgg rule.
However, that rule has a limitation: the TopN (limit) must output all group-by
keys.
This PR removes the limitation by using the first group-by key as the order
key.
  • Loading branch information
englefly authored and Your Name committed Nov 27, 2024
1 parent 78556da commit 2c484f6
Show file tree
Hide file tree
Showing 54 changed files with 1,679 additions and 1,589 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
Expand All @@ -32,6 +34,7 @@
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
Expand All @@ -53,7 +56,11 @@ public List<Rule> buildRules() {
>= limit.getLimit() + limit.getOffset())
.then(limit -> {
LogicalAggregate<? extends Plan> agg = limit.child();
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg);
if (!orderKeysOpt.isPresent()) {
return null;
}
List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get());
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), agg);
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
//limit->project->agg to topn->project->agg
Expand All @@ -62,12 +69,47 @@ public List<Rule> buildRules() {
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= limit.getLimit() + limit.getOffset())
.when(limit -> outputAllGroupKeys(limit, limit.child().child()))
.then(limit -> {
LogicalProject<? extends Plan> project = limit.child();
LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) project.child();
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), project);
LogicalAggregate<? extends Plan> agg
= (LogicalAggregate<? extends Plan>) project.child();
Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg);
if (!orderKeysOpt.isPresent()) {
return null;
}
List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get());
Plan result;

if (outputAllGroupKeys(limit, agg)) {
result = new LogicalTopN<>(orderKeys, limit.getLimit(),
limit.getOffset(), project);
} else {
// add the first group by key to topn, and prune this key by upper project
// topn order keys are prefix of group by keys
// refer to PushTopnToAgg.tryGenerateOrderKeyByGroupKeyAndTopnKey()
Expression firstGroupByKey = agg.getGroupByExpressions().get(0);
if (!(firstGroupByKey instanceof SlotReference)) {
return null;
}
boolean shouldPruneFirstGroupByKey = true;
if (project.getOutputs().contains(firstGroupByKey)) {
shouldPruneFirstGroupByKey = false;
} else {
List<NamedExpression> bottomProjections = Lists.newArrayList(project.getProjects());
bottomProjections.add((SlotReference) firstGroupByKey);
project = project.withProjects(bottomProjections);
}
LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(),
limit.getOffset(), project);
if (shouldPruneFirstGroupByKey) {
List<NamedExpression> limitOutput = limit.getOutput().stream()
.map(e -> (NamedExpression) e).collect(Collectors.toList());
result = new LogicalProject<>(limitOutput, topn);
} else {
result = topn;
}
}
return result;
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
// topn -> agg: add all group key to sort key, if sort key is prefix of group key
logicalTopN(logicalAggregate())
Expand Down Expand Up @@ -111,9 +153,10 @@ private boolean outputAllGroupKeys(LogicalLimit limit, LogicalAggregate agg) {
return limit.getOutputSet().containsAll(agg.getGroupByExpressions());
}

private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().stream()
.map(key -> new OrderKey(key, true, false))
.collect(Collectors.toList());
/**
 * Builds a single order key from the first group-by expression of the given aggregate.
 *
 * @param agg the aggregate whose group-by expressions are inspected
 * @return an {@link Optional} holding an {@link OrderKey} on the first group-by expression,
 *         or {@link Optional#empty()} when the aggregate has no group-by expressions
 */
private Optional<OrderKey> tryGenerateOrderKeyByTheFirstGroupKey(LogicalAggregate<? extends Plan> agg) {
    List<? extends Expression> groupKeys = agg.getGroupByExpressions();
    // NOTE(review): the (true, false) flags are presumably "ascending, nulls not first" --
    // confirm against the OrderKey constructor.
    return groupKeys.isEmpty()
            ? Optional.empty()
            : Optional.of(new OrderKey(groupKeys.get(0), true, false));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ public String toString() {
"groupByExpr", groupByExpressions,
"outputExpr", outputExpressions,
"partitionExpr", partitionExpressions,
"requireProperties", requireProperties,
"topnOpt", topnPushInfo != null
"topnFilter", topnPushInfo != null,
"topnPushDown", getMutableState(MutableState.KEY_PUSH_TOPN_TO_AGG).isPresent()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ void testSortLimit() {
PlanChecker.from(connectContext).disableNereidsRules("PRUNE_EMPTY_PARTITION")
.analyze("select count(*) from (select * from student order by id) t limit 1")
.rewrite()
// there is no topn below agg
.matches(logicalTopN(logicalAggregate(logicalProject(logicalOlapScan()))));
.nonMatch(logicalTopN());
PlanChecker.from(connectContext)
.disableNereidsRules("PRUNE_EMPTY_PARTITION")
.analyze("select count(*) from (select * from student order by id limit 1) t")
Expand All @@ -184,8 +183,6 @@ void testSortLimit() {
.analyze("select count(*) from "
+ "(select * from student order by id) t1 left join student t2 on t1.id = t2.id limit 1")
.rewrite()
.matches(logicalTopN(logicalAggregate(logicalProject(logicalJoin(
logicalProject(logicalOlapScan()),
logicalProject(logicalOlapScan()))))));
.nonMatch(logicalTopN());
}
}
63 changes: 32 additions & 31 deletions regression-test/data/nereids_hint_tpcds_p0/shape/query23.out
Original file line number Diff line number Diff line change
Expand Up @@ -46,35 +46,36 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------------filter(d_year IN (2000, 2001, 2002, 2003))
----------------------------------PhysicalOlapScan[date_dim]
----PhysicalResultSink
------PhysicalTopN[GATHER_SORT]
--------hashAgg[GLOBAL]
----------PhysicalDistribute[DistributionSpecGather]
------------hashAgg[LOCAL]
--------------PhysicalUnion
----------------PhysicalProject
------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((catalog_sales.cs_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF5 cs_item_sk->[item_sk]
--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF5
--------------------PhysicalProject
----------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF4 c_customer_sk->[cs_bill_customer_sk]
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[cs_sold_date_sk]
----------------------------PhysicalProject
------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF3 RF4
----------------------------PhysicalProject
------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
--------------------------------PhysicalOlapScan[date_dim]
------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
----------------PhysicalProject
------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((web_sales.ws_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF8 ws_item_sk->[item_sk]
--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF8
--------------------PhysicalProject
----------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF7 c_customer_sk->[ws_bill_customer_sk]
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF6 d_date_sk->[ws_sold_date_sk]
----------------------------PhysicalProject
------------------------------PhysicalOlapScan[web_sales] apply RFs: RF6 RF7
----------------------------PhysicalProject
------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
--------------------------------PhysicalOlapScan[date_dim]
------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
------PhysicalLimit[GLOBAL]
--------PhysicalLimit[LOCAL]
----------hashAgg[GLOBAL]
------------PhysicalDistribute[DistributionSpecGather]
--------------hashAgg[LOCAL]
----------------PhysicalUnion
------------------PhysicalProject
--------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((catalog_sales.cs_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF5 cs_item_sk->[item_sk]
----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF5
----------------------PhysicalProject
------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF4 c_customer_sk->[cs_bill_customer_sk]
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[cs_sold_date_sk]
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF3 RF4
------------------------------PhysicalProject
--------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
----------------------------------PhysicalOlapScan[date_dim]
--------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
------------------PhysicalProject
--------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((web_sales.ws_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF8 ws_item_sk->[item_sk]
----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF8
----------------------PhysicalProject
------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF7 c_customer_sk->[ws_bill_customer_sk]
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF6 d_date_sk->[ws_sold_date_sk]
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[web_sales] apply RFs: RF6 RF7
------------------------------PhysicalProject
--------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
----------------------------------PhysicalOlapScan[date_dim]
--------------------------PhysicalCteConsumer ( cteId=CTEId#2 )

43 changes: 22 additions & 21 deletions regression-test/data/nereids_hint_tpcds_p0/shape/query32.out
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !ds_shape_32 --
PhysicalResultSink
--PhysicalTopN[GATHER_SORT]
----hashAgg[GLOBAL]
------PhysicalDistribute[DistributionSpecGather]
--------hashAgg[LOCAL]
----------PhysicalProject
------------filter((cast(cs_ext_discount_amt as DECIMALV3(38, 5)) > (1.3 * avg(cast(cs_ext_discount_amt as DECIMALV3(9, 4))) OVER(PARTITION BY i_item_sk))))
--------------PhysicalWindow
----------------PhysicalQuickSort[LOCAL_SORT]
------------------PhysicalDistribute[DistributionSpecHash]
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN broadcast] hashCondition=((date_dim.d_date_sk = catalog_sales.cs_sold_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[cs_sold_date_sk]
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((item.i_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF0 i_item_sk->[cs_item_sk]
----------------------------PhysicalProject
------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1
----------------------------PhysicalProject
------------------------------filter((item.i_manufact_id = 722))
--------------------------------PhysicalOlapScan[item]
------------------------PhysicalProject
--------------------------filter((date_dim.d_date <= '2001-06-07') and (date_dim.d_date >= '2001-03-09'))
----------------------------PhysicalOlapScan[date_dim]
--PhysicalLimit[GLOBAL]
----PhysicalLimit[LOCAL]
------hashAgg[GLOBAL]
--------PhysicalDistribute[DistributionSpecGather]
----------hashAgg[LOCAL]
------------PhysicalProject
--------------filter((cast(cs_ext_discount_amt as DECIMALV3(38, 5)) > (1.3 * avg(cast(cs_ext_discount_amt as DECIMALV3(9, 4))) OVER(PARTITION BY i_item_sk))))
----------------PhysicalWindow
------------------PhysicalQuickSort[LOCAL_SORT]
--------------------PhysicalDistribute[DistributionSpecHash]
----------------------PhysicalProject
------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((date_dim.d_date_sk = catalog_sales.cs_sold_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[cs_sold_date_sk]
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((item.i_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF0 i_item_sk->[cs_item_sk]
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1
------------------------------PhysicalProject
--------------------------------filter((item.i_manufact_id = 722))
----------------------------------PhysicalOlapScan[item]
--------------------------PhysicalProject
----------------------------filter((date_dim.d_date <= '2001-06-07') and (date_dim.d_date >= '2001-03-09'))
------------------------------PhysicalOlapScan[date_dim]

Hint log:
Used: leading(catalog_sales item date_dim )
Expand Down
Loading

0 comments on commit 2c484f6

Please sign in to comment.