From 6b01cc139ce9ff1b6cf9842659258e937da1517b Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Sun, 26 Jul 2020 12:22:10 +0200 Subject: [PATCH] Support recursive CTE --- .../io/prestosql/SystemSessionProperties.java | 15 + .../io/prestosql/sql/analyzer/Analysis.java | 35 ++ .../sql/analyzer/FeaturesConfig.java | 13 + .../java/io/prestosql/sql/analyzer/Scope.java | 5 + .../sql/analyzer/StatementAnalyzer.java | 285 +++++++++- .../prestosql/sql/planner/LogicalPlanner.java | 4 +- .../sql/planner/NodeAndMappings.java | 42 ++ .../io/prestosql/sql/planner/PlanCopier.java | 221 ++++++++ .../prestosql/sql/planner/QueryPlanner.java | 220 +++++++- .../sql/planner/RelationPlanner.java | 30 +- .../sql/planner/SubqueryPlanner.java | 8 +- .../rule/PushTableWriteThroughUnion.java | 3 +- .../optimizations/PlanNodeDecorrelator.java | 3 +- .../planner/optimizations/SymbolMapper.java | 41 +- .../UnaliasSymbolReferences.java | 234 +++++---- .../prestosql/sql/analyzer/TestAnalyzer.java | 488 +++++++++++++++++- .../sql/analyzer/TestFeaturesConfig.java | 3 + .../sql/planner/TestRecursiveCTE.java | 98 ++++ .../prestosql/sql/query/TestRecursiveCTE.java | 267 ++++++++++ .../io/prestosql/spi/StandardErrorCode.java | 4 + .../tests/AbstractTestEngineOnlyQueries.java | 4 +- 21 files changed, 1865 insertions(+), 158 deletions(-) create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/NodeAndMappings.java create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/PlanCopier.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/TestRecursiveCTE.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/query/TestRecursiveCTE.java diff --git a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java index 0b4bcc4f51b0..caa732948b3e 100644 --- a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java +++ b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java @@ -106,6 +106,7 @@ public final class SystemSessionProperties public static final String FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_SIZE = "filter_and_project_min_output_page_size"; public static final String FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_ROW_COUNT = "filter_and_project_min_output_page_row_count"; public static final String DISTRIBUTED_SORT = "distributed_sort"; + public static final String MAX_RECURSION_DEPTH = "max_recursion_depth"; public static final String USE_MARK_DISTINCT = "use_mark_distinct"; public static final String PREFER_PARTIAL_AGGREGATION = "prefer_partial_aggregation"; public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number"; @@ -448,6 +449,15 @@ public SystemSessionProperties( "Parallelize sort across multiple nodes", featuresConfig.isDistributedSortEnabled(), false), + new PropertyMetadata<>( + MAX_RECURSION_DEPTH, + "Maximum recursion depth for recursive common table expression", + INTEGER, + Integer.class, + featuresConfig.getMaxRecursionDepth(), + false, + value -> validateIntegerValue(value, MAX_RECURSION_DEPTH, 1, false), + object -> object), booleanProperty( USE_MARK_DISTINCT, "Implement DISTINCT aggregations using MarkDistinct", @@ -890,6 +900,11 @@ public static boolean isDistributedSortEnabled(Session session) return session.getSystemProperty(DISTRIBUTED_SORT, Boolean.class); } + public static int getMaxRecursionDepth(Session session) + { + return session.getSystemProperty(MAX_RECURSION_DEPTH, Integer.class); + } + public static int getMaxGroupingSets(Session session) { return session.getSystemProperty(MAX_GROUPING_SETS, Integer.class); diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java index 8581bcfb1cb9..04c27c2eff3e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java @@ -102,6 +102,12 @@ public class Analysis private final Map, Query> namedQueries = new LinkedHashMap<>(); + // map expandable query to the node being the inner recursive reference + private final Map, Node> expandableNamedQueries = new LinkedHashMap<>(); + + // map inner recursive reference in the expandable query to the recursion base scope + private final Map, Scope> expandableBaseScopes = new LinkedHashMap<>(); + // Synthetic scope when a query does not have a FROM clause // We need to track this separately because there's no node we can attach it to. private final Map, Scope> implicitFromScopes = new LinkedHashMap<>(); @@ -614,6 +620,35 @@ public void registerNamedQuery(Table tableReference, Query query) namedQueries.put(NodeRef.of(tableReference), query); } + public void registerExpandableQuery(Query query, Node recursiveReference) + { + requireNonNull(query, "query is null"); + requireNonNull(recursiveReference, "recursiveReference is null"); + + expandableNamedQueries.put(NodeRef.of(query), recursiveReference); + } + + public boolean isExpandableQuery(Query query) + { + return expandableNamedQueries.containsKey(NodeRef.of(query)); + } + + public Node getRecursiveReference(Query query) + { + checkArgument(isExpandableQuery(query), "query is not registered as expandable"); + return expandableNamedQueries.get(NodeRef.of(query)); + } + + public void setExpandableBaseScope(Node node, Scope scope) + { + expandableBaseScopes.put(NodeRef.of(node), scope); + } + + public Optional getExpandableBaseScope(Node node) + { + return Optional.ofNullable(expandableBaseScopes.get(NodeRef.of(node))); + } + public void registerTableForView(Table tableReference) { tablesForView.push(requireNonNull(tableReference, "table is null")); diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java index 504244a69f99..ecc1b2d30af7 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java @@ -96,6 +96,7 @@ public class FeaturesConfig private boolean pagesIndexEagerCompactionEnabled; private boolean distributedSort = true; private boolean omitDateTimeTypePrecision; + private int maxRecursionDepth = 10; private boolean dictionaryAggregation; @@ -1003,6 +1004,18 @@ public FeaturesConfig setDistributedSortEnabled(boolean enabled) return this; } + public int getMaxRecursionDepth() + { + return maxRecursionDepth; + } + + @Config("max-recursion-depth") + public FeaturesConfig setMaxRecursionDepth(int maxRecursionDepth) + { + this.maxRecursionDepth = maxRecursionDepth; + return this; + } + public int getMaxGroupingSets() { return maxGroupingSets; diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/Scope.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/Scope.java index ea4deac592d3..ba57a5a5a78b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/Scope.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/Scope.java @@ -73,6 +73,11 @@ private Scope( this.namedQueries = ImmutableMap.copyOf(requireNonNull(namedQueries, "namedQueries is null")); } + public Scope withRelationType(RelationType relationType) + { + return new Scope(parent, queryBoundary, relationId, relationType, namedQueries); + } + public Scope getQueryBoundaryScope() { Scope scope = this; diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java index df0a5124ca04..cccecd6fa6a8 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java @@ -147,6 +147,7 @@ import io.prestosql.sql.tree.SubscriptExpression; import io.prestosql.sql.tree.Table; import io.prestosql.sql.tree.TableSubquery; +import io.prestosql.sql.tree.Union; import io.prestosql.sql.tree.Unnest; import io.prestosql.sql.tree.Use; import io.prestosql.sql.tree.Values; @@ -164,6 +165,7 @@ import java.util.Optional; import java.util.OptionalLong; import java.util.Set; +import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -190,13 +192,17 @@ import static io.prestosql.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.prestosql.spi.StandardErrorCode.INVALID_COLUMN_REFERENCE; import static io.prestosql.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.prestosql.spi.StandardErrorCode.INVALID_LIMIT_CLAUSE; +import static io.prestosql.spi.StandardErrorCode.INVALID_RECURSIVE_REFERENCE; import static io.prestosql.spi.StandardErrorCode.INVALID_ROW_FILTER; import static io.prestosql.spi.StandardErrorCode.INVALID_VIEW; import static io.prestosql.spi.StandardErrorCode.INVALID_WINDOW_FRAME; import static io.prestosql.spi.StandardErrorCode.MISMATCHED_COLUMN_ALIASES; +import static io.prestosql.spi.StandardErrorCode.MISSING_COLUMN_ALIASES; import static io.prestosql.spi.StandardErrorCode.MISSING_COLUMN_NAME; import static io.prestosql.spi.StandardErrorCode.MISSING_GROUP_BY; import static io.prestosql.spi.StandardErrorCode.MISSING_ORDER_BY; +import static io.prestosql.spi.StandardErrorCode.NESTED_RECURSIVE; import static io.prestosql.spi.StandardErrorCode.NESTED_WINDOW; import static io.prestosql.spi.StandardErrorCode.NOT_FOUND; import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED; @@ -240,8 +246,10 @@ import static io.prestosql.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; import static io.prestosql.sql.tree.Join.Type.FULL; import static io.prestosql.sql.tree.Join.Type.INNER; +import static io.prestosql.sql.tree.Join.Type.LEFT; import static io.prestosql.sql.tree.Join.Type.RIGHT; import static io.prestosql.sql.tree.WindowFrame.Type.RANGE; +import static io.prestosql.sql.util.AstUtils.preOrder; import static io.prestosql.type.UnknownType.UNKNOWN; import static io.prestosql.util.MoreLists.mappedCopy; import static java.lang.Math.toIntExact; @@ -1022,6 +1030,17 @@ protected Scope visitTable(Table table, Optional scope) if (withQuery.isPresent()) { return createScopeForCommonTableExpression(table, scope, withQuery.get()); } + // is this a recursive reference in expandable WITH query? If so, there's base scope recorded. + Optional expandableBaseScope = analysis.getExpandableBaseScope(table); + if (expandableBaseScope.isPresent()) { + Scope baseScope = expandableBaseScope.get(); + // adjust local and outer parent scopes accordingly to the local context of the recursive reference + Scope resultScope = scopeBuilder(scope) + .withRelationType(baseScope.getRelationId(), baseScope.getRelationType()) + .build(); + analysis.setScope(table, resultScope); + return resultScope; + } } QualifiedObjectName name = createQualifiedObjectName(session, table, table.getName()); @@ -2515,42 +2534,276 @@ private List descriptorToFields(Scope scope) private Scope analyzeWith(Query node, Optional scope) { - // analyze WITH clause if (node.getWith().isEmpty()) { return createScope(scope); } - With with = node.getWith().get(); - if (with.isRecursive()) { - throw semanticException(NOT_SUPPORTED, with, "Recursive WITH queries are not supported"); - } + // analyze WITH clause + With with = node.getWith().get(); Scope.Builder withScopeBuilder = scopeBuilder(scope); - for (WithQuery withQuery : with.getQueries()) { - Query query = withQuery.getQuery(); - process(query, withScopeBuilder.build()); + for (WithQuery withQuery : with.getQueries()) { String name = withQuery.getName().getValue().toLowerCase(ENGLISH); if (withScopeBuilder.containsNamedQuery(name)) { throw semanticException(DUPLICATE_NAMED_QUERY, withQuery, "WITH query name '%s' specified more than once", name); } - // check if all or none of the columns are explicitly alias - if (withQuery.getColumnNames().isPresent()) { - List columnNames = withQuery.getColumnNames().get(); - RelationType queryDescriptor = analysis.getOutputDescriptor(query); - if (columnNames.size() != queryDescriptor.getVisibleFieldCount()) { - throw semanticException(MISMATCHED_COLUMN_ALIASES, withQuery, "WITH column alias list has %s entries but WITH query(%s) has %s columns", columnNames.size(), name, queryDescriptor.getVisibleFieldCount()); + boolean isRecursive = false; + if (with.isRecursive()) { + isRecursive = tryProcessRecursiveQuery(withQuery, name, withScopeBuilder); + // WITH query is not shaped accordingly to the rules for expandable query and will be processed like a plain WITH query. + // Since RECURSIVE is specified, any reference to WITH query name is considered a recursive reference and is not allowed. + if (!isRecursive) { + List recursiveReferences = findReferences(withQuery.getQuery(), withQuery.getName()); + if (!recursiveReferences.isEmpty()) { + throw semanticException(INVALID_RECURSIVE_REFERENCE, recursiveReferences.get(0), "recursive reference not allowed in this context"); + } } } - withScopeBuilder.withNamedQuery(name, withQuery); - } + if (!isRecursive) { + Query query = withQuery.getQuery(); + process(query, withScopeBuilder.build()); + // check if all or none of the columns are explicitly alias + if (withQuery.getColumnNames().isPresent()) { + validateColumnAliases(withQuery.getColumnNames().get(), analysis.getOutputDescriptor(query).getVisibleFieldCount()); + } + + withScopeBuilder.withNamedQuery(name, withQuery); + } + } Scope withScope = withScopeBuilder.build(); analysis.setScope(with, withScope); return withScope; } + private boolean tryProcessRecursiveQuery(WithQuery withQuery, String name, Scope.Builder withScopeBuilder) + { + if (withQuery.getColumnNames().isEmpty()) { + throw semanticException(MISSING_COLUMN_ALIASES, withQuery, "missing column aliases in recursive WITH query"); + } + preOrder(withQuery.getQuery()) + .filter(child -> child instanceof With && ((With) child).isRecursive()) + .findFirst() + .ifPresent(child -> { + throw semanticException(NESTED_RECURSIVE, child, "nested recursive WITH query"); + }); + // if RECURSIVE is specified, all queries in the WITH list are considered potentially recursive + // try resolve WITH query as expandable query + // a) validate shape of the query and location of recursive reference + if (!(withQuery.getQuery().getQueryBody() instanceof Union)) { + return false; + } + Union union = (Union) withQuery.getQuery().getQueryBody(); + if (union.getRelations().size() != 2) { + return false; + } + Relation anchor = union.getRelations().get(0); + Relation step = union.getRelations().get(1); + List anchorReferences = findReferences(anchor, withQuery.getName()); + if (!anchorReferences.isEmpty()) { + throw semanticException(INVALID_RECURSIVE_REFERENCE, anchorReferences.get(0), "WITH table name is referenced in the base relation of recursion"); + } + // a WITH query is linearly recursive if it has a single recursive reference + List stepReferences = findReferences(step, withQuery.getName()); + if (stepReferences.size() > 1) { + throw semanticException(INVALID_RECURSIVE_REFERENCE, stepReferences.get(1), "multiple recursive references in the step relation of recursion"); + } + if (stepReferences.size() != 1) { + return false; + } + // search for QuerySpecification in parenthesized subquery + Relation specification = step; + while (specification instanceof TableSubquery) { + Query query = ((TableSubquery) specification).getQuery(); + query.getLimit().ifPresent(limit -> { + throw semanticException(INVALID_LIMIT_CLAUSE, limit, "FETCH FIRST / LIMIT clause in the step relation of recursion"); + }); + specification = query.getQueryBody(); + } + if (!(specification instanceof QuerySpecification) || ((QuerySpecification) specification).getFrom().isEmpty()) { + throw semanticException(INVALID_RECURSIVE_REFERENCE, stepReferences.get(0), "recursive reference outside of FROM clause of the step relation of recursion"); + } + Relation from = ((QuerySpecification) specification).getFrom().get(); + List fromReferences = findReferences(from, withQuery.getName()); + if (fromReferences.size() == 0) { + throw semanticException(INVALID_RECURSIVE_REFERENCE, stepReferences.get(0), "recursive reference outside of FROM clause of the step relation of recursion"); + } + + // b) validate top-level shape of recursive query + withQuery.getQuery().getWith().ifPresent(innerWith -> { + throw semanticException(NOT_SUPPORTED, innerWith, "immediate WITH clause in recursive query"); + }); + withQuery.getQuery().getOrderBy().ifPresent(orderBy -> { + throw semanticException(NOT_SUPPORTED, orderBy, "immediate ORDER BY clause in recursive query"); + }); + withQuery.getQuery().getOffset().ifPresent(offset -> { + throw semanticException(NOT_SUPPORTED, offset, "immediate OFFSET clause in recursive query"); + }); + withQuery.getQuery().getLimit().ifPresent(limit -> { + throw semanticException(INVALID_LIMIT_CLAUSE, limit, "immediate FETCH FIRST / LIMIT clause in recursive query"); + }); + + // c) validate recursion step has no illegal clauses + validateFromClauseOfRecursiveTerm(from, withQuery.getName()); + + // shape validation complete - process query as expandable query + Scope parentScope = withScopeBuilder.build(); + // process expandable query -- anchor + Scope anchorScope = process(anchor, parentScope); + // set aliases in anchor scope as defined for WITH query. Recursion step will refer to anchor fields by aliases. + Scope aliasedAnchorScope = setAliases(anchorScope, withQuery.getName(), withQuery.getColumnNames().get()); + // record expandable query base scope for recursion step analysis + Node recursiveReference = fromReferences.get(0); + analysis.setExpandableBaseScope(recursiveReference, aliasedAnchorScope); + // process expandable query -- recursion step + Scope stepScope = process(step, parentScope); + + // verify anchor and step have matching descriptors + RelationType anchorType = aliasedAnchorScope.getRelationType().withOnlyVisibleFields(); + RelationType stepType = stepScope.getRelationType().withOnlyVisibleFields(); + if (anchorType.getVisibleFieldCount() != stepType.getVisibleFieldCount()) { + throw semanticException(TYPE_MISMATCH, step, "base and step relations of recursion have different number of fields: %s, %s", anchorType.getVisibleFieldCount(), stepType.getVisibleFieldCount()); + } + + List anchorFieldTypes = anchorType.getVisibleFields().stream() + .map(Field::getType) + .collect(toImmutableList()); + List stepFieldTypes = stepType.getVisibleFields().stream() + .map(Field::getType) + .collect(toImmutableList()); + + for (int i = 0; i < anchorFieldTypes.size(); i++) { + if (!typeCoercion.canCoerce(stepFieldTypes.get(i), anchorFieldTypes.get(i))) { + // TODO for more precise error location, pass the mismatching select expression instead of `step` + throw semanticException( + TYPE_MISMATCH, + step, + "recursion step relation output type (%s) is not coercible to recursion base relation output type (%s) at column %s", + stepFieldTypes.get(i), + anchorFieldTypes.get(i), + i + 1); + } + } + + if (!anchorFieldTypes.equals(stepFieldTypes)) { + analysis.addRelationCoercion(step, anchorFieldTypes.toArray(Type[]::new)); + } + + analysis.setScope(withQuery.getQuery(), aliasedAnchorScope); + analysis.registerExpandableQuery(withQuery.getQuery(), recursiveReference); + withScopeBuilder.withNamedQuery(name, withQuery); + return true; + } + + private List findReferences(Node node, Identifier name) + { + Stream allReferences = preOrder(node) + .filter(isTableWithName(name)); + + // TODO: recursive references could be supported in subquery before the point of shadowing. + //currently, the recursive query name is considered shadowed in the whole subquery if the subquery defines a common table with the same name + Set shadowedReferences = preOrder(node) + .filter(isQueryWithNameShadowed(name)) + .flatMap(query -> preOrder(query) + .filter(isTableWithName(name))) + .collect(toImmutableSet()); + + return allReferences + .filter(reference -> !shadowedReferences.contains(reference)) + .collect(toImmutableList()); + } + + private Predicate isTableWithName(Identifier name) + { + return node -> { + if (!(node instanceof Table)) { + return false; + } + Table table = (Table) node; + QualifiedName tableName = table.getName(); + return tableName.getPrefix().isEmpty() && tableName.hasSuffix(QualifiedName.of(name.getValue())); + }; + } + + private Predicate isQueryWithNameShadowed(Identifier name) + { + return node -> { + if (!(node instanceof Query)) { + return false; + } + Query query = (Query) node; + if (query.getWith().isEmpty()) { + return false; + } + return query.getWith().get().getQueries().stream() + .map(WithQuery::getName) + .map(Identifier::getValue) + .anyMatch(withQueryName -> withQueryName.equalsIgnoreCase(name.getValue())); + }; + } + + private void validateFromClauseOfRecursiveTerm(Relation from, Identifier name) + { + preOrder(from) + .filter(node -> node instanceof Join) + .forEach(node -> { + Join join = (Join) node; + Join.Type type = join.getType(); + if (type == LEFT || type == RIGHT || type == FULL) { + List leftRecursiveReferences = findReferences(join.getLeft(), name); + List rightRecursiveReferences = findReferences(join.getRight(), name); + if (!leftRecursiveReferences.isEmpty() && (type == RIGHT || type == FULL)) { + throw semanticException(INVALID_RECURSIVE_REFERENCE, leftRecursiveReferences.get(0), "recursive reference in left source of %s join", type); + } + if (!rightRecursiveReferences.isEmpty() && (type == LEFT || type == FULL)) { + throw semanticException(INVALID_RECURSIVE_REFERENCE, rightRecursiveReferences.get(0), "recursive reference in right source of %s join", type); + } + } + }); + + preOrder(from) + .filter(node -> node instanceof Intersect && !((Intersect) node).isDistinct()) + .forEach(node -> { + Intersect intersect = (Intersect) node; + intersect.getRelations().stream() + .flatMap(relation -> findReferences(relation, name).stream()) + .findFirst() + .ifPresent(reference -> { + throw semanticException(INVALID_RECURSIVE_REFERENCE, reference, "recursive reference in INTERSECT ALL"); + }); + }); + + preOrder(from) + .filter(node -> node instanceof Except) + .forEach(node -> { + Except except = (Except) node; + List rightRecursiveReferences = findReferences(except.getRight(), name); + if (!rightRecursiveReferences.isEmpty()) { + throw semanticException( + INVALID_RECURSIVE_REFERENCE, + rightRecursiveReferences.get(0), + "recursive reference in right relation of EXCEPT %s", + except.isDistinct() ? "DISTINCT" : "ALL"); + } + if (!except.isDistinct()) { + List leftRecursiveReferences = findReferences(except.getLeft(), name); + if (!leftRecursiveReferences.isEmpty()) { + throw semanticException(INVALID_RECURSIVE_REFERENCE, leftRecursiveReferences.get(0), "recursive reference in left relation of EXCEPT ALL"); + } + } + }); + } + + private Scope setAliases(Scope scope, Identifier tableName, List columnNames) + { + RelationType oldDescriptor = scope.getRelationType(); + validateColumnAliases(columnNames, oldDescriptor.getVisibleFieldCount()); + RelationType newDescriptor = oldDescriptor.withAlias(tableName.getValue(), columnNames.stream().map(Identifier::getValue).collect(toImmutableList())); + return scope.withRelationType(newDescriptor); + } + private void verifySelectDistinct(QuerySpecification node, List orderByExpressions, List outputExpressions, Scope sourceScope, Scope orderByScope) { Set> aliases = getAliases(node.getSelect()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java index 6e1604931980..caec018e465f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java @@ -568,7 +568,7 @@ private Expression noTruncationCast(Expression expression, Type fromType, Type t private RelationPlan createDeletePlan(Analysis analysis, Delete node) { - DeleteNode deleteNode = new QueryPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, Optional.empty(), session) + DeleteNode deleteNode = new QueryPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, Optional.empty(), session, ImmutableMap.of()) .plan(node); TableFinishNode commitNode = new TableFinishNode( @@ -605,7 +605,7 @@ private PlanNode createOutputPlan(RelationPlan plan, Analysis analysis) private RelationPlan createRelationPlan(Analysis analysis, Query query) { - return new RelationPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, Optional.empty(), session) + return new RelationPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, Optional.empty(), session, ImmutableMap.of()) .process(query, null); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/NodeAndMappings.java b/presto-main/src/main/java/io/prestosql/sql/planner/NodeAndMappings.java new file mode 100644 index 000000000000..c7bca3f17929 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/NodeAndMappings.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import io.prestosql.sql.planner.plan.PlanNode; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class NodeAndMappings +{ + private final PlanNode node; + private final List fields; + + public NodeAndMappings(PlanNode node, List fields) + { + this.node = requireNonNull(node, "node is null"); + this.fields = requireNonNull(fields, "fields is null"); + } + + public PlanNode getNode() + { + return node; + } + + public List getFields() + { + return fields; + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanCopier.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanCopier.java new file mode 100644 index 000000000000..ac7dd2fa1419 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanCopier.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import io.prestosql.metadata.Metadata; +import io.prestosql.sql.planner.optimizations.UnaliasSymbolReferences; +import io.prestosql.sql.planner.plan.AggregationNode; +import io.prestosql.sql.planner.plan.ApplyNode; +import io.prestosql.sql.planner.plan.CorrelatedJoinNode; +import io.prestosql.sql.planner.plan.EnforceSingleRowNode; +import io.prestosql.sql.planner.plan.ExceptNode; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.GroupIdNode; +import io.prestosql.sql.planner.plan.IntersectNode; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.LimitNode; +import io.prestosql.sql.planner.plan.OffsetNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.SampleNode; +import io.prestosql.sql.planner.plan.SimplePlanRewriter; +import io.prestosql.sql.planner.plan.SortNode; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.planner.plan.TopNNode; +import io.prestosql.sql.planner.plan.UnionNode; +import io.prestosql.sql.planner.plan.UnnestNode; +import io.prestosql.sql.planner.plan.ValuesNode; +import io.prestosql.sql.planner.plan.WindowNode; + +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * Clones plan and assigns new PlanNodeIds to the copied PlanNodes. + * Also, replaces all symbols in the copied plan with new symbols. + * The original and copied plans can be safely used in different + * branches of plan. + */ +public final class PlanCopier +{ + private PlanCopier() {} + + public static NodeAndMappings copyPlan(PlanNode plan, List fields, Metadata metadata, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) + { + PlanNode copy = SimplePlanRewriter.rewriteWith(new Copier(idAllocator), plan, null); + return new UnaliasSymbolReferences(metadata).reallocateSymbols(copy, fields, symbolAllocator); + } + + private static class Copier + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + + private Copier(PlanNodeIdAllocator idAllocator) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + } + + @Override + protected PlanNode visitPlan(PlanNode node, RewriteContext context) + { + throw new UnsupportedOperationException("plan copying not implemented for " + node.getClass().getSimpleName()); + } + + @Override + public PlanNode visitAggregation(AggregationNode node, RewriteContext context) + { + return new AggregationNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getAggregations(), node.getGroupingSets(), node.getPreGroupedSymbols(), node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol()); + } + + @Override + public PlanNode visitFilter(FilterNode node, RewriteContext context) + { + return new FilterNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getPredicate()); + } + + @Override + public PlanNode visitProject(ProjectNode node, RewriteContext context) + { + return new ProjectNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getAssignments()); + } + + @Override + public PlanNode visitTopN(TopNNode node, RewriteContext context) + { + return new TopNNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getCount(), node.getOrderingScheme(), node.getStep()); + } + + @Override + public PlanNode visitOffset(OffsetNode node, RewriteContext context) + { + return new OffsetNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getCount()); + } + + @Override + public PlanNode visitLimit(LimitNode node, RewriteContext context) + { + return new LimitNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getCount(), node.getTiesResolvingScheme(), node.isPartial()); + } + + @Override + public PlanNode visitSample(SampleNode node, RewriteContext context) + { + return new SampleNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getSampleRatio(), node.getSampleType()); + } + + @Override + public PlanNode visitTableScan(TableScanNode node, RewriteContext context) + { + return new TableScanNode(idAllocator.getNextId(), node.getTable(), node.getOutputSymbols(), node.getAssignments(), node.getEnforcedConstraint()); + } + + @Override + public PlanNode visitValues(ValuesNode node, RewriteContext context) + { + return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), node.getRows()); + } + + @Override + public PlanNode visitJoin(JoinNode node, RewriteContext context) + { + return new JoinNode( + idAllocator.getNextId(), + node.getType(), + context.rewrite(node.getLeft()), + context.rewrite(node.getRight()), + node.getCriteria(), + node.getLeftOutputSymbols(), + node.getRightOutputSymbols(), + node.getFilter(), + node.getLeftHashSymbol(), + node.getRightHashSymbol(), + node.getDistributionType(), + node.isSpillable(), + node.getDynamicFilters(), + node.getReorderJoinStatsAndCost()); + } + + @Override + public PlanNode visitSort(SortNode node, RewriteContext context) + { + return new SortNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getOrderingScheme(), node.isPartial()); + } + + @Override + public PlanNode visitWindow(WindowNode node, RewriteContext context) + { + return new WindowNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getSpecification(), node.getWindowFunctions(), node.getHashSymbol(), node.getPrePartitionedInputs(), node.getPreSortedOrderPrefix()); + } + + @Override + public PlanNode visitUnion(UnionNode node, RewriteContext context) + { + List copiedSources = node.getSources().stream() + .map(context::rewrite) + .collect(toImmutableList()); + return new UnionNode(idAllocator.getNextId(), copiedSources, node.getSymbolMapping(), node.getOutputSymbols()); + } + + @Override + public PlanNode visitIntersect(IntersectNode node, RewriteContext context) + { + List copiedSources = node.getSources().stream() + .map(context::rewrite) + .collect(toImmutableList()); + return new IntersectNode(idAllocator.getNextId(), copiedSources, node.getSymbolMapping(), node.getOutputSymbols()); + } + + @Override + public PlanNode visitExcept(ExceptNode node, RewriteContext context) + { + List copiedSources = node.getSources().stream() + .map(context::rewrite) + .collect(toImmutableList()); + return new ExceptNode(idAllocator.getNextId(), copiedSources, node.getSymbolMapping(), node.getOutputSymbols()); + } + + @Override + public PlanNode visitUnnest(UnnestNode node, RewriteContext context) + { + return new UnnestNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getReplicateSymbols(), node.getMappings(), node.getOrdinalitySymbol(), node.getJoinType(), node.getFilter()); + } + + @Override + public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) + { + return new GroupIdNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getGroupingSets(), node.getGroupingColumns(), node.getAggregationArguments(), node.getGroupIdSymbol()); + } + + @Override + public PlanNode visitEnforceSingleRow(EnforceSingleRowNode node, RewriteContext context) + { + return new EnforceSingleRowNode(idAllocator.getNextId(), context.rewrite(node.getSource())); + } + + @Override + public PlanNode visitApply(ApplyNode node, RewriteContext context) + { + return new ApplyNode(idAllocator.getNextId(), context.rewrite(node.getInput()), context.rewrite(node.getSubquery()), node.getSubqueryAssignments(), node.getCorrelation(), node.getOriginSubquery()); + } + + @Override + public PlanNode visitCorrelatedJoin(CorrelatedJoinNode node, RewriteContext context) + { + return new CorrelatedJoinNode(idAllocator.getNextId(), context.rewrite(node.getInput()), context.rewrite(node.getSubquery()), node.getCorrelation(), node.getType(), node.getFilter(), node.getOriginSubquery()); + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java index b546c78b12cd..31aa850f6b5a 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java @@ -14,12 +14,14 @@ package io.prestosql.sql.planner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Sets; import io.prestosql.Session; import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.ResolvedFunction; import io.prestosql.metadata.TableHandle; import io.prestosql.spi.block.SortOrder; import io.prestosql.spi.type.Type; @@ -39,27 +41,36 @@ import io.prestosql.sql.planner.plan.OffsetNode; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.SimplePlanRewriter; import io.prestosql.sql.planner.plan.SortNode; import io.prestosql.sql.planner.plan.TableWriterNode.DeleteTarget; +import io.prestosql.sql.planner.plan.UnionNode; import io.prestosql.sql.planner.plan.ValuesNode; import io.prestosql.sql.planner.plan.WindowNode; import io.prestosql.sql.tree.Cast; +import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Delete; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FetchFirst; import io.prestosql.sql.tree.FrameBound; import io.prestosql.sql.tree.FunctionCall; import io.prestosql.sql.tree.FunctionCall.NullTreatment; +import io.prestosql.sql.tree.GenericLiteral; +import io.prestosql.sql.tree.IfExpression; import io.prestosql.sql.tree.LambdaArgumentDeclaration; import io.prestosql.sql.tree.LambdaExpression; import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.Offset; import io.prestosql.sql.tree.OrderBy; +import io.prestosql.sql.tree.QualifiedName; import io.prestosql.sql.tree.Query; import io.prestosql.sql.tree.QuerySpecification; +import io.prestosql.sql.tree.Relation; import io.prestosql.sql.tree.SortItem; +import io.prestosql.sql.tree.StringLiteral; import io.prestosql.sql.tree.Table; +import io.prestosql.sql.tree.Union; import io.prestosql.sql.tree.Window; import io.prestosql.sql.tree.WindowFrame; import io.prestosql.type.TypeCoercion; @@ -81,10 +92,14 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.SystemSessionProperties.getMaxRecursionDepth; import static io.prestosql.SystemSessionProperties.isSkipRedundantSort; import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.spi.type.VarbinaryType.VARBINARY; +import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.sql.NodeUtils.getSortItemsFromOrderBy; +import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.prestosql.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.prestosql.sql.planner.GroupingOperationRewriter.rewriteGroupingOperation; import static io.prestosql.sql.planner.OrderingScheme.sortItemToSortOrder; @@ -92,6 +107,11 @@ import static io.prestosql.sql.planner.ScopeAware.scopeAwareKey; import static io.prestosql.sql.planner.plan.AggregationNode.groupingSets; import static io.prestosql.sql.planner.plan.AggregationNode.singleGroupingSet; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static io.prestosql.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static io.prestosql.sql.tree.FrameBound.Type.CURRENT_ROW; +import static io.prestosql.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; +import static io.prestosql.sql.tree.WindowFrame.Type.RANGE; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -106,6 +126,7 @@ class QueryPlanner private final Session session; private final SubqueryPlanner subqueryPlanner; private final Optional outerContext; + private final Map, RelationPlan> recursiveSubqueries; QueryPlanner( Analysis analysis, @@ -114,7 +135,8 @@ class QueryPlanner Map, Symbol> lambdaDeclarationToSymbolMap, Metadata metadata, Optional outerContext, - Session session) + Session session, + Map, RelationPlan> recursiveSubqueries) { requireNonNull(analysis, "analysis is null"); requireNonNull(symbolAllocator, "symbolAllocator is null"); @@ -123,6 +145,7 @@ class QueryPlanner requireNonNull(metadata, "metadata is null"); requireNonNull(session, "session is null"); requireNonNull(outerContext, "outerContext is null"); + requireNonNull(recursiveSubqueries, "recursiveSubqueries is null"); this.analysis = analysis; this.symbolAllocator = symbolAllocator; @@ -132,7 +155,8 @@ class QueryPlanner this.typeCoercion = new TypeCoercion(metadata::getType); this.session = session; this.outerContext = outerContext; - this.subqueryPlanner = new SubqueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, typeCoercion, outerContext, session); + this.subqueryPlanner = new SubqueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, typeCoercion, outerContext, session, recursiveSubqueries); + this.recursiveSubqueries = recursiveSubqueries; } public RelationPlan plan(Query query) @@ -161,6 +185,170 @@ public RelationPlan plan(Query query) outerContext); } + public RelationPlan planExpand(Query query) + { + checkArgument(analysis.isExpandableQuery(query), "query is not registered as expandable"); + + Union union = (Union) query.getQueryBody(); + ImmutableList.Builder recursionSteps = ImmutableList.builder(); + + // plan anchor relation + Relation anchorNode = union.getRelations().get(0); + RelationPlan anchorPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session, recursiveSubqueries) + .process(anchorNode, null); + + // prune anchor plan outputs to contain only the symbols exposed in the scope + NodeAndMappings prunedAnchorPlan = pruneInvisibleFields(anchorPlan, idAllocator); + anchorPlan = new RelationPlan(prunedAnchorPlan.getNode(), analysis.getScope(query), prunedAnchorPlan.getFields(), outerContext); + + recursionSteps.add(copy(anchorPlan.getRoot(), anchorPlan.getFieldMappings())); + + // plan recursion step + Relation recursionStepRelation = union.getRelations().get(1); + RelationPlan recursionStepPlan = new RelationPlanner( + analysis, + symbolAllocator, + idAllocator, + lambdaDeclarationToSymbolMap, + metadata, + outerContext, + session, + ImmutableMap.of(NodeRef.of(analysis.getRecursiveReference(query)), anchorPlan)) + .process(recursionStepRelation, null); + + // coerce recursion step outputs and prune them to contain only the symbols exposed in the scope + NodeAndMappings coercedRecursionStep; + List types = analysis.getRelationCoercion(recursionStepRelation); + if (types == null) { + coercedRecursionStep = pruneInvisibleFields(recursionStepPlan, idAllocator); + } + else { + coercedRecursionStep = coerce(recursionStepPlan, types, symbolAllocator, idAllocator); + } + + NodeAndMappings replacementSpot = new NodeAndMappings(anchorPlan.getRoot(), anchorPlan.getFieldMappings()); + PlanNode recursionStep = coercedRecursionStep.getNode(); + List mappings = coercedRecursionStep.getFields(); + + // unroll recursion + int maxRecursionDepth = getMaxRecursionDepth(session); + for (int i = 0; i < maxRecursionDepth; i++) { + recursionSteps.add(copy(recursionStep, mappings)); + NodeAndMappings replacement = copy(recursionStep, mappings); + recursionStep = replace(recursionStep, replacementSpot, replacement); + replacementSpot = replacement; + } + + // after the last recursion step, check if the recursion converged. the last step is expected to return empty result + // 1. append window to count rows + NodeAndMappings checkConvergenceStep = copy(recursionStep, mappings); + Symbol countSymbol = symbolAllocator.newSymbol("count", BIGINT); + ResolvedFunction function = metadata.resolveFunction(QualifiedName.of("count"), ImmutableList.of()); + WindowNode.Frame frame = new WindowNode.Frame(RANGE, UNBOUNDED_PRECEDING, Optional.empty(), CURRENT_ROW, Optional.empty(), Optional.empty(), Optional.empty()); + WindowNode.Function countFunction = new WindowNode.Function(function, ImmutableList.of(), frame, false); + + WindowNode windowNode = new WindowNode( + idAllocator.getNextId(), + checkConvergenceStep.getNode(), + new WindowNode.Specification(ImmutableList.of(), Optional.empty()), + ImmutableMap.of(countSymbol, countFunction), + Optional.empty(), + ImmutableSet.of(), + 0); + + // 2. append filter to fail on non-empty result + ResolvedFunction fail = metadata.resolveFunction(QualifiedName.of("fail"), fromTypes(VARCHAR)); + String recursionLimitExceededMessage = format("Recursion depth limit exceeded (%s). Use 'max_recursion_depth' session property to modify the limit.", maxRecursionDepth); + Expression predicate = new IfExpression( + new ComparisonExpression( + GREATER_THAN_OR_EQUAL, + countSymbol.toSymbolReference(), + new GenericLiteral("BIGINT", "0")), + new Cast( + new FunctionCall( + fail.toQualifiedName(), + ImmutableList.of(new Cast(new StringLiteral(recursionLimitExceededMessage), toSqlType(VARCHAR)))), + toSqlType(BOOLEAN)), + TRUE_LITERAL); + FilterNode filterNode = new FilterNode(idAllocator.getNextId(), windowNode, predicate); + + recursionSteps.add(new NodeAndMappings(filterNode, checkConvergenceStep.getFields())); + + // union all the recursion steps + List recursionStepsToUnion = recursionSteps.build(); + + List unionOutputSymbols = anchorPlan.getFieldMappings().stream() + .map(symbol -> symbolAllocator.newSymbol(symbol, "_expanded")) + .collect(toImmutableList()); + + ImmutableListMultimap.Builder unionSymbolMapping = ImmutableListMultimap.builder(); + for (NodeAndMappings plan : recursionStepsToUnion) { + for (int i = 0; i < unionOutputSymbols.size(); i++) { + unionSymbolMapping.put(unionOutputSymbols.get(i), plan.getFields().get(i)); + } + } + + List nodesToUnion = recursionStepsToUnion.stream() + .map(NodeAndMappings::getNode) + .collect(toImmutableList()); + + PlanNode result = new UnionNode(idAllocator.getNextId(), nodesToUnion, unionSymbolMapping.build(), unionOutputSymbols); + + if (union.isDistinct()) { + result = new AggregationNode( + idAllocator.getNextId(), + result, + ImmutableMap.of(), + singleGroupingSet(result.getOutputSymbols()), + ImmutableList.of(), + AggregationNode.Step.SINGLE, + Optional.empty(), + Optional.empty()); + } + + return new RelationPlan(result, anchorPlan.getScope(), unionOutputSymbols, outerContext); + } + + // Return a copy of the plan and remapped field mappings. In the copied plan: + // - all PlanNodeIds are replaced with new values, + // - all symbols are replaced with new symbols. + // Copying the plan might reorder symbols. The returned field mappings keep the original + // order and might be used to identify the original output symbols with their copies. + private NodeAndMappings copy(PlanNode plan, List fields) + { + return PlanCopier.copyPlan(plan, fields, metadata, symbolAllocator, idAllocator); + } + + private PlanNode replace(PlanNode plan, NodeAndMappings replacementSpot, NodeAndMappings replacement) + { + checkArgument( + replacementSpot.getFields().size() == replacement.getFields().size(), + "mismatching outputs in replacement, expected: %s, got: %s", + replacementSpot.getFields().size(), + replacement.getFields().size()); + + return SimplePlanRewriter.rewriteWith(new SimplePlanRewriter() + { + @Override + protected PlanNode visitPlan(PlanNode node, RewriteContext context) + { + return node.replaceChildren(node.getSources().stream() + .map(child -> { + if (child == replacementSpot.getNode()) { + // add projection to adjust symbols + Assignments.Builder assignments = Assignments.builder(); + for (int i = 0; i < replacementSpot.getFields().size(); i++) { + assignments.put(replacementSpot.getFields().get(i), replacement.getFields().get(i).toSymbolReference()); + } + return new ProjectNode(idAllocator.getNextId(), replacement.getNode(), assignments.build()); + } + return context.rewrite(child); + }) + .collect(toImmutableList())); + } + }, plan, null); + } + public RelationPlan plan(QuerySpecification node) { PlanBuilder builder = planFrom(node); @@ -255,7 +443,7 @@ public DeleteNode plan(Delete node) TableHandle handle = analysis.getTableHandle(table); // create table scan - RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session) + RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session, recursiveSubqueries) .process(table, null); PlanBuilder builder = newPlanBuilder(relationPlan, analysis, lambdaDeclarationToSymbolMap); @@ -283,7 +471,7 @@ private static List computeOutputs(PlanBuilder builder, List private PlanBuilder planQueryBody(Query query) { - RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session) + RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session, recursiveSubqueries) .process(query.getQueryBody(), null); return newPlanBuilder(relationPlan, analysis, lambdaDeclarationToSymbolMap); @@ -292,7 +480,7 @@ private PlanBuilder planQueryBody(Query query) private PlanBuilder planFrom(QuerySpecification node) { if (node.getFrom().isPresent()) { - RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session) + RelationPlan relationPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session, recursiveSubqueries) .process(node.getFrom().get(), null); return newPlanBuilder(relationPlan, analysis, lambdaDeclarationToSymbolMap); } @@ -1029,26 +1217,4 @@ public Aggregation getRewritten() return aggregation; } } - - public static class NodeAndMappings - { - private final PlanNode node; - private final List fields; - - public NodeAndMappings(PlanNode node, List fields) - { - this.node = requireNonNull(node, "node is null"); - this.fields = requireNonNull(fields, "fields is null"); - } - - public PlanNode getNode() - { - return node; - } - - public List getFields() - { - return fields; - } - } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java index a2ff2b343533..966c1c177013 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java @@ -28,7 +28,6 @@ import io.prestosql.sql.analyzer.Field; import io.prestosql.sql.analyzer.RelationType; import io.prestosql.sql.analyzer.Scope; -import io.prestosql.sql.planner.QueryPlanner.NodeAndMappings; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.Assignments; import io.prestosql.sql.planner.plan.CorrelatedJoinNode; @@ -114,6 +113,7 @@ class RelationPlanner private final Optional outerContext; private final Session session; private final SubqueryPlanner subqueryPlanner; + private final Map, RelationPlan> recursiveSubqueries; RelationPlanner( Analysis analysis, @@ -122,7 +122,8 @@ class RelationPlanner Map, Symbol> lambdaDeclarationToSymbolMap, Metadata metadata, Optional outerContext, - Session session) + Session session, + Map, RelationPlan> recursiveSubqueries) { requireNonNull(analysis, "analysis is null"); requireNonNull(symbolAllocator, "symbolAllocator is null"); @@ -131,6 +132,7 @@ class RelationPlanner requireNonNull(metadata, "metadata is null"); requireNonNull(outerContext, "outerContext is null"); requireNonNull(session, "session is null"); + requireNonNull(recursiveSubqueries, "recursiveSubqueries is null"); this.analysis = analysis; this.symbolAllocator = symbolAllocator; @@ -140,7 +142,8 @@ class RelationPlanner this.typeCoercion = new TypeCoercion(metadata::getType); this.outerContext = outerContext; this.session = session; - this.subqueryPlanner = new SubqueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, typeCoercion, outerContext, session); + this.subqueryPlanner = new SubqueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, typeCoercion, outerContext, session, recursiveSubqueries); + this.recursiveSubqueries = recursiveSubqueries; } @Override @@ -152,12 +155,25 @@ protected RelationPlan visitNode(Node node, Void context) @Override protected RelationPlan visitTable(Table node, Void context) { + // is this a recursive reference in expandable named query? If so, there's base relation already planned. + RelationPlan expansion = recursiveSubqueries.get(NodeRef.of(node)); + if (expansion != null) { + return expansion; + } + Query namedQuery = analysis.getNamedQuery(node); Scope scope = analysis.getScope(node); RelationPlan plan; if (namedQuery != null) { - RelationPlan subPlan = process(namedQuery, null); + RelationPlan subPlan; + if (analysis.isExpandableQuery(namedQuery)) { + subPlan = new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session, recursiveSubqueries) + .planExpand(namedQuery); + } + else { + subPlan = process(namedQuery, null); + } // Add implicit coercions if view query produces types that don't match the declared output types // of the view (e.g., if the underlying tables referenced by the view changed) @@ -636,7 +652,7 @@ private RelationPlan planCorrelatedJoin(Join join, RelationPlan leftPlan, Latera { PlanBuilder leftPlanBuilder = newPlanBuilder(leftPlan, analysis, lambdaDeclarationToSymbolMap); - RelationPlan rightPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, Optional.of(leftPlanBuilder.getTranslations()), session) + RelationPlan rightPlan = new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, Optional.of(leftPlanBuilder.getTranslations()), session, recursiveSubqueries) .process(lateral.getQuery(), null); PlanBuilder rightPlanBuilder = newPlanBuilder(rightPlan, analysis, lambdaDeclarationToSymbolMap); @@ -753,14 +769,14 @@ protected RelationPlan visitTableSubquery(TableSubquery node, Void context) @Override protected RelationPlan visitQuery(Query node, Void context) { - return new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session) + return new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session, recursiveSubqueries) .plan(node); } @Override protected RelationPlan visitQuerySpecification(QuerySpecification node, Void context) { - return new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session) + return new QueryPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, outerContext, session, recursiveSubqueries) .plan(node); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java index 00de680990ee..2fdd061c14ee 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java @@ -71,6 +71,7 @@ class SubqueryPlanner private final Metadata metadata; private final TypeCoercion typeCoercion; private final Session session; + private final Map, RelationPlan> recursiveSubqueries; SubqueryPlanner( Analysis analysis, @@ -80,7 +81,8 @@ class SubqueryPlanner Metadata metadata, TypeCoercion typeCoercion, Optional outerContext, - Session session) + Session session, + Map, RelationPlan> recursiveSubqueries) { requireNonNull(analysis, "analysis is null"); requireNonNull(symbolAllocator, "symbolAllocator is null"); @@ -90,6 +92,7 @@ class SubqueryPlanner requireNonNull(typeCoercion, "typeCoercion is null"); requireNonNull(outerContext, "outerContext is null"); requireNonNull(session, "session is null"); + requireNonNull(recursiveSubqueries, "recursiveSubqueries is null"); this.analysis = analysis; this.symbolAllocator = symbolAllocator; @@ -98,6 +101,7 @@ class SubqueryPlanner this.metadata = metadata; this.typeCoercion = typeCoercion; this.session = session; + this.recursiveSubqueries = recursiveSubqueries; } public PlanBuilder handleSubqueries(PlanBuilder builder, Collection expressions, Node node) @@ -284,7 +288,7 @@ private PlanBuilder planExists(PlanBuilder subPlan, Cluster clu private RelationPlan planSubquery(Expression subquery, TranslationMap outerContext) { - return new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, Optional.of(outerContext), session) + return new RelationPlanner(analysis, symbolAllocator, idAllocator, lambdaDeclarationToSymbolMap, metadata, Optional.of(outerContext), session, recursiveSubqueries) .process(subquery, null); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTableWriteThroughUnion.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTableWriteThroughUnion.java index c04224a16ae3..25d1838d6373 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTableWriteThroughUnion.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTableWriteThroughUnion.java @@ -34,6 +34,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.prestosql.SystemSessionProperties.isPushTableWriteThroughUnion; import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.optimizations.SymbolMapper.symbolMapper; import static io.prestosql.sql.planner.plan.Patterns.source; import static io.prestosql.sql.planner.plan.Patterns.tableWriterNode; import static io.prestosql.sql.planner.plan.Patterns.union; @@ -107,7 +108,7 @@ private static TableWriterNode rewriteSource( } } sourceMappings.add(outputMappings.build()); - SymbolMapper symbolMapper = new SymbolMapper(mappings.build()); + SymbolMapper symbolMapper = symbolMapper(mappings.build()); return symbolMapper.map(writerNode, unionNode.getSources().get(source), context.getIdAllocator().getNextId()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PlanNodeDecorrelator.java index be2840e22ec5..ae791846740b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -57,6 +57,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.planner.optimizations.SymbolMapper.symbolMapper; import static io.prestosql.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.prestosql.sql.tree.ComparisonExpression.Operator.EQUAL; import static java.lang.Math.toIntExact; @@ -516,7 +517,7 @@ private static class DecorrelationResult SymbolMapper getCorrelatedSymbolMapper() { - return new SymbolMapper(correlatedSymbolsMapping.asMap().entrySet().stream() + return symbolMapper(correlatedSymbolsMapping.asMap().entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, symbols -> Iterables.getLast(symbols.getValue())))); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/SymbolMapper.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/SymbolMapper.java index d6dbe0a631f5..081384bc2fe5 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/SymbolMapper.java @@ -19,6 +19,7 @@ import io.prestosql.sql.planner.OrderingScheme; import io.prestosql.sql.planner.PartitioningScheme; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.SymbolAllocator; import io.prestosql.sql.planner.plan.AggregationNode; import io.prestosql.sql.planner.plan.AggregationNode.Aggregation; import io.prestosql.sql.planner.plan.DistinctLimitNode; @@ -45,6 +46,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import java.util.function.Function; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -54,25 +56,44 @@ public class SymbolMapper { - private final Map mapping; + private final Function mappingFunction; - public SymbolMapper(Map mapping) + private SymbolMapper(Function mappingFunction) { - this.mapping = ImmutableMap.copyOf(requireNonNull(mapping, "mapping is null")); + this.mappingFunction = requireNonNull(mappingFunction, "mappingFunction is null"); } - public Map getMapping() + public static SymbolMapper symbolMapper(Map mapping) { - return mapping; + return new SymbolMapper(symbol -> { + while (mapping.containsKey(symbol) && !mapping.get(symbol).equals(symbol)) { + symbol = mapping.get(symbol); + } + return symbol; + }); + } + + public static SymbolMapper symbolReallocator(Map mapping, SymbolAllocator symbolAllocator) + { + return new SymbolMapper(symbol -> { + if (mapping.containsKey(symbol)) { + while (mapping.containsKey(symbol) && !mapping.get(symbol).equals(symbol)) { + symbol = mapping.get(symbol); + } + return symbol; + } + Symbol newSymbol = symbolAllocator.newSymbol(symbol); + mapping.put(symbol, newSymbol); + // do not remap the symbol further + mapping.put(newSymbol, newSymbol); + return newSymbol; + }); } // Return the canonical mapping for the symbol. public Symbol map(Symbol symbol) { - while (mapping.containsKey(symbol) && !mapping.get(symbol).equals(symbol)) { - symbol = mapping.get(symbol); - } - return symbol; + return mappingFunction.apply(symbol); } public List map(List symbols) @@ -363,7 +384,7 @@ public void put(Symbol from, Symbol to) public SymbolMapper build() { - return new SymbolMapper(mappings.build()); + return SymbolMapper.symbolMapper(mappings.build()); } } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java index 662102d33939..d15bab44391c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -24,6 +24,7 @@ import io.prestosql.metadata.Metadata; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.sql.planner.DeterminismEvaluator; +import io.prestosql.sql.planner.NodeAndMappings; import io.prestosql.sql.planner.OrderingScheme; import io.prestosql.sql.planner.PartitioningScheme; import io.prestosql.sql.planner.PlanNodeIdAllocator; @@ -85,11 +86,14 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.sql.planner.optimizations.SymbolMapper.symbolMapper; +import static io.prestosql.sql.planner.optimizations.SymbolMapper.symbolReallocator; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; import static java.util.Objects.requireNonNull; @@ -123,17 +127,40 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym requireNonNull(symbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); - return plan.accept(new Visitor(metadata), UnaliasContext.empty()).getRoot(); + return plan.accept(new Visitor(metadata, SymbolMapper::symbolMapper), UnaliasContext.empty()).getRoot(); + } + + /** + * Replace all symbols in the plan with new symbols. + * The returned plan has different output than the original plan. Also, the order of symbols might change during symbol replacement. + * Symbols in the list `fields` are replaced maintaining the order so they might be used to match original symbols with their replacements. + * Replacing symbols helps avoid collisions when symbols or parts of the plan are reused. + */ + public NodeAndMappings reallocateSymbols(PlanNode plan, List fields, SymbolAllocator symbolAllocator) + { + requireNonNull(plan, "plan is null"); + requireNonNull(fields, "fields is null"); + requireNonNull(symbolAllocator, "symbolAllocator is null"); + + PlanAndMappings result = plan.accept(new Visitor(metadata, mapping -> symbolReallocator(mapping, symbolAllocator)), UnaliasContext.empty()); + return new NodeAndMappings(result.getRoot(), symbolMapper(result.getMappings()).map(fields)); } private static class Visitor extends PlanVisitor { private final Metadata metadata; + private final Function, SymbolMapper> mapperProvider; - public Visitor(Metadata metadata) + public Visitor(Metadata metadata, Function, SymbolMapper> mapperProvider) { this.metadata = requireNonNull(metadata, "metadata is null"); + this.mapperProvider = requireNonNull(mapperProvider, "mapperProvider is null"); + } + + private SymbolMapper symbolMapper(Map mappings) + { + return mapperProvider.apply(mappings); } @Override @@ -146,42 +173,46 @@ protected PlanAndMappings visitPlan(PlanNode node, UnaliasContext context) public PlanAndMappings visitAggregation(AggregationNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); AggregationNode rewrittenAggregation = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenAggregation, mapper.getMapping()); + return new PlanAndMappings(rewrittenAggregation, mapping); } @Override public PlanAndMappings visitGroupId(GroupIdNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); GroupIdNode rewrittenGroupId = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenGroupId, mapper.getMapping()); + return new PlanAndMappings(rewrittenGroupId, mapping); } @Override public PlanAndMappings visitExplainAnalyze(ExplainAnalyzeNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); Symbol newOutputSymbol = mapper.map(node.getOutputSymbol()); return new PlanAndMappings( new ExplainAnalyzeNode(node.getId(), rewrittenSource.getRoot(), newOutputSymbol, node.isVerbose()), - mapper.getMapping()); + mapping); } @Override public PlanAndMappings visitMarkDistinct(MarkDistinctNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); Symbol newMarkerSymbol = mapper.map(node.getMarkerSymbol()); List newDistinctSymbols = mapper.mapAndDistinct(node.getDistinctSymbols()); @@ -194,14 +225,15 @@ public PlanAndMappings visitMarkDistinct(MarkDistinctNode node, UnaliasContext c newMarkerSymbol, newDistinctSymbols, newHashSymbol), - mapper.getMapping()); + mapping); } @Override public PlanAndMappings visitUnnest(UnnestNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); List newReplicateSymbols = mapper.mapAndDistinct(node.getReplicateSymbols()); @@ -222,24 +254,26 @@ public PlanAndMappings visitUnnest(UnnestNode node, UnaliasContext context) newOrdinalitySymbol, node.getJoinType(), newFilter), - mapper.getMapping()); + mapping); } @Override public PlanAndMappings visitWindow(WindowNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); WindowNode rewrittenWindow = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenWindow, mapper.getMapping()); + return new PlanAndMappings(rewrittenWindow, mapping); } @Override public PlanAndMappings visitTableScan(TableScanNode node, UnaliasContext context) { - SymbolMapper mapper = new SymbolMapper(context.getCorrelationMapping()); + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper mapper = symbolMapper(mapping); List newOutputs = mapper.map(node.getOutputSymbols()); @@ -250,7 +284,7 @@ public PlanAndMappings visitTableScan(TableScanNode node, UnaliasContext context return new PlanAndMappings( new TableScanNode(node.getId(), node.getTable(), newOutputs, newAssignments, node.getEnforcedConstraint()), - mapper.getMapping()); + mapping); } @Override @@ -263,13 +297,14 @@ public PlanAndMappings visitExchange(ExchangeNode node, UnaliasContext context) for (int i = 0; i < node.getSources().size(); i++) { PlanAndMappings rewrittenChild = node.getSources().get(i).accept(this, context); rewrittenChildren.add(rewrittenChild.getRoot()); - SymbolMapper mapper = new SymbolMapper(rewrittenChild.getMappings()); + SymbolMapper mapper = symbolMapper(new HashMap<>(rewrittenChild.getMappings())); rewrittenInputsBuilder.add(mapper.map(node.getInputs().get(i))); } List> rewrittenInputs = rewrittenInputsBuilder.build(); // canonicalize ExchangeNode outputs - SymbolMapper mapper = new SymbolMapper(context.getCorrelationMapping()); + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper mapper = symbolMapper(mapping); List rewrittenOutputs = mapper.map(node.getOutputSymbols()); // sanity check: assert that duplicate outputs result from same inputs @@ -317,10 +352,10 @@ public PlanAndMappings visitExchange(ExchangeNode node, UnaliasContext context) } Map outputMapping = new HashMap<>(); - outputMapping.putAll(mapper.getMapping()); + outputMapping.putAll(mapping); outputMapping.putAll(newMapping); - mapper = new SymbolMapper(outputMapping); + mapper = symbolMapper(outputMapping); // deduplicate outputs and prune input symbols lists accordingly List> newInputs = new ArrayList<>(); @@ -354,13 +389,14 @@ public PlanAndMappings visitExchange(ExchangeNode node, UnaliasContext context) rewrittenChildren.build(), newInputs, newOrderingScheme), - mapper.getMapping()); + outputMapping); } @Override public PlanAndMappings visitRemoteSource(RemoteSourceNode node, UnaliasContext context) { - SymbolMapper mapper = new SymbolMapper(context.getCorrelationMapping()); + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper mapper = symbolMapper(mapping); List newOutputs = mapper.mapAndDistinct(node.getOutputSymbols()); Optional newOrderingScheme = node.getOrderingScheme().map(mapper::map); @@ -372,7 +408,7 @@ public PlanAndMappings visitRemoteSource(RemoteSourceNode node, UnaliasContext c newOutputs, newOrderingScheme, node.getExchangeType()), - mapper.getMapping()); + mapping); } @Override @@ -389,22 +425,24 @@ public PlanAndMappings visitOffset(OffsetNode node, UnaliasContext context) public PlanAndMappings visitLimit(LimitNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); LimitNode rewrittenLimit = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenLimit, mapper.getMapping()); + return new PlanAndMappings(rewrittenLimit, mapping); } @Override public PlanAndMappings visitDistinctLimit(DistinctLimitNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); DistinctLimitNode rewrittenDistinctLimit = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenDistinctLimit, mapper.getMapping()); + return new PlanAndMappings(rewrittenDistinctLimit, mapping); } @Override @@ -420,7 +458,8 @@ public PlanAndMappings visitSample(SampleNode node, UnaliasContext context) @Override public PlanAndMappings visitValues(ValuesNode node, UnaliasContext context) { - SymbolMapper mapper = new SymbolMapper(context.getCorrelationMapping()); + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper mapper = symbolMapper(mapping); List> newRows = node.getRows().stream() .map(row -> row.stream() @@ -433,26 +472,28 @@ public PlanAndMappings visitValues(ValuesNode node, UnaliasContext context) return new PlanAndMappings( new ValuesNode(node.getId(), newOutputSymbols, newRows), - mapper.getMapping()); + mapping); } @Override public PlanAndMappings visitTableDelete(TableDeleteNode node, UnaliasContext context) { - SymbolMapper mapper = new SymbolMapper(context.getCorrelationMapping()); + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper mapper = symbolMapper(mapping); Symbol newOutput = mapper.map(node.getOutput()); return new PlanAndMappings( new TableDeleteNode(node.getId(), node.getTarget(), newOutput), - mapper.getMapping()); + mapping); } @Override public PlanAndMappings visitDelete(DeleteNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); Symbol newRowId = mapper.map(node.getRowId()); List newOutputs = mapper.map(node.getOutputSymbols()); @@ -464,99 +505,107 @@ public PlanAndMappings visitDelete(DeleteNode node, UnaliasContext context) node.getTarget(), newRowId, newOutputs), - mapper.getMapping()); + mapping); } @Override public PlanAndMappings visitStatisticsWriterNode(StatisticsWriterNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); StatisticsWriterNode rewrittenStatisticsWriter = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenStatisticsWriter, mapper.getMapping()); + return new PlanAndMappings(rewrittenStatisticsWriter, mapping); } @Override public PlanAndMappings visitTableWriter(TableWriterNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); TableWriterNode rewrittenTableWriter = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenTableWriter, mapper.getMapping()); + return new PlanAndMappings(rewrittenTableWriter, mapping); } @Override public PlanAndMappings visitTableFinish(TableFinishNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); TableFinishNode rewrittenTableFinish = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenTableFinish, mapper.getMapping()); + return new PlanAndMappings(rewrittenTableFinish, mapping); } @Override public PlanAndMappings visitRowNumber(RowNumberNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); RowNumberNode rewrittenRowNumber = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenRowNumber, mapper.getMapping()); + return new PlanAndMappings(rewrittenRowNumber, mapping); } @Override public PlanAndMappings visitTopNRowNumber(TopNRowNumberNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); TopNRowNumberNode rewrittenTopNRowNumber = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenTopNRowNumber, mapper.getMapping()); + return new PlanAndMappings(rewrittenTopNRowNumber, mapping); } @Override public PlanAndMappings visitTopN(TopNNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); TopNNode rewrittenTopN = mapper.map(node, rewrittenSource.getRoot()); - return new PlanAndMappings(rewrittenTopN, mapper.getMapping()); + return new PlanAndMappings(rewrittenTopN, mapping); } @Override public PlanAndMappings visitSort(SortNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); OrderingScheme newOrderingScheme = mapper.map(node.getOrderingScheme()); return new PlanAndMappings( new SortNode(node.getId(), rewrittenSource.getRoot(), newOrderingScheme, node.isPartial()), - mapper.getMapping()); + mapping); } @Override public PlanAndMappings visitFilter(FilterNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); Expression newPredicate = mapper.map(node.getPredicate()); return new PlanAndMappings( new FilterNode(node.getId(), rewrittenSource.getRoot(), newPredicate), - mapper.getMapping()); + mapping); } @Override @@ -590,7 +639,8 @@ public PlanAndMappings visitProject(ProjectNode node, UnaliasContext context) .build(); boolean ambiguousSymbolsPresent = !Sets.intersection(newlyAssignedSymbols, Sets.difference(symbolsInSourceMapping, symbolsInCorrelationMapping)).isEmpty(); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); // canonicalize ProjectNode assignments ImmutableList.Builder> rewrittenAssignments = ImmutableList.builder(); @@ -609,10 +659,10 @@ public PlanAndMappings visitProject(ProjectNode node, UnaliasContext context) Map newMapping = mappingFromAssignments(deduplicateAssignments, ambiguousSymbolsPresent); Map outputMapping = new HashMap<>(); - outputMapping.putAll(ambiguousSymbolsPresent ? context.getCorrelationMapping() : mapper.getMapping()); + outputMapping.putAll(ambiguousSymbolsPresent ? context.getCorrelationMapping() : mapping); outputMapping.putAll(newMapping); - mapper = new SymbolMapper(outputMapping); + mapper = symbolMapper(outputMapping); // build new Assignments with canonical outputs // duplicate entries will be removed by the Builder @@ -623,7 +673,7 @@ public PlanAndMappings visitProject(ProjectNode node, UnaliasContext context) return new PlanAndMappings( new ProjectNode(node.getId(), rewrittenSource.getRoot(), newAssignments.build()), - mapper.getMapping()); + outputMapping); } private Map mappingFromAssignments(Map assignments, boolean ambiguousSymbolsPresent) @@ -661,13 +711,14 @@ else if (DeterminismEvaluator.isDeterministic(expression, metadata) && !(express public PlanAndMappings visitOutput(OutputNode node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); List newOutputs = mapper.map(node.getOutputSymbols()); return new PlanAndMappings( new OutputNode(node.getId(), rewrittenSource.getRoot(), node.getColumnNames(), newOutputs), - mapper.getMapping()); + mapping); } @Override @@ -684,13 +735,14 @@ public PlanAndMappings visitEnforceSingleRow(EnforceSingleRowNode node, UnaliasC public PlanAndMappings visitAssignUniqueId(AssignUniqueId node, UnaliasContext context) { PlanAndMappings rewrittenSource = node.getSource().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenSource.getMappings()); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); Symbol newUnique = mapper.map(node.getIdColumn()); return new PlanAndMappings( new AssignUniqueId(node.getId(), rewrittenSource.getRoot(), newUnique), - mapper.getMapping()); + mapping); } @Override @@ -699,14 +751,15 @@ public PlanAndMappings visitApply(ApplyNode node, UnaliasContext context) // it is assumed that apart from correlation (and possibly outer correlation), symbols are distinct between Input and Subquery // rewrite Input PlanAndMappings rewrittenInput = node.getInput().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenInput.getMappings()); + Map inputMapping = new HashMap<>(rewrittenInput.getMappings()); + SymbolMapper mapper = symbolMapper(inputMapping); // rewrite correlation with mapping from Input List rewrittenCorrelation = mapper.mapAndDistinct(node.getCorrelation()); // extract new mappings for correlation symbols to apply in Subquery Set correlationSymbols = ImmutableSet.copyOf(node.getCorrelation()); - Map correlationMapping = mapper.getMapping().entrySet().stream() + Map correlationMapping = inputMapping.entrySet().stream() .filter(mapping -> correlationSymbols.contains(mapping.getKey())) .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); @@ -721,7 +774,7 @@ public PlanAndMappings visitApply(ApplyNode node, UnaliasContext context) Map resultMapping = new HashMap<>(); resultMapping.putAll(rewrittenInput.getMappings()); resultMapping.putAll(rewrittenSubquery.getMappings()); - mapper = new SymbolMapper(resultMapping); + mapper = symbolMapper(resultMapping); ImmutableList.Builder> rewrittenAssignments = ImmutableList.builder(); for (Map.Entry assignment : node.getSubqueryAssignments().entrySet()) { @@ -737,10 +790,10 @@ public PlanAndMappings visitApply(ApplyNode node, UnaliasContext context) Map newMapping = mappingFromAssignments(deduplicateAssignments, false); Map assignmentsOutputMapping = new HashMap<>(); - assignmentsOutputMapping.putAll(mapper.getMapping()); + assignmentsOutputMapping.putAll(resultMapping); assignmentsOutputMapping.putAll(newMapping); - mapper = new SymbolMapper(assignmentsOutputMapping); + mapper = symbolMapper(assignmentsOutputMapping); // build new Assignments with canonical outputs // duplicate entries will be removed by the Builder @@ -751,7 +804,7 @@ public PlanAndMappings visitApply(ApplyNode node, UnaliasContext context) return new PlanAndMappings( new ApplyNode(node.getId(), rewrittenInput.getRoot(), rewrittenSubquery.getRoot(), newAssignments.build(), rewrittenCorrelation, node.getOriginSubquery()), - mapper.getMapping()); + assignmentsOutputMapping); } @Override @@ -760,14 +813,15 @@ public PlanAndMappings visitCorrelatedJoin(CorrelatedJoinNode node, UnaliasConte // it is assumed that apart from correlation (and possibly outer correlation), symbols are distinct between left and right CorrelatedJoin source // rewrite Input PlanAndMappings rewrittenInput = node.getInput().accept(this, context); - SymbolMapper mapper = new SymbolMapper(rewrittenInput.getMappings()); + Map inputMapping = new HashMap<>(rewrittenInput.getMappings()); + SymbolMapper mapper = symbolMapper(inputMapping); // rewrite correlation with mapping from Input List rewrittenCorrelation = mapper.mapAndDistinct(node.getCorrelation()); // extract new mappings for correlation symbols to apply in Subquery Set correlationSymbols = ImmutableSet.copyOf(node.getCorrelation()); - Map correlationMapping = mapper.getMapping().entrySet().stream() + Map correlationMapping = inputMapping.entrySet().stream() .filter(mapping -> correlationSymbols.contains(mapping.getKey())) .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); @@ -784,12 +838,12 @@ public PlanAndMappings visitCorrelatedJoin(CorrelatedJoinNode node, UnaliasConte resultMapping.putAll(rewrittenSubquery.getMappings()); // rewrite filter with unified mapping - mapper = new SymbolMapper(resultMapping); + mapper = symbolMapper(resultMapping); Expression newFilter = mapper.map(node.getFilter()); return new PlanAndMappings( new CorrelatedJoinNode(node.getId(), rewrittenInput.getRoot(), rewrittenSubquery.getRoot(), rewrittenCorrelation, node.getType(), newFilter, node.getOriginSubquery()), - mapper.getMapping()); + resultMapping); } @Override @@ -804,7 +858,7 @@ public PlanAndMappings visitJoin(JoinNode node, UnaliasContext context) unifiedMapping.putAll(rewrittenLeft.getMappings()); unifiedMapping.putAll(rewrittenRight.getMappings()); - SymbolMapper mapper = new SymbolMapper(unifiedMapping); + SymbolMapper mapper = symbolMapper(unifiedMapping); ImmutableList.Builder builder = ImmutableList.builder(); for (JoinNode.EquiJoinClause clause : node.getCriteria()) { @@ -838,10 +892,10 @@ public PlanAndMappings visitJoin(JoinNode node, UnaliasContext context) } Map outputMapping = new HashMap<>(); - outputMapping.putAll(mapper.getMapping()); + outputMapping.putAll(unifiedMapping); outputMapping.putAll(newMapping); - mapper = new SymbolMapper(outputMapping); + mapper = symbolMapper(outputMapping); List canonicalOutputs = mapper.mapAndDistinct(node.getOutputSymbols()); List newLeftOutputSymbols = canonicalOutputs.stream() .filter(rewrittenLeft.getRoot().getOutputSymbols()::contains) @@ -866,7 +920,7 @@ public PlanAndMappings visitJoin(JoinNode node, UnaliasContext context) node.isSpillable(), newDynamicFilters, node.getReorderJoinStatsAndCost()), - mapper.getMapping()); + outputMapping); } @Override @@ -880,7 +934,7 @@ public PlanAndMappings visitSemiJoin(SemiJoinNode node, UnaliasContext context) outputMapping.putAll(rewrittenSource.getMappings()); outputMapping.putAll(rewrittenFilteringSource.getMappings()); - SymbolMapper mapper = new SymbolMapper(outputMapping); + SymbolMapper mapper = symbolMapper(outputMapping); Symbol newSourceJoinSymbol = mapper.map(node.getSourceJoinSymbol()); Symbol newFilteringSourceJoinSymbol = mapper.map(node.getFilteringSourceJoinSymbol()); @@ -899,7 +953,7 @@ public PlanAndMappings visitSemiJoin(SemiJoinNode node, UnaliasContext context) newSourceHashSymbol, newFilteringSourceHashSymbol, node.getDistributionType()), - mapper.getMapping()); + outputMapping); } @Override @@ -913,7 +967,7 @@ public PlanAndMappings visitSpatialJoin(SpatialJoinNode node, UnaliasContext con outputMapping.putAll(rewrittenLeft.getMappings()); outputMapping.putAll(rewrittenRight.getMappings()); - SymbolMapper mapper = new SymbolMapper(outputMapping); + SymbolMapper mapper = symbolMapper(outputMapping); List newOutputSymbols = mapper.mapAndDistinct(node.getOutputSymbols()); Expression newFilter = mapper.map(node.getFilter()); @@ -922,7 +976,7 @@ public PlanAndMappings visitSpatialJoin(SpatialJoinNode node, UnaliasContext con return new PlanAndMappings( new SpatialJoinNode(node.getId(), node.getType(), rewrittenLeft.getRoot(), rewrittenRight.getRoot(), newOutputSymbols, newFilter, newLeftPartitionSymbol, newRightPartitionSymbol, node.getKdbTree()), - mapper.getMapping()); + outputMapping); } @Override @@ -936,7 +990,7 @@ public PlanAndMappings visitIndexJoin(IndexJoinNode node, UnaliasContext context outputMapping.putAll(rewrittenProbe.getMappings()); outputMapping.putAll(rewrittenIndex.getMappings()); - SymbolMapper mapper = new SymbolMapper(outputMapping); + SymbolMapper mapper = symbolMapper(outputMapping); // canonicalize index join criteria ImmutableList.Builder builder = ImmutableList.builder(); @@ -950,13 +1004,14 @@ public PlanAndMappings visitIndexJoin(IndexJoinNode node, UnaliasContext context return new PlanAndMappings( new IndexJoinNode(node.getId(), node.getType(), rewrittenProbe.getRoot(), rewrittenIndex.getRoot(), newEquiCriteria, newProbeHashSymbol, newIndexHashSymbol), - mapper.getMapping()); + outputMapping); } @Override public PlanAndMappings visitIndexSource(IndexSourceNode node, UnaliasContext context) { - SymbolMapper mapper = new SymbolMapper(context.getCorrelationMapping()); + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper mapper = symbolMapper(mapping); Set newLookupSymbols = node.getLookupSymbols().stream() .map(mapper::map) @@ -969,7 +1024,7 @@ public PlanAndMappings visitIndexSource(IndexSourceNode node, UnaliasContext con return new PlanAndMappings( new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), newLookupSymbols, newOutputSymbols, newAssignments), - mapper.getMapping()); + mapping); } @Override @@ -980,10 +1035,11 @@ public PlanAndMappings visitUnion(UnionNode node, UnaliasContext context) .collect(toImmutableList()); List inputMappers = rewrittenSources.stream() - .map(source -> new SymbolMapper(source.getMappings())) + .map(source -> symbolMapper(new HashMap<>(source.getMappings()))) .collect(toImmutableList()); - SymbolMapper outputMapper = new SymbolMapper(context.getCorrelationMapping()); + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper outputMapper = symbolMapper(mapping); ListMultimap newOutputToInputs = rewriteOutputToInputsMap(node.getSymbolMapping(), outputMapper, inputMappers); List newOutputs = outputMapper.mapAndDistinct(node.getOutputSymbols()); @@ -996,7 +1052,7 @@ public PlanAndMappings visitUnion(UnionNode node, UnaliasContext context) .collect(toImmutableList()), newOutputToInputs, newOutputs), - outputMapper.getMapping()); + mapping); } @Override @@ -1007,10 +1063,11 @@ public PlanAndMappings visitIntersect(IntersectNode node, UnaliasContext context .collect(toImmutableList()); List inputMappers = rewrittenSources.stream() - .map(source -> new SymbolMapper(source.getMappings())) + .map(source -> symbolMapper(new HashMap<>(source.getMappings()))) .collect(toImmutableList()); - SymbolMapper outputMapper = new SymbolMapper(context.getCorrelationMapping()); + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper outputMapper = symbolMapper(mapping); ListMultimap newOutputToInputs = rewriteOutputToInputsMap(node.getSymbolMapping(), outputMapper, inputMappers); List newOutputs = outputMapper.mapAndDistinct(node.getOutputSymbols()); @@ -1023,7 +1080,7 @@ public PlanAndMappings visitIntersect(IntersectNode node, UnaliasContext context .collect(toImmutableList()), newOutputToInputs, newOutputs), - outputMapper.getMapping()); + mapping); } @Override @@ -1034,10 +1091,11 @@ public PlanAndMappings visitExcept(ExceptNode node, UnaliasContext context) .collect(toImmutableList()); List inputMappers = rewrittenSources.stream() - .map(source -> new SymbolMapper(source.getMappings())) + .map(source -> symbolMapper(new HashMap<>(source.getMappings()))) .collect(toImmutableList()); - SymbolMapper outputMapper = new SymbolMapper(context.getCorrelationMapping()); + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper outputMapper = symbolMapper(mapping); ListMultimap newOutputToInputs = rewriteOutputToInputsMap(node.getSymbolMapping(), outputMapper, inputMappers); List newOutputs = outputMapper.mapAndDistinct(node.getOutputSymbols()); @@ -1050,7 +1108,7 @@ public PlanAndMappings visitExcept(ExceptNode node, UnaliasContext context) .collect(toImmutableList()), newOutputToInputs, newOutputs), - outputMapper.getMapping()); + mapping); } private ListMultimap rewriteOutputToInputsMap(ListMultimap oldMapping, SymbolMapper outputMapper, List inputMappers) diff --git a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java index 81be675f0c9e..fad353d794d8 100644 --- a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java @@ -80,18 +80,22 @@ import static io.prestosql.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.prestosql.spi.StandardErrorCode.INVALID_COLUMN_REFERENCE; import static io.prestosql.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.prestosql.spi.StandardErrorCode.INVALID_LIMIT_CLAUSE; import static io.prestosql.spi.StandardErrorCode.INVALID_LITERAL; import static io.prestosql.spi.StandardErrorCode.INVALID_PARAMETER_USAGE; +import static io.prestosql.spi.StandardErrorCode.INVALID_RECURSIVE_REFERENCE; import static io.prestosql.spi.StandardErrorCode.INVALID_VIEW; import static io.prestosql.spi.StandardErrorCode.INVALID_WINDOW_FRAME; import static io.prestosql.spi.StandardErrorCode.MISMATCHED_COLUMN_ALIASES; import static io.prestosql.spi.StandardErrorCode.MISSING_CATALOG_NAME; +import static io.prestosql.spi.StandardErrorCode.MISSING_COLUMN_ALIASES; import static io.prestosql.spi.StandardErrorCode.MISSING_COLUMN_NAME; import static io.prestosql.spi.StandardErrorCode.MISSING_GROUP_BY; import static io.prestosql.spi.StandardErrorCode.MISSING_ORDER_BY; import static io.prestosql.spi.StandardErrorCode.MISSING_OVER; import static io.prestosql.spi.StandardErrorCode.MISSING_SCHEMA_NAME; import static io.prestosql.spi.StandardErrorCode.NESTED_AGGREGATION; +import static io.prestosql.spi.StandardErrorCode.NESTED_RECURSIVE; import static io.prestosql.spi.StandardErrorCode.NESTED_WINDOW; import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED; import static io.prestosql.spi.StandardErrorCode.NULL_TREATMENT_NOT_ALLOWED; @@ -119,6 +123,8 @@ import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType; import static io.prestosql.spi.type.VarcharType.createVarcharType; +import static io.prestosql.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; +import static io.prestosql.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE; import static io.prestosql.testing.TestingEventListenerManager.emptyEventListenerManager; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static io.prestosql.testing.assertions.PrestoExceptionAssert.assertPrestoExceptionThrownBy; @@ -1112,6 +1118,11 @@ public void testDuplicateWithQuery() " a AS (SELECT * FROM t1)" + "SELECT * FROM a") .hasErrorCode(DUPLICATE_NAMED_QUERY); + + assertFails("WITH RECURSIVE a(w, x, y, z) AS (SELECT * FROM t1)," + + " a(a, b, c, d) AS (SELECT * FROM t1)" + + "SELECT * FROM a") + .hasErrorCode(DUPLICATE_NAMED_QUERY); } @Test @@ -1121,6 +1132,11 @@ public void testCaseInsensitiveDuplicateWithQuery() " A AS (SELECT * FROM t1)" + "SELECT * FROM a") .hasErrorCode(DUPLICATE_NAMED_QUERY); + + assertFails("WITH RECURSIVE a(w, x, y, z) AS (SELECT * FROM t1)," + + " A(a, b, c, d) AS (SELECT * FROM t1)" + + "SELECT * FROM a") + .hasErrorCode(DUPLICATE_NAMED_QUERY); } @Test @@ -1132,6 +1148,475 @@ public void testWithForwardReference() .hasErrorCode(TABLE_NOT_FOUND); } + @Test + public void testMultipleWithListEntries() + { + analyze("WITH a(x) AS (SELECT 1)," + + " b(y) AS (SELECT x + 1 FROM a)," + + " c(z) AS (SELECT y * 10 FROM b)" + + "SELECT * FROM a, b, c"); + + analyze("WITH RECURSIVE a(x) AS (SELECT 1)," + + " b(y) AS (" + + " SELECT x FROM a" + + " UNION ALL" + + " SELECT y + 1 FROM b WHERE y < 3)," + + " c(z) AS (" + + " SELECT y FROM b" + + " UNION ALL" + + " SELECT z - 1 FROM c WHERE z > 0)" + + "SELECT * FROM a, b, c"); + } + + @Test + public void testWithQueryInvalidAliases() + { + assertFails("WITH a(x) AS (SELECT * FROM t1)" + + "SELECT * FROM a") + .hasErrorCode(MISMATCHED_COLUMN_ALIASES); + + assertFails("WITH a(x, y, z, x) AS (SELECT * FROM t1)" + + "SELECT * FROM a") + .hasErrorCode(DUPLICATE_COLUMN_NAME); + + // effectively non recursive + assertFails("WITH RECURSIVE a(x) AS (SELECT * FROM t1)" + + "SELECT * FROM a") + .hasErrorCode(MISMATCHED_COLUMN_ALIASES); + + assertFails("WITH RECURSIVE a(x, y, z, x) AS (SELECT * FROM t1)" + + "SELECT * FROM a") + .hasErrorCode(DUPLICATE_COLUMN_NAME); + + assertFails("WITH RECURSIVE a AS (SELECT * FROM t1)" + + "SELECT * FROM a") + .hasErrorCode(MISSING_COLUMN_ALIASES); + + // effectively recursive + assertFails("WITH RECURSIVE t(n, m) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " )" + + " SELECT * from t") + .hasErrorCode(MISMATCHED_COLUMN_ALIASES); + + assertFails("WITH RECURSIVE t(n, n) AS (" + + " SELECT 1, 2" + + " UNION ALL" + + " SELECT n + 2, m - 2 FROM t WHERE n < 6" + + " )" + + " SELECT * from t") + .hasErrorCode(DUPLICATE_COLUMN_NAME); + + assertFails("WITH RECURSIVE t AS (" + + " SELECT 1, 2" + + " UNION ALL" + + " SELECT n + 2, m - 2 FROM t WHERE n < 6" + + " )" + + " SELECT * from t") + .hasErrorCode(MISSING_COLUMN_ALIASES); + } + + @Test + public void testRecursiveBaseRelationAliasing() + { + // base relation anonymous + analyze("WITH RECURSIVE t(n, m) AS (" + + " SELECT * FROM (VALUES(1, 2), (4, 100))" + + " UNION ALL" + + " SELECT n + 1, m - 1 FROM t WHERE n < 5" + + " )" + + " SELECT * from t"); + + // base relation aliased same as WITH query resulting table + analyze("WITH RECURSIVE t(n, m) AS (" + + " SELECT * FROM (VALUES(1, 2), (4, 100)) AS T(n, m)" + + " UNION ALL" + + " SELECT n + 1, m - 1 FROM t WHERE n < 5" + + " )" + + " SELECT * from t"); + + // base relation aliased different than WITH query resulting table + analyze("WITH RECURSIVE t(n, m) AS (" + + " SELECT * FROM (VALUES(1, 2), (4, 100)) AS T1(x1, y1)" + + " UNION ALL" + + " SELECT n + 1, m - 1 FROM t WHERE n < 5" + + " )" + + " SELECT * from t"); + + // same aliases for base relation and WITH query resulting table, different order + analyze("WITH RECURSIVE t(n, m) AS (" + + " SELECT * FROM (VALUES(1, 2), (4, 100)) AS T(m, n)" + + " UNION ALL" + + " SELECT n + 1, m - 1 FROM t WHERE n < 5" + + " )" + + " SELECT * from t"); + } + + @Test + public void testColumnNumberMismatch() + { + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2, n + 10 FROM t WHERE n < 6" + + " )" + + " SELECT * from t") + .hasErrorCode(TYPE_MISMATCH); + + assertFails("WITH RECURSIVE t(n, m) AS (" + + " SELECT 1, 2" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " )" + + " SELECT * from t") + .hasErrorCode(TYPE_MISMATCH); + } + + @Test + public void testNestedWith() + { + // effectively non recursive + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT * FROM (WITH RECURSIVE t2(m) AS (SELECT 1) SELECT m FROM t2)" + + " )" + + " SELECT * from t") + .hasErrorCode(NESTED_RECURSIVE); + + analyze("WITH t(n) AS (" + + " SELECT * FROM (WITH RECURSIVE t2(m) AS (SELECT 1) SELECT m FROM t2)" + + " )" + + " SELECT * from t"); + + analyze("WITH RECURSIVE t(n) AS (" + + " SELECT * FROM (WITH t2(m) AS (SELECT 1) SELECT m FROM t2)" + + " )" + + " SELECT * from t"); + + // effectively recursive + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM (WITH RECURSIVE t2(m) AS (SELECT 4) SELECT m FROM t2 UNION SELECT n + 1 FROM t) t(n) WHERE n < 4" + + " )" + + " SELECT * from t") + .hasErrorCode(NESTED_RECURSIVE); + + analyze("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM (WITH t2(m) AS (SELECT 4) SELECT m FROM t2 UNION SELECT n + 1 FROM t) t(n) WHERE n < 4" + + " )" + + " SELECT * from t"); + } + + @Test + public void testParenthesedRecursionStep() + { + analyze("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (((SELECT n + 2 FROM t WHERE n < 6)))" + + " )" + + " SELECT * from t"); + + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (((TABLE t)))" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE); + + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (((SELECT n + 2 FROM t WHERE n < 6) LIMIT 1))" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_LIMIT_CLAUSE); + } + + @Test + public void testInvalidRecursiveReference() + { + // WITH table name is referenced in the base relation of recursion + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1 FROM T" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE); + + // multiple recursive references in the step relation of recursion + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT a.n + 2 FROM t AS a, t AS b WHERE n < 6" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE); + + // step relation of recursion is not a query specification + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " TABLE T" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE); + + // step relation of recursion is a query specification without FROM clause + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT 2 WHERE (SELECT true FROM t)" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE); + + // step relation of recursion is a query specification with a FROM clause, but the recursive reference is not in the FROM clause + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT m FROM (VALUES 2) t2(m) WHERE (SELECT true FROM t)" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE); + + // not a well-formed RECURSIVE query with recursive reference + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " INTERSECT" + + " SELECT n + 2 FROM t WHERE n < 6" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE); + } + + @Test + public void testWithRecursiveUnsupportedClauses() + { + // immediate WITH clause in recursive query + assertFails("WITH RECURSIVE t(n) AS (" + + " WITH t2(m) AS (SELECT 1)" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " )" + + " SELECT * from t") + .hasErrorCode(NOT_SUPPORTED); + + // immediate ORDER BY clause in recursive query + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " ORDER BY 1" + + " )" + + " SELECT * from t") + .hasErrorCode(NOT_SUPPORTED); + + // immediate OFFSET clause in recursive query + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " OFFSET 1" + + " )" + + " SELECT * from t") + .hasErrorCode(NOT_SUPPORTED); + + // immediate LIMIT clause in recursive query + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " LIMIT 1" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_LIMIT_CLAUSE); + + // immediate FETCH FIRST clause in recursive query + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " FETCH FIRST 1 ROW ONLY" + + " )" + + " SELECT * from t") + .hasErrorCode(INVALID_LIMIT_CLAUSE); + } + + @Test + public void testIllegalClausesInRecursiveTerm() + { + // recursive reference in inner source of outer join + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM (SELECT 10) u LEFT JOIN t ON true WHERE n < 6" + + " )" + + " SELECT * FROM t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE) + .hasMessage("line 1:114: recursive reference in right source of LEFT join"); + + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t RIGHT JOIN (SELECT 10) u ON true WHERE n < 6" + + " )" + + " SELECT * FROM t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE) + .hasMessage("line 1:90: recursive reference in left source of RIGHT join"); + + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t FULL JOIN (SELECT 10) u ON true WHERE n < 6" + + " )" + + " SELECT * FROM t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE) + .hasMessage("line 1:90: recursive reference in left source of FULL join"); + + // recursive reference in INTERSECT + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (SELECT n + 2 FROM ((SELECT 10) INTERSECT ALL (TABLE t)) u(n))" + + " )" + + " SELECT * FROM t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE) + .hasMessage("line 1:119: recursive reference in INTERSECT ALL"); + + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (SELECT n + 2 FROM ((TABLE t) INTERSECT ALL (SELECT 10)) u(n))" + + " )" + + " SELECT * FROM t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE) + .hasMessage("line 1:93: recursive reference in INTERSECT ALL"); + + // recursive reference in EXCEPT + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (SELECT n + 2 FROM ((SELECT 10) EXCEPT (TABLE t)) u(n))" + + " )" + + " SELECT * FROM t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE) + .hasMessage("line 1:112: recursive reference in right relation of EXCEPT DISTINCT"); + + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (SELECT n + 2 FROM ((SELECT 10) EXCEPT ALL (TABLE t)) u(n))" + + " )" + + " SELECT * FROM t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE) + .hasMessage("line 1:116: recursive reference in right relation of EXCEPT ALL"); + + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (SELECT n + 2 FROM ((TABLE t) EXCEPT ALL (SELECT 10)) u(n))" + + " )" + + " SELECT * FROM t") + .hasErrorCode(INVALID_RECURSIVE_REFERENCE) + .hasMessage("line 1:93: recursive reference in left relation of EXCEPT ALL"); + } + + @Test + public void testRecursiveReferenceShadowing() + { + // table 't' in subquery refers to WITH-query defined in subquery, so it is not a recursive reference to 't' in the top-level WITH-list + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM (WITH t(m) AS (SELECT 4) SELECT n + 1 FROM t)" + + " )" + + " SELECT * from t") + .hasErrorCode(COLUMN_NOT_FOUND); + + // table 't' in subquery refers to WITH-query defined in subquery, so it is not a recursive reference to 't' in the top-level WITH-list + // the top-level WITH-query is effectively not recursive + analyze("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM (WITH t(n) AS (SELECT 4) SELECT n + 1 FROM t)" + + " )" + + " SELECT * from t"); + + // the inner WITH-clause does not define a table with conflicting name 't'. Recursive reference is found in the subquery + analyze("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM (WITH t2(m) AS (SELECT 4) SELECT m FROM t2 UNION SELECT n + 1 FROM t) t(n) WHERE n < 4" + + " )" + + " SELECT * from t"); + + // the inner WITH-clause defines a table with conflicting name 't'. Recursive reference in the subquery is not found even though it is before the point of shadowing + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM (WITH t2(m) AS (TABLE t), t(p) AS (SELECT 1) SELECT m + 1 FROM t2) t(n) WHERE n < 4" + + " )" + + " SELECT * from t") + .hasErrorCode(TABLE_NOT_FOUND); + } + + @Test + public void testWithRecursiveUncoercibleTypes() + { + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT BIGINT '9' FROM t WHERE n < 7" + + " )" + + " SELECT * from t") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:72: recursion step relation output type (bigint) is not coercible to recursion base relation output type (integer) at column 1"); + + assertFails("WITH RECURSIVE t(n, m, p) AS (" + + " SELECT * FROM (VALUES(1, 2, 3))" + + " UNION ALL" + + " SELECT n + 1, BIGINT '9', BIGINT '9' FROM t WHERE n < 7" + + " )" + + " SELECT * from t") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:101: recursion step relation output type (bigint) is not coercible to recursion base relation output type (integer) at column 2"); + + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT DECIMAL '1'" + + " UNION ALL" + + " SELECT n * 0.9 FROM t WHERE n > 0.7" + + " )" + + " SELECT * from t") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:82: recursion step relation output type (decimal(2,1)) is not coercible to recursion base relation output type (decimal(1,0)) at column 1"); + + assertFails("WITH RECURSIVE t(n) AS (" + + " SELECT * FROM (VALUES('a'), ('b')) AS T(n)" + + " UNION ALL" + + " SELECT n || 'x' FROM t WHERE n < 'axxxx'" + + " )" + + " SELECT * from t") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:106: recursion step relation output type (varchar) is not coercible to recursion base relation output type (varchar(1)) at column 1"); + + assertFails("WITH RECURSIVE t(n, m, o) AS (" + + " SELECT * FROM (VALUES(1, 2, ROW('a', 4)), (5, 6, ROW('a', 8)))" + + " UNION ALL" + + " SELECT t.o.*, ROW('a', 10) FROM t WHERE m < 3" + + " )" + + " SELECT * from t") + .hasErrorCode(TYPE_MISMATCH) + .hasMessage("line 1:132: recursion step relation output type (varchar(1)) is not coercible to recursion base relation output type (integer) at column 1"); + } + @Test public void testExpressions() { @@ -2242,7 +2727,8 @@ private void analyze(Session clientSession, @Language("SQL") String query) .readUncommitted() .execute(clientSession, session -> { Analyzer analyzer = createAnalyzer(session, metadata); - Statement statement = SQL_PARSER.createStatement(query, new ParsingOptions()); + Statement statement = SQL_PARSER.createStatement(query, new ParsingOptions( + new FeaturesConfig().isParseDecimalLiteralsAsDouble() ? AS_DOUBLE : AS_DECIMAL)); analyzer.analyze(statement); }); } diff --git a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java index d73be36a35a6..73e19c7ec27e 100644 --- a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java @@ -106,6 +106,7 @@ public void testDefaults() .setArrayAggGroupImplementation(ArrayAggGroupImplementation.NEW) .setMultimapAggGroupImplementation(MultimapAggGroupImplementation.NEW) .setDistributedSortEnabled(true) + .setMaxRecursionDepth(10) .setMaxGroupingSets(2048) .setLateMaterializationEnabled(false) .setSkipRedundantSort(true) @@ -183,6 +184,7 @@ public void testExplicitPropertyMappings() .put("optimizer.prefer-partial-aggregation", "false") .put("optimizer.optimize-top-n-row-number", "false") .put("distributed-sort", "false") + .put("max-recursion-depth", "8") .put("analyzer.max-grouping-sets", "2047") .put("experimental.late-materialization.enabled", "true") .put("optimizer.skip-redundant-sort", "false") @@ -256,6 +258,7 @@ public void testExplicitPropertyMappings() .setArrayAggGroupImplementation(ArrayAggGroupImplementation.LEGACY) .setMultimapAggGroupImplementation(MultimapAggGroupImplementation.LEGACY) .setDistributedSortEnabled(false) + .setMaxRecursionDepth(8) .setMaxGroupingSets(2047) .setDefaultFilterFactorEnabled(true) .setLateMaterializationEnabled(true) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestRecursiveCTE.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestRecursiveCTE.java new file mode 100644 index 000000000000..008defb607ef --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestRecursiveCTE.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.Session; +import io.prestosql.sql.planner.assertions.BasePlanTest; +import io.prestosql.sql.planner.assertions.PlanMatchPattern; +import io.prestosql.testing.LocalQueryRunner; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.LogicalPlanner.Stage.CREATED; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.functionCall; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.union; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.window; +import static io.prestosql.testing.TestingSession.testSessionBuilder; + +public class TestRecursiveCTE + extends BasePlanTest +{ + @Override + protected LocalQueryRunner createLocalQueryRunner() + { + Session.SessionBuilder sessionBuilder = testSessionBuilder() + .setSystemProperty("max_recursion_depth", "1"); + + return LocalQueryRunner.create(sessionBuilder.build()); + } + + @Test + public void testRecursiveQuery() + { + @Language("SQL") String sql = "WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " )" + + " SELECT * from t"; + + PlanMatchPattern pattern = + anyTree( + union( + // base term + project(project(project( + ImmutableMap.of("expr", expression("1")), + values()))), + // first recursion step + project(project(project( + ImmutableMap.of("expr_0", expression("expr + 2")), + filter( + "expr < 6", + project(project(project( + ImmutableMap.of("expr", expression("1")), + values()))))))), + // "post-recursion" step with convergence assertion + filter( + "IF((count >= BIGINT '0'), " + + "CAST(fail(CAST('Recursion depth limit exceeded (1). Use ''max_recursion_depth'' session property to modify the limit.' AS varchar)) AS boolean), " + + "true)", + window(windowBuilder -> windowBuilder + .addFunction( + "count", + functionCall("count", ImmutableList.of())), + project(project(project( + ImmutableMap.of("expr_1", expression("expr + 2")), + filter( + "expr < 6", + project( + ImmutableMap.of("expr", expression("expr_0")), + project(project(project( + ImmutableMap.of("expr_0", expression("expr + 2")), + filter( + "expr < 6", + project(project(project( + ImmutableMap.of("expr", expression("1")), + values())))))))))))))))); + + assertPlan(sql, CREATED, pattern); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/query/TestRecursiveCTE.java b/presto-main/src/test/java/io/prestosql/sql/query/TestRecursiveCTE.java new file mode 100644 index 000000000000..8423dacec1b8 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/query/TestRecursiveCTE.java @@ -0,0 +1,267 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.query; + +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static io.prestosql.SystemSessionProperties.getMaxRecursionDepth; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestRecursiveCTE +{ + private QueryAssertions assertions; + + @BeforeClass + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterClass(alwaysRun = true) + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testSimpleRecursion() + { + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t WHERE n < 6" + + " )" + + " SELECT * from t")) + .matches("VALUES (1), (3), (5), (7)"); + + assertThat(assertions.query("WITH RECURSIVE t(n, m) AS (" + + " SELECT * FROM (VALUES(1, 2), (4, 100))" + + " UNION ALL" + + " SELECT n + 1, m - 1 FROM t WHERE n < 5" + + " )" + + " SELECT * from t")) + .matches("VALUES (1, 2), (4, 100), (2, 1), (5, 99), (3, 0), (4, -1), (5, -2)"); + + assertThat(assertions.query("WITH RECURSIVE t(n, m, o) AS (" + + " SELECT * FROM (VALUES(1, 2, ROW(3, 4)), (5, 6, ROW(7, 8)))" + + " UNION ALL" + + " SELECT t.o.*, ROW(10, 10) FROM t WHERE m < 3" + + " )" + + " SELECT * from t")) + .matches("VALUES (1, 2, ROW(3, 4)), (5, 6, ROW(7, 8)), (3, 4, ROW(10, 10))"); + } + + @Test + public void testUnionDistinct() + { + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT * FROM (VALUES(1), (1), (10))" + + " UNION" + + " SELECT n + 2 FROM t WHERE n < 4" + + " )" + + " SELECT * from t")) + .matches("VALUES (1), (10), (3), (5)"); + + assertThat(assertions.query("WITH RECURSIVE t(n, m) AS (" + + " SELECT * FROM (VALUES(1, 2), (2, 3))" + + " UNION" + + " SELECT n + 1, m + 1 FROM t WHERE n < 3" + + " )" + + " SELECT * from t")) + .matches("VALUES (1, 2), (2, 3), (3, 4)"); + } + + @Test + public void testNestedWith() + { + // recursive reference visible in subquery containing WITH + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM (WITH t2(m) AS (SELECT 4) SELECT m FROM t2 UNION SELECT n + 1 FROM t) t(n) WHERE n < 4" + + " )" + + " SELECT * from t")) + .matches("VALUES (1), (2), (3)"); + + // recursive reference shadowed by WITH in subquery. The query is effectively not recursive + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM (WITH t(n) AS (SELECT 5) SELECT n + 1 FROM t)" + + " )" + + " SELECT * from t")) + .matches("VALUES (1), (6)"); + + // multiple nesting + assertThat(assertions.query("WITH t(n) AS (" + + " WITH t2(m) AS (" + + " WITH RECURSIVE t3(p) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM (WITH t4(q) AS (SELECT 4) SELECT p + 1 FROM t3 WHERE p < 3)" + + " )" + + " SELECT * from t3" + + " )" + + " SELECT * FROM t2" + + " )" + + " SELECT * FROM t")) + .matches("VALUES (1), (2), (3)"); + } + + @Test + public void testMultipleWithListEntries() + { + // second and third WITH-queries are recursive + assertThat(assertions.query("WITH RECURSIVE a(x) AS (SELECT 1)," + + " b(y) AS (" + + " SELECT x FROM a" + + " UNION ALL" + + " SELECT y + 1 FROM b WHERE y < 2" + + " )," + + " c(z) AS (" + + " SELECT y FROM b" + + " UNION ALL" + + " SELECT z * 4 FROM c WHERE z < 4" + + " )" + + " SELECT * FROM a, b, c")) + .matches("VALUES " + + "(1, 1, 1), " + + "(1, 1, 2), " + + "(1, 1, 4), " + + "(1, 1, 8), " + + "(1, 2, 1), " + + "(1, 2, 2), " + + "(1, 2, 4), " + + "(1, 2, 8)"); + } + + @Test + public void testVarchar() + { + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT CAST(n AS varchar) FROM (VALUES('a'), ('b')) AS T(n)" + + " UNION ALL" + + " SELECT n || 'x' FROM t WHERE n < 'axx'" + + " )" + + " SELECT * from t")) + .matches("VALUES (varchar 'a'), (varchar 'b'), (varchar 'ax'), (varchar 'axx')"); + } + + @Test + public void testTypeCoercion() + { + // integer result of step relation coerced to bigint + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT BIGINT '1'" + + " UNION ALL" + + " SELECT CAST(n + 1 AS integer) FROM t WHERE n < 3" + + " )" + + " SELECT * from t")) + .matches("VALUES (BIGINT '1'), (BIGINT '2'), (BIGINT '3')"); + + // result of step relation coerced from decimal(10,0) to decimal(20,10) + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT CAST(1 AS decimal(20,10))" + + " UNION ALL" + + " SELECT CAST(n + 1 AS decimal(10,0)) FROM t WHERE n < 2" + + " )" + + " SELECT * from t")) + .matches("VALUES (CAST(1 AS decimal(20,10))), (CAST(2 AS decimal(20,10)))"); + + // result of step relation coerced from varchar(5) to varchar + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT CAST('ABCDE' AS varchar)" + + " UNION ALL" + + " SELECT CAST(substr(n, 2) AS varchar(5)) FROM t WHERE n < 'E'" + + " )" + + " SELECT * from t")) + .matches("VALUES (CAST('ABCDE' AS varchar)), (CAST('BCDE' AS varchar)), (CAST('CDE' AS varchar)), (CAST('DE' AS varchar)), (CAST('E' AS varchar))"); + + //multiple coercions + assertThat(assertions.query("WITH RECURSIVE t(n, m) AS (" + + " SELECT BIGINT '1', INTEGER '2'" + + " UNION ALL" + + " SELECT CAST(n + 1 AS tinyint), CAST(m + 2 AS smallint) FROM t WHERE n < 3" + + " )" + + " SELECT * from t")) + .matches("VALUES " + + "(BIGINT '1', INTEGER '2'), " + + "(BIGINT '2', INTEGER '4'), " + + "(BIGINT '3', INTEGER '6')"); + } + + @Test + public void testJoin() + { + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT b + 1 FROM ((SELECT 5) JOIN t ON true) t(a, b) WHERE b < 3" + + " )" + + " SELECT * from t")) + .matches("VALUES (1), (2), (3)"); + + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM (SELECT 10) u RIGHT JOIN t ON true WHERE n < 6" + + " )" + + " SELECT * FROM t")) + .matches("VALUES (1), (3), (5), (7)"); + + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT n + 2 FROM t LEFT JOIN (SELECT 10) u ON true WHERE n < 6" + + " )" + + " SELECT * FROM t")) + .matches("VALUES (1), (3), (5), (7)"); + } + + @Test + public void testSetOperation() + { + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (SELECT n + 2 FROM ((TABLE t) INTERSECT DISTINCT (SELECT 1)) u(n))" + + " )" + + " SELECT * FROM t")) + .matches("VALUES (1), (3)"); + + assertThat(assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " (SELECT n + 2 FROM ((TABLE t) EXCEPT DISTINCT (SELECT 10)) u(n) WHERE n < 3)" + + " )" + + " SELECT * FROM t")) + .matches("VALUES (1), (3)"); + } + + @Test + public void testRecursionDepthLimitExceeded() + { + assertThatThrownBy(() -> assertions.query("WITH RECURSIVE t(n) AS (" + + " SELECT 1" + + " UNION ALL" + + " SELECT * FROM t" + + " )" + + " SELECT * FROM t")) + .hasMessage("Recursion depth limit exceeded (%s). Use 'max_recursion_depth' session property to modify the limit.", getMaxRecursionDepth(assertions.getDefaultSession())); + } +} diff --git a/presto-spi/src/main/java/io/prestosql/spi/StandardErrorCode.java b/presto-spi/src/main/java/io/prestosql/spi/StandardErrorCode.java index 68cc5a1c4d28..07298b02b857 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/StandardErrorCode.java +++ b/presto-spi/src/main/java/io/prestosql/spi/StandardErrorCode.java @@ -107,6 +107,10 @@ public enum StandardErrorCode INVALID_ROW_FILTER(84, USER_ERROR), INVALID_COLUMN_MASK(85, USER_ERROR), MISSING_TABLE(86, USER_ERROR), + INVALID_RECURSIVE_REFERENCE(87, USER_ERROR), + MISSING_COLUMN_ALIASES(88, USER_ERROR), + NESTED_RECURSIVE(89, USER_ERROR), + INVALID_LIMIT_CLAUSE(90, USER_ERROR), GENERIC_INTERNAL_ERROR(65536, INTERNAL_ERROR), TOO_MANY_REQUESTS_FAILED(65537, INTERNAL_ERROR), diff --git a/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java b/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java index e9c95b23c7d9..e4a1a18377b2 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java +++ b/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java @@ -2371,9 +2371,7 @@ public void testWithHiding() @Test public void testWithRecursive() { - assertQueryFails( - "WITH RECURSIVE a AS (SELECT 123) SELECT * FROM a", - "line 1:1: Recursive WITH queries are not supported"); + assertQuery("WITH RECURSIVE a(x) AS (SELECT 123) SELECT * FROM a"); } @Test