From 5c125b5ef0e355b7f89d4927171dc7dd029d0b18 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Sat, 24 Sep 2022 14:08:37 +0200 Subject: [PATCH 1/6] Fix formatting of table function table arguments in SqlFormatter --- .../src/main/java/io/trino/sql/SqlFormatter.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index 779090070d34..18670f0e8064 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -325,7 +325,11 @@ private void appendTableFunctionArguments(List arguments, protected Void visitTableArgument(TableFunctionTableArgument node, Integer indent) { Relation relation = node.getTable(); - Relation unaliased = relation instanceof AliasedRelation ? ((AliasedRelation) relation).getRelation() : relation; + Node unaliased = relation instanceof AliasedRelation ? ((AliasedRelation) relation).getRelation() : relation; + if (unaliased instanceof TableSubquery) { + // unpack the relation from TableSubquery to avoid adding another pair of parentheses + unaliased = ((TableSubquery) unaliased).getQuery(); + } builder.append("TABLE("); process(unaliased, indent); builder.append(")"); From a6f537d5519e34a4a46a411e6967d585b382c56f Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Sat, 24 Sep 2022 14:20:18 +0200 Subject: [PATCH 2/6] Pass plan node tag in the context of PlanPrinter It will be used to append table argument names to nodes being the sources of a table function. --- .../sql/planner/planprinter/PlanPrinter.java | 369 ++++++++++-------- 1 file changed, 214 insertions(+), 155 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 80128e1f57d2..4ddbf2647e05 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -226,7 +226,7 @@ public class PlanPrinter this.representation = new PlanRepresentation(planRoot, types, totalCpuTime, totalScheduledTime, totalBlockedTime); Visitor visitor = new Visitor(types, estimatedStatsAndCosts, stats); - planRoot.accept(visitor, null); + planRoot.accept(visitor, new Context()); } private String toText(boolean verbose, int level) @@ -591,7 +591,7 @@ public static String graphvizDistributedPlan(SubPlan plan) } private class Visitor - extends PlanVisitor + extends PlanVisitor { private final TypeProvider types; private final StatsAndCosts estimatedStatsAndCosts; @@ -605,14 +605,14 @@ public Visitor(TypeProvider types, StatsAndCosts estimatedStatsAndCosts, Optiona } @Override - public Void visitExplainAnalyze(ExplainAnalyzeNode node, Void context) + public Void visitExplainAnalyze(ExplainAnalyzeNode node, Context context) { - addNode(node, "ExplainAnalyze"); - return processChildren(node, context); + addNode(node, "ExplainAnalyze", context.tag()); + return processChildren(node, new Context()); } @Override - public Void visitJoin(JoinNode node, Void context) + public Void visitJoin(JoinNode node, Context context) { List joinExpressions = new ArrayList<>(); for (JoinNode.EquiJoinClause clause : node.getCriteria()) { @@ -625,14 +625,14 @@ public Void visitJoin(JoinNode node, Void context) NodeRepresentation nodeOutput; if (node.isCrossJoin()) { checkState(joinExpressions.isEmpty()); - nodeOutput = addNode(node, "CrossJoin"); + nodeOutput = addNode(node, "CrossJoin", context.tag()); } else { ImmutableMap.Builder descriptor = ImmutableMap.builder() .put("criteria", Joiner.on(" AND ").join(anonymizeExpressions(joinExpressions))) .put("hash", formatHash(node.getLeftHashSymbol(), node.getRightHashSymbol())); node.getDistributionType().ifPresent(distribution -> descriptor.put("distribution", distribution.name())); - nodeOutput = addNode(node, node.getType().getJoinLabel(), descriptor.buildOrThrow(), node.getReorderJoinStatsAndCost()); + nodeOutput = addNode(node, node.getType().getJoinLabel(), descriptor.buildOrThrow(), node.getReorderJoinStatsAndCost(), context.tag()); } node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetails("Distribution: %s", distributionType)); @@ -642,61 +642,65 @@ public Void visitJoin(JoinNode node, Void context) if (!node.getDynamicFilters().isEmpty()) { nodeOutput.appendDetails("dynamicFilterAssignments = %s", printDynamicFilterAssignments(node.getDynamicFilters())); } - node.getLeft().accept(this, context); - node.getRight().accept(this, context); + node.getLeft().accept(this, new Context()); + node.getRight().accept(this, new Context()); return null; } @Override - public Void visitSpatialJoin(SpatialJoinNode node, Void context) + public Void visitSpatialJoin(SpatialJoinNode node, Context context) { NodeRepresentation nodeOutput = addNode(node, node.getType().getJoinLabel(), - ImmutableMap.of("filter", formatFilter(node.getFilter()))); + ImmutableMap.of("filter", formatFilter(node.getFilter())), + context.tag()); nodeOutput.appendDetails("Distribution: %s", node.getDistributionType()); - node.getLeft().accept(this, context); - node.getRight().accept(this, context); + node.getLeft().accept(this, new Context()); + node.getRight().accept(this, new Context()); return null; } @Override - public Void visitSemiJoin(SemiJoinNode node, Void context) + public Void visitSemiJoin(SemiJoinNode node, Context context) { NodeRepresentation nodeOutput = addNode(node, "SemiJoin", ImmutableMap.of( "criteria", anonymizer.anonymize(node.getSourceJoinSymbol()) + " = " + anonymizer.anonymize(node.getFilteringSourceJoinSymbol()), - "hash", formatHash(node.getSourceHashSymbol(), node.getFilteringSourceHashSymbol()))); + "hash", formatHash(node.getSourceHashSymbol(), node.getFilteringSourceHashSymbol())), + context.tag()); node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetails("Distribution: %s", distributionType)); node.getDynamicFilterId().ifPresent(dynamicFilterId -> nodeOutput.appendDetails("dynamicFilterId: %s", dynamicFilterId)); - node.getSource().accept(this, context); - node.getFilteringSource().accept(this, context); + node.getSource().accept(this, new Context()); + node.getFilteringSource().accept(this, new Context()); return null; } @Override - public Void visitDynamicFilterSource(DynamicFilterSourceNode node, Void context) + public Void visitDynamicFilterSource(DynamicFilterSourceNode node, Context context) { addNode( node, "DynamicFilterSource", - ImmutableMap.of("dynamicFilterAssignments", printDynamicFilterAssignments(node.getDynamicFilters()))); - node.getSource().accept(this, context); + ImmutableMap.of("dynamicFilterAssignments", printDynamicFilterAssignments(node.getDynamicFilters())), + context.tag()); + node.getSource().accept(this, new Context()); return null; } @Override - public Void visitIndexSource(IndexSourceNode node, Void context) + public Void visitIndexSource(IndexSourceNode node, Context context) { NodeRepresentation nodeOutput = addNode(node, "IndexSource", ImmutableMap.of( "indexedTable", anonymizer.anonymize(node.getIndexHandle()), - "lookup", formatSymbols(node.getLookupSymbols()))); + "lookup", formatSymbols(node.getLookupSymbols())), + context.tag()); for (Map.Entry entry : node.getAssignments().entrySet()) { if (node.getOutputSymbols().contains(entry.getKey())) { @@ -707,7 +711,7 @@ public Void visitIndexSource(IndexSourceNode node, Void context) } @Override - public Void visitIndexJoin(IndexJoinNode node, Void context) + public Void visitIndexJoin(IndexJoinNode node, Context context) { List joinExpressions = new ArrayList<>(); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { @@ -720,47 +724,51 @@ public Void visitIndexJoin(IndexJoinNode node, Void context) format("%sIndexJoin", node.getType().getJoinLabel()), ImmutableMap.of( "criteria", Joiner.on(" AND ").join(anonymizeExpressions(joinExpressions)), - "hash", formatHash(node.getProbeHashSymbol(), node.getIndexHashSymbol()))); - node.getProbeSource().accept(this, context); - node.getIndexSource().accept(this, context); + "hash", formatHash(node.getProbeHashSymbol(), node.getIndexHashSymbol())), + context.tag()); + node.getProbeSource().accept(this, new Context()); + node.getIndexSource().accept(this, new Context()); return null; } @Override - public Void visitOffset(OffsetNode node, Void context) + public Void visitOffset(OffsetNode node, Context context) { addNode(node, "Offset", - ImmutableMap.of("count", String.valueOf(node.getCount()))); - return processChildren(node, context); + ImmutableMap.of("count", String.valueOf(node.getCount())), + context.tag()); + return processChildren(node, new Context()); } @Override - public Void visitLimit(LimitNode node, Void context) + public Void visitLimit(LimitNode node, Context context) { addNode(node, format("Limit%s", node.isPartial() ? "Partial" : ""), ImmutableMap.of( "count", String.valueOf(node.getCount()), "withTies", formatBoolean(node.isWithTies()), - "inputPreSortedBy", formatSymbols(node.getPreSortedInputs()))); - return processChildren(node, context); + "inputPreSortedBy", formatSymbols(node.getPreSortedInputs())), + context.tag()); + return processChildren(node, new Context()); } @Override - public Void visitDistinctLimit(DistinctLimitNode node, Void context) + public Void visitDistinctLimit(DistinctLimitNode node, Context context) { addNode(node, format("DistinctLimit%s", node.isPartial() ? "Partial" : ""), ImmutableMap.of( "limit", String.valueOf(node.getLimit()), - "hash", formatHash(node.getHashSymbol()))); - return processChildren(node, context); + "hash", formatHash(node.getHashSymbol())), + context.tag()); + return processChildren(node, new Context()); } @Override - public Void visitAggregation(AggregationNode node, Void context) + public Void visitAggregation(AggregationNode node, Context context) { String type = ""; if (node.getStep() != AggregationNode.Step.SINGLE) { @@ -777,16 +785,17 @@ public Void visitAggregation(AggregationNode node, Void context) NodeRepresentation nodeOutput = addNode( node, "Aggregate", - ImmutableMap.of("type", type, "keys", keys, "hash", formatHash(node.getHashSymbol()))); + ImmutableMap.of("type", type, "keys", keys, "hash", formatHash(node.getHashSymbol())), + context.tag()); node.getAggregations().forEach((symbol, aggregation) -> nodeOutput.appendDetails("%s := %s", anonymizer.anonymize(symbol), formatAggregation(anonymizer, aggregation))); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitGroupId(GroupIdNode node, Void context) + public Void visitGroupId(GroupIdNode node, Context context) { // grouping sets are easier to understand in terms of inputs List anonymizedInputGroupingSetSymbols = node.getGroupingSets().stream() @@ -799,30 +808,32 @@ public Void visitGroupId(GroupIdNode node, Void context) NodeRepresentation nodeOutput = addNode( node, "GroupId", - ImmutableMap.of("symbols", formatCollection(anonymizedInputGroupingSetSymbols, Objects::toString))); + ImmutableMap.of("symbols", formatCollection(anonymizedInputGroupingSetSymbols, Objects::toString)), + context.tag()); for (Map.Entry mapping : node.getGroupingColumns().entrySet()) { nodeOutput.appendDetails("%s := %s", anonymizer.anonymize(mapping.getKey()), anonymizer.anonymize(mapping.getValue())); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitMarkDistinct(MarkDistinctNode node, Void context) + public Void visitMarkDistinct(MarkDistinctNode node, Context context) { addNode(node, "MarkDistinct", ImmutableMap.of( "distinct", formatOutputs(types, node.getDistinctSymbols()), "marker", anonymizer.anonymize(node.getMarkerSymbol()), - "hash", formatHash(node.getHashSymbol()))); + "hash", formatHash(node.getHashSymbol())), + context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitWindow(WindowNode node, Void context) + public Void visitWindow(WindowNode node, Context context) { ImmutableMap.Builder descriptor = ImmutableMap.builder(); if (!node.getPartitionBy().isEmpty()) { @@ -855,7 +866,8 @@ public Void visitWindow(WindowNode node, Void context) NodeRepresentation nodeOutput = addNode( node, "Window", - descriptor.put("hash", formatHash(node.getHashSymbol())).buildOrThrow()); + descriptor.put("hash", formatHash(node.getHashSymbol())).buildOrThrow(), + context.tag()); for (Map.Entry entry : node.getWindowFunctions().entrySet()) { WindowNode.Function function = entry.getValue(); @@ -868,11 +880,11 @@ public Void visitWindow(WindowNode node, Void context) Joiner.on(", ").join(anonymizeExpressions(function.getArguments())), frameInfo); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitPatternRecognition(PatternRecognitionNode node, Void context) + public Void visitPatternRecognition(PatternRecognitionNode node, Context context) { ImmutableMap.Builder descriptor = ImmutableMap.builder(); if (!node.getPartitionBy().isEmpty()) { @@ -905,7 +917,8 @@ public Void visitPatternRecognition(PatternRecognitionNode node, Void context) NodeRepresentation nodeOutput = addNode( node, "PatterRecognition", - descriptor.put("hash", formatHash(node.getHashSymbol())).buildOrThrow()); + descriptor.put("hash", formatHash(node.getHashSymbol())).buildOrThrow(), + context.tag()); if (node.getCommonBaseFrame().isPresent()) { nodeOutput.appendDetails("base frame: " + formatFrame(node.getCommonBaseFrame().get())); @@ -943,7 +956,7 @@ public Void visitPatternRecognition(PatternRecognitionNode node, Void context) appendValuePointers(nodeOutput, entry.getValue()); } - return processChildren(node, context); + return processChildren(node, new Context()); } private void appendValuePointers(NodeRepresentation nodeOutput, ExpressionAndValuePointers expressionAndPointers) @@ -1048,7 +1061,7 @@ private String formatSkipTo(Position position, Optional label) } @Override - public Void visitTopNRanking(TopNRankingNode node, Void context) + public Void visitTopNRanking(TopNRankingNode node, Context context) { ImmutableMap.Builder descriptor = ImmutableMap.builder(); descriptor.put("partitionBy", formatSymbols(node.getPartitionBy())); @@ -1060,15 +1073,16 @@ public Void visitTopNRanking(TopNRankingNode node, Void context) descriptor .put("limit", String.valueOf(node.getMaxRankingPerPartition())) .put("hash", formatHash(node.getHashSymbol())) - .buildOrThrow()); + .buildOrThrow(), + context.tag()); nodeOutput.appendDetails("%s := %s", anonymizer.anonymize(node.getRankingSymbol()), node.getRankingType()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitRowNumber(RowNumberNode node, Void context) + public Void visitRowNumber(RowNumberNode node, Context context) { ImmutableMap.Builder descriptor = ImmutableMap.builder(); if (!node.getPartitionBy().isEmpty()) { @@ -1082,19 +1096,20 @@ public Void visitRowNumber(RowNumberNode node, Void context) NodeRepresentation nodeOutput = addNode( node, "RowNumber", - descriptor.put("hash", formatHash(node.getHashSymbol())).buildOrThrow()); + descriptor.put("hash", formatHash(node.getHashSymbol())).buildOrThrow(), + context.tag()); nodeOutput.appendDetails("%s := %s", anonymizer.anonymize(node.getRowNumberSymbol()), "row_number()"); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableScan(TableScanNode node, Void context) + public Void visitTableScan(TableScanNode node, Context context) { TableHandle table = node.getTable(); TableInfo tableInfo = tableInfoSupplier.apply(node); NodeRepresentation nodeOutput; - nodeOutput = addNode(node, "TableScan", ImmutableMap.of("table", anonymizer.anonymize(table, tableInfo))); + nodeOutput = addNode(node, "TableScan", ImmutableMap.of("table", anonymizer.anonymize(table, tableInfo)), context.tag()); printTableScanInfo(nodeOutput, node, tableInfo); PlanNodeStats nodeStats = stats.map(s -> s.get(node.getId())).orElse(null); if (nodeStats != null) { @@ -1112,9 +1127,9 @@ public Void visitTableScan(TableScanNode node, Void context) } @Override - public Void visitValues(ValuesNode node, Void context) + public Void visitValues(ValuesNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "Values"); + NodeRepresentation nodeOutput = addNode(node, "Values", context.tag()); if (node.getRows().isEmpty()) { for (int i = 0; i < node.getRowCount(); i++) { nodeOutput.appendDetails("()"); @@ -1139,13 +1154,13 @@ public Void visitValues(ValuesNode node, Void context) } @Override - public Void visitFilter(FilterNode node, Void context) + public Void visitFilter(FilterNode node, Context context) { return visitScanFilterAndProjectInfo(node, Optional.of(node), Optional.empty(), context); } @Override - public Void visitProject(ProjectNode node, Void context) + public Void visitProject(ProjectNode node, Context context) { if (node.getSource() instanceof FilterNode) { return visitScanFilterAndProjectInfo(node, Optional.of((FilterNode) node.getSource()), Optional.of(node), context); @@ -1158,7 +1173,7 @@ private Void visitScanFilterAndProjectInfo( PlanNode node, Optional filterNode, Optional projectNode, - Void context) + Context context) { checkState(projectNode.isPresent() || filterNode.isPresent()); @@ -1215,7 +1230,8 @@ private Void visitScanFilterAndProjectInfo( allNodes, ImmutableList.of(sourceNode), ImmutableList.of(), - Optional.empty()); + Optional.empty(), + context.tag()); projectNode.ifPresent(value -> printAssignments(nodeOutput, value.getAssignments())); @@ -1260,7 +1276,7 @@ private Void visitScanFilterAndProjectInfo( return null; } - sourceNode.accept(this, context); + sourceNode.accept(this, new Context()); return null; } @@ -1310,7 +1326,7 @@ private void printTableScanInfo(NodeRepresentation nodeOutput, TableScanNode nod } @Override - public Void visitUnnest(UnnestNode node, Void context) + public Void visitUnnest(UnnestNode node, Context context) { String name; if (node.getFilter().isPresent()) { @@ -1338,17 +1354,18 @@ else if (!node.getReplicateSymbols().isEmpty()) { } descriptor.put("unnest", formatOutputs(types, unnestInputs)); node.getFilter().ifPresent(filter -> descriptor.put("filter", formatFilter(filter))); - addNode(node, name, descriptor.buildOrThrow()); - return processChildren(node, context); + addNode(node, name, descriptor.buildOrThrow(), context.tag()); + return processChildren(node, new Context()); } @Override - public Void visitOutput(OutputNode node, Void context) + public Void visitOutput(OutputNode node, Context context) { NodeRepresentation nodeOutput = addNode( node, "Output", - ImmutableMap.of("columnNames", formatCollection(node.getColumnNames(), anonymizer::anonymizeColumn))); + ImmutableMap.of("columnNames", formatCollection(node.getColumnNames(), anonymizer::anonymizeColumn)), + context.tag()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); Symbol symbol = node.getOutputSymbols().get(i); @@ -1356,32 +1373,34 @@ public Void visitOutput(OutputNode node, Void context) nodeOutput.appendDetails("%s := %s", anonymizer.anonymizeColumn(name), anonymizer.anonymize(symbol)); } } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTopN(TopNNode node, Void context) + public Void visitTopN(TopNNode node, Context context) { addNode(node, format("TopN%s", node.getStep() == TopNNode.Step.PARTIAL ? "Partial" : ""), ImmutableMap.of( "count", String.valueOf(node.getCount()), - "orderBy", formatOrderingScheme(node.getOrderingScheme()))); - return processChildren(node, context); + "orderBy", formatOrderingScheme(node.getOrderingScheme())), + context.tag()); + return processChildren(node, new Context()); } @Override - public Void visitSort(SortNode node, Void context) + public Void visitSort(SortNode node, Context context) { addNode(node, format("%sSort", node.isPartial() ? "Partial" : ""), - ImmutableMap.of("orderBy", formatOrderingScheme(node.getOrderingScheme()))); + ImmutableMap.of("orderBy", formatOrderingScheme(node.getOrderingScheme())), + context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitRemoteSource(RemoteSourceNode node, Void context) + public Void visitRemoteSource(RemoteSourceNode node, Context context) { addNode(node, format("Remote%s", node.getOrderingScheme().isPresent() ? "Merge" : "Source"), @@ -1389,52 +1408,56 @@ public Void visitRemoteSource(RemoteSourceNode node, Void context) ImmutableList.of(), ImmutableList.of(), node.getSourceFragmentIds(), - Optional.empty()); + Optional.empty(), + context.tag()); return null; } @Override - public Void visitUnion(UnionNode node, Void context) + public Void visitUnion(UnionNode node, Context context) { - addNode(node, "Union"); + addNode(node, "Union", context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitIntersect(IntersectNode node, Void context) + public Void visitIntersect(IntersectNode node, Context context) { addNode(node, "Intersect", - ImmutableMap.of("isDistinct", formatBoolean(node.isDistinct()))); + ImmutableMap.of("isDistinct", formatBoolean(node.isDistinct())), + context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitExcept(ExceptNode node, Void context) + public Void visitExcept(ExceptNode node, Context context) { addNode(node, "Except", - ImmutableMap.of("isDistinct", formatBoolean(node.isDistinct()))); + ImmutableMap.of("isDistinct", formatBoolean(node.isDistinct())), + context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitRefreshMaterializedView(RefreshMaterializedViewNode node, Void context) + public Void visitRefreshMaterializedView(RefreshMaterializedViewNode node, Context context) { addNode(node, "RefreshMaterializedView", - ImmutableMap.of("viewName", anonymizer.anonymize(node.getViewName()))); + ImmutableMap.of("viewName", anonymizer.anonymize(node.getViewName())), + context.tag()); return null; } @Override - public Void visitTableWriter(TableWriterNode node, Void context) + public Void visitTableWriter(TableWriterNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "TableWriter"); + NodeRepresentation nodeOutput = addNode(node, "TableWriter", context.tag()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); Symbol symbol = node.getColumns().get(i); @@ -1446,32 +1469,34 @@ public Void visitTableWriter(TableWriterNode node, Void context) printStatisticAggregations(nodeOutput, node.getStatisticsAggregation().get(), node.getStatisticsAggregationDescriptor().get()); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitStatisticsWriterNode(StatisticsWriterNode node, Void context) + public Void visitStatisticsWriterNode(StatisticsWriterNode node, Context context) { addNode(node, "StatisticsWriter", - ImmutableMap.of("target", anonymizer.anonymize(node.getTarget()))); - return processChildren(node, context); + ImmutableMap.of("target", anonymizer.anonymize(node.getTarget())), + context.tag()); + return processChildren(node, new Context()); } @Override - public Void visitTableFinish(TableFinishNode node, Void context) + public Void visitTableFinish(TableFinishNode node, Context context) { NodeRepresentation nodeOutput = addNode( node, "TableCommit", - ImmutableMap.of("target", anonymizer.anonymize(node.getTarget()))); + ImmutableMap.of("target", anonymizer.anonymize(node.getTarget())), + context.tag()); if (node.getStatisticsAggregation().isPresent()) { verify(node.getStatisticsAggregationDescriptor().isPresent(), "statisticsAggregationDescriptor is not present"); printStatisticAggregations(nodeOutput, node.getStatisticsAggregation().get(), node.getStatisticsAggregationDescriptor().get()); } - return processChildren(node, context); + return processChildren(node, new Context()); } private void printStatisticAggregations(NodeRepresentation nodeOutput, StatisticAggregations aggregations, StatisticAggregationsDescriptor descriptor) @@ -1526,24 +1551,26 @@ private void printStatisticAggregationsInfo( } @Override - public Void visitSample(SampleNode node, Void context) + public Void visitSample(SampleNode node, Context context) { addNode(node, "Sample", ImmutableMap.of( "type", node.getSampleType().name(), - "ratio", String.valueOf(node.getSampleRatio()))); + "ratio", String.valueOf(node.getSampleRatio())), + context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitExchange(ExchangeNode node, Void context) + public Void visitExchange(ExchangeNode node, Context context) { if (node.getOrderingScheme().isPresent()) { addNode(node, format("%sMerge", UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, node.getScope().toString())), - ImmutableMap.of("orderBy", formatOrderingScheme(node.getOrderingScheme().get()))); + ImmutableMap.of("orderBy", formatOrderingScheme(node.getOrderingScheme().get())), + context.tag()); } else if (node.getScope() == Scope.LOCAL) { addNode(node, @@ -1552,7 +1579,8 @@ else if (node.getScope() == Scope.LOCAL) { "partitioning", anonymizer.anonymize(node.getPartitioningScheme().getPartitioning().getHandle()), "isReplicateNullsAndAny", formatBoolean(node.getPartitioningScheme().isReplicateNullsAndAny()), "hashColumn", formatHash(node.getPartitioningScheme().getHashColumn()), - "arguments", formatCollection(node.getPartitioningScheme().getPartitioning().getArguments(), anonymizer::anonymize))); + "arguments", formatCollection(node.getPartitioningScheme().getPartitioning().getArguments(), anonymizer::anonymize)), + context.tag()); } else { addNode(node, @@ -1560,146 +1588,155 @@ else if (node.getScope() == Scope.LOCAL) { ImmutableMap.of( "type", node.getType().name(), "isReplicateNullsAndAny", formatBoolean(node.getPartitioningScheme().isReplicateNullsAndAny()), - "hashColumn", formatHash(node.getPartitioningScheme().getHashColumn()))); + "hashColumn", formatHash(node.getPartitioningScheme().getHashColumn())), + context.tag()); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitDelete(DeleteNode node, Void context) + public Void visitDelete(DeleteNode node, Context context) { addNode(node, "Delete", - ImmutableMap.of("target", anonymizer.anonymize(node.getTarget()))); + ImmutableMap.of("target", anonymizer.anonymize(node.getTarget())), + context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitUpdate(UpdateNode node, Void context) + public Void visitUpdate(UpdateNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, format("Update[%s]", anonymizer.anonymize(node.getTarget()))); + NodeRepresentation nodeOutput = addNode(node, format("Update[%s]", anonymizer.anonymize(node.getTarget())), context.tag()); int index = 0; for (String columnName : node.getTarget().getUpdatedColumns()) { nodeOutput.appendDetails("%s := %s", anonymizer.anonymizeColumn(columnName), anonymizer.anonymize(node.getColumnValueAndRowIdSymbols().get(index))); index++; } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableExecute(TableExecuteNode node, Void context) + public Void visitTableExecute(TableExecuteNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "TableExecute"); + NodeRepresentation nodeOutput = addNode(node, "TableExecute", context.tag()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); Symbol symbol = node.getColumns().get(i); nodeOutput.appendDetails("%s := %s", anonymizer.anonymizeColumn(name), anonymizer.anonymize(symbol)); } - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitSimpleTableExecuteNode(SimpleTableExecuteNode node, Void context) + public Void visitSimpleTableExecuteNode(SimpleTableExecuteNode node, Context context) { addNode(node, "SimpleTableExecute", - ImmutableMap.of("table", anonymizer.anonymize(node.getExecuteHandle()))); + ImmutableMap.of("table", anonymizer.anonymize(node.getExecuteHandle())), + context.tag()); return null; } @Override - public Void visitMergeWriter(MergeWriterNode node, Void context) + public Void visitMergeWriter(MergeWriterNode node, Context context) { addNode(node, "MergeWriter", - ImmutableMap.of("table", anonymizer.anonymize(node.getTarget()))); - return processChildren(node, context); + ImmutableMap.of("table", anonymizer.anonymize(node.getTarget())), + context.tag()); + return processChildren(node, new Context()); } @Override - public Void visitMergeProcessor(MergeProcessorNode node, Void context) + public Void visitMergeProcessor(MergeProcessorNode node, Context context) { - NodeRepresentation nodeOutput = addNode(node, "MergeProcessor"); + NodeRepresentation nodeOutput = addNode(node, "MergeProcessor", context.tag()); nodeOutput.appendDetails("target: %s", anonymizer.anonymize(node.getTarget())); nodeOutput.appendDetails("merge row column: %s", anonymizer.anonymize(node.getMergeRowSymbol())); nodeOutput.appendDetails("row id column: %s", anonymizer.anonymize(node.getRowIdSymbol())); nodeOutput.appendDetails("redistribution columns: %s", anonymize(node.getRedistributionColumnSymbols())); nodeOutput.appendDetails("data columns: %s", anonymize(node.getDataColumnSymbols())); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableDelete(TableDeleteNode node, Void context) + public Void visitTableDelete(TableDeleteNode node, Context context) { addNode(node, "TableDelete", - ImmutableMap.of("target", anonymizer.anonymize(node.getTarget()))); + ImmutableMap.of("target", anonymizer.anonymize(node.getTarget())), + context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitEnforceSingleRow(EnforceSingleRowNode node, Void context) + public Void visitEnforceSingleRow(EnforceSingleRowNode node, Context context) { - addNode(node, "EnforceSingleRow"); + addNode(node, "EnforceSingleRow", context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitAssignUniqueId(AssignUniqueId node, Void context) + public Void visitAssignUniqueId(AssignUniqueId node, Context context) { - addNode(node, "AssignUniqueId"); + addNode(node, "AssignUniqueId", context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitGroupReference(GroupReference node, Void context) + public Void visitGroupReference(GroupReference node, Context context) { addNode(node, "GroupReference", ImmutableMap.of("groupId", String.valueOf(node.getGroupId())), ImmutableList.of(), - Optional.empty()); + Optional.empty(), + context.tag()); return null; } @Override - public Void visitApply(ApplyNode node, Void context) + public Void visitApply(ApplyNode node, Context context) { NodeRepresentation nodeOutput = addNode( node, "Apply", - ImmutableMap.of("correlation", formatSymbols(node.getCorrelation()))); + ImmutableMap.of("correlation", formatSymbols(node.getCorrelation())), + context.tag()); printAssignments(nodeOutput, node.getSubqueryAssignments()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitCorrelatedJoin(CorrelatedJoinNode node, Void context) + public Void visitCorrelatedJoin(CorrelatedJoinNode node, Context context) { addNode(node, "CorrelatedJoin", ImmutableMap.of( "correlation", formatSymbols(node.getCorrelation()), - "filter", formatFilter(node.getFilter()))); + "filter", formatFilter(node.getFilter())), + context.tag()); - return processChildren(node, context); + return processChildren(node, new Context()); } @Override - public Void visitTableFunction(TableFunctionNode node, Void context) + public Void visitTableFunction(TableFunctionNode node, Context context) { NodeRepresentation nodeOutput = addNode( node, "TableFunction", - ImmutableMap.of("name", node.getName())); + ImmutableMap.of("name", node.getName()), + context.tag()); checkArgument( node.getSources().isEmpty() && node.getTableArgumentProperties().isEmpty(), @@ -1722,12 +1759,12 @@ private String formatArgument(ScalarArgument argument) } @Override - protected Void visitPlan(PlanNode node, Void context) + protected Void visitPlan(PlanNode node, Context context) { throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName()); } - private Void processChildren(PlanNode node, Void context) + private Void processChildren(PlanNode node, Context context) { for (PlanNode child : node.getSources()) { child.accept(this, context); @@ -1875,24 +1912,24 @@ private String formatOutputs(TypeProvider types, Iterable outputs) .collect(joining(", ", "[", "]")); } - public NodeRepresentation addNode(PlanNode node, String name) + public NodeRepresentation addNode(PlanNode node, String name, Optional tag) { - return addNode(node, name, ImmutableMap.of()); + return addNode(node, name, ImmutableMap.of(), tag); } - public NodeRepresentation addNode(PlanNode node, String name, Map descriptor) + public NodeRepresentation addNode(PlanNode node, String name, Map descriptor, Optional tag) { - return addNode(node, name, descriptor, node.getSources(), Optional.empty()); + return addNode(node, name, descriptor, node.getSources(), Optional.empty(), tag); } - public NodeRepresentation addNode(PlanNode node, String name, Map descriptor, Optional reorderJoinStatsAndCost) + public NodeRepresentation addNode(PlanNode node, String name, Map descriptor, Optional reorderJoinStatsAndCost, Optional tag) { - return addNode(node, name, descriptor, node.getSources(), reorderJoinStatsAndCost); + return addNode(node, name, descriptor, node.getSources(), reorderJoinStatsAndCost, tag); } - public NodeRepresentation addNode(PlanNode node, String name, Map descriptor, List children, Optional reorderJoinStatsAndCost) + public NodeRepresentation addNode(PlanNode node, String name, Map descriptor, List children, Optional reorderJoinStatsAndCost, Optional tag) { - return addNode(node, name, descriptor, ImmutableList.of(node.getId()), children, ImmutableList.of(), reorderJoinStatsAndCost); + return addNode(node, name, descriptor, ImmutableList.of(node.getId()), children, ImmutableList.of(), reorderJoinStatsAndCost, tag); } public NodeRepresentation addNode( @@ -1902,7 +1939,8 @@ public NodeRepresentation addNode( List allNodes, List children, List remoteSources, - Optional reorderJoinStatsAndCost) + Optional reorderJoinStatsAndCost, + Optional tag) { List childrenIds = children.stream().map(PlanNode::getId).collect(toImmutableList()); List estimatedStats = allNodes.stream() @@ -1911,6 +1949,9 @@ public NodeRepresentation addNode( List estimatedCosts = allNodes.stream() .map(nodeId -> estimatedStatsAndCosts.getCosts().getOrDefault(nodeId, PlanCostEstimate.unknown())) .collect(toList()); + name = tag + .map(tagName -> format("[%s] ", tagName)) + .orElse("") + name; NodeRepresentation nodeOutput = new NodeRepresentation( rootNode.getId(), @@ -1994,4 +2035,22 @@ public Expression rewriteFunctionCall(FunctionCall node, Void context, Expressio } }, expression); } + + private record Context(Optional tag) + { + public Context() + { + this(Optional.empty()); + } + + public Context(String tag) + { + this(Optional.of(tag)); + } + + private Context + { + requireNonNull(tag, "tag is null"); + } + } } From 4666472b0188aa26087840cdb587cc6e4495edef Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Tue, 4 Oct 2022 10:32:35 +0200 Subject: [PATCH 3/6] Copy arguments in the constructor --- .../io/trino/sql/planner/plan/TableFunctionNode.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java index 1fa713174fd0..3c80ec8d7012 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java @@ -15,6 +15,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.trino.metadata.TableFunctionHandle; import io.trino.spi.ptf.Argument; import io.trino.sql.planner.Symbol; @@ -51,10 +53,10 @@ public TableFunctionNode( { super(id); this.name = requireNonNull(name, "name is null"); - this.arguments = requireNonNull(arguments, "arguments is null"); - this.properOutputs = requireNonNull(properOutputs, "properOutputs is null"); - this.sources = requireNonNull(sources, "sources is null"); - this.tableArgumentProperties = requireNonNull(tableArgumentProperties, "tableArgumentProperties is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.sources = ImmutableList.copyOf(sources); + this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); this.handle = requireNonNull(handle, "handle is null"); } From 1aea489884346822c812b1a242acc286e3e1248e Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Wed, 19 Oct 2022 16:17:47 +0200 Subject: [PATCH 4/6] Refactor TestingTableFunctions --- .../connector/TestingTableFunctions.java | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java index e1808f800a9f..6869f8cae3b0 100644 --- a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java +++ b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java @@ -45,10 +45,15 @@ public class TestingTableFunctions { private static final String SCHEMA_NAME = "system"; - private static final ConnectorTableFunctionHandle HANDLE = new ConnectorTableFunctionHandle() {}; + private static final String TABLE_NAME = "table"; + private static final String COLUMN_NAME = "column"; + private static final ConnectorTableFunctionHandle HANDLE = new TestingTableFunctionHandle(); private static final TableFunctionAnalysis ANALYSIS = TableFunctionAnalysis.builder() .handle(HANDLE) - .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN))))) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .build(); + private static final TableFunctionAnalysis NO_DESCRIPTOR_ANALYSIS = TableFunctionAnalysis.builder() + .handle(HANDLE) .build(); /** @@ -252,9 +257,7 @@ public OnlyPassThroughFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return TableFunctionAnalysis.builder() - .handle(HANDLE) - .build(); + return NO_DESCRIPTOR_ANALYSIS; } } @@ -275,9 +278,7 @@ public MonomorphicStaticReturnTypeFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return TableFunctionAnalysis.builder() - .handle(HANDLE) - .build(); + return NO_DESCRIPTOR_ANALYSIS; } } @@ -300,9 +301,7 @@ public PolymorphicStaticReturnTypeFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return TableFunctionAnalysis.builder() - .handle(HANDLE) - .build(); + return NO_DESCRIPTOR_ANALYSIS; } } @@ -326,9 +325,26 @@ public PassThroughFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return TableFunctionAnalysis.builder() - .handle(HANDLE) - .build(); + return NO_DESCRIPTOR_ANALYSIS; + } + } + + public static class TestingTableFunctionHandle + implements ConnectorTableFunctionHandle + { + private final MockConnectorTableHandle tableHandle; + + public TestingTableFunctionHandle() + { + this.tableHandle = new MockConnectorTableHandle( + new SchemaTableName(SCHEMA_NAME, TABLE_NAME), + TupleDomain.all(), + Optional.of(ImmutableList.of(new MockConnectorColumnHandle(COLUMN_NAME, BOOLEAN)))); + } + + public MockConnectorTableHandle getTableHandle() + { + return tableHandle; } } } From 80c7fa0519eea07d8417d23908e8d1f8774dc3cd Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Mon, 21 Nov 2022 12:33:59 +0100 Subject: [PATCH 5/6] Extract WindowNode.Specification as a separate class --- .../io/trino/sql/planner/QueryPlanner.java | 13 +-- .../io/trino/sql/planner/RelationPlanner.java | 4 +- .../iterative/rule/DecorrelateUnnest.java | 4 +- .../rule/ImplementLimitWithTies.java | 3 +- ...ushDownDereferencesThroughTopNRanking.java | 4 +- .../PushDownDereferencesThroughWindow.java | 3 +- .../rule/SetOperationNodeTranslator.java | 4 +- .../optimizations/PlanNodeDecorrelator.java | 4 +- .../planner/optimizations/SymbolMapper.java | 5 +- .../plan/DataOrganizationSpecification.java | 82 +++++++++++++++++++ .../planner/plan/PatternRecognitionNode.java | 7 +- .../sql/planner/plan/TableFunctionNode.java | 7 +- .../sql/planner/plan/TopNRankingNode.java | 7 +- .../io/trino/sql/planner/plan/WindowNode.java | 62 +------------- .../trino/sql/planner/TestCanonicalize.java | 4 +- .../TestEffectivePredicateExtractor.java | 3 +- .../trino/sql/planner/TestTypeValidator.java | 7 +- .../assertions/PatternRecognitionMatcher.java | 9 +- .../planner/assertions/PlanMatchPattern.java | 3 +- .../assertions/SpecificationProvider.java | 8 +- .../assertions/TopNRankingMatcher.java | 10 +-- .../sql/planner/assertions/WindowMatcher.java | 9 +- .../rule/TestMergeAdjacentWindows.java | 7 +- .../rule/TestPruneTopNRankingColumns.java | 14 ++-- .../rule/TestPruneWindowColumns.java | 3 +- .../rule/TestPushDownDereferencesRules.java | 5 +- ...PushPredicateThroughProjectIntoWindow.java | 16 ++-- .../rule/TestPushdownFilterIntoWindow.java | 9 +- .../rule/TestPushdownLimitIntoWindow.java | 13 +-- .../rule/TestReplaceWindowWithRowNumber.java | 11 +-- ...stSwapAdjacentWindowsBySpecifications.java | 15 ++-- .../rule/test/PatternRecognitionBuilder.java | 3 +- .../iterative/rule/test/PlanBuilder.java | 8 +- .../optimizations/TestEliminateSorts.java | 4 +- .../optimizations/TestMergeWindows.java | 25 +++--- .../optimizations/TestReorderWindows.java | 16 ++-- ...stPatternRecognitionNodeSerialization.java | 3 +- .../sql/planner/plan/TestWindowNode.java | 2 +- 38 files changed, 229 insertions(+), 187 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/sql/planner/plan/DataOrganizationSpecification.java diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index 5f98c0d25303..34decdc3b7e3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -47,6 +47,7 @@ import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.GroupIdNode; @@ -328,7 +329,7 @@ public RelationPlan planExpand(Query query) WindowNode windowNode = new WindowNode( idAllocator.getNextId(), checkConvergenceStep.getNode(), - new WindowNode.Specification(ImmutableList.of(), Optional.empty()), + new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()), ImmutableMap.of(countSymbol, countFunction), Optional.empty(), ImmutableSet.of(), @@ -1829,7 +1830,7 @@ private PlanBuilder planWindow( } } - WindowNode.Specification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), coercions::get); + DataOrganizationSpecification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), coercions::get); // Rewrite frame bounds in terms of pre-projected inputs WindowNode.Frame frame = new WindowNode.Frame( @@ -1882,7 +1883,7 @@ private PlanBuilder planPatternRecognition( PlanAndMappings coercions, Optional frameEndSymbol) { - WindowNode.Specification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), coercions::get); + DataOrganizationSpecification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), coercions::get); // in window frame with pattern recognition, the frame extent is specified as `ROWS BETWEEN CURRENT ROW AND ... ` WindowFrame frame = window.getFrame().orElseThrow(); @@ -1949,7 +1950,7 @@ private PlanBuilder planPatternRecognition( components.getVariableDefinitions())); } - public static WindowNode.Specification planWindowSpecification(List partitionBy, Optional orderBy, Function expressionRewrite) + public static DataOrganizationSpecification planWindowSpecification(List partitionBy, Optional orderBy, Function expressionRewrite) { // Rewrite PARTITION BY ImmutableList.Builder partitionBySymbols = ImmutableList.builder(); @@ -1970,7 +1971,7 @@ public static WindowNode.Specification planWindowSpecification(List orderingScheme = Optional.of(new OrderingScheme(ImmutableList.copyOf(orderings.keySet()), orderings)); } - return new WindowNode.Specification(partitionBySymbols.build(), orderingScheme); + return new DataOrganizationSpecification(partitionBySymbols.build(), orderingScheme); } private PlanBuilder planWindowMeasures(Node node, PlanBuilder subPlan, List windowMeasures) @@ -2031,7 +2032,7 @@ private PlanBuilder planPatternRecognition( ResolvedWindow window, Optional frameEndSymbol) { - WindowNode.Specification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), subPlan::translate); + DataOrganizationSpecification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), subPlan::translate); // in window frame with pattern recognition, the frame extent is specified as `ROWS BETWEEN CURRENT ROW AND ... ` WindowFrame frame = window.getFrame().orElseThrow(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 19d2d37b0d40..9a4c91a11cd6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -34,6 +34,7 @@ import io.trino.sql.analyzer.Scope; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.ExceptNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.IntersectNode; @@ -49,7 +50,6 @@ import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.UnnestNode; import io.trino.sql.planner.plan.ValuesNode; -import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.planner.rowpattern.LogicalIndexExtractor; import io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers; import io.trino.sql.planner.rowpattern.RowPatternToIrRewriter; @@ -416,7 +416,7 @@ protected RelationPlan visitPatternRecognitionRelation(PatternRecognitionRelatio ImmutableList.Builder outputLayout = ImmutableList.builder(); boolean oneRowOutput = node.getRowsPerMatch().isEmpty() || node.getRowsPerMatch().get().isOneRow(); - WindowNode.Specification specification = planWindowSpecification(node.getPartitionBy(), node.getOrderBy(), planBuilder::translate); + DataOrganizationSpecification specification = planWindowSpecification(node.getPartitionBy(), node.getOrderBy(), planBuilder::translate); outputLayout.addAll(specification.getPartitionBy()); if (!oneRowOutput) { getSortItemsFromOrderBy(node.getOrderBy()).stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java index e937e68b02d2..6373e5700d05 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java @@ -30,6 +30,7 @@ import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.EnforceSingleRowNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode.Type; @@ -41,7 +42,6 @@ import io.trino.sql.planner.plan.TopNNode; import io.trino.sql.planner.plan.UnnestNode; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.planner.plan.WindowNode.Specification; import io.trino.sql.tree.Cast; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; @@ -473,7 +473,7 @@ public RewriteResult visitTopN(TopNNode node, Void context) WindowNode windowNode = new WindowNode( idAllocator.getNextId(), source.getPlan(), - new Specification(ImmutableList.of(uniqueSymbol), Optional.of(node.getOrderingScheme())), + new DataOrganizationSpecification(ImmutableList.of(uniqueSymbol), Optional.of(node.getOrderingScheme())), ImmutableMap.of(rowNumberSymbol, rowNumberFunction), Optional.empty(), ImmutableSet.of(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java index c57c9435e4a6..7bdaf6e76546 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java @@ -26,6 +26,7 @@ import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.PlanNode; @@ -128,7 +129,7 @@ public static PlanNode rewriteLimitWithTiesWithPartitioning(LimitNode limitNode, WindowNode windowNode = new WindowNode( idAllocator.getNextId(), source, - new WindowNode.Specification(partitionBy, limitNode.getTiesResolvingScheme()), + new DataOrganizationSpecification(partitionBy, limitNode.getTiesResolvingScheme()), ImmutableMap.of(rankSymbol, rankFunction), Optional.empty(), ImmutableSet.of(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java index 7981b59d03a3..e0bcb2dec420 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRanking.java @@ -23,9 +23,9 @@ import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.TopNRankingNode; -import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.Expression; import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.SymbolReference; @@ -89,7 +89,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) Set dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()); // Exclude dereferences on symbols being used in partitionBy and orderBy - WindowNode.Specification specification = topNRankingNode.getSpecification(); + DataOrganizationSpecification specification = topNRankingNode.getSpecification(); dereferences = dereferences.stream() .filter(expression -> { Symbol symbol = getBase(expression); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java index ecb826aa68f4..fbf6a829a28f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java @@ -23,6 +23,7 @@ import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.Expression; @@ -99,7 +100,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) typeAnalyzer, context.getSymbolAllocator().getTypes()); - WindowNode.Specification specification = windowNode.getSpecification(); + DataOrganizationSpecification specification = windowNode.getSpecification(); dereferences = dereferences.stream() .filter(expression -> { Symbol symbol = getBase(expression); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java index 0b37ba9d9e0c..ff98e767aecf 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java @@ -26,12 +26,12 @@ import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.SetOperationNode; import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.planner.plan.WindowNode.Specification; import io.trino.sql.tree.Cast; import io.trino.sql.tree.Expression; import io.trino.sql.tree.NullLiteral; @@ -210,7 +210,7 @@ private WindowNode appendCounts(UnionNode sourceNode, List originalColum return new WindowNode( idAllocator.getNextId(), sourceNode, - new Specification(originalColumns, Optional.empty()), + new DataOrganizationSpecification(originalColumns, Optional.empty()), functions.buildOrThrow(), Optional.empty(), ImmutableSet.of(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java index b614e92db70c..dbd5824433e0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -34,6 +34,7 @@ import io.trino.sql.planner.iterative.Lookup; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.EnforceSingleRowNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.LimitNode; @@ -44,7 +45,6 @@ import io.trino.sql.planner.plan.RowNumberNode; import io.trino.sql.planner.plan.TopNNode; import io.trino.sql.planner.plan.TopNRankingNode; -import io.trino.sql.planner.plan.WindowNode.Specification; import io.trino.sql.tree.Cast; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; @@ -336,7 +336,7 @@ public Optional visitTopN(TopNNode node, Void context) TopNRankingNode topNRankingNode = new TopNRankingNode( node.getId(), decorrelatedChildNode, - new Specification( + new DataOrganizationSpecification( ImmutableList.copyOf(childDecorrelationResult.symbolsToPropagate), Optional.of(orderingScheme)), ROW_NUMBER, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 395217bef8a4..5048da2d530e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -22,6 +22,7 @@ import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.GroupIdNode; import io.trino.sql.planner.plan.LimitNode; @@ -239,9 +240,9 @@ private WindowNode.Frame map(WindowNode.Frame frame) frame.getOriginalEndValue()); } - private WindowNode.Specification mapAndDistinct(WindowNode.Specification specification) + private DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) { - return new WindowNode.Specification( + return new DataOrganizationSpecification( mapAndDistinct(specification.getPartitionBy()), specification.getOrderingScheme().map(this::map)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/DataOrganizationSpecification.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/DataOrganizationSpecification.java new file mode 100644 index 000000000000..588a83c96967 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/DataOrganizationSpecification.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.Symbol; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +@Immutable +public class DataOrganizationSpecification +{ + private final List partitionBy; + private final Optional orderingScheme; + + @JsonCreator + public DataOrganizationSpecification( + @JsonProperty("partitionBy") List partitionBy, + @JsonProperty("orderingScheme") Optional orderingScheme) + { + requireNonNull(partitionBy, "partitionBy is null"); + requireNonNull(orderingScheme, "orderingScheme is null"); + + this.partitionBy = ImmutableList.copyOf(partitionBy); + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + } + + @JsonProperty + public List getPartitionBy() + { + return partitionBy; + } + + @JsonProperty + public Optional getOrderingScheme() + { + return orderingScheme; + } + + @Override + public int hashCode() + { + return Objects.hash(partitionBy, orderingScheme); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + DataOrganizationSpecification other = (DataOrganizationSpecification) obj; + + return Objects.equals(this.partitionBy, other.partitionBy) && + Objects.equals(this.orderingScheme, other.orderingScheme); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PatternRecognitionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PatternRecognitionNode.java index e2aa4b24c422..d3994564efc9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PatternRecognitionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PatternRecognitionNode.java @@ -23,7 +23,6 @@ import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.WindowNode.Frame; -import io.trino.sql.planner.plan.WindowNode.Specification; import io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers; import io.trino.sql.planner.rowpattern.ir.IrLabel; import io.trino.sql.planner.rowpattern.ir.IrRowPattern; @@ -52,7 +51,7 @@ public class PatternRecognitionNode extends PlanNode { private final PlanNode source; - private final Specification specification; + private final DataOrganizationSpecification specification; private final Optional hashSymbol; private final Set prePartitionedInputs; private final int preSortedOrderPrefix; @@ -81,7 +80,7 @@ Because the base frame is common to all window functions (and measures), it is a public PatternRecognitionNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("specification") Specification specification, + @JsonProperty("specification") DataOrganizationSpecification specification, @JsonProperty("hashSymbol") Optional hashSymbol, @JsonProperty("prePartitionedInputs") Set prePartitionedInputs, @JsonProperty("preSortedOrderPrefix") int preSortedOrderPrefix, @@ -173,7 +172,7 @@ public PlanNode getSource() } @JsonProperty - public Specification getSpecification() + public DataOrganizationSpecification getSpecification() { return specification; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java index 3c80ec8d7012..ad1b20e26426 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java @@ -20,7 +20,6 @@ import io.trino.metadata.TableFunctionHandle; import io.trino.spi.ptf.Argument; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.plan.WindowNode.Specification; import javax.annotation.concurrent.Immutable; @@ -122,14 +121,14 @@ public static class TableArgumentProperties private final boolean rowSemantics; private final boolean pruneWhenEmpty; private final boolean passThroughColumns; - private final Specification specification; + private final DataOrganizationSpecification specification; @JsonCreator public TableArgumentProperties( @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, @JsonProperty("passThroughColumns") boolean passThroughColumns, - @JsonProperty("specification") Specification specification) + @JsonProperty("specification") DataOrganizationSpecification specification) { this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; @@ -156,7 +155,7 @@ public boolean isPassThroughColumns() } @JsonProperty - public Specification getSpecification() + public DataOrganizationSpecification getSpecification() { return specification; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TopNRankingNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TopNRankingNode.java index 38b0739f5d9d..6183b44a9038 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TopNRankingNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TopNRankingNode.java @@ -19,7 +19,6 @@ import com.google.common.collect.Iterables; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.plan.WindowNode.Specification; import javax.annotation.concurrent.Immutable; @@ -42,7 +41,7 @@ public enum RankingType } private final PlanNode source; - private final Specification specification; + private final DataOrganizationSpecification specification; private final RankingType rankingType; private final Symbol rankingSymbol; private final int maxRankingPerPartition; @@ -53,7 +52,7 @@ public enum RankingType public TopNRankingNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("specification") Specification specification, + @JsonProperty("specification") DataOrganizationSpecification specification, @JsonProperty("rankingType") RankingType rankingType, @JsonProperty("rankingSymbol") Symbol rankingSymbol, @JsonProperty("maxRankingPerPartition") int maxRankingPerPartition, @@ -101,7 +100,7 @@ public PlanNode getSource() } @JsonProperty - public Specification getSpecification() + public DataOrganizationSpecification getSpecification() { return specification; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/WindowNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/WindowNode.java index 8f9c23e16006..7b4bd5ee9fbe 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/WindowNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/WindowNode.java @@ -48,7 +48,7 @@ public class WindowNode { private final PlanNode source; private final Set prePartitionedInputs; - private final Specification specification; + private final DataOrganizationSpecification specification; private final int preSortedOrderPrefix; private final Map windowFunctions; private final Optional hashSymbol; @@ -57,7 +57,7 @@ public class WindowNode public WindowNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("specification") Specification specification, + @JsonProperty("specification") DataOrganizationSpecification specification, @JsonProperty("windowFunctions") Map windowFunctions, @JsonProperty("hashSymbol") Optional hashSymbol, @JsonProperty("prePartitionedInputs") Set prePartitionedInputs, @@ -111,7 +111,7 @@ public PlanNode getSource() } @JsonProperty - public Specification getSpecification() + public DataOrganizationSpecification getSpecification() { return specification; } @@ -123,7 +123,7 @@ public List getPartitionBy() public Optional getOrderingScheme() { - return specification.orderingScheme; + return specification.getOrderingScheme(); } @JsonProperty @@ -169,60 +169,6 @@ public PlanNode replaceChildren(List newChildren) return new WindowNode(getId(), Iterables.getOnlyElement(newChildren), specification, windowFunctions, hashSymbol, prePartitionedInputs, preSortedOrderPrefix); } - @Immutable - public static class Specification - { - private final List partitionBy; - private final Optional orderingScheme; - - @JsonCreator - public Specification( - @JsonProperty("partitionBy") List partitionBy, - @JsonProperty("orderingScheme") Optional orderingScheme) - { - requireNonNull(partitionBy, "partitionBy is null"); - requireNonNull(orderingScheme, "orderingScheme is null"); - - this.partitionBy = ImmutableList.copyOf(partitionBy); - this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); - } - - @JsonProperty - public List getPartitionBy() - { - return partitionBy; - } - - @JsonProperty - public Optional getOrderingScheme() - { - return orderingScheme; - } - - @Override - public int hashCode() - { - return Objects.hash(partitionBy, orderingScheme); - } - - @Override - public boolean equals(Object obj) - { - if (this == obj) { - return true; - } - - if (obj == null || getClass() != obj.getClass()) { - return false; - } - - Specification other = (Specification) obj; - - return Objects.equals(this.partitionBy, other.partitionBy) && - Objects.equals(this.orderingScheme, other.orderingScheme); - } - } - @Immutable public static class Frame { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestCanonicalize.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestCanonicalize.java index 4df96dcf07a7..e0d7dd657603 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestCanonicalize.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestCanonicalize.java @@ -22,7 +22,7 @@ import io.trino.sql.planner.iterative.IterativeOptimizer; import io.trino.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; import io.trino.sql.planner.optimizations.UnaliasSymbolReferences; -import io.trino.sql.planner.plan.WindowNode; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.tree.GenericLiteral; import io.trino.sql.tree.LongLiteral; import org.testng.annotations.Test; @@ -55,7 +55,7 @@ public void testJoin() @Test public void testDuplicatesInWindowOrderBy() { - ExpectedValueProvider specification = specification( + ExpectedValueProvider specification = specification( ImmutableList.of(), ImmutableList.of("A"), ImmutableMap.of("A", SortOrder.ASC_NULLS_LAST)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java index 9ca9bdfb3708..38e1f8646d38 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java @@ -44,6 +44,7 @@ import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; @@ -445,7 +446,7 @@ public void testWindow() equals(AE, BE), equals(BE, CE), lessThan(CE, bigintLiteral(10)))), - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(A), Optional.of(new OrderingScheme( ImmutableList.of(A), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java index 52ed4b72c118..5830ec6628fa 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java @@ -26,6 +26,7 @@ import io.trino.spi.type.VarcharType; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.ProjectNode; @@ -158,7 +159,7 @@ public void testValidWindow() WindowNode.Function function = new WindowNode.Function(resolvedFunction, ImmutableList.of(columnC.toSymbolReference()), frame, false); - WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); + DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); PlanNode node = new WindowNode( newId(), @@ -287,7 +288,7 @@ public void testInvalidWindowFunctionCall() WindowNode.Function function = new WindowNode.Function(resolvedFunction, ImmutableList.of(columnA.toSymbolReference()), frame, false); - WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); + DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); PlanNode node = new WindowNode( newId(), @@ -322,7 +323,7 @@ public void testInvalidWindowFunctionSignature() WindowNode.Function function = new WindowNode.Function(resolvedFunction, ImmutableList.of(columnC.toSymbolReference()), frame, false); - WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); + DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); PlanNode node = new WindowNode( newId(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PatternRecognitionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PatternRecognitionMatcher.java index 1538a53d516d..78d932baf398 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PatternRecognitionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PatternRecognitionMatcher.java @@ -19,6 +19,7 @@ import io.trino.cost.StatsProvider; import io.trino.metadata.Metadata; import io.trino.spi.type.Type; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.WindowNode; @@ -53,7 +54,7 @@ public class PatternRecognitionMatcher implements Matcher { - private final Optional> specification; + private final Optional> specification; private final Optional> frame; private final RowsPerMatch rowsPerMatch; private final Optional skipToLabel; @@ -64,7 +65,7 @@ public class PatternRecognitionMatcher private final Map variableDefinitions; private PatternRecognitionMatcher( - Optional> specification, + Optional> specification, Optional> frame, RowsPerMatch rowsPerMatch, Optional skipToLabel, @@ -179,7 +180,7 @@ public String toString() public static class Builder { private final PlanMatchPattern source; - private Optional> specification = Optional.empty(); + private Optional> specification = Optional.empty(); private final List windowFunctionMatchers = new LinkedList<>(); private final Map> measures = new HashMap<>(); private Optional> frame = Optional.empty(); @@ -198,7 +199,7 @@ public static class Builder } @CanIgnoreReturnValue - public Builder specification(ExpectedValueProvider specification) + public Builder specification(ExpectedValueProvider specification) { this.specification = Optional.of(specification); return this; diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 33c266a24260..16ead964c9e5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -35,6 +35,7 @@ import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.EnforceSingleRowNode; import io.trino.sql.planner.plan.ExceptNode; @@ -1073,7 +1074,7 @@ private static List toSymbolAliases(List aliases) .collect(toImmutableList()); } - public static ExpectedValueProvider specification( + public static ExpectedValueProvider specification( List partitionBy, List orderBy, Map orderings) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SpecificationProvider.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SpecificationProvider.java index 832ef49df9de..85ac5933dfd0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SpecificationProvider.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/SpecificationProvider.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.SortOrder; import io.trino.sql.planner.OrderingScheme; -import io.trino.sql.planner.plan.WindowNode; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import java.util.List; import java.util.Map; @@ -29,7 +29,7 @@ import static java.util.Objects.requireNonNull; public class SpecificationProvider - implements ExpectedValueProvider + implements ExpectedValueProvider { private final List partitionBy; private final List orderBy; @@ -46,7 +46,7 @@ public class SpecificationProvider } @Override - public WindowNode.Specification getExpectedValue(SymbolAliases aliases) + public DataOrganizationSpecification getExpectedValue(SymbolAliases aliases) { Optional orderingScheme = Optional.empty(); if (!orderBy.isEmpty()) { @@ -61,7 +61,7 @@ public WindowNode.Specification getExpectedValue(SymbolAliases aliases) .collect(toImmutableMap(entry -> entry.getKey().toSymbol(aliases), Map.Entry::getValue)))); } - return new WindowNode.Specification( + return new DataOrganizationSpecification( partitionBy .stream() .map(alias -> alias.toSymbol(aliases)) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TopNRankingMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TopNRankingMatcher.java index ec7550500c6e..f11bfed8d472 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TopNRankingMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TopNRankingMatcher.java @@ -18,10 +18,10 @@ import io.trino.metadata.Metadata; import io.trino.spi.connector.SortOrder; import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.TopNRankingNode; import io.trino.sql.planner.plan.TopNRankingNode.RankingType; -import io.trino.sql.planner.plan.WindowNode; import java.util.List; import java.util.Map; @@ -37,7 +37,7 @@ public class TopNRankingMatcher implements Matcher { - private final Optional> specification; + private final Optional> specification; private final Optional rankingSymbol; private final Optional rankingType; private final Optional maxRankingPerPartition; @@ -45,7 +45,7 @@ public class TopNRankingMatcher private final Optional> hashSymbol; private TopNRankingMatcher( - Optional> specification, + Optional> specification, Optional rankingSymbol, Optional rankingType, Optional maxRankingPerPartition, @@ -74,7 +74,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses TopNRankingNode topNRankingNode = (TopNRankingNode) node; if (specification.isPresent()) { - WindowNode.Specification expected = specification.get().getExpectedValue(symbolAliases); + DataOrganizationSpecification expected = specification.get().getExpectedValue(symbolAliases); if (!expected.equals(topNRankingNode.getSpecification())) { return NO_MATCH; } @@ -131,7 +131,7 @@ public String toString() public static class Builder { private final PlanMatchPattern source; - private Optional> specification = Optional.empty(); + private Optional> specification = Optional.empty(); private Optional rankingSymbol = Optional.empty(); private Optional rankingType = Optional.empty(); private Optional maxRankingPerPartition = Optional.empty(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowMatcher.java index d69009da2992..ca2df05b2387 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowMatcher.java @@ -18,6 +18,7 @@ import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; import io.trino.spi.connector.SortOrder; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.FunctionCall; @@ -43,13 +44,13 @@ public final class WindowMatcher implements Matcher { private final Optional> prePartitionedInputs; - private final Optional> specification; + private final Optional> specification; private final Optional preSortedOrderPrefix; private final Optional> hashSymbol; private WindowMatcher( Optional> prePartitionedInputs, - Optional> specification, + Optional> specification, Optional preSortedOrderPrefix, Optional> hashSymbol) { @@ -133,7 +134,7 @@ public static class Builder { private final PlanMatchPattern source; private Optional> prePartitionedInputs = Optional.empty(); - private Optional> specification = Optional.empty(); + private Optional> specification = Optional.empty(); private Optional preSortedOrderPrefix = Optional.empty(); private final List windowFunctionMatchers = new LinkedList<>(); private Optional> hashSymbol = Optional.empty(); @@ -161,7 +162,7 @@ public Builder specification( return specification(PlanMatchPattern.specification(partitionBy, orderBy, orderings)); } - public Builder specification(ExpectedValueProvider specification) + public Builder specification(ExpectedValueProvider specification) { requireNonNull(specification, "specification is null"); this.specification = Optional.of(specification); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 300afc92b10d..317ce4b37b76 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -22,6 +22,7 @@ import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.SymbolReference; @@ -51,7 +52,7 @@ public class TestMergeAdjacentWindows private static final ResolvedFunction LAG = FUNCTION_RESOLUTION.resolveFunction(QualifiedName.of("lag"), fromTypes(DOUBLE)); private static final String columnAAlias = "ALIAS_A"; - private static final ExpectedValueProvider specificationA = + private static final ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); @Test @@ -203,9 +204,9 @@ public void testIntermediateProjectNodes() values(columnAAlias, unusedAlias)))))); } - private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName) + private static DataOrganizationSpecification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName) { - return new WindowNode.Specification(ImmutableList.of(planBuilder.symbol(symbolName, BIGINT)), Optional.empty()); + return new DataOrganizationSpecification(ImmutableList.of(planBuilder.symbol(symbolName, BIGINT)), Optional.empty()); } private static WindowNode.Function newWindowNodeFunction(ResolvedFunction resolvedFunction, String... symbols) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNRankingColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNRankingColumns.java index 9ca0841a2fc4..ba8995c636d7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNRankingColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTopNRankingColumns.java @@ -21,7 +21,7 @@ import io.trino.sql.planner.assertions.TopNRankingSymbolMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; -import io.trino.sql.planner.plan.WindowNode.Specification; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import org.testng.annotations.Test; import java.util.Optional; @@ -46,7 +46,7 @@ public void testDoNotPrunePartitioningSymbol() return p.project( Assignments.identity(b, ranking), p.topNRanking( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(a), Optional.of(new OrderingScheme(ImmutableList.of(b), ImmutableMap.of(b, SortOrder.ASC_NULLS_FIRST)))), ROW_NUMBER, @@ -68,7 +68,7 @@ public void testDoNotPruneOrderingSymbol() return p.project( Assignments.identity(ranking), p.topNRanking( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)))), ROW_NUMBER, @@ -91,7 +91,7 @@ public void testDoNotPruneHashSymbol() return p.project( Assignments.identity(a, ranking), p.topNRanking( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)))), ROW_NUMBER, @@ -114,7 +114,7 @@ public void testSourceSymbolNotReferenced() return p.project( Assignments.identity(a, ranking), p.topNRanking( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)))), ROW_NUMBER, @@ -151,7 +151,7 @@ public void testAllSymbolsReferenced() return p.project( Assignments.identity(a, b, ranking), p.topNRanking( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)))), ROW_NUMBER, @@ -173,7 +173,7 @@ public void testRankingSymbolNotReferenced() return p.project( Assignments.identity(a), p.topNRanking( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)))), ROW_NUMBER, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java index 71abfb92047e..7815bcf45dac 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -28,6 +28,7 @@ import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.QualifiedName; @@ -207,7 +208,7 @@ private static PlanNode buildProjectedWindow( .filter(projectionFilter) .collect(toImmutableList())), p.window( - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(partitionKey), Optional.of(new OrderingScheme( ImmutableList.of(orderKey), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java index ab3e95f8ce1d..a864efdfce42 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java @@ -28,6 +28,7 @@ import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.UnnestNode; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.FrameBound; @@ -518,7 +519,7 @@ public void testPushdownDereferenceThroughTopNRanking() .put(p.symbol("msg3_x"), expression("msg3[1]")) .build(), p.topNRanking( - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(p.symbol("msg1", ROW_TYPE)), Optional.of(new OrderingScheme( ImmutableList.of(p.symbol("msg2", ROW_TYPE)), @@ -589,7 +590,7 @@ public void testPushdownDereferenceThroughWindow() .put(p.symbol("msg5_x"), expression("msg5[1]")) .build(), p.window( - new WindowNode.Specification( + new DataOrganizationSpecification( ImmutableList.of(p.symbol("msg1", ROW_TYPE)), Optional.of(new OrderingScheme( ImmutableList.of(p.symbol("msg2", ROW_TYPE)), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java index 79aa054d6a90..d5d66669512e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java @@ -22,9 +22,9 @@ import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.TopNRankingNode.RankingType; import io.trino.sql.planner.plan.WindowNode.Function; -import io.trino.sql.planner.plan.WindowNode.Specification; import io.trino.sql.tree.QualifiedName; import org.testng.annotations.Test; @@ -63,7 +63,7 @@ private void assertRankingSymbolPruned(Function rankingFunction) p.project( Assignments.identity(a), p.window( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, ASC_NULLS_FIRST)))), ImmutableMap.of(ranking, rankingFunction), @@ -90,7 +90,7 @@ private void assertNoUpperBoundForRankingSymbol(Function rankingFunction) p.project( Assignments.identity(a, ranking), p.window( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, ASC_NULLS_FIRST)))), ImmutableMap.of(ranking, rankingFunction), @@ -117,7 +117,7 @@ private void assertNonPositiveUpperBoundForRankingSymbol(Function rankingFunctio p.project( Assignments.identity(a, ranking), p.window( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, ASC_NULLS_FIRST)))), ImmutableMap.of(ranking, rankingFunction), @@ -144,7 +144,7 @@ private void assertPredicateNotSatisfied(Function rankingFunction, RankingType r p.project( Assignments.identity(ranking), p.window( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, ASC_NULLS_FIRST)))), ImmutableMap.of(ranking, rankingFunction), @@ -185,7 +185,7 @@ private void assertPredicateSatisfied(Function rankingFunction, RankingType rank p.project( Assignments.identity(ranking), p.window( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, ASC_NULLS_FIRST)))), ImmutableMap.of(ranking, rankingFunction), @@ -224,7 +224,7 @@ private void assertPredicatePartiallySatisfied(Function rankingFunction, Ranking p.project( Assignments.identity(ranking, a), p.window( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, ASC_NULLS_FIRST)))), ImmutableMap.of(ranking, rankingFunction), @@ -255,7 +255,7 @@ private void assertPredicatePartiallySatisfied(Function rankingFunction, Ranking p.project( Assignments.identity(ranking), p.window( - new Specification( + new DataOrganizationSpecification( ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, ASC_NULLS_FIRST)))), ImmutableMap.of(ranking, rankingFunction), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java index 0cf66000d14b..7daefc9279c0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java @@ -21,6 +21,7 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.assertions.TopNRankingSymbolMatcher; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.QualifiedName; import org.testng.annotations.Test; @@ -56,7 +57,7 @@ private void assertEliminateFilter(String rankingFunctionName) ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)); return p.filter(expression("rank_1 < cast(100 as bigint)"), p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.of(orderingScheme)), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rankSymbol, newWindowNodeFunction(ranking, a)), p.values(p.symbol("a")))); }) @@ -84,7 +85,7 @@ private void assertKeepFilter(String rankingFunctionName) ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)); return p.filter(expression("cast(3 as bigint) < row_number_1 and row_number_1 < cast(100 as bigint)"), p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.of(orderingScheme)), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(p.symbol("a")))); }) @@ -107,7 +108,7 @@ private void assertKeepFilter(String rankingFunctionName) ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)); return p.filter(expression("row_number_1 < cast(100 as bigint) and a = BIGINT '1'"), p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.of(orderingScheme)), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(p.symbol("a")))); }) @@ -143,7 +144,7 @@ private void assertNoUpperBound(String rankingFunctionName) return p.filter( expression("cast(3 as bigint) < row_number_1"), p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.of(orderingScheme)), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(a))); }) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownLimitIntoWindow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownLimitIntoWindow.java index c3e445c8c512..be39f290ea02 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownLimitIntoWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownLimitIntoWindow.java @@ -20,6 +20,7 @@ import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.QualifiedName; import org.testng.annotations.Test; @@ -55,7 +56,7 @@ private void assertLimitAboveWindow(String rankingFunctionName) return p.limit( 3, p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.of(orderingScheme)), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(a))); }) @@ -82,7 +83,7 @@ public void testConvertToTopNRowNumber() ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)); return p.limit(3, p.window( - new WindowNode.Specification(ImmutableList.of(), Optional.of(orderingScheme)), + new DataOrganizationSpecification(ImmutableList.of(), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(a))); }) @@ -114,7 +115,7 @@ public void testLimitWithPreSortedInputs() false, ImmutableList.of(a), p.window( - new WindowNode.Specification(ImmutableList.of(), Optional.of(orderingScheme)), + new DataOrganizationSpecification(ImmutableList.of(), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(a))); }) @@ -141,7 +142,7 @@ private void assertZeroLimit(String rankingFunctionName) return p.limit( 0, p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.of(orderingScheme)), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(a))); }) @@ -165,7 +166,7 @@ private void assertWindowNotOrdered(String rankingFunctionName) return p.limit( 3, p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.empty()), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.empty()), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(ranking, a)), p.values(a))); }) @@ -185,7 +186,7 @@ public void testMultipleWindowFunctions() return p.limit( 3, p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.empty()), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.empty()), ImmutableMap.of( rowNumberSymbol, newWindowNodeFunction(rowNumberFunction, a), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceWindowWithRowNumber.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceWindowWithRowNumber.java index 70bae2c8dc06..ee6d7438587c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceWindowWithRowNumber.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceWindowWithRowNumber.java @@ -20,6 +20,7 @@ import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.QualifiedName; import org.testng.annotations.Test; @@ -43,7 +44,7 @@ public void test() Symbol a = p.symbol("a"); Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.empty()), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.empty()), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(rowNumberFunction)), p.values(a)); }) @@ -58,7 +59,7 @@ public void test() Symbol a = p.symbol("a"); Symbol rowNumberSymbol = p.symbol("row_number_1"); return p.window( - new WindowNode.Specification(ImmutableList.of(), Optional.empty()), + new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()), ImmutableMap.of(rowNumberSymbol, newWindowNodeFunction(rowNumberFunction)), p.values(a)); }) @@ -78,7 +79,7 @@ public void testDoNotFire() Symbol a = p.symbol("a"); Symbol rank1 = p.symbol("rank_1"); return p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.empty()), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.empty()), ImmutableMap.of(rank1, newWindowNodeFunction(rank)), p.values(a)); }) @@ -91,7 +92,7 @@ public void testDoNotFire() Symbol rowNumber1 = p.symbol("row_number_1"); Symbol rank1 = p.symbol("rank_1"); return p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.empty()), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.empty()), ImmutableMap.of(rowNumber1, newWindowNodeFunction(rowNumber), rank1, newWindowNodeFunction(rank)), p.values(a)); }) @@ -103,7 +104,7 @@ public void testDoNotFire() OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST)); Symbol rowNumber1 = p.symbol("row_number_1"); return p.window( - new WindowNode.Specification(ImmutableList.of(a), Optional.of(orderingScheme)), + new DataOrganizationSpecification(ImmutableList.of(a), Optional.of(orderingScheme)), ImmutableMap.of(rowNumber1, newWindowNodeFunction(rowNumber)), p.values(a)); }) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index 76ca49adf12e..0cb9bbb897fe 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -19,6 +19,7 @@ import io.trino.metadata.TestingFunctionResolution; import io.trino.sql.planner.assertions.ExpectedValueProvider; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.SymbolReference; @@ -57,7 +58,7 @@ public void doesNotFireOnPlanWithoutWindowFunctions() public void doesNotFireOnPlanWithSingleWindowNode() { tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) - .on(p -> p.window(new WindowNode.Specification( + .on(p -> p.window(new DataOrganizationSpecification( ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1"), @@ -72,17 +73,17 @@ public void subsetComesFirst() String columnAAlias = "ALIAS_A"; String columnBAlias = "ALIAS_B"; - ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); - ExpectedValueProvider specificationAB = specification(ImmutableList.of(columnAAlias, columnBAlias), ImmutableList.of(), ImmutableMap.of()); + ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); + ExpectedValueProvider specificationAB = specification(ImmutableList.of(columnAAlias, columnBAlias), ImmutableList.of(), ImmutableMap.of()); tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1", DOUBLE), new WindowNode.Function(resolvedFunction, ImmutableList.of(new SymbolReference("a")), DEFAULT_FRAME, false)), - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.symbol("a"), p.symbol("b")), Optional.empty()), ImmutableMap.of(p.symbol("avg_2", DOUBLE), @@ -103,12 +104,12 @@ public void dependentWindowsAreNotReordered() { tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0)) .on(p -> - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1"), new WindowNode.Function(resolvedFunction, ImmutableList.of(new SymbolReference("avg_2")), DEFAULT_FRAME, false)), - p.window(new WindowNode.Specification( + p.window(new DataOrganizationSpecification( ImmutableList.of(p.symbol("a"), p.symbol("b")), Optional.empty()), ImmutableMap.of(p.symbol("avg_2"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PatternRecognitionBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PatternRecognitionBuilder.java index 4138f426e1e7..d678618e4c62 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PatternRecognitionBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PatternRecognitionBuilder.java @@ -20,6 +20,7 @@ import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PatternRecognitionNode.Measure; import io.trino.sql.planner.plan.PlanNode; @@ -156,7 +157,7 @@ public PatternRecognitionNode build(PlanNodeIdAllocator idAllocator) return new PatternRecognitionNode( idAllocator.getNextId(), source, - new WindowNode.Specification(partitionBy, orderBy), + new DataOrganizationSpecification(partitionBy, orderBy), Optional.empty(), ImmutableSet.of(), 0, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 6573e7bc4f15..b39a2df31d4a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -54,6 +54,7 @@ import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.DynamicFilterId; @@ -103,7 +104,6 @@ import io.trino.sql.planner.plan.UpdateNode; import io.trino.sql.planner.plan.ValuesNode; import io.trino.sql.planner.plan.WindowNode; -import io.trino.sql.planner.plan.WindowNode.Specification; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.NullLiteral; @@ -1344,7 +1344,7 @@ public UnnestNode unnest(List replicateSymbols, List filter); } - public WindowNode window(Specification specification, Map functions, PlanNode source) + public WindowNode window(DataOrganizationSpecification specification, Map functions, PlanNode source) { return new WindowNode( idAllocator.getNextId(), @@ -1356,7 +1356,7 @@ public WindowNode window(Specification specification, Map functions, Symbol hashSymbol, PlanNode source) + public WindowNode window(DataOrganizationSpecification specification, Map functions, Symbol hashSymbol, PlanNode source) { return new WindowNode( idAllocator.getNextId(), @@ -1385,7 +1385,7 @@ public RowNumberNode rowNumber(List partitionBy, Optional maxRo hashSymbol); } - public TopNRankingNode topNRanking(Specification specification, RankingType rankingType, int maxRankingPerPartition, Symbol rankingSymbol, Optional hashSymbol, PlanNode source) + public TopNRankingNode topNRanking(DataOrganizationSpecification specification, RankingType rankingType, int maxRankingPerPartition, Symbol rankingSymbol, Optional hashSymbol, PlanNode source) { return new TopNRankingNode( idAllocator.getNextId(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java index 27932bd49639..f593d7e766fa 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestEliminateSorts.java @@ -26,7 +26,7 @@ import io.trino.sql.planner.iterative.IterativeOptimizer; import io.trino.sql.planner.iterative.rule.DetermineTableScanNodePartitioning; import io.trino.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; -import io.trino.sql.planner.plan.WindowNode; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import org.intellij.lang.annotations.Language; import org.testng.annotations.Test; @@ -47,7 +47,7 @@ public class TestEliminateSorts { private static final String QUANTITY_ALIAS = "QUANTITY"; - private static final ExpectedValueProvider windowSpec = specification( + private static final ExpectedValueProvider windowSpec = specification( ImmutableList.of(), ImmutableList.of(QUANTITY_ALIAS), ImmutableMap.of(QUANTITY_ALIAS, SortOrder.ASC_NULLS_LAST)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java index 06781c71f54a..e32bc1ea90b2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestMergeWindows.java @@ -25,6 +25,7 @@ import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.iterative.rule.GatherAndMergeWindows; import io.trino.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.WindowNode; import io.trino.sql.tree.FrameBound; import io.trino.sql.tree.WindowFrame; @@ -95,8 +96,8 @@ public class TestMergeWindows private static final Optional UNSPECIFIED_FRAME = Optional.empty(); - private final ExpectedValueProvider specificationA; - private final ExpectedValueProvider specificationB; + private final ExpectedValueProvider specificationA; + private final ExpectedValueProvider specificationB; public TestMergeWindows() { @@ -305,12 +306,12 @@ public void testIdenticalWindowSpecificationsAAfilterA() @Test public void testIdenticalWindowSpecificationsDefaultFrame() { - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); - ExpectedValueProvider specificationD = specification( + ExpectedValueProvider specificationD = specification( ImmutableList.of(ORDERKEY_ALIAS), ImmutableList.of(SHIPDATE_ALIAS), ImmutableMap.of(SHIPDATE_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -348,7 +349,7 @@ public void testMergeDifferentFrames() ImmutableList.of(), ImmutableList.of())); - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -394,7 +395,7 @@ public void testMergeDifferentFramesWithDefault() ImmutableList.of(), ImmutableList.of())); - ExpectedValueProvider specificationD = specification( + ExpectedValueProvider specificationD = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -436,12 +437,12 @@ public void testNotMergeAcrossJoinBranches() ")" + "SELECT * FROM foo, bar WHERE foo.a = bar.b"; - ExpectedValueProvider leftSpecification = specification( + ExpectedValueProvider leftSpecification = specification( ImmutableList.of(ORDERKEY_ALIAS), ImmutableList.of(SHIPDATE_ALIAS, QUANTITY_ALIAS), ImmutableMap.of(SHIPDATE_ALIAS, SortOrder.ASC_NULLS_LAST, QUANTITY_ALIAS, SortOrder.DESC_NULLS_LAST)); - ExpectedValueProvider rightSpecification = specification( + ExpectedValueProvider rightSpecification = specification( ImmutableList.of(rOrderkeyAlias), ImmutableList.of(rShipdateAlias, rQuantityAlias), ImmutableMap.of(rShipdateAlias, SortOrder.ASC_NULLS_LAST, rQuantityAlias, SortOrder.DESC_NULLS_LAST)); @@ -492,7 +493,7 @@ public void testNotMergeDifferentPartition() "SUM(quantity) over (PARTITION BY quantity ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_quantity_C " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(QUANTITY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -517,7 +518,7 @@ public void testNotMergeDifferentOrderBy() "SUM(quantity) OVER (PARTITION BY suppkey ORDER BY quantity ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_quantity_C " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(QUANTITY_ALIAS), ImmutableMap.of(QUANTITY_ALIAS, SortOrder.ASC_NULLS_LAST)); @@ -543,7 +544,7 @@ public void testNotMergeDifferentOrdering() "SUM(discount) over (PARTITION BY suppkey ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_discount_A " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.DESC_NULLS_LAST)); @@ -570,7 +571,7 @@ public void testNotMergeDifferentNullOrdering() "SUM(discount) OVER (PARTITION BY suppkey ORDER BY orderkey ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) sum_discount_A " + "FROM lineitem"; - ExpectedValueProvider specificationC = specification( + ExpectedValueProvider specificationC = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), ImmutableMap.of(ORDERKEY_ALIAS, SortOrder.ASC_NULLS_FIRST)); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java index 24a89e429bc7..cd0b5a380cd0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestReorderWindows.java @@ -25,7 +25,7 @@ import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.iterative.rule.GatherAndMergeWindows; import io.trino.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; -import io.trino.sql.planner.plan.WindowNode; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.tree.WindowFrame; import org.intellij.lang.annotations.Language; import org.testng.annotations.Test; @@ -61,13 +61,13 @@ public class TestReorderWindows private static final Optional commonFrame; - private static final ExpectedValueProvider windowA; - private static final ExpectedValueProvider windowAp; - private static final ExpectedValueProvider windowApp; - private static final ExpectedValueProvider windowB; - private static final ExpectedValueProvider windowC; - private static final ExpectedValueProvider windowD; - private static final ExpectedValueProvider windowE; + private static final ExpectedValueProvider windowA; + private static final ExpectedValueProvider windowAp; + private static final ExpectedValueProvider windowApp; + private static final ExpectedValueProvider windowB; + private static final ExpectedValueProvider windowC; + private static final ExpectedValueProvider windowD; + private static final ExpectedValueProvider windowE; static { ImmutableMap.Builder columns = ImmutableMap.builder(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java index 9141b6d2d80e..d4058ee2cbb0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java @@ -28,7 +28,6 @@ import io.trino.sql.planner.plan.PatternRecognitionNode.Measure; import io.trino.sql.planner.plan.WindowNode.Frame; import io.trino.sql.planner.plan.WindowNode.Function; -import io.trino.sql.planner.plan.WindowNode.Specification; import io.trino.sql.planner.rowpattern.AggregatedSetDescriptor; import io.trino.sql.planner.rowpattern.AggregationValuePointer; import io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers; @@ -197,7 +196,7 @@ public void testPatternRecognitionNodeRoundtrip() PatternRecognitionNode node = new PatternRecognitionNode( new PlanNodeId("0"), new ValuesNode(new PlanNodeId("1"), 1), - new Specification(ImmutableList.of(), Optional.empty()), + new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()), Optional.empty(), ImmutableSet.of(), 0, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestWindowNode.java b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestWindowNode.java index ae1fc9cd712d..7013458d89ef 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestWindowNode.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestWindowNode.java @@ -109,7 +109,7 @@ public void testSerializationRoundtrip() Optional.empty()); PlanNodeId id = newId(); - WindowNode.Specification specification = new WindowNode.Specification( + DataOrganizationSpecification specification = new DataOrganizationSpecification( ImmutableList.of(columnA), Optional.of(new OrderingScheme( ImmutableList.of(columnB), From 8bd17171a8469b9351e2fd7d9f2f49f4af9ea209 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Sun, 18 Sep 2022 14:25:38 +0200 Subject: [PATCH 6/6] Plan table function invocation with table arguments --- .../io/trino/sql/planner/QueryPlanner.java | 2 +- .../io/trino/sql/planner/RelationPlanner.java | 116 +++++-- .../planner/optimizations/SymbolMapper.java | 2 +- .../UnaliasSymbolReferences.java | 31 +- .../sql/planner/plan/TableFunctionNode.java | 60 +++- .../sql/planner/planprinter/PlanPrinter.java | 81 ++++- .../sanity/ValidateDependenciesChecker.java | 36 +- .../connector/TestingTableFunctions.java | 39 +++ .../planner/TestTableFunctionInvocation.java | 151 +++++++++ .../planner/assertions/PlanMatchPattern.java | 7 + .../assertions/TableFunctionMatcher.java | 310 ++++++++++++++++++ 11 files changed, 780 insertions(+), 55 deletions(-) create mode 100644 core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java create mode 100644 core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index 34decdc3b7e3..47034a65578b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -1414,7 +1414,7 @@ private List scopeAwareDistinct(PlanBuilder subPlan, L .collect(toImmutableList()); } - private static OrderingScheme translateOrderingScheme(List items, Function coercions) + public static OrderingScheme translateOrderingScheme(List items, Function coercions) { List coerced = items.stream() .map(SortItem::getSortKey) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 9a4c91a11cd6..f66a8b3f96cf 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import io.trino.Session; @@ -27,11 +28,13 @@ import io.trino.sql.ExpressionUtils; import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.Analysis; +import io.trino.sql.analyzer.Analysis.TableArgumentAnalysis; import io.trino.sql.analyzer.Analysis.TableFunctionInvocationAnalysis; import io.trino.sql.analyzer.Analysis.UnnestAnalysis; import io.trino.sql.analyzer.Field; import io.trino.sql.analyzer.RelationType; import io.trino.sql.analyzer.Scope; +import io.trino.sql.planner.QueryPlanner.PlanAndMappings; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.DataOrganizationSpecification; @@ -88,9 +91,7 @@ import io.trino.sql.tree.SubqueryExpression; import io.trino.sql.tree.SubsetDefinition; import io.trino.sql.tree.Table; -import io.trino.sql.tree.TableFunctionDescriptorArgument; import io.trino.sql.tree.TableFunctionInvocation; -import io.trino.sql.tree.TableFunctionTableArgument; import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.Union; import io.trino.sql.tree.Unnest; @@ -106,6 +107,7 @@ import java.util.Optional; import java.util.Set; import java.util.function.Function; +import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -122,6 +124,7 @@ import static io.trino.sql.planner.QueryPlanner.extractPatternRecognitionExpressions; import static io.trino.sql.planner.QueryPlanner.planWindowSpecification; import static io.trino.sql.planner.QueryPlanner.pruneInvisibleFields; +import static io.trino.sql.planner.QueryPlanner.translateOrderingScheme; import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -329,46 +332,99 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan) @Override protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node, Void context) { - node.getArguments().stream() - .forEach(argument -> { - if (argument.getValue() instanceof TableFunctionTableArgument) { - throw semanticException(NOT_SUPPORTED, argument, "Table arguments are not yet supported for table functions"); - } - if (argument.getValue() instanceof TableFunctionDescriptorArgument) { - throw semanticException(NOT_SUPPORTED, argument, "Descriptor arguments are not yet supported for table functions"); - } - }); - TableFunctionInvocationAnalysis functionAnalysis = analysis.getTableFunctionAnalysis(node); - // TODO handle input relations: - // 1. extract the input relations from node.getArguments() and plan them. Apply relation coercions if requested. - // 2. for each input relation, prepare the TableArgumentProperties record, consisting of: - // - row or set semantics (from the actualArgument) - // - prune when empty property (from the actualArgument) - // - pass through columns property (from the actualArgument) - // - optional Specification: ordering scheme and partitioning (from the node's argument) <- planned upon the source's RelationPlan (or combined RelationPlan from all sources) - // TODO add - argument name - // TODO add - mapping column name => Symbol // TODO mind the fields without names and duplicate field names in RelationType - List sources = ImmutableList.of(); - List inputRelationsProperties = ImmutableList.of(); + ImmutableList.Builder sources = ImmutableList.builder(); + ImmutableList.Builder sourceProperties = ImmutableList.builder(); + ImmutableList.Builder outputSymbols = ImmutableList.builder(); - Scope scope = analysis.getScope(node); - // TODO pass columns from input relations, and make sure they have the right qualifier - List outputSymbols = scope.getRelationType().getAllFields().stream() + // create new symbols for table function's proper columns + RelationType relationType = analysis.getScope(node).getRelationType(); + List properOutputs = IntStream.range(0, functionAnalysis.getProperColumnsCount()) + .mapToObj(relationType::getFieldByIndex) .map(symbolAllocator::newSymbol) .collect(toImmutableList()); + outputSymbols.addAll(properOutputs); + + // process sources in order of argument declarations + for (TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) { + RelationPlan sourcePlan = process(tableArgument.getRelation(), context); + PlanBuilder sourcePlanBuilder = newPlanBuilder(sourcePlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); + + // map column names to symbols + // note: hidden columns are included in the mapping. They are present both in sourceDescriptor.allFields, and in sourcePlan.fieldMappings + // note: for an aliased relation or a CTE, the field names in the relation type are in the same case as specified in the alias. + // quotes and canonicalization rules are not applied. + ImmutableMultimap.Builder columnMapping = ImmutableMultimap.builder(); + RelationType sourceDescriptor = sourcePlan.getDescriptor(); + for (int i = 0; i < sourceDescriptor.getAllFieldCount(); i++) { + Optional name = sourceDescriptor.getFieldByIndex(i).getName(); + if (name.isPresent()) { + columnMapping.put(name.get(), sourcePlan.getSymbol(i)); + } + } + + Optional specification = Optional.empty(); + + // if the table argument has set semantics, create Specification + if (!tableArgument.isRowSemantics()) { + // partition by + List partitionBy = ImmutableList.of(); + // if there are partitioning columns, they might have to be coerced for copartitioning + if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) { + List partitioningColumns = tableArgument.getPartitionBy().get(); + PlanAndMappings copartitionCoercions = coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, symbolAllocator, typeCoercion); + sourcePlanBuilder = copartitionCoercions.getSubPlan(); + partitionBy = partitioningColumns.stream() + .map(copartitionCoercions::get) + .collect(toImmutableList()); + } + + // order by + Optional orderBy = Optional.empty(); + if (tableArgument.getOrderBy().isPresent()) { + // the ordering symbols are not coerced + orderBy = Optional.of(translateOrderingScheme(tableArgument.getOrderBy().get().getSortItems(), sourcePlanBuilder::translate)); + } + + specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy)); + } + + sources.add(sourcePlanBuilder.getRoot()); + sourceProperties.add(new TableArgumentProperties( + tableArgument.getArgumentName(), + columnMapping.build(), + tableArgument.isRowSemantics(), + tableArgument.isPruneWhenEmpty(), + tableArgument.isPassThroughColumns(), + specification)); + + // add output symbols passed from the table argument + if (tableArgument.isPassThroughColumns()) { + // the original output symbols from the source node, not coerced + // note: hidden columns are included. They are present in sourcePlan.fieldMappings + outputSymbols.addAll(sourcePlan.getFieldMappings()); + } + else if (tableArgument.getPartitionBy().isPresent()) { + tableArgument.getPartitionBy().get().stream() + // the original symbols for partitioning columns, not coerced + .map(sourcePlanBuilder::translate) + .forEach(outputSymbols::add); + } + } + PlanNode root = new TableFunctionNode( idAllocator.getNextId(), functionAnalysis.getFunctionName(), functionAnalysis.getArguments(), - outputSymbols, - sources.stream().map(RelationPlan::getRoot).collect(toImmutableList()), - inputRelationsProperties, + properOutputs, + sources.build(), + sourceProperties.build(), + functionAnalysis.getCopartitioningLists(), new TableFunctionHandle(functionAnalysis.getCatalogHandle(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle())); - return new RelationPlan(root, scope, outputSymbols, outerContext); + return new RelationPlan(root, analysis.getScope(node), outputSymbols.build(), outerContext); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 5048da2d530e..a8f379b57d2d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -240,7 +240,7 @@ private WindowNode.Frame map(WindowNode.Frame frame) frame.getOriginalEndValue()); } - private DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) + public DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) { return new DataOrganizationSpecification( mapAndDistinct(specification.getPartitionBy()), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 35562a189feb..ac302ca75719 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.collect.Sets; @@ -39,6 +40,7 @@ import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; +import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.DynamicFilterId; @@ -76,6 +78,7 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -321,20 +324,42 @@ public PlanAndMappings visitPatternRecognition(PatternRecognitionNode node, Unal @Override public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext context) { - // TODO rewrite sources, and tableArgumentProperties when we add support for input tables Map mapping = new HashMap<>(context.getCorrelationMapping()); SymbolMapper mapper = symbolMapper(mapping); List newProperOutputs = mapper.map(node.getProperOutputs()); + ImmutableList.Builder newSources = ImmutableList.builder(); + ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); + + for (int i = 0; i < node.getSources().size(); i++) { + PlanAndMappings newSource = node.getSources().get(i).accept(this, context); + newSources.add(newSource.getRoot()); + + SymbolMapper inputMapper = symbolMapper(new HashMap<>(newSource.getMappings())); + TableArgumentProperties properties = node.getTableArgumentProperties().get(i); + ImmutableMultimap.Builder newColumnMapping = ImmutableMultimap.builder(); + properties.getColumnMapping().entries().stream() + .forEach(entry -> newColumnMapping.put(entry.getKey(), inputMapper.map(entry.getValue()))); + Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); + newTableArgumentProperties.add(new TableArgumentProperties( + properties.getArgumentName(), + newColumnMapping.build(), + properties.isRowSemantics(), + properties.isPruneWhenEmpty(), + properties.isPassThroughColumns(), + newSpecification)); + } + return new PlanAndMappings( new TableFunctionNode( node.getId(), node.getName(), node.getArguments(), newProperOutputs, - node.getSources(), - node.getTableArgumentProperties(), + newSources.build(), + newTableArgumentProperties.build(), + node.getCopartitioningLists(), node.getHandle()), mapping); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java index ad1b20e26426..924d88960693 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java @@ -17,6 +17,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Multimap; import io.trino.metadata.TableFunctionHandle; import io.trino.spi.ptf.Argument; import io.trino.sql.planner.Symbol; @@ -25,8 +27,10 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @Immutable @@ -38,6 +42,7 @@ public class TableFunctionNode private final List properOutputs; private final List sources; private final List tableArgumentProperties; + private final List> copartitioningLists; private final TableFunctionHandle handle; @JsonCreator @@ -48,6 +53,7 @@ public TableFunctionNode( @JsonProperty("properOutputs") List properOutputs, @JsonProperty("sources") List sources, @JsonProperty("tableArgumentProperties") List tableArgumentProperties, + @JsonProperty("copartitioningLists") List> copartitioningLists, @JsonProperty("handle") TableFunctionHandle handle) { super(id); @@ -56,6 +62,9 @@ public TableFunctionNode( this.properOutputs = ImmutableList.copyOf(properOutputs); this.sources = ImmutableList.copyOf(sources); this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); this.handle = requireNonNull(handle, "handle is null"); } @@ -83,6 +92,12 @@ public List getTableArgumentProperties() return tableArgumentProperties; } + @JsonProperty + public List> getCopartitioningLists() + { + return copartitioningLists; + } + @JsonProperty public TableFunctionHandle getHandle() { @@ -99,8 +114,23 @@ public List getSources() @Override public List getOutputSymbols() { - // TODO add outputs from input relations - return properOutputs; + ImmutableList.Builder symbols = ImmutableList.builder(); + + symbols.addAll(properOutputs); + + for (int i = 0; i < sources.size(); i++) { + TableArgumentProperties sourceProperties = tableArgumentProperties.get(i); + if (sourceProperties.isPassThroughColumns()) { + symbols.addAll(sources.get(i).getOutputSymbols()); + } + else { + sourceProperties.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .ifPresent(symbols::addAll); + } + } + + return symbols.build(); } @Override @@ -113,29 +143,47 @@ public R accept(PlanVisitor visitor, C context) public PlanNode replaceChildren(List newSources) { checkArgument(sources.size() == newSources.size(), "wrong number of new children"); - return new TableFunctionNode(getId(), name, arguments, properOutputs, newSources, tableArgumentProperties, handle); + return new TableFunctionNode(getId(), name, arguments, properOutputs, newSources, tableArgumentProperties, copartitioningLists, handle); } public static class TableArgumentProperties { + private final String argumentName; + private final Multimap columnMapping; private final boolean rowSemantics; private final boolean pruneWhenEmpty; private final boolean passThroughColumns; - private final DataOrganizationSpecification specification; + private final Optional specification; @JsonCreator public TableArgumentProperties( + @JsonProperty("argumentName") String argumentName, + @JsonProperty("columnMapping") Multimap columnMapping, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, @JsonProperty("passThroughColumns") boolean passThroughColumns, - @JsonProperty("specification") DataOrganizationSpecification specification) + @JsonProperty("specification") Optional specification) { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); + this.columnMapping = ImmutableMultimap.copyOf(columnMapping); this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; this.passThroughColumns = passThroughColumns; this.specification = requireNonNull(specification, "specification is null"); } + @JsonProperty + public String getArgumentName() + { + return argumentName; + } + + @JsonProperty + public Multimap getColumnMapping() + { + return columnMapping; + } + @JsonProperty public boolean isRowSemantics() { @@ -155,7 +203,7 @@ public boolean isPassThroughColumns() } @JsonProperty - public DataOrganizationSpecification getSpecification() + public Optional getSpecification() { return specification; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 4ddbf2647e05..1af71aeb6080 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -42,6 +42,8 @@ import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.ptf.Argument; +import io.trino.spi.ptf.DescriptorArgument; import io.trino.spi.ptf.ScalarArgument; import io.trino.spi.statistics.ColumnStatisticMetadata; import io.trino.spi.statistics.TableStatisticType; @@ -104,6 +106,7 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -154,6 +157,7 @@ import static io.trino.execution.StageInfo.getAllStages; import static io.trino.metadata.ResolvedFunction.extractFunctionName; import static io.trino.server.DynamicFilterService.DynamicFilterDomainStats; +import static io.trino.spi.ptf.DescriptorArgument.NULL_DESCRIPTOR; import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.ExpressionUtils.combineConjunctsWithDuplicates; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -1738,24 +1742,77 @@ public Void visitTableFunction(TableFunctionNode node, Context context) ImmutableMap.of("name", node.getName()), context.tag()); - checkArgument( - node.getSources().isEmpty() && node.getTableArgumentProperties().isEmpty(), - "Table or descriptor arguments are not yet supported in PlanPrinter"); + if (!node.getArguments().isEmpty()) { + nodeOutput.appendDetails("Arguments:"); - node.getArguments().entrySet().stream() - .forEach(entry -> nodeOutput.appendDetails(entry.getKey() + " => " + formatArgument((ScalarArgument) entry.getValue()))); + Map tableArguments = node.getTableArgumentProperties().stream() + .collect(toImmutableMap(TableArgumentProperties::getArgumentName, identity())); + + node.getArguments().entrySet().stream() + .forEach(entry -> nodeOutput.appendDetails(formatArgument(entry.getKey(), entry.getValue(), tableArguments))); + + if (!node.getCopartitioningLists().isEmpty()) { + nodeOutput.appendDetails(node.getCopartitioningLists().stream() + .map(list -> list.stream().collect(Collectors.joining(", ", "(", ")"))) + .collect(joining(", ", "Co-partition: [", "]"))); + } + } + + for (int i = 0; i < node.getSources().size(); i++) { + node.getSources().get(i).accept(this, new Context(node.getTableArgumentProperties().get(i).getArgumentName())); + } return null; } - private String formatArgument(ScalarArgument argument) + private String formatArgument(String argumentName, Argument argument, Map tableArguments) { - return format( - "ScalarArgument{type=%s, value=%s}", - argument.getType(), - anonymizer.anonymize( - argument.getType(), - valuePrinter.castToVarchar(argument.getType(), argument.getValue()))); + if (argument instanceof ScalarArgument scalarArgument) { + return format( + "%s => ScalarArgument{type=%s, value=%s}", + argumentName, + scalarArgument.getType().getDisplayName(), + anonymizer.anonymize( + scalarArgument.getType(), + valuePrinter.castToVarchar(scalarArgument.getType(), scalarArgument.getValue()))); + } + if (argument instanceof DescriptorArgument descriptorArgument) { + String descriptor; + if (descriptorArgument.equals(NULL_DESCRIPTOR)) { + descriptor = "NULL"; + } + else { + descriptor = descriptorArgument.getDescriptor().orElseThrow().getFields().stream() + .map(field -> anonymizer.anonymizeColumn(field.getName()) + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) + .collect(joining(", ", "(", ")")); + } + return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + } + else { + TableArgumentProperties argumentProperties = tableArguments.get(argumentName); + StringBuilder properties = new StringBuilder(); + if (argumentProperties.isRowSemantics()) { + properties.append("row semantics"); + } + argumentProperties.getSpecification().ifPresent(specification -> { + properties + .append("partition by: [") + .append(Joiner.on(", ").join(anonymize(specification.getPartitionBy()))) + .append("]"); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + properties + .append(", order by: ") + .append(formatOrderingScheme(orderingScheme)); + }); + }); + if (argumentProperties.isPruneWhenEmpty()) { + properties.append(", prune when empty"); + } + if (argumentProperties.isPassThroughColumns()) { + properties.append(", pass through columns"); + } + return format("%s => TableArgument{%s}", argumentName, properties); + } } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 734bbad11c59..fc3631cbd2a1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -218,9 +218,41 @@ public Void visitPatternRecognition(PatternRecognitionNode node, Set bou } @Override - public Void visitTableFunction(TableFunctionNode node, Set context) + public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) { - // TODO + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode source = node.getSources().get(i); + source.accept(this, boundSymbols); + Set inputs = createInputs(source, boundSymbols); + TableFunctionNode.TableArgumentProperties argumentProperties = node.getTableArgumentProperties().get(i); + + checkDependencies( + inputs, + argumentProperties.getColumnMapping().values(), + "Invalid node. Input symbols from source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + argumentProperties.getColumnMapping().values(), + source.getOutputSymbols()); + argumentProperties.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + specification.getPartitionBy(), + source.getOutputSymbols()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderBy(), + "Invalid node. Order by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + orderingScheme.getOrderBy(), + source.getOutputSymbols()); + }); + }); + } + return null; } diff --git a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java index 6869f8cae3b0..b5ba197c029c 100644 --- a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java +++ b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java @@ -329,6 +329,45 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact } } + public static class DifferentArgumentTypesFunction + extends AbstractConnectorTableFunction + { + public DifferentArgumentTypesFunction() + { + super( + SCHEMA_NAME, + "different_arguments_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .build(), + DescriptorArgumentSpecification.builder() + .name("LAYOUT") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .rowSemantics() + .passThroughColumns() + .build(), + ScalarArgumentSpecification.builder() + .name("ID") + .type(BIGINT) + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .pruneWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return ANALYSIS; + } + } + public static class TestingTableFunctionHandle implements ConnectorTableFunctionHandle { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java new file mode 100644 index 000000000000..6b6f7f320ffa --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java @@ -0,0 +1,151 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.connector.MockConnectorFactory; +import io.trino.connector.MockConnectorPlugin; +import io.trino.connector.TestingTableFunctions.DescriptorArgumentFunction; +import io.trino.connector.TestingTableFunctions.DifferentArgumentTypesFunction; +import io.trino.connector.TestingTableFunctions.TestingTableFunctionHandle; +import io.trino.connector.TestingTableFunctions.TwoScalarArgumentsFunction; +import io.trino.spi.connector.TableFunctionApplicationResult; +import io.trino.spi.ptf.Descriptor; +import io.trino.spi.ptf.Descriptor.Field; +import io.trino.sql.planner.assertions.BasePlanTest; +import io.trino.sql.tree.LongLiteral; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.sql.planner.LogicalPlanner.Stage.CREATED; +import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; +import static io.trino.sql.planner.assertions.PlanMatchPattern.specification; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableFunction; +import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.descriptorArgument; +import static io.trino.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.nullDescriptor; +import static io.trino.sql.planner.assertions.TableFunctionMatcher.TableArgumentValue.Builder.tableArgument; + +public class TestTableFunctionInvocation + extends BasePlanTest +{ + private static final String TESTING_CATALOG = "mock"; + + @BeforeClass + public final void setup() + { + getQueryRunner().installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() + .withTableFunctions(ImmutableSet.of( + new DifferentArgumentTypesFunction(), + new TwoScalarArgumentsFunction(), + new DescriptorArgumentFunction())) + .withApplyTableFunction((session, handle) -> { + if (handle instanceof TestingTableFunctionHandle) { + TestingTableFunctionHandle functionHandle = (TestingTableFunctionHandle) handle; + return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow())); + } + throw new IllegalStateException("Unsupported table function handle: " + handle.getClass().getSimpleName()); + }) + .build())); + getQueryRunner().createCatalog(TESTING_CATALOG, "mock", ImmutableMap.of()); + } + + @Test + public void testTableFunctionInitialPlan() + { + assertPlan( + """ + SELECT * FROM TABLE(mock.system.different_arguments_function( + INPUT_1 => TABLE(SELECT 'a') t1(c1) PARTITION BY c1 ORDER BY c1, + INPUT_3 => TABLE(SELECT 'b') t3(c3) PARTITION BY c3, + INPUT_2 => TABLE(VALUES 1) t2(c2), + ID => BIGINT '2001', + LAYOUT => DESCRIPTOR (x boolean, y bigint) + COPARTITION (t1, t3))) t + """, + CREATED, + anyTree(tableFunction(builder -> builder + .name("different_arguments_function") + .addTableArgument( + "INPUT_1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1"), ImmutableList.of("c1"), ImmutableMap.of("c1", ASC_NULLS_LAST))) + .passThroughColumns()) + .addTableArgument( + "INPUT_3", + tableArgument(2) + .specification(specification(ImmutableList.of("c3"), ImmutableList.of(), ImmutableMap.of())) + .pruneWhenEmpty()) + .addTableArgument( + "INPUT_2", + tableArgument(1) + .rowSemantics() + .passThroughColumns()) + .addScalarArgument("ID", 2001L) + .addDescriptorArgument( + "LAYOUT", + descriptorArgument(new Descriptor(ImmutableList.of( + new Field("X", Optional.of(BOOLEAN)), + new Field("Y", Optional.of(BIGINT)))))) + .addCopartitioning(ImmutableList.of("INPUT_1", "INPUT_3")) + .properOutputs(ImmutableList.of("OUTPUT")), + anyTree(project(ImmutableMap.of("c1", expression("'a'")), values(1))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("1"))))), + anyTree(project(ImmutableMap.of("c3", expression("'b'")), values(1)))))); + } + + @Test + public void testNullScalarArgument() + { + // the argument NUMBER has null default value + assertPlan( + " SELECT * FROM TABLE(mock.system.two_arguments_function(TEXT => null))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_arguments_function") + .addScalarArgument("TEXT", null) + .addScalarArgument("NUMBER", null) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testNullDescriptorArgument() + { + assertPlan( + " SELECT * FROM TABLE(mock.system.descriptor_argument_function(SCHEMA => CAST(null AS DESCRIPTOR)))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + + // the argument SCHEMA has null default value + assertPlan( + " SELECT * FROM TABLE(mock.system.descriptor_argument_function())", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 16ead964c9e5..185653a1546c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -835,6 +835,13 @@ public static PlanMatchPattern tableExecute(List columns, List c return node(TableExecuteNode.class, source).with(new TableExecuteMatcher(columns, columnNames)); } + public static PlanMatchPattern tableFunction(Consumer handler, PlanMatchPattern... sources) + { + TableFunctionMatcher.Builder builder = new TableFunctionMatcher.Builder(sources); + handler.accept(builder); + return builder.build(); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java new file mode 100644 index 000000000000..322c9fbc1c87 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java @@ -0,0 +1,310 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.assertions; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.cost.StatsProvider; +import io.trino.metadata.Metadata; +import io.trino.spi.ptf.Argument; +import io.trino.spi.ptf.Descriptor; +import io.trino.spi.ptf.DescriptorArgument; +import io.trino.spi.ptf.ScalarArgument; +import io.trino.spi.ptf.TableArgument; +import io.trino.sql.planner.plan.DataOrganizationSpecification; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import io.trino.sql.tree.SymbolReference; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; +import static io.trino.sql.planner.assertions.MatchResult.match; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static java.util.Objects.requireNonNull; + +public class TableFunctionMatcher + implements Matcher +{ + private final String name; + private final Map arguments; + private final List properOutputs; + private final List> copartitioningLists; + + private TableFunctionMatcher( + String name, + Map arguments, + List properOutputs, + List> copartitioningLists) + { + this.name = requireNonNull(name, "name is null"); + this.arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null")); + this.properOutputs = ImmutableList.copyOf(requireNonNull(properOutputs, "properOutputs is null")); + requireNonNull(copartitioningLists, "copartitioningLists is null"); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionNode tableFunctionNode = (TableFunctionNode) node; + + if (!name.equals(tableFunctionNode.getName())) { + return NO_MATCH; + } + + if (arguments.size() != tableFunctionNode.getArguments().size()) { + return NO_MATCH; + } + for (Map.Entry entry : arguments.entrySet()) { + String name = entry.getKey(); + Argument actual = tableFunctionNode.getArguments().get(name); + if (actual == null) { + return NO_MATCH; + } + ArgumentValue expected = entry.getValue(); + if (expected instanceof DescriptorArgumentValue expectedDescriptor) { + if (!(actual instanceof DescriptorArgument actualDescriptor) || !expectedDescriptor.descriptor().equals(actualDescriptor.getDescriptor())) { + return NO_MATCH; + } + } + else if (expected instanceof ScalarArgumentValue expectedScalar) { + if (!(actual instanceof ScalarArgument actualScalar) || !Objects.equals(expectedScalar.value(), actualScalar.getValue())) { + return NO_MATCH; + } + } + else { + if (!(actual instanceof TableArgument)) { + return NO_MATCH; + } + TableArgumentValue expectedTableArgument = (TableArgumentValue) expected; + TableArgumentProperties argumentProperties = tableFunctionNode.getTableArgumentProperties().get(expectedTableArgument.sourceIndex()); + if (!name.equals(argumentProperties.getArgumentName())) { + return NO_MATCH; + } + if (expectedTableArgument.rowSemantics() != argumentProperties.isRowSemantics() || + expectedTableArgument.pruneWhenEmpty() != argumentProperties.isPruneWhenEmpty() || + expectedTableArgument.passThroughColumns() != argumentProperties.isPassThroughColumns()) { + return NO_MATCH; + } + boolean specificationMatches = expectedTableArgument.specification() + .map(specification -> specification.getExpectedValue(symbolAliases)) + .equals(argumentProperties.getSpecification()); + if (!specificationMatches) { + return NO_MATCH; + } + } + } + + if (properOutputs.size() != tableFunctionNode.getProperOutputs().size()) { + return NO_MATCH; + } + + if (!ImmutableSet.copyOf(copartitioningLists).equals(ImmutableSet.copyOf(tableFunctionNode.getCopartitioningLists()))) { + return NO_MATCH; + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), tableFunctionNode.getProperOutputs().get(i).toSymbolReference()); + } + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("arguments", arguments) + .add("properOutputs", properOutputs) + .add("copartitioningLists", copartitioningLists) + .toString(); + } + + public static class Builder + { + private final PlanMatchPattern[] sources; + private String name; + private final ImmutableMap.Builder arguments = ImmutableMap.builder(); + private List properOutputs = ImmutableList.of(); + private final ImmutableList.Builder> copartitioningLists = ImmutableList.builder(); + + Builder(PlanMatchPattern... sources) + { + this.sources = Arrays.copyOf(sources, sources.length); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder addDescriptorArgument(String name, DescriptorArgumentValue descriptor) + { + this.arguments.put(name, descriptor); + return this; + } + + public Builder addScalarArgument(String name, Object value) + { + this.arguments.put(name, new ScalarArgumentValue(value)); + return this; + } + + public Builder addTableArgument(String name, TableArgumentValue.Builder tableArgument) + { + this.arguments.put(name, tableArgument.build()); + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder addCopartitioning(List copartitioning) + { + this.copartitioningLists.add(copartitioning); + return this; + } + + public PlanMatchPattern build() + { + return node(TableFunctionNode.class, sources) + .with(new TableFunctionMatcher(name, arguments.buildOrThrow(), properOutputs, copartitioningLists.build())); + } + } + + public sealed interface ArgumentValue + permits DescriptorArgumentValue, ScalarArgumentValue, TableArgumentValue + {} + + public record DescriptorArgumentValue(Optional descriptor) + implements ArgumentValue + { + public DescriptorArgumentValue(Optional descriptor) + { + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + public static DescriptorArgumentValue descriptorArgument(Descriptor descriptor) + { + return new DescriptorArgumentValue(Optional.of(requireNonNull(descriptor, "descriptor is null"))); + } + + public static DescriptorArgumentValue nullDescriptor() + { + return new DescriptorArgumentValue(Optional.empty()); + } + } + + public record ScalarArgumentValue(Object value) + implements ArgumentValue + {} + + public record TableArgumentValue( + int sourceIndex, + boolean rowSemantics, + boolean pruneWhenEmpty, + boolean passThroughColumns, + Optional> specification) + implements ArgumentValue + { + public TableArgumentValue(int sourceIndex, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, Optional> specification) + { + this.sourceIndex = sourceIndex; + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + this.specification = requireNonNull(specification, "specification is null"); + } + + public static class Builder + { + private final int sourceIndex; + private boolean rowSemantics; + private boolean pruneWhenEmpty; + private boolean passThroughColumns; + private Optional> specification = Optional.empty(); + + private Builder(int sourceIndex) + { + this.sourceIndex = sourceIndex; + } + + public static Builder tableArgument(int sourceIndex) + { + return new Builder(sourceIndex); + } + + public Builder rowSemantics() + { + this.rowSemantics = true; + this.pruneWhenEmpty = true; + return this; + } + + public Builder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public Builder passThroughColumns() + { + this.passThroughColumns = true; + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + return this; + } + + private TableArgumentValue build() + { + return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification); + } + } + } +}