Skip to content

Commit

Permalink
push limit to local agg
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed May 14, 2024
1 parent 29df2ba commit a9fa9a9
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,23 @@ public PlanFragment visitPhysicalHashAggregate(
// local exchanger will be used.
aggregationNode.setColocate(true);
}
if (aggregate.getTopn() != null) {
List<Expr> orderingExprs = Lists.newArrayList();
List<Boolean> ascOrders = Lists.newArrayList();
List<Boolean> nullsFirstParams = Lists.newArrayList();
aggregate.getTopn().orderkeys.forEach(k -> {
orderingExprs.add(ExpressionTranslator.translate(k.getExpr(), context));
ascOrders.add(k.isAsc());
nullsFirstParams.add(k.isNullFirst());
});
SortInfo sortInfo = new SortInfo(orderingExprs, ascOrders, nullsFirstParams, outputTupleDesc);
aggregationNode.setSortByGroupKey(sortInfo);
if (aggregationNode.getLimit() == -1) {
aggregationNode.setLimit(aggregate.getTopn().limit);
}
} else {
aggregationNode.setSortByGroupKey(null);
}
setPlanRoot(inputPlanFragment, aggregationNode, aggregate);
if (aggregate.getStats() != null) {
aggregationNode.setCardinality((long) aggregate.getStats().getRowCount());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,100 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/apache/impala/blob/branch-2.9.0/fe/src/main/java/org/apache/impala/AggregationNode.java
// and modified by Doris

package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate.TopNOptInfo;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;

import com.google.common.collect.Lists;
import java.util.List;
import java.util.stream.Collectors;

/**
Pattern1:
limit(n) -> aggGlobal -> distribute -> aggLocal
=>
limit(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
public class PushLimitToLocalAgg extends PlanPostProcessor{
Pattern2: topn orderkeys are the same as group keys
topn -> aggGlobal -> distribute -> aggLocal
=>
topn(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
*/
public class PushLimitToLocalAgg extends PlanPostProcessor {
/**
 * Visits a TopN node. In the visible lines, when the TopN's child is a hash aggregate the
 * TopN's order keys and {@code limit + offset} are recorded on that aggregate through
 * {@code setTopn}; when that aggregate is in the global phase and sits on
 * distribute -> hash aggregate, the same hint is recorded on the lower aggregate and the
 * visit continues below it. Otherwise the TopN's child is visited.
 *
 * NOTE(review): this span looks like the pre-change and post-change bodies of the method
 * interleaved by the diff view (the braces do not balance, and both the agg1/agg2 rewrite and
 * the upperAgg/bottomAgg hint-setting variants are present) - confirm against the real file
 * before relying on it.
 */
@Override
public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) {
// visit the child first and rebuild the TopN only when the child instance changed
Plan newChild = this.visit(topN.child(), ctx);
if (newChild != topN.child()) {
topN = (PhysicalTopN<? extends Plan>) topN.withChildren(Lists.newArrayList(newChild))
.copyStatsAndGroupIdFrom(topN);
}
if (topN.child() instanceof PhysicalHashAggregate) {
PhysicalHashAggregate agg1 = (PhysicalHashAggregate) topN.child();
if (agg1.getAggPhase().isGlobal()) {
if (agg1.child(0) instanceof PhysicalDistribute
&& agg1.child(0).child(0) instanceof PhysicalHashAggregate) {
// variant that rebuilds the tree with a copy of the TopN placed above the lower aggregate
PhysicalDistribute dist = (PhysicalDistribute) agg1.child(0);
PhysicalHashAggregate agg2 = (PhysicalHashAggregate) agg1.child(0).child(0);
PhysicalTopN topnForLocalAgg = (PhysicalTopN) topN.withChildren(agg2);
dist = (PhysicalDistribute) dist.withChildren(topnForLocalAgg);
agg1 = (PhysicalHashAggregate) agg1.withChildren(dist);
topN = (PhysicalTopN<? extends Plan>) topN.withChildren(agg1);
// variant that keeps the tree shape and only tags the aggregates with a top-N hint
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topN.child();
upperAgg.setTopn(new TopNOptInfo(
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
if (upperAgg.getAggPhase().isGlobal()) {
if (upperAgg.child() instanceof PhysicalDistribute
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
PhysicalHashAggregate<? extends Plan> bottomAgg =
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
// the lower aggregate receives the same order keys and limit + offset as the upper one
bottomAgg.setTopn(new TopNOptInfo(
topN.getOrderKeys(),
topN.getLimit() + topN.getOffset()));
// continue the traversal below the lower aggregate; the returned plan is not used
bottomAgg.child().accept(this, ctx);
}
}
} else {
topN.child().accept(this, ctx);
}
return topN;
}

/**
 * Pushes a limit into the aggregate(s) directly beneath it.
 *
 * <p>When the limit's child is a hash aggregate, that aggregate is tagged with a top-N hint
 * whose order keys are derived from its group-by keys and whose row count is
 * {@code limit + offset}. When the aggregate is in the global phase and its input is
 * distribute -> hash aggregate, the lower aggregate is tagged the same way (using its own
 * group-by keys).
 *
 * <p>The traversal always continues below the node(s) handled here, so limit/topN patterns
 * located deeper in the plan are processed as well. The aggregates are updated in place
 * through {@code setTopn}; the limit node itself is returned unchanged.
 *
 * @param limit the limit node being visited
 * @param ctx the cascades context passed through to child visits
 * @return the same {@code limit} instance
 */
@Override
public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) {
    if (!(limit.child() instanceof PhysicalHashAggregate)) {
        limit.child().accept(this, ctx);
        return limit;
    }

    // rows that must survive below the limit: the skipped rows plus the returned rows
    long rowsToKeep = limit.getLimit() + limit.getOffset();

    PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limit.child();
    upperAgg.setTopn(new TopNOptInfo(generateOrderKeysByGroupKeys(upperAgg), rowsToKeep));

    boolean hasLowerAgg = upperAgg.getAggPhase().isGlobal()
            && upperAgg.child() instanceof PhysicalDistribute
            && upperAgg.child().child(0) instanceof PhysicalHashAggregate;
    if (hasLowerAgg) {
        PhysicalHashAggregate<? extends Plan> bottomAgg =
                (PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
        bottomAgg.setTopn(new TopNOptInfo(generateOrderKeysByGroupKeys(bottomAgg), rowsToKeep));
        bottomAgg.child().accept(this, ctx);
    } else {
        // Previously the subtree under the aggregate was skipped whenever the
        // global -> distribute -> aggregate shape did not match; keep walking it.
        upperAgg.child().accept(this, ctx);
    }
    return limit;
}

/**
 * Builds one order key per group-by expression of {@code agg}, in group-by order.
 * Every key is created with the flags {@code (true, false)} - presumably ascending with
 * nulls last; confirm against the {@code OrderKey} constructor.
 */
private List<OrderKey> generateOrderKeysByGroupKeys(PhysicalHashAggregate<? extends Plan> agg) {
    List<OrderKey> orderKeys = Lists.newArrayList();
    agg.getGroupByExpressions().forEach(groupKey -> orderKeys.add(new OrderKey(groupKey, true, false)));
    return orderKeys;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.RequireProperties;
import org.apache.doris.nereids.properties.RequirePropertiesSupplier;
Expand Down Expand Up @@ -60,6 +61,9 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar

private final RequireProperties requireProperties;

// only used in post processor
private TopNOptInfo topn = null;

public PhysicalHashAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions,
AggregateParam aggregateParam, boolean maybeUsingStream, LogicalProperties logicalProperties,
RequireProperties requireProperties, CHILD_TYPE child) {
Expand Down Expand Up @@ -196,6 +200,7 @@ public String toString() {
"outputExpr", outputExpressions,
"partitionExpr", partitionExpressions,
"requireProperties", requireProperties,
"topnOpt", topn != null,
"stats", statistics
);
}
Expand Down Expand Up @@ -231,19 +236,22 @@ public PhysicalHashAggregate<Plan> withChildren(List<Plan> children) {
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(),
requireProperties, physicalProperties, statistics,
children.get(0));
children.get(0))
.setTopn(topn);
}

public PhysicalHashAggregate<CHILD_TYPE> withPartitionExpressions(List<Expression> partitionExpressions) {
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions,
Optional.ofNullable(partitionExpressions), aggregateParam, maybeUsingStream,
Optional.empty(), getLogicalProperties(), requireProperties, child());
Optional.empty(), getLogicalProperties(), requireProperties, child())
.setTopn(topn);
}

@Override
public PhysicalHashAggregate<CHILD_TYPE> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(), requireProperties, child());
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(), requireProperties, child())
.setTopn(topn);
}

@Override
Expand All @@ -252,7 +260,7 @@ public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpr
Preconditions.checkArgument(children.size() == 1);
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, logicalProperties.get(),
requireProperties, children.get(0));
requireProperties, children.get(0)).setTopn(topn);
}

@Override
Expand All @@ -261,21 +269,21 @@ public PhysicalHashAggregate<CHILD_TYPE> withPhysicalPropertiesAndStats(Physical
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(),
requireProperties, physicalProperties, statistics,
child());
child()).setTopn(topn);
}

@Override
public PhysicalHashAggregate<CHILD_TYPE> withAggOutput(List<NamedExpression> newOutput) {
return new PhysicalHashAggregate<>(groupByExpressions, newOutput, partitionExpressions,
aggregateParam, maybeUsingStream, Optional.empty(), getLogicalProperties(),
requireProperties, physicalProperties, statistics, child());
requireProperties, physicalProperties, statistics, child()).setTopn(topn);
}

public <C extends Plan> PhysicalHashAggregate<C> withRequirePropertiesAndChild(
RequireProperties requireProperties, C newChild) {
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, Optional.empty(), getLogicalProperties(),
requireProperties, physicalProperties, statistics, newChild);
requireProperties, physicalProperties, statistics, newChild).setTopn(topn);
}

@Override
Expand All @@ -299,4 +307,26 @@ public PhysicalHashAggregate<CHILD_TYPE> resetLogicalProperties() {
requireProperties, physicalProperties, statistics,
child());
}

/**
 * Top-N hint attached to an aggregate so that a limit can be pushed down to the local
 * aggregate. Holds the order keys to sort by and the number of rows to keep.
 *
 * <p>Instances are immutable: the order-key list is copied on construction and both fields
 * are final, so a hint shared between plan nodes cannot be altered after it is created.
 */
public static class TopNOptInfo {
    /** Immutable snapshot of the order keys supplied to the constructor. */
    public final List<OrderKey> orderkeys;
    /** Number of rows to keep; the callers in this change pass {@code limit + offset}. */
    public final long limit;

    /**
     * @param orderkeys order keys for the hint; copied, so later changes to the argument are not seen
     * @param limit number of rows to keep
     */
    public TopNOptInfo(List<OrderKey> orderkeys, long limit) {
        this.orderkeys = ImmutableList.copyOf(orderkeys);
        this.limit = limit;
    }
}

/**
 * Returns the top-N hint recorded on this aggregate, or {@code null} when none has been set.
 */
public TopNOptInfo getTopn() {
    return this.topn;
}

/**
 * Stores the given top-N hint on this node and returns this same instance, so the call can be
 * chained onto a constructor or another call. The node is modified in place; no copy is made.
 *
 * @param topn the hint to store; {@code null} clears it
 * @return this instance
 */
public PhysicalHashAggregate<CHILD_TYPE> setTopn(TopNOptInfo topn) {
this.topn = topn;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotId;
import org.apache.doris.analysis.SortInfo;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.common.NotImplementedException;
import org.apache.doris.common.UserException;
Expand Down Expand Up @@ -65,6 +66,8 @@ public class AggregationNode extends PlanNode {
// If true, use streaming preaggregation algorithm. Not valid if this is a merge agg.
private boolean useStreamingPreagg;

private SortInfo sortByGroupKey;

/**
* Create an agg node that is not an intermediate node.
* isIntermediate is true if it is a slave node in a 2-part agg plan.
Expand Down Expand Up @@ -288,6 +291,9 @@ protected void toThrift(TPlanNode msg) {
msg.agg_node.setUseStreamingPreaggregation(useStreamingPreagg);
msg.agg_node.setIsFirstPhase(aggInfo.isFirstPhase());
msg.agg_node.setIsColocate(isColocate);
if (sortByGroupKey != null) {
msg.agg_node.setAggSortInfoByGroupKey(sortByGroupKey.toThrift());
}
List<Expr> groupingExprs = aggInfo.getGroupingExprs();
if (groupingExprs != null) {
msg.agg_node.setGroupingExprs(Expr.treesToThrift(groupingExprs));
Expand Down Expand Up @@ -333,6 +339,7 @@ public String getNodeExplainString(String detailPrefix, TExplainLevel detailLeve
if (!conjuncts.isEmpty()) {
output.append(detailPrefix).append("having: ").append(getExplainString(conjuncts)).append("\n");
}
output.append(detailPrefix).append("sortByGroupKey:").append(sortByGroupKey != null).append("\n");
output.append(detailPrefix).append(String.format(
"cardinality=%,d", cardinality)).append("\n");
return output.toString();
Expand Down Expand Up @@ -411,4 +418,13 @@ public void finalize(Analyzer analyzer) throws UserException {
/**
 * Sets the colocate flag of this aggregation node.
 */
public void setColocate(boolean colocate) {
    this.isColocate = colocate;
}


/**
 * Tells whether a group-key sort description has been attached to this node.
 *
 * @return {@code true} when {@code sortByGroupKey} is non-null
 */
public boolean isSortByGroupKey() {
    return this.sortByGroupKey != null;
}

/**
 * Attaches (or, with {@code null}, removes) the sort description used to order this
 * aggregation's output by its group keys.
 *
 * @param sortByGroupKey the sort description, or {@code null} for none
 */
public void setSortByGroupKey(SortInfo sortByGroupKey) {
this.sortByGroupKey = sortByGroupKey;
}
}
2 changes: 1 addition & 1 deletion gensrc/thrift/PlanNodes.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ struct TAggregationNode {
7: optional list<TSortInfo> agg_sort_infos
8: optional bool is_first_phase
9: optional bool is_colocate
// 9: optional bool use_fixed_length_serialization_opt
10: optional TSortInfo agg_sort_info_by_group_key
}

struct TRepeatNode {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

// Regression check for pushing a top-N/limit into aggregation nodes: explains a
// "group by k order by k limit n" query and looks for the "sortByGroupKey:" marker in the plan text.
suite("push_limit_to_local_agg") {
// switch to the database associated with this test file's directory
String db = context.config.getDbNameByFile(new File(context.file.parent))
sql "use ${db}"
explain{
// the order key (o_custkey) is also the only group-by key
sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey order by o_custkey limit 4;"
// presumably requires the marker to appear twice in the explain output
// (one per aggregation node) - confirm multiContains semantics in the test framework
multiContains ("sortByGroupKey:true", 2)
}
}

0 comments on commit a9fa9a9

Please sign in to comment.