From a6474d87187e16c4934cfd1aed167eb1c95025ac Mon Sep 17 00:00:00 2001 From: Lukasz Stec Date: Tue, 30 Apr 2024 11:42:14 +0200 Subject: [PATCH] Gather partial TopN results Gathering TopN avoids unnecessary network overhead, especially when both the number of splits and the TopN limit are big. Co-authored-by: Kamil Endruszkiewicz --- .../io/trino/sql/planner/PlanOptimizers.java | 5 +- .../iterative/rule/GatherPartialTopN.java | 113 ++++++++++++++++++ .../iterative/rule/TestGatherPartialTopN.java | 94 +++++++++++++++ .../TestPartialTopNWithPresortedInput.java | 15 ++- .../plugin/jdbc/BaseJdbcConnectorTest.java | 7 +- 5 files changed, 225 insertions(+), 9 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/GatherPartialTopN.java create mode 100644 core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestGatherPartialTopN.java diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index d34066da7d42..bc89ebb44088 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -52,6 +52,7 @@ import io.trino.sql.planner.iterative.rule.ExtractDereferencesFromFilterAboveScan; import io.trino.sql.planner.iterative.rule.ExtractSpatialJoins; import io.trino.sql.planner.iterative.rule.GatherAndMergeWindows; +import io.trino.sql.planner.iterative.rule.GatherPartialTopN; import io.trino.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter; import io.trino.sql.planner.iterative.rule.ImplementExceptAll; import io.trino.sql.planner.iterative.rule.ImplementExceptDistinctAsUnion; @@ -961,7 +962,9 @@ public PlanOptimizers( ruleStats, statsCalculator, costCalculator, - ImmutableSet.of(new UseNonPartitionedJoinLookupSource()))); + ImmutableSet.of( + new UseNonPartitionedJoinLookupSource(), + new GatherPartialTopN()))); // 
Optimizers above this do not need to care about aggregations with the type other than SINGLE // This optimizer must be run after all exchange-related optimizers diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/GatherPartialTopN.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/GatherPartialTopN.java new file mode 100644 index 000000000000..29fad33512fc --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/GatherPartialTopN.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.sql.planner.iterative.rule;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.matching.Capture;
+import io.trino.matching.Captures;
+import io.trino.matching.Pattern;
+import io.trino.sql.planner.Partitioning;
+import io.trino.sql.planner.PartitioningScheme;
+import io.trino.sql.planner.iterative.Rule;
+import io.trino.sql.planner.plan.ExchangeNode;
+import io.trino.sql.planner.plan.PlanNode;
+import io.trino.sql.planner.plan.TopNNode;
+
+import static io.trino.matching.Capture.newCapture;
+import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
+import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL;
+import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE;
+import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER;
+import static io.trino.sql.planner.plan.ExchangeNode.gatheringExchange;
+import static io.trino.sql.planner.plan.ExchangeNode.partitionedExchange;
+import static io.trino.sql.planner.plan.Patterns.exchange;
+import static io.trino.sql.planner.plan.Patterns.source;
+import static io.trino.sql.planner.plan.Patterns.topN;
+import static io.trino.sql.planner.plan.TopNNode.Step.PARTIAL;
+
+/**
+ * Adds a local round-robin exchange and a local gathering exchange, each topped with its own partial TopN, above the original partial TopN to limit the task output size.
+ * Replaces plans like:
+ * <pre>
+ * exchange(remote)
+ *   - topn(partial)
+ * </pre>
+ * with
+ * <pre>
+ *  exchange(remote)
+ *      - topn(partial)
+ *          - exchange(local, gather)
+ *              - topn(partial)
+ *                  - local_exchange(round_robin)
+ *                      - topn(partial)
+ * </pre>
+ */
+public class GatherPartialTopN
+        implements Rule<ExchangeNode>
+{
+    private static final Capture<TopNNode> TOPN = newCapture();
+    // the pattern filters for parent and source exchanges are added to avoid infinite recursion in iterative optimizer
+    private static final Pattern<ExchangeNode> PATTERN = exchange()
+            .matching(GatherPartialTopN::isGatherRemoteExchange)
+            .with(source().matching(
+                    topN().matching(topN -> topN.getStep().equals(PARTIAL))
+                            .with(source().matching(source -> !isGatherLocalExchange(source)))
+                            .capturedAs(TOPN)));
+
+    private static boolean isGatherLocalExchange(PlanNode source)
+    {
+        return source instanceof ExchangeNode exchange
+                && exchange.getScope().equals(LOCAL)
+                && exchange.getType().equals(GATHER);
+    }
+
+    private static boolean isGatherRemoteExchange(ExchangeNode exchangeNode)
+    {
+        return exchangeNode.getScope().equals(REMOTE)
+                && exchangeNode.getType().equals(GATHER)
+                // non-empty orderingScheme means it's a merging exchange
+                && exchangeNode.getOrderingScheme().isEmpty();
+    }
+
+    @Override
+    public Pattern<ExchangeNode> getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Result apply(ExchangeNode node, Captures captures, Context context)
+    {
+        TopNNode originalPartialTopN = captures.get(TOPN);
+
+        TopNNode roundRobinTopN = new TopNNode( // partial TopN above a local FIXED_ARBITRARY exchange fed by the original partial TopN
+                context.getIdAllocator().getNextId(),
+                partitionedExchange(
+                        context.getIdAllocator().getNextId(),
+                        LOCAL,
+                        originalPartialTopN,
+                        new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), originalPartialTopN.getOutputSymbols())),
+                originalPartialTopN.getCount(),
+                originalPartialTopN.getOrderingScheme(),
+                PARTIAL);
+
+        return Result.ofPlanNode(node.replaceChildren( // the remote exchange now reads one more partial TopN placed above a local gather
+                ImmutableList.of(new TopNNode(
+                        context.getIdAllocator().getNextId(),
+                        gatheringExchange(context.getIdAllocator().getNextId(), LOCAL, roundRobinTopN),
+                        originalPartialTopN.getCount(),
+                        originalPartialTopN.getOrderingScheme(),
+                        PARTIAL))));
+    }
+}
diff --git
a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestGatherPartialTopN.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestGatherPartialTopN.java
new file mode 100644
index 000000000000..6b7a69280e7e
--- /dev/null
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestGatherPartialTopN.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.sql.planner.iterative.rule;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.sql.planner.Symbol;
+import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
+import org.junit.jupiter.api.Test;
+
+import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.sort;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.topN;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
+import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL;
+import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE;
+import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER;
+import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION;
+import static io.trino.sql.planner.plan.TopNNode.Step.PARTIAL;
+import static io.trino.sql.tree.SortItem.NullOrdering.FIRST;
+import static io.trino.sql.tree.SortItem.Ordering.ASCENDING;
+
+public class TestGatherPartialTopN
+        extends BaseRuleTest
+{
+    @Test
+    public void testPartialTopNGather() // remote exchange over a partial TopN: expects two extra partial TopN stages (local repartition, then local gather)
+    {
+        tester().assertThat(new GatherPartialTopN())
+                .on(p ->
+                {
+                    Symbol orderBy = p.symbol("a");
+                    return p.exchange(exchange -> exchange
+                            .scope(REMOTE)
+                            .singleDistributionPartitioningScheme(orderBy)
+                            .addInputsSet(orderBy)
+                            .addSource(p.topN(10, ImmutableList.of(orderBy), PARTIAL, p.values(orderBy))));
+                }).matches(exchange(REMOTE,
+                        topN(
+                                10,
+                                ImmutableList.of(sort("a", ASCENDING, FIRST)),
+                                PARTIAL,
+                                exchange(LOCAL, GATHER,
+                                        topN(
+                                                10,
+                                                ImmutableList.of(sort("a", ASCENDING, FIRST)),
+                                                PARTIAL,
+                                                exchange(LOCAL, REPARTITION, FIXED_ARBITRARY_DISTRIBUTION,
+                                                        topN(
+                                                                10,
+                                                                ImmutableList.of(sort("a", ASCENDING, FIRST)),
+                                                                PARTIAL,
+                                                                values("a"))))))));
+    }
+
+    @Test
+    public void testRuleDoesNotFireTwice() // the partial TopN already sits directly on a local gather exchange, so the rule must not fire
+    {
+        tester().assertThat(new GatherPartialTopN())
+                .on(p ->
+                {
+                    Symbol orderBy = p.symbol("a");
+                    return p.exchange(exchange -> exchange
+                            .scope(REMOTE)
+                            .singleDistributionPartitioningScheme(orderBy)
+                            .addInputsSet(orderBy)
+                            .addSource(p.topN(
+                                    10,
+                                    ImmutableList.of(orderBy),
+                                    PARTIAL,
+                                    p.exchange(localExchange -> localExchange
+                                            .scope(LOCAL)
+                                            .type(GATHER)
+                                            .singleDistributionPartitioningScheme(orderBy)
+                                            .addInputsSet(orderBy)
+                                            .addSource(p.topN(
+                                                    10,
+                                                    ImmutableList.of(orderBy),
+                                                    PARTIAL,
+                                                    p.values(orderBy)))))));
+                }).doesNotFire();
+    }
+}
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java
index 6bcffa256143..f46888d61552 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestPartialTopNWithPresortedInput.java
@@ -45,6 +45,7 @@ import static
io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; @@ -147,11 +148,15 @@ public void testWithSortedTable() orderBy = ImmutableList.of(sort("t_col_a", ASCENDING, LAST)); assertDistributedPlan("SELECT col_a FROM table_a ORDER BY 1 ASC NULLS LAST LIMIT 10", output( - topN(10, orderBy, FINAL, - exchange(LOCAL, GATHER, ImmutableList.of(), - exchange(REMOTE, GATHER, ImmutableList.of(), - topN(10, orderBy, PARTIAL, - tableScan("table_a", ImmutableMap.of("t_col_a", "col_a")))))))); + topN(10, orderBy, FINAL, + exchange(LOCAL, GATHER, ImmutableList.of(), + exchange(REMOTE, GATHER, ImmutableList.of(), + topN(10, orderBy, PARTIAL, + exchange(LOCAL, GATHER, + topN(10, orderBy, PARTIAL, + exchange(LOCAL, REPARTITION, FIXED_ARBITRARY_DISTRIBUTION, + topN(10, orderBy, PARTIAL, + tableScan("table_a", ImmutableMap.of("t_col_a", "col_a")))))))))))); } @Test diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 42ef84c3619d..14a4acd5908b 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -76,6 +76,7 @@ import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.PARTITIONED; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static 
io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN; @@ -251,7 +252,7 @@ public void testAggregationPushdown() getSession(), "SELECT custkey, sum(totalprice) FROM (SELECT custkey, totalprice FROM orders ORDER BY orderdate ASC, totalprice ASC LIMIT 10) GROUP BY custkey", hasBehavior(SUPPORTS_TOPN_PUSHDOWN), - node(TopNNode.class, anyTree(node(TableScanNode.class)))); + project(node(TopNNode.class, anyTree(node(TableScanNode.class))))); // GROUP BY with JOIN assertConditionallyPushedDown( joinPushdownEnabled(getSession()), @@ -908,7 +909,7 @@ public void testLimitPushdown() } // with TopN over numeric column - PlanMatchPattern topnOverTableScan = node(TopNNode.class, anyTree(node(TableScanNode.class))); + PlanMatchPattern topnOverTableScan = project(node(TopNNode.class, anyTree(node(TableScanNode.class)))); assertConditionallyPushedDown( getSession(), "SELECT * FROM (SELECT regionkey FROM nation ORDER BY nationkey ASC LIMIT 10) LIMIT 5", @@ -1099,7 +1100,7 @@ public void testCaseSensitiveTopNPushdown() // topN over varchar/char columns should only be pushed down if the remote systems's sort order matches Trino boolean expectTopNPushdown = hasBehavior(SUPPORTS_TOPN_PUSHDOWN_WITH_VARCHAR); - PlanMatchPattern topNOverTableScan = node(TopNNode.class, anyTree(node(TableScanNode.class))); + PlanMatchPattern topNOverTableScan = project(node(TopNNode.class, anyTree(node(TableScanNode.class)))); try (TestTable testTable = new TestTable( getQueryRunner()::execute,