From beba0b8e0f3ac2b9fd7dd2fc2d79175dda0ee3dd Mon Sep 17 00:00:00 2001 From: Lukasz Stec Date: Fri, 22 Apr 2022 14:40:51 +0200 Subject: [PATCH] Extract AggregationNode.singleAggregation --- .../io/trino/sql/planner/LogicalPlanner.java | 10 +++---- .../io/trino/sql/planner/QueryPlanner.java | 17 ++++-------- .../io/trino/sql/planner/RelationPlanner.java | 10 +++---- ...elateInnerUnnestWithGlobalAggregation.java | 17 ++++-------- ...relateLeftUnnestWithGlobalAggregation.java | 10 +++---- .../rule/PushAggregationThroughOuterJoin.java | 9 +++---- .../rule/RemoveEmptyExceptBranches.java | 12 +++------ .../rule/SetOperationNodeTranslator.java | 9 +++---- .../SingleDistinctAggregationToGroupBy.java | 11 +++----- .../TransformCorrelatedInPredicateToJoin.java | 9 +++---- .../TransformExistsApplyToCorrelatedJoin.java | 10 +++---- ...TransformFilteringSemiJoinToInnerJoin.java | 11 +++----- .../optimizations/PlanNodeDecorrelator.java | 9 +++---- ...tifiedComparisonApplyToCorrelatedJoin.java | 10 +++---- .../sql/planner/plan/AggregationNode.java | 9 +++++++ .../io/trino/cost/TestCostCalculator.java | 9 +++---- .../TestEffectivePredicateExtractor.java | 9 +++---- .../trino/sql/planner/TestTypeValidator.java | 27 +++++-------------- 18 files changed, 68 insertions(+), 140 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index ef11d6fcf67a..595236e44afa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -50,7 +50,6 @@ import io.trino.sql.analyzer.Scope; import io.trino.sql.planner.StatisticsAggregationPlanner.TableStatisticAggregation; import io.trino.sql.planner.optimizations.PlanOptimizer; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.ExplainAnalyzeNode; @@ -134,6 +133,7 @@ import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; import static io.trino.sql.planner.QueryPlanner.visibleFields; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.TableWriterNode.CreateReference; import static io.trino.sql.planner.plan.TableWriterNode.InsertReference; @@ -362,15 +362,11 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme PlanNode planNode = new StatisticsWriterNode( idAllocator.getNextId(), - new AggregationNode( + singleAggregation( idAllocator.getNextId(), TableScanNode.newInstance(idAllocator.getNextId(), targetTable, tableScanOutputs.build(), symbolToColumnHandle.buildOrThrow(), false, Optional.empty()), statisticAggregations.getAggregations(), - singleGroupingSet(groupingSymbols), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()), + singleGroupingSet(groupingSymbols)), new StatisticsWriterNode.WriteStatisticsReference(targetTable), symbolAllocator.newSymbol("rows", BIGINT), tableStatisticsMetadata.getTableStatistics().contains(ROW_COUNT), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index 494cc69adc01..a207904874e6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -128,6 +128,7 @@ import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; import static io.trino.sql.planner.ScopeAware.scopeAwareKey; import static io.trino.sql.planner.plan.AggregationNode.groupingSets; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.WindowNode.Frame.DEFAULT_FRAME; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -333,15 +334,11 @@ public RelationPlan planExpand(Query query) PlanNode result = new UnionNode(idAllocator.getNextId(), nodesToUnion, unionSymbolMapping.build(), unionOutputSymbols); if (union.isDistinct()) { - result = new AggregationNode( + result = singleAggregation( idAllocator.getNextId(), result, ImmutableMap.of(), - singleGroupingSet(result.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(result.getOutputSymbols())); } return new RelationPlan(result, anchorPlan.getScope(), unionOutputSymbols, outerContext); @@ -1654,15 +1651,11 @@ private PlanBuilder distinct(PlanBuilder subPlan, QuerySpecification node, List< .collect(Collectors.toList()); return subPlan.withNewRoot( - new AggregationNode( + singleAggregation( idAllocator.getNextId(), subPlan.getRoot(), ImmutableMap.of(), - singleGroupingSet(symbols), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty())); + singleGroupingSet(symbols))); } return subPlan; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 29dec22e387a..4ff8db098a76 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -33,7 +33,6 @@ import io.trino.sql.analyzer.Field; import io.trino.sql.analyzer.RelationType; import io.trino.sql.analyzer.Scope; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.ExceptNode; @@ -122,6 +121,7 @@ import static io.trino.sql.planner.QueryPlanner.extractPatternRecognitionExpressions; import static io.trino.sql.planner.QueryPlanner.planWindowSpecification; import static io.trino.sql.planner.QueryPlanner.pruneInvisibleFields; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.tree.Join.Type.CROSS; @@ -1160,14 +1160,10 @@ private SetOperationPlan process(SetOperation node) private PlanNode distinct(PlanNode node) { - return new AggregationNode(idAllocator.getNextId(), + return singleAggregation(idAllocator.getNextId(), node, ImmutableMap.of(), - singleGroupingSet(node.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(node.getOutputSymbols())); } private static final class SetOperationPlan diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java index d0807c088d51..b21137418df5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java @@ -53,6 +53,7 @@ import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; @@ -337,15 +338,11 @@ private static AggregationNode withGroupingAndMask(AggregationNode aggregationNo .build()); } - return new AggregationNode( + return singleAggregation( aggregationNode.getId(), source, rewriteWithMasks(aggregationNode.getAggregations(), masks.buildOrThrow()), - singleGroupingSet(groupingSymbols), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(groupingSymbols)); } private static AggregationNode withGrouping(AggregationNode aggregationNode, List groupingSymbols, PlanNode source) @@ -354,14 +351,10 @@ private static AggregationNode withGrouping(AggregationNode aggregationNode, Lis .distinct() .collect(toImmutableList())); - return new AggregationNode( + return singleAggregation( aggregationNode.getId(), source, aggregationNode.getAggregations(), - groupingSet, - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + groupingSet); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java index a4705a673911..fd95414d8ae1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java @@ -13,7 +13,6 @@ */ package io.trino.sql.planner.iterative.rule; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import io.trino.matching.Captures; @@ -42,6 +41,7 @@ import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; import static io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation; @@ -264,14 +264,10 @@ private static AggregationNode withGrouping(AggregationNode aggregationNode, Lis .distinct() .collect(toImmutableList())); - return new AggregationNode( + return singleAggregation( aggregationNode.getId(), source, aggregationNode.getAggregations(), - groupingSet, - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + groupingSet); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index 502aeea97d04..8e30721d0d00 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -55,6 +55,7 @@ import static io.trino.sql.planner.optimizations.DistinctOutputQueryUtil.isDistinct; import static io.trino.sql.planner.optimizations.SymbolMapper.symbolMapper; import static io.trino.sql.planner.plan.AggregationNode.globalAggregation; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.Patterns.aggregation; import static io.trino.sql.planner.plan.Patterns.join; @@ -309,15 +310,11 @@ private MappedAggregationInfo createAggregationOverNull(AggregationNode referenc Map aggregationsSymbolMapping = aggregationsSymbolMappingBuilder.buildOrThrow(); // create an aggregation node whose source is the null row. - AggregationNode aggregationOverNullRow = new AggregationNode( + AggregationNode aggregationOverNullRow = singleAggregation( idAllocator.getNextId(), nullRow, aggregationsOverNullBuilder.buildOrThrow(), - globalAggregation(), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + globalAggregation()); return new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyExceptBranches.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyExceptBranches.java index ea25275a0506..8a263bbc75ec 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyExceptBranches.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyExceptBranches.java @@ -21,8 +21,6 @@ import io.trino.matching.Pattern; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; -import io.trino.sql.planner.plan.AggregationNode; -import io.trino.sql.planner.plan.AggregationNode.Step; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ExceptNode; import io.trino.sql.planner.plan.PlanNode; @@ -30,9 +28,9 @@ import io.trino.sql.planner.plan.ValuesNode; import java.util.List; -import java.util.Optional; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isEmpty; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.Patterns.except; @@ -96,18 +94,14 @@ public Result apply(ExceptNode node, Captures captures, Context context) if (node.isDistinct()) { return Result.ofPlanNode( - new AggregationNode( + singleAggregation( node.getId(), new ProjectNode( context.getIdAllocator().getNextId(), newSources.get(0), assignments.build()), ImmutableMap.of(), - singleGroupingSet(node.getOutputSymbols()), - ImmutableList.of(), - Step.SINGLE, - Optional.empty(), - Optional.empty())); + singleGroupingSet(node.getOutputSymbols()))); } return Result.ofPlanNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java index 64b9cef1fdfb..0b37ba9d9e0c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java @@ -49,6 +49,7 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.tree.FrameBound.Type.UNBOUNDED_FOLLOWING; @@ -180,14 +181,10 @@ private AggregationNode computeCounts(UnionNode sourceNode, List origina Optional.empty())); } - return new AggregationNode(idAllocator.getNextId(), + return singleAggregation(idAllocator.getNextId(), sourceNode, aggregations.buildOrThrow(), - singleGroupingSet(originalColumns), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(originalColumns)); } private WindowNode appendCounts(UnionNode sourceNode, List originalColumns, List markers, List countOutputs, Symbol rowNumberSymbol) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java index 9475cf135700..203ea6c476c4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java @@ -27,13 +27,12 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.Patterns.aggregation; import static java.util.Collections.emptyList; @@ -125,18 +124,14 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont return Result.ofPlanNode( new AggregationNode( aggregation.getId(), - new AggregationNode( + singleAggregation( context.getIdAllocator().getNextId(), aggregation.getSource(), ImmutableMap.of(), singleGroupingSet(ImmutableList.builder() .addAll(aggregation.getGroupingKeys()) .addAll(symbols) - .build()), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()), + .build())), // remove DISTINCT flag from function calls aggregation.getAggregations() .entrySet().stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 8da7a2ea2919..47ce4c17308a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -64,6 +64,7 @@ import static io.trino.sql.ExpressionUtils.and; import static io.trino.sql.ExpressionUtils.or; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.Patterns.Apply.correlation; import static io.trino.sql.planner.plan.Patterns.applyNode; @@ -216,18 +217,14 @@ private PlanNode buildInPredicateEquivalent( Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", BIGINT); Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", BIGINT); - AggregationNode aggregation = new AggregationNode( + AggregationNode aggregation = singleAggregation( idAllocator.getNextId(), preProjection, ImmutableMap.builder() .put(countMatchesSymbol, countWithFilter(session, matchConditionSymbol)) .put(countNullMatchesSymbol, countWithFilter(session, nullMatchConditionSymbol)) .buildOrThrow(), - singleGroupingSet(probeSide.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(probeSide.getOutputSymbols())); // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java index 19e66ec4990f..a2e339723706 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java @@ -22,7 +22,6 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.optimizations.PlanNodeDecorrelator; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.Assignments; @@ -47,6 +46,7 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.plan.AggregationNode.globalAggregation; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.LEFT; import static io.trino.sql.planner.plan.Patterns.applyNode; @@ -174,7 +174,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Context contex applyNode.getInput(), new ProjectNode( context.getIdAllocator().getNextId(), - new AggregationNode( + singleAggregation( context.getIdAllocator().getNextId(), applyNode.getSubquery(), ImmutableMap.of(count, new Aggregation( @@ -184,11 +184,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Context contex Optional.empty(), Optional.empty(), Optional.empty())), - globalAggregation(), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()), + globalAggregation()), Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), toSqlType(BIGINT))))), applyNode.getCorrelation(), INNER, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java index 1e3ea2442777..5cba9ba69b63 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java @@ -22,7 +22,6 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.optimizations.PlanNodeSearcher; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode; @@ -43,7 +42,7 @@ import static io.trino.sql.ExpressionUtils.and; import static io.trino.sql.ExpressionUtils.extractConjuncts; import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols; -import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.sql.planner.plan.Patterns.filter; @@ -124,15 +123,11 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) Optional joinFilter = simplifiedPredicate.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(simplifiedPredicate); - PlanNode filteringSourceDistinct = new AggregationNode( + PlanNode filteringSourceDistinct = singleAggregation( context.getIdAllocator().getNextId(), semiJoin.getFilteringSource(), ImmutableMap.of(), - singleGroupingSet(ImmutableList.of(semiJoin.getFilteringSourceJoinSymbol())), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(semiJoin.getFilteringSourceJoinSymbol()))); JoinNode innerJoin = new JoinNode( semiJoin.getId(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java index 1638a0436add..bdc35a9552fa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -64,6 +64,7 @@ import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.optimizations.SymbolMapper.symbolMapper; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.TopNRankingNode.RankingType.ROW_NUMBER; import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; @@ -229,15 +230,11 @@ private Optional rewriteLimitWithRowCountOne(DecorrelationR } // rewrite Limit to aggregation on constant symbols - AggregationNode aggregationNode = new AggregationNode( + AggregationNode aggregationNode = singleAggregation( nodeId, decorrelatedChildNode, ImmutableMap.of(), - singleGroupingSet(decorrelatedChildNode.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(decorrelatedChildNode.getOutputSymbols())); return Optional.of(new DecorrelationResult( aggregationNode, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java index 69d38052ee7c..f5d086a0d820 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java @@ -25,7 +25,6 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.TypeProvider; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.Assignments; @@ -57,6 +56,7 @@ import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.plan.AggregationNode.globalAggregation; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith; import static io.trino.sql.tree.BooleanLiteral.FALSE_LITERAL; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -140,7 +140,7 @@ private PlanNode rewriteQuantifiedApplyNode(ApplyNode node, QuantifiedComparison List outputColumnReferences = ImmutableList.of(outputColumn.toSymbolReference()); - subqueryPlan = new AggregationNode( + subqueryPlan = singleAggregation( idAllocator.getNextId(), subqueryPlan, ImmutableMap.of( @@ -172,11 +172,7 @@ countNonNullValue, new Aggregation( Optional.empty(), Optional.empty(), Optional.empty())), - globalAggregation(), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + globalAggregation()); PlanNode join = new CorrelatedJoinNode( node.getId(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java index ab98d65001be..7115b125708c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java @@ -56,6 +56,15 @@ public class AggregationNode private final Optional groupIdSymbol; private final List outputs; + public static AggregationNode singleAggregation( + PlanNodeId id, + PlanNode source, + Map aggregations, + GroupingSetDescriptor groupingSets) + { + return new AggregationNode(id, source, aggregations, groupingSets, ImmutableList.of(), SINGLE, Optional.empty(), Optional.empty()); + } + @JsonCreator public AggregationNode( @JsonProperty("id") PlanNodeId id, diff --git a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java index b796d47c1233..bcbaf76c8c75 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java @@ -71,6 +71,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE; @@ -824,15 +825,11 @@ private AggregationNode aggregation(String id, PlanNode source) Optional.empty(), Optional.empty()); - return new AggregationNode( + return singleAggregation( new PlanNodeId(id), source, ImmutableMap.of(new Symbol("count"), aggregation), - singleGroupingSet(source.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(source.getOutputSymbols())); } /** diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java index ff39b5360640..0ba894957661 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java @@ -106,6 +106,7 @@ import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.sql.planner.plan.AggregationNode.globalAggregation; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.tree.BooleanLiteral.FALSE_LITERAL; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -209,7 +210,7 @@ public void setUp() @Test public void testAggregation() { - PlanNode node = new AggregationNode( + PlanNode node = singleAggregation( newId(), filter( baseTableScan, @@ -236,11 +237,7 @@ D, new Aggregation( Optional.empty(), Optional.empty(), Optional.empty())), - singleGroupingSet(ImmutableList.of(A, B, C)), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(A, B, C))); Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java index 0236be628794..52ed4b72c118 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java @@ -24,7 +24,6 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.VarcharType; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.PlanNode; @@ -57,7 +56,7 @@ import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; -import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -178,7 +177,7 @@ public void testValidAggregation() { Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - PlanNode node = new AggregationNode( + PlanNode node = singleAggregation( newId(), baseTableScan, ImmutableMap.of(aggregationSymbol, new Aggregation( @@ -188,11 +187,7 @@ public void testValidAggregation() Optional.empty(), Optional.empty(), Optional.empty())), - singleGroupingSet(ImmutableList.of(columnA, columnB)), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(columnA, columnB))); assertTypesValid(node); } @@ -234,7 +229,7 @@ public void testInvalidAggregationFunctionCall() { Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - PlanNode node = new AggregationNode( + PlanNode node = singleAggregation( newId(), baseTableScan, ImmutableMap.of(aggregationSymbol, new Aggregation( @@ -244,11 +239,7 @@ public void testInvalidAggregationFunctionCall() Optional.empty(), Optional.empty(), Optional.empty())), - singleGroupingSet(ImmutableList.of(columnA, columnB)), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(columnA, columnB))); assertThatThrownBy(() -> assertTypesValid(node)) .isInstanceOf(IllegalArgumentException.class) @@ -260,7 +251,7 @@ public void testInvalidAggregationFunctionSignature() { Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", BIGINT); - PlanNode node = new AggregationNode( + PlanNode node = singleAggregation( newId(), baseTableScan, ImmutableMap.of(aggregationSymbol, new Aggregation( @@ -270,11 +261,7 @@ public void testInvalidAggregationFunctionSignature() Optional.empty(), Optional.empty(), Optional.empty())), - singleGroupingSet(ImmutableList.of(columnA, columnB)), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(columnA, columnB))); assertThatThrownBy(() -> assertTypesValid(node)) .isInstanceOf(IllegalArgumentException.class)