Skip to content

Commit

Permalink
Gather partial TopN results
Browse files Browse the repository at this point in the history
Gathering TopN avoids unnecessary network overhead, especially when
both the number of splits and the TopN limit are big.

Co-authored-by: Kamil Endruszkiewicz <kamil.endruszkiewicz@starburstdata.com>
  • Loading branch information
2 people authored and raunaqmorarka committed May 2, 2024
1 parent 06505f2 commit a6474d8
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import io.trino.sql.planner.iterative.rule.ExtractDereferencesFromFilterAboveScan;
import io.trino.sql.planner.iterative.rule.ExtractSpatialJoins;
import io.trino.sql.planner.iterative.rule.GatherAndMergeWindows;
import io.trino.sql.planner.iterative.rule.GatherPartialTopN;
import io.trino.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter;
import io.trino.sql.planner.iterative.rule.ImplementExceptAll;
import io.trino.sql.planner.iterative.rule.ImplementExceptDistinctAsUnion;
Expand Down Expand Up @@ -961,7 +962,9 @@ public PlanOptimizers(
ruleStats,
statsCalculator,
costCalculator,
ImmutableSet.of(new UseNonPartitionedJoinLookupSource())));
ImmutableSet.of(
new UseNonPartitionedJoinLookupSource(),
new GatherPartialTopN())));

// Optimizers above this do not need to care about aggregations with the type other than SINGLE
// This optimizer must be run after all exchange-related optimizers
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TopNNode;

import static io.trino.matching.Capture.newCapture;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE;
import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER;
import static io.trino.sql.planner.plan.ExchangeNode.gatheringExchange;
import static io.trino.sql.planner.plan.ExchangeNode.partitionedExchange;
import static io.trino.sql.planner.plan.Patterns.exchange;
import static io.trino.sql.planner.plan.Patterns.source;
import static io.trino.sql.planner.plan.Patterns.topN;
import static io.trino.sql.planner.plan.TopNNode.Step.PARTIAL;

/**
 * Limits the task output size of a partial TopN by stacking two more partial TopN
 * stages on top of it: one behind a local round-robin exchange and one behind a
 * local gathering exchange.
 * <p>
 * A plan of the shape:
 * <pre>
 * exchange(remote)
 *   - topn(partial)
 * </pre>
 * is rewritten to:
 * <pre>
 * exchange(remote)
 *   - topn(partial)
 *     - exchange(local, gather)
 *       - topn(partial)
 *         - local_exchange(round_robin)
 *           - topn(partial)
 * </pre>
 */
public class GatherPartialTopN
        implements Rule<ExchangeNode>
{
    private static final Capture<TopNNode> TOPN = newCapture();

    // The parent exchange and the TopN source are both filtered so that the plan produced
    // by this rule does not match again; without these filters the iterative optimizer
    // would recurse infinitely.
    private static final Pattern<ExchangeNode> PATTERN = exchange()
            .matching(GatherPartialTopN::isGatherRemoteExchange)
            .with(source().matching(
                    topN().matching(GatherPartialTopN::isPartialTopN)
                            .with(source().matching(child -> !isGatherLocalExchange(child)))
                            .capturedAs(TOPN)));

    @Override
    public Pattern<ExchangeNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Builds the replacement subtree below the matched remote exchange.
     *
     * @param node the matched remote gathering exchange
     * @param captures holds the partial TopN captured under {@link #TOPN}
     * @param context supplies the id allocator used for the four new plan nodes
     * @return the remote exchange with its child replaced by the new TopN chain
     */
    @Override
    public Result apply(ExchangeNode node, Captures captures, Context context)
    {
        TopNNode capturedTopN = captures.get(TOPN);

        // FIXED_ARBITRARY_DISTRIBUTION with no partitioning arguments; the class Javadoc
        // calls this the round-robin local exchange.
        PartitioningScheme roundRobinScheme = new PartitioningScheme(
                Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()),
                capturedTopN.getOutputSymbols());

        // Ids are requested in the order: round-robin TopN, round-robin exchange,
        // gathering TopN, gathering exchange.
        TopNNode roundRobinTopN = new TopNNode(
                context.getIdAllocator().getNextId(),
                partitionedExchange(context.getIdAllocator().getNextId(), LOCAL, capturedTopN, roundRobinScheme),
                capturedTopN.getCount(),
                capturedTopN.getOrderingScheme(),
                PARTIAL);

        TopNNode gatheredTopN = new TopNNode(
                context.getIdAllocator().getNextId(),
                gatheringExchange(context.getIdAllocator().getNextId(), LOCAL, roundRobinTopN),
                capturedTopN.getCount(),
                capturedTopN.getOrderingScheme(),
                PARTIAL);

        return Result.ofPlanNode(node.replaceChildren(ImmutableList.of(gatheredTopN)));
    }

    /** Returns whether the given TopN runs in the {@code PARTIAL} step. */
    private static boolean isPartialTopN(TopNNode topN)
    {
        return topN.getStep().equals(PARTIAL);
    }

    /** Returns whether {@code source} is an exchange with {@code LOCAL} scope and {@code GATHER} type. */
    private static boolean isGatherLocalExchange(PlanNode source)
    {
        if (!(source instanceof ExchangeNode localCandidate)) {
            return false;
        }
        if (!localCandidate.getScope().equals(LOCAL)) {
            return false;
        }
        return localCandidate.getType().equals(GATHER);
    }

    /** Returns whether the exchange has {@code REMOTE} scope, {@code GATHER} type and no ordering scheme. */
    private static boolean isGatherRemoteExchange(ExchangeNode remoteCandidate)
    {
        if (!remoteCandidate.getScope().equals(REMOTE)) {
            return false;
        }
        if (!remoteCandidate.getType().equals(GATHER)) {
            return false;
        }
        // a non-empty ordering scheme means it is a merging exchange
        return remoteCandidate.getOrderingScheme().isEmpty();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import org.junit.jupiter.api.Test;

import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
import static io.trino.sql.planner.assertions.PlanMatchPattern.sort;
import static io.trino.sql.planner.assertions.PlanMatchPattern.topN;
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE;
import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER;
import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static io.trino.sql.planner.plan.TopNNode.Step.PARTIAL;
import static io.trino.sql.tree.SortItem.NullOrdering.FIRST;
import static io.trino.sql.tree.SortItem.Ordering.ASCENDING;

/**
 * Rule tests for {@link GatherPartialTopN}: one case where the rule rewrites the plan
 * and one case where it must leave the plan untouched.
 */
public class TestGatherPartialTopN
        extends BaseRuleTest
{
    /**
     * Input: a REMOTE exchange whose only source is a partial TopN (count 10, ordered by
     * symbol "a") over a values node. Expected: the same remote exchange on top of three
     * nested partial TopN nodes separated by a LOCAL GATHER exchange and a LOCAL
     * REPARTITION exchange using FIXED_ARBITRARY_DISTRIBUTION.
     */
    @Test
    public void testPartialTopNGather()
    {
        tester().assertThat(new GatherPartialTopN())
                .on(p ->
                {
                    // the single symbol serves as exchange input/output and as the TopN ordering key
                    Symbol orderBy = p.symbol("a");
                    return p.exchange(exchange -> exchange
                            .scope(REMOTE)
                            .singleDistributionPartitioningScheme(orderBy)
                            .addInputsSet(orderBy)
                            .addSource(p.topN(10, ImmutableList.of(orderBy), PARTIAL, p.values(orderBy))));
                }).matches(exchange(REMOTE,
                        // partial TopN added above the local gathering exchange
                        topN(
                                10,
                                ImmutableList.of(sort("a", ASCENDING, FIRST)),
                                PARTIAL,
                                exchange(LOCAL, GATHER,
                                        // partial TopN added above the local repartitioning exchange
                                        topN(
                                                10,
                                                ImmutableList.of(sort("a", ASCENDING, FIRST)),
                                                PARTIAL,
                                                exchange(LOCAL, REPARTITION, FIXED_ARBITRARY_DISTRIBUTION,
                                                        // partial TopN with the same count and ordering as in the input plan
                                                        topN(
                                                                10,
                                                                ImmutableList.of(sort("a", ASCENDING, FIRST)),
                                                                PARTIAL,
                                                                values("a"))))))));
    }

    /**
     * Input: a REMOTE exchange over a partial TopN whose source is already a LOCAL GATHER
     * exchange (itself over another partial TopN). Expected: the rule does not fire.
     */
    @Test
    public void testRuleDoesNotFireTwice()
    {
        tester().assertThat(new GatherPartialTopN())
                .on(p ->
                {
                    Symbol orderBy = p.symbol("a");
                    return p.exchange(exchange -> exchange
                            .scope(REMOTE)
                            .singleDistributionPartitioningScheme(orderBy)
                            .addInputsSet(orderBy)
                            .addSource(p.topN(
                                    10,
                                    ImmutableList.of(orderBy),
                                    PARTIAL,
                                    // the TopN source is a local gathering exchange
                                    p.exchange(localExchange -> localExchange
                                            .scope(LOCAL)
                                            .type(GATHER)
                                            .singleDistributionPartitioningScheme(orderBy)
                                            .addInputsSet(orderBy)
                                            .addSource(p.topN(
                                                    10,
                                                    ImmutableList.of(orderBy),
                                                    PARTIAL,
                                                    p.values(orderBy)))))));
                }).doesNotFire();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.ir.Comparison.Operator.EQUAL;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
import static io.trino.sql.planner.assertions.PlanMatchPattern.expression;
Expand Down Expand Up @@ -147,11 +148,15 @@ public void testWithSortedTable()

orderBy = ImmutableList.of(sort("t_col_a", ASCENDING, LAST));
assertDistributedPlan("SELECT col_a FROM table_a ORDER BY 1 ASC NULLS LAST LIMIT 10", output(
topN(10, orderBy, FINAL,
exchange(LOCAL, GATHER, ImmutableList.of(),
exchange(REMOTE, GATHER, ImmutableList.of(),
topN(10, orderBy, PARTIAL,
tableScan("table_a", ImmutableMap.of("t_col_a", "col_a"))))))));
topN(10, orderBy, FINAL,
exchange(LOCAL, GATHER, ImmutableList.of(),
exchange(REMOTE, GATHER, ImmutableList.of(),
topN(10, orderBy, PARTIAL,
exchange(LOCAL, GATHER,
topN(10, orderBy, PARTIAL,
exchange(LOCAL, REPARTITION, FIXED_ARBITRARY_DISTRIBUTION,
topN(10, orderBy, PARTIAL,
tableScan("table_a", ImmutableMap.of("t_col_a", "col_a"))))))))))));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.PARTITIONED;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN;
Expand Down Expand Up @@ -251,7 +252,7 @@ public void testAggregationPushdown()
getSession(),
"SELECT custkey, sum(totalprice) FROM (SELECT custkey, totalprice FROM orders ORDER BY orderdate ASC, totalprice ASC LIMIT 10) GROUP BY custkey",
hasBehavior(SUPPORTS_TOPN_PUSHDOWN),
node(TopNNode.class, anyTree(node(TableScanNode.class))));
project(node(TopNNode.class, anyTree(node(TableScanNode.class)))));
// GROUP BY with JOIN
assertConditionallyPushedDown(
joinPushdownEnabled(getSession()),
Expand Down Expand Up @@ -908,7 +909,7 @@ public void testLimitPushdown()
}

// with TopN over numeric column
PlanMatchPattern topnOverTableScan = node(TopNNode.class, anyTree(node(TableScanNode.class)));
PlanMatchPattern topnOverTableScan = project(node(TopNNode.class, anyTree(node(TableScanNode.class))));
assertConditionallyPushedDown(
getSession(),
"SELECT * FROM (SELECT regionkey FROM nation ORDER BY nationkey ASC LIMIT 10) LIMIT 5",
Expand Down Expand Up @@ -1099,7 +1100,7 @@ public void testCaseSensitiveTopNPushdown()

// topN over varchar/char columns should only be pushed down if the remote system's sort order matches Trino
boolean expectTopNPushdown = hasBehavior(SUPPORTS_TOPN_PUSHDOWN_WITH_VARCHAR);
PlanMatchPattern topNOverTableScan = node(TopNNode.class, anyTree(node(TableScanNode.class)));
PlanMatchPattern topNOverTableScan = project(node(TopNNode.class, anyTree(node(TableScanNode.class))));

try (TestTable testTable = new TestTable(
getQueryRunner()::execute,
Expand Down

0 comments on commit a6474d8

Please sign in to comment.