Skip to content

Commit

Permalink
Create dedicated aggregation function matcher
Browse files Browse the repository at this point in the history
In IR, aggregation functions are not represented with FunctionCall.
  • Loading branch information
martint committed Mar 4, 2024
1 parent fd4cffc commit 21d29bd
Show file tree
Hide file tree
Showing 31 changed files with 369 additions and 242 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED;
import static io.trino.sql.planner.assertions.PlanMatchPattern.DynamicFilterPattern;
import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation;
import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction;
import static io.trino.sql.planner.assertions.PlanMatchPattern.aliasToIndex;
import static io.trino.sql.planner.assertions.PlanMatchPattern.any;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyNot;
Expand Down Expand Up @@ -226,25 +227,25 @@ public void testAggregation()
assertDistributedPlan("SELECT orderstatus, sum(totalprice) FROM orders GROUP BY orderstatus",
anyTree(
aggregation(
ImmutableMap.of("final_sum", functionCall("sum", ImmutableList.of("partial_sum"))),
ImmutableMap.of("final_sum", aggregationFunction("sum", ImmutableList.of("partial_sum"))),
FINAL,
exchange(LOCAL, GATHER,
exchange(REMOTE, REPARTITION,
aggregation(
ImmutableMap.of("partial_sum", functionCall("sum", ImmutableList.of("totalprice"))),
ImmutableMap.of("partial_sum", aggregationFunction("sum", ImmutableList.of("totalprice"))),
PARTIAL,
tableScan("orders", ImmutableMap.of("totalprice", "totalprice"))))))));

// simple group by over filter that keeps at most one group
assertDistributedPlan("SELECT orderstatus, sum(totalprice) FROM orders WHERE orderstatus='O' GROUP BY orderstatus",
anyTree(
aggregation(
ImmutableMap.of("final_sum", functionCall("sum", ImmutableList.of("partial_sum"))),
ImmutableMap.of("final_sum", aggregationFunction("sum", ImmutableList.of("partial_sum"))),
FINAL,
exchange(LOCAL, GATHER,
exchange(REMOTE, REPARTITION,
aggregation(
ImmutableMap.of("partial_sum", functionCall("sum", ImmutableList.of("totalprice"))),
ImmutableMap.of("partial_sum", aggregationFunction("sum", ImmutableList.of("totalprice"))),
PARTIAL,
tableScan("orders", ImmutableMap.of("totalprice", "totalprice"))))))));
}
Expand All @@ -262,14 +263,14 @@ public void testAllFieldsDereferenceOnSubquery()
ImmutableMap.of("row", expression("ROW(min, max)")),
aggregation(
ImmutableMap.of(
"min", functionCall("min", ImmutableList.of("min_regionkey")),
"max", functionCall("max", ImmutableList.of("max_name"))),
"min", aggregationFunction("min", ImmutableList.of("min_regionkey")),
"max", aggregationFunction("max", ImmutableList.of("max_name"))),
FINAL,
any(
aggregation(
ImmutableMap.of(
"min_regionkey", functionCall("min", ImmutableList.of("REGIONKEY")),
"max_name", functionCall("max", ImmutableList.of("NAME"))),
"min_regionkey", aggregationFunction("min", ImmutableList.of("REGIONKEY")),
"max_name", aggregationFunction("max", ImmutableList.of("NAME"))),
PARTIAL,
tableScan("nation", ImmutableMap.of("NAME", "name", "REGIONKEY", "regionkey")))))))));
}
Expand Down Expand Up @@ -863,7 +864,7 @@ public void testStreamingAggregationForCorrelatedSubquery()
anyTree(
aggregation(
singleGroupingSet("n_name", "n_regionkey", "unique"),
ImmutableMap.of(Optional.of("max"), functionCall("max", ImmutableList.of("r_name"))),
ImmutableMap.of(Optional.of("max"), aggregationFunction("max", ImmutableList.of("r_name"))),
ImmutableList.of("n_name", "n_regionkey", "unique"),
ImmutableList.of("non_null"),
Optional.empty(),
Expand All @@ -883,7 +884,7 @@ public void testStreamingAggregationForCorrelatedSubquery()
anyTree(
aggregation(
singleGroupingSet("n_name", "n_regionkey", "unique"),
ImmutableMap.of(Optional.of("max"), functionCall("max", ImmutableList.of("r_name"))),
ImmutableMap.of(Optional.of("max"), aggregationFunction("max", ImmutableList.of("r_name"))),
ImmutableList.of("n_name", "n_regionkey", "unique"),
ImmutableList.of("non_null"),
Optional.empty(),
Expand All @@ -909,7 +910,7 @@ public void testStreamingAggregationOverJoin()
anyTree(
aggregation(
singleGroupingSet("l_orderkey"),
ImmutableMap.of(Optional.empty(), functionCall("count", ImmutableList.of())),
ImmutableMap.of(Optional.empty(), aggregationFunction("count", ImmutableList.of())),
ImmutableList.of("l_orderkey"), // streaming
Optional.empty(),
SINGLE,
Expand All @@ -927,7 +928,7 @@ public void testStreamingAggregationOverJoin()
anyTree(
aggregation(
singleGroupingSet("o_orderkey"),
ImmutableMap.of(Optional.empty(), functionCall("count", ImmutableList.of())),
ImmutableMap.of(Optional.empty(), aggregationFunction("count", ImmutableList.of())),
ImmutableList.of("o_orderkey"), // streaming
Optional.empty(),
SINGLE,
Expand All @@ -944,7 +945,7 @@ public void testStreamingAggregationOverJoin()
anyTree(
aggregation(
singleGroupingSet("orderkey"),
ImmutableMap.of(Optional.empty(), functionCall("count", ImmutableList.of())),
ImmutableMap.of(Optional.empty(), aggregationFunction("count", ImmutableList.of())),
ImmutableList.of(), // not streaming
Optional.empty(),
SINGLE,
Expand Down Expand Up @@ -1041,7 +1042,7 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin()
"exists", expression("FINAL_COUNT > BIGINT '0'")),
aggregation(
singleGroupingSet("ORDERKEY", "UNIQUE"),
ImmutableMap.of(Optional.of("FINAL_COUNT"), functionCall("count", ImmutableList.of())),
ImmutableMap.of(Optional.of("FINAL_COUNT"), aggregationFunction("count", ImmutableList.of())),
ImmutableList.of("ORDERKEY", "UNIQUE"),
ImmutableList.of("NON_NULL"),
Optional.empty(),
Expand Down Expand Up @@ -1071,7 +1072,7 @@ public void testCorrelatedDistinctAggregationRewriteToLeftOuterJoin()
.left(tableScan("customer", ImmutableMap.of("c_custkey", "custkey")))
.right(aggregation(
singleGroupingSet("o_custkey"),
ImmutableMap.of(Optional.of("count"), functionCall("count", ImmutableList.of("o_orderkey"))),
ImmutableMap.of(Optional.of("count"), aggregationFunction("count", ImmutableList.of("o_orderkey"))),
ImmutableList.of(),
ImmutableList.of("non_null"),
Optional.empty(),
Expand Down Expand Up @@ -1106,7 +1107,7 @@ public void testCorrelatedDistinctGroupedAggregationRewriteToLeftOuterJoin()
.right(
project(aggregation(
singleGroupingSet("o_orderstatus", "o_custkey"),
ImmutableMap.of(Optional.of("count"), functionCall("count", ImmutableList.of("o_orderkey"))),
ImmutableMap.of(Optional.of("count"), aggregationFunction("count", ImmutableList.of("o_orderkey"))),
Optional.empty(),
SINGLE,
aggregation(
Expand Down Expand Up @@ -1209,7 +1210,7 @@ public void testInlineCountOverLiteral()
"SELECT regionkey, count(1) FROM nation GROUP BY regionkey",
anyTree(
aggregation(
ImmutableMap.of("count_0", functionCall("count", ImmutableList.of())),
ImmutableMap.of("count_0", aggregationFunction("count", ImmutableList.of())),
PARTIAL,
tableScan("nation", ImmutableMap.of("regionkey", "regionkey")))));
}
Expand All @@ -1221,7 +1222,7 @@ public void testInlineCountOverEffectivelyLiteral()
"SELECT regionkey, count(CAST(DECIMAL '1' AS decimal(8,4))) FROM nation GROUP BY regionkey",
anyTree(
aggregation(
ImmutableMap.of("count_0", functionCall("count", ImmutableList.of())),
ImmutableMap.of("count_0", aggregationFunction("count", ImmutableList.of())),
PARTIAL,
tableScan("nation", ImmutableMap.of("regionkey", "regionkey")))));
}
Expand Down Expand Up @@ -2029,7 +2030,7 @@ public void testGroupingSetsWithDefaultValue()
output(
anyTree(
aggregation(
ImmutableMap.of("final_count", functionCall("count", ImmutableList.of("partial_count"))),
ImmutableMap.of("final_count", aggregationFunction("count", ImmutableList.of("partial_count"))),
FINAL,
exchange(
LOCAL,
Expand All @@ -2038,7 +2039,7 @@ public void testGroupingSetsWithDefaultValue()
REMOTE,
REPARTITION,
aggregation(
ImmutableMap.of("partial_count", functionCall("count", ImmutableList.of("CONSTANT"))),
ImmutableMap.of("partial_count", aggregationFunction("count", ImmutableList.of("CONSTANT"))),
PARTIAL,
anyTree(
project(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.trino.sql.planner.OptimizerConfig.JoinDistributionType;
import io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.TableScanNode;
import org.junit.jupiter.api.Test;
Expand All @@ -29,7 +30,6 @@
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
import static io.trino.sql.planner.assertions.PlanMatchPattern.columnReference;
import static io.trino.sql.planner.assertions.PlanMatchPattern.expression;
import static io.trino.sql.planner.assertions.PlanMatchPattern.functionCall;
import static io.trino.sql.planner.assertions.PlanMatchPattern.join;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.output;
Expand Down Expand Up @@ -178,7 +178,7 @@ public void testAggregation()
{
assertMinimallyOptimizedPlan("SELECT COUNT(nationkey) FROM nation",
output(ImmutableList.of("COUNT"),
aggregation(ImmutableMap.of("COUNT", functionCall("count", ImmutableList.of("NATIONKEY"))),
aggregation(ImmutableMap.of("COUNT", PlanMatchPattern.aggregationFunction("count", ImmutableList.of("NATIONKEY"))),
tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey")))));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@
import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED;
import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation;
import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
import static io.trino.sql.planner.assertions.PlanMatchPattern.functionCall;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL;
import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL;
Expand Down Expand Up @@ -147,9 +147,9 @@ void assertTableScanPlannedWithPartitioning(Session session, String table, Conne
String query = "SELECT count(column_b) FROM " + table + " GROUP BY column_a";
assertDistributedPlan(query, session,
anyTree(
aggregation(ImmutableMap.of("COUNT", functionCall("count", ImmutableList.of("COUNT_PART"))), FINAL,
aggregation(ImmutableMap.of("COUNT", aggregationFunction("count", ImmutableList.of("COUNT_PART"))), FINAL,
exchange(LOCAL, REPARTITION,
aggregation(ImmutableMap.of("COUNT_PART", functionCall("count", ImmutableList.of("B"))), PARTIAL,
aggregation(ImmutableMap.of("COUNT_PART", aggregationFunction("count", ImmutableList.of("B"))), PARTIAL,
tableScan(table, ImmutableMap.of("A", "column_a", "B", "column_b")))))));
SubPlan subPlan = subplan(query, OPTIMIZED_AND_VALIDATED, false, session);
assertThat(subPlan.getAllFragments()).hasSize(1);
Expand All @@ -161,10 +161,10 @@ void assertTableScanPlannedWithoutPartitioning(Session session, String table)
String query = "SELECT count(column_b) FROM " + table + " GROUP BY column_a";
assertDistributedPlan("SELECT count(column_b) FROM " + table + " GROUP BY column_a", session,
anyTree(
aggregation(ImmutableMap.of("COUNT", functionCall("count", ImmutableList.of("COUNT_PART"))), FINAL,
aggregation(ImmutableMap.of("COUNT", aggregationFunction("count", ImmutableList.of("COUNT_PART"))), FINAL,
exchange(LOCAL, REPARTITION,
exchange(REMOTE, REPARTITION,
aggregation(ImmutableMap.of("COUNT_PART", functionCall("count", ImmutableList.of("B"))), PARTIAL,
aggregation(ImmutableMap.of("COUNT_PART", aggregationFunction("count", ImmutableList.of("B"))), PARTIAL,
tableScan(table, ImmutableMap.of("A", "column_a", "B", "column_b"))))))));
SubPlan subPlan = subplan(query, OPTIMIZED_AND_VALIDATED, false, session);
assertThat(subPlan.getAllFragments()).hasSize(2);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.assertions;

import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.tree.Expression;

import java.util.List;
import java.util.Optional;

public record AggregationFunction(
String name,
Optional<Symbol> filter,
Optional<OrderingScheme> orderBy,
boolean distinct,
List<Expression> arguments)
{
}
Loading

0 comments on commit 21d29bd

Please sign in to comment.