From 09668946252bb27afddc39a24bdb03c09802cf83 Mon Sep 17 00:00:00 2001 From: Zhenxiao Luo Date: Wed, 13 Mar 2019 23:59:19 -0700 Subject: [PATCH] Pushdown Dereference Expressions --- .../presto/sql/planner/PlanOptimizers.java | 8 + .../iterative/rule/PushDownDereferences.java | 471 ++++++++++++++++++ .../presto/sql/planner/plan/Assignments.java | 4 +- .../sql/planner/TestDereferencePushDown.java | 122 +++++ .../assertions/ExpressionVerifier.java | 40 ++ 5 files changed, 643 insertions(+), 2 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/TestDereferencePushDown.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index ae5418ef4cda..983e4542783d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -70,6 +70,7 @@ import com.facebook.presto.sql.planner.iterative.rule.PruneValuesColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneWindowColumns; import com.facebook.presto.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; +import com.facebook.presto.sql.planner.iterative.rule.PushDownDereferences; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughMarkDistinct; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughOuterJoin; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughProject; @@ -344,6 +345,13 @@ public PlanOptimizers( new TransformCorrelatedSingleRowSubqueryToProject())), new CheckSubqueryNodesAreRewritten(), predicatePushDown, + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.>builder() + .addAll(new PushDownDereferences(metadata, sqlParser).rules()) + .build()), new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java new file mode 100644 index 000000000000..4f37b30f8e72 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java @@ -0,0 +1,471 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.iterative.Rule.Context; +import com.facebook.presto.sql.planner.iterative.Rule.Result; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.SortNode; +import com.facebook.presto.sql.planner.plan.UnnestNode; +import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; +import com.facebook.presto.sql.tree.DereferenceExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.NodeRef; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import static com.facebook.presto.matching.Capture.newCapture; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.planner.SymbolsExtractor.extractAll; +import static com.facebook.presto.sql.planner.plan.Patterns.join; +import static com.facebook.presto.sql.planner.plan.Patterns.project; +import static com.facebook.presto.sql.planner.plan.Patterns.sort; +import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.planner.plan.Patterns.unnest; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Collections.emptyList; +import static java.util.Objects.requireNonNull; + +public class PushDownDereferences +{ + private final Metadata metadata; + private final SqlParser sqlParser; + + public PushDownDereferences(Metadata metadata, SqlParser sqlParser) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + } + + public Set> rules() + { + return ImmutableSet.of( + new PushDownDereferenceThroughJoin(metadata, sqlParser), + new PushDownDereferenceThroughSort(metadata, sqlParser), + new PushDownDereferenceThroughUnnest(metadata, sqlParser), + new PushDownDereferenceThroughProject(metadata, sqlParser)); + } + + private abstract class DereferencePushDownRule + implements Rule + { + private final Capture targetCapture = newCapture(); + private final Pattern targetPattern; + + protected final Metadata metadata; + protected final SqlParser sqlParser; + + protected DereferencePushDownRule(Metadata metadata, SqlParser sqlParser, Pattern targetPattern) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.targetPattern = requireNonNull(targetPattern, "targetPattern is null"); + } + + @Override + public Pattern getPattern() + { + return project().with(source().matching(targetPattern.capturedAs(targetCapture))); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + N child = captures.get(targetCapture); + Map expressions = getDereferenceSymbolMap(node.getAssignments().getExpressions(), context, metadata, sqlParser); + Assignments assignments = node.getAssignments().rewrite(new DereferenceReplacer(expressions)); + + Result result = pushDownDereferences(context, child, expressions, assignments); + if (result.isEmpty()) { + return Result.empty(); + } + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), result.getTransformedPlan().get(), assignments)); + } + + protected abstract Result pushDownDereferences(Context context, N targetNode, Map expressions, Assignments assignments); + + protected boolean validPushDown(DereferenceExpression dereference) + { + Expression base = dereference.getBase(); + return (base instanceof SymbolReference) || (base instanceof DereferenceExpression); + } + } + + /** Transforms: + *
+     *  Project(a_x := a.msg.x)
+     *    Join(a_y = b_y) => [a]
+     *      Project(a_y := a.msg.y)
+     *          Source(a)
+     *      Project(b_y := b.msg.y)
+     *          Source(b)
+     *  
+ * to: + *
+     *  Project(a_x := a_x)
+     *    Join(a_y = b_y) => [a_x]
+     *      Project(a_x := a.msg.x, a_y := a.msg.y)
+     *        Source(a)
+     *      Project(b_y := b.msg.y)
+     *        Source(b)
+     * 
+ */ + public class PushDownDereferenceThroughJoin + extends DereferencePushDownRule + { + public PushDownDereferenceThroughJoin(Metadata metadata, SqlParser sqlParser) + { + super(metadata, sqlParser, join()); + } + + @Override + protected Result pushDownDereferences(Context context, JoinNode joinNode, Map expressions, Assignments assignments) + { + List outputSymbols = joinNode.getOutputSymbols(); + Map projectExpressions = expressions.entrySet().stream() + .filter(entry -> validPushDown(entry.getKey())) + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); + + ImmutableMap.Builder dereferenceSymbolsBuilder = ImmutableMap.builder(); + dereferenceSymbolsBuilder.putAll(expressions); + if (joinNode.getFilter().isPresent()) { + Map predicateSymbols = getDereferenceSymbolMap(ImmutableList.of(joinNode.getFilter().get()), context, metadata, sqlParser).entrySet().stream() + .filter(entry -> !projectExpressions.values().contains(entry.getKey())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + dereferenceSymbolsBuilder.putAll(predicateSymbols); + } + Map dereferenceSymbols = dereferenceSymbolsBuilder.build(); + + Map dereferences = dereferenceSymbols.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); + + ImmutableMap.Builder pushdownExpressionsBuilder = ImmutableMap.builder(); + pushdownExpressionsBuilder.putAll(dereferences); + Map remainingProjectExpressions = projectExpressions.entrySet().stream() + .filter(entry -> !dereferences.keySet().contains(entry.getKey())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + pushdownExpressionsBuilder.putAll(remainingProjectExpressions); + + Map pushdownExpressions = pushdownExpressionsBuilder.build(); + if (pushdownExpressions.isEmpty()) { + return Result.empty(); + } + + Map symbolsMap = pushdownExpressions.entrySet().stream() + .collect(toImmutableMap(entry -> getOnlyElement(extractAll(entry.getValue())), Map.Entry::getKey)); + + PlanNode left = joinNode.getLeft(); + PlanNode right = joinNode.getRight(); + + Assignments.Builder leftBuilder = Assignments.builder(); + leftBuilder.putIdentities(left.getOutputSymbols().stream() + .filter(symbol -> !symbolsMap.containsKey(symbol)) + .collect(toImmutableList())); + + Assignments.Builder rightBuilder = Assignments.builder(); + rightBuilder.putIdentities(right.getOutputSymbols().stream() + .filter(symbol -> !symbolsMap.containsKey(symbol)) + .collect(toImmutableList())); + + for (Map.Entry entry : pushdownExpressions.entrySet()) { + Symbol outputSymbol = getOnlyElement(extractAll(entry.getValue())); + if (left.getOutputSymbols().contains(outputSymbol)) { + leftBuilder.put(entry.getKey(), entry.getValue()); + } + if (right.getOutputSymbols().contains(outputSymbol)) { + rightBuilder.put(entry.getKey(), entry.getValue()); + } + } + ProjectNode leftChild = new ProjectNode(context.getIdAllocator().getNextId(), left, leftBuilder.build()); + ProjectNode rightChild = new ProjectNode(context.getIdAllocator().getNextId(), right, rightBuilder.build()); + + return Result.ofPlanNode( + new JoinNode( + context.getIdAllocator().getNextId(), + joinNode.getType(), + leftChild, + rightChild, + joinNode.getCriteria(), + ImmutableList.builder() + .addAll(leftChild.getOutputSymbols()) + .addAll(rightChild.getOutputSymbols()) + .build(), + joinNode.getFilter().map(expression -> replaceDereferences(expression, dereferenceSymbols)), + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType())); + } + } + + /** Transforms: + *
+     *  Project(a_x := a.msg.x)
+     *    Project(a_y := key)
+     *          Source(a)
+     *  
+ * to: + *
+     *  Project(a_y := key, a_z = a.msg.x)
+     *    Source(a)
+     * 
+ */ + public class PushDownDereferenceThroughProject + extends DereferencePushDownRule + { + public PushDownDereferenceThroughProject(Metadata metadata, SqlParser sqlParser) + { + super(metadata, sqlParser, project()); + } + + @Override + protected Result pushDownDereferences(Context context, ProjectNode projectNode, Map expressions, Assignments assignments) + { + List outputSymbols = projectNode.getOutputSymbols(); + Map pushdownExpressions = expressions.entrySet().stream() + .filter(entry -> validPushDown(entry.getKey())) + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); + + if (pushdownExpressions.isEmpty()) { + return Result.empty(); + } + + Map symbolsMap = pushdownExpressions.entrySet().stream() + .collect(toImmutableMap(entry -> getOnlyElement(extractAll(entry.getValue())), Map.Entry::getKey)); + + Assignments.Builder sourceBuilder = Assignments.builder(); + for (Map.Entry entry : projectNode.getAssignments().entrySet()) { + if (symbolsMap.containsKey(entry.getKey())) { + Symbol targetSymbol = symbolsMap.get(entry.getKey()); + DereferenceExpression targetDereference = (DereferenceExpression) pushdownExpressions.get(targetSymbol); + DereferenceExpression dereference = new DereferenceExpression(entry.getValue(), targetDereference.getField()); + sourceBuilder.put(targetSymbol, dereference); + } + else { + sourceBuilder.put(entry.getKey(), entry.getValue()); + } + } + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), projectNode.getSource(), sourceBuilder.build())); + } + } + + /** + * Transforms: + *
+     *  Project(a_x := a.msg.x)
+     *    Sort
+     *      Source(a)
+     *  
+ * to: + *
+     *  Sort
+     *    Project(a_y := a.msg.x)
+     *      Source(a)
+     *  
+ */ + public class PushDownDereferenceThroughSort + extends DereferencePushDownRule + { + public PushDownDereferenceThroughSort(Metadata metadata, SqlParser sqlParser) + { + super(metadata, sqlParser, sort()); + } + + @Override + protected Result pushDownDereferences(Context context, SortNode sortNode, Map expressions, Assignments assignments) + { + List outputSymbols = sortNode.getOutputSymbols(); + Map pushdownExpressions = expressions.entrySet().stream() + .filter(entry -> validPushDown(entry.getKey())) + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); + + if (pushdownExpressions.isEmpty()) { + return Result.empty(); + } + + Assignments newAssignments = Assignments.builder() + .putAll(pushdownExpressions) + .putIdentities(outputSymbols) + .build(); + ProjectNode source = new ProjectNode(context.getIdAllocator().getNextId(), sortNode.getSource(), newAssignments); + SortNode result = new SortNode(context.getIdAllocator().getNextId(), source, sortNode.getOrderingScheme()); + return Result.ofPlanNode(result); + } + } + + /** + * Transforms: + *
+     *  Project(a_x := a.msg.x)
+     *    Unnest
+     *      Source(a)
+     *  
+ * to: + *
+     *  Unnest
+     *    Project(a_y := a.msg.x)
+     *      Source(a)
+     *  
+ */ + public class PushDownDereferenceThroughUnnest + extends DereferencePushDownRule + { + public PushDownDereferenceThroughUnnest(Metadata metadata, SqlParser sqlParser) + { + super(metadata, sqlParser, unnest()); + } + + @Override + protected Result pushDownDereferences(Context context, UnnestNode unnestNode, Map expressions, Assignments assignments) + { + List outputSymbols = unnestNode.getOutputSymbols(); + Map pushdownExpressions = expressions.entrySet().stream() + .filter(entry -> validPushDown(entry.getKey())) + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); + + if (pushdownExpressions.isEmpty()) { + return Result.empty(); + } + + ImmutableMap.Builder symbolsMapBuilder = ImmutableMap.builder(); + for (Map.Entry entry : expressions.entrySet()) { + Expression expression = entry.getKey().getBase(); + if (expression instanceof SymbolReference) { + symbolsMapBuilder.put(Symbol.from(expression), entry.getValue()); + } + } + Map symbolsMap = symbolsMapBuilder.build(); + + List sourceSymbols = unnestNode.getSource().getOutputSymbols().stream() + .filter(symbol -> !symbolsMap.containsKey(symbol)) + .collect(toImmutableList()); + + List relicateSymbols = unnestNode.getReplicateSymbols().stream() + .map(symbol -> replaceSymbol(symbolsMap, symbol)) + .collect(toImmutableList()); + + Assignments newAssignments = Assignments.builder() + .putAll(pushdownExpressions) + .putIdentities(sourceSymbols) + .build(); + ProjectNode source = new ProjectNode(context.getIdAllocator().getNextId(), unnestNode.getSource(), newAssignments); + UnnestNode result = new UnnestNode(context.getIdAllocator().getNextId(), source, relicateSymbols, unnestNode.getUnnestSymbols(), unnestNode.getOrdinalitySymbol()); + return Result.ofPlanNode(result); + } + } + + private static Symbol replaceSymbol(Map symbolsMap, Symbol symbol) + { + if (symbolsMap.containsKey(symbol)) { + return symbolsMap.get(symbol); + } + return symbol; + } + + private static Expression replaceDereferences(Expression expression, Map replacements) + { + return ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(replacements), expression); + } + + private static class DereferenceReplacer + extends ExpressionRewriter + { + private final Map expressions; + + DereferenceReplacer(Map expressions) + { + this.expressions = requireNonNull(expressions, "expressions is null"); + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (expressions.containsKey(node)) { + return expressions.get(node).toSymbolReference(); + } + return treeRewriter.defaultRewrite(node, context); + } + } + + private static List extractDereferenceExpressions(Expression expression) + { + ImmutableList.Builder builder = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor>() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableList.Builder context) + { + context.add(node); + process(node.getBase(), context); + return null; + } + }.process(expression, builder); + return builder.build(); + } + + private static Map getDereferenceSymbolMap(Collection expressions, Context context, Metadata metadata, SqlParser sqlParser) + { + Set dereferences = expressions.stream() + .flatMap(expression -> extractDereferenceExpressions(expression).stream()) + .collect(toImmutableSet()); + + return dereferences.stream() + .filter(expression -> !baseExists(expression, dereferences)) + .collect(toImmutableMap(Function.identity(), expression -> newSymbol(expression, context, metadata, sqlParser))); + } + + private static Symbol newSymbol(Expression expression, Context context, Metadata metadata, SqlParser sqlParser) + { + Type type = getExpressionTypes(context.getSession(), metadata, sqlParser, context.getSymbolAllocator().getTypes(), expression, emptyList(), WarningCollector.NOOP).get(NodeRef.of(expression)); + verify(type != null); + return context.getSymbolAllocator().newSymbol(expression, type); + } + + private static boolean baseExists(DereferenceExpression expression, Set dereferences) + { + Expression base = expression.getBase(); + while (base instanceof DereferenceExpression) { + if (dereferences.contains(base)) { + return true; + } + base = ((DereferenceExpression) base).getBase(); + } + return false; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java index 749613cd80d8..d99615f76def 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java @@ -205,9 +205,9 @@ public Builder putAll(Assignments assignments) return putAll(assignments.getMap()); } - public Builder putAll(Map assignments) + public Builder putAll(Map assignments) { - for (Entry assignment : assignments.entrySet()) { + for (Entry assignment : assignments.entrySet()) { put(assignment.getKey(), assignment.getValue()); } return this; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDereferencePushDown.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDereferencePushDown.java new file mode 100644 index 000000000000..da27a341f4f0 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDereferencePushDown.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.Ordering; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sort; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.unnest; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.tree.SortItem.NullOrdering.LAST; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; + +public class TestDereferencePushDown + extends BasePlanTest +{ + @Test + public void testDereferencePushdownJoin() + { + assertPlan("WITH t(msg) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT b.msg.x FROM t a, t b WHERE a.msg.y = b.msg.y", + output(ImmutableList.of("b_x"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + values("msg")) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + + assertPlan("WITH t(msg) AS ( SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT a.msg.y FROM t a JOIN t b ON a.msg.y = b.msg.y WHERE a.msg.x > bigint '5'", + output(ImmutableList.of("a_y"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + filter("msg.x > bigint '5'", + values("msg"))) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y")), + values("msg")))))); + + assertPlan("WITH t(msg) AS ( SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT b.msg.x FROM t a JOIN t b ON a.msg.y = b.msg.y WHERE a.msg.x + b.msg.x < bigint '10'", + output(ImmutableList.of("b_x"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + Optional.of("a_x + b_x < bigint '10'"), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x")), + values("msg")) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownSort() + { + ImmutableList orderBy = ImmutableList.of(sort("b_x", ASCENDING, LAST)); + assertPlan("WITH t(msg) AS ( SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT a.msg.x FROM t a JOIN t b ON a.msg.y = b.msg.y WHERE a.msg.x < bigint '10' ORDER BY b.msg.x", + output(ImmutableList.of("expr"), + project(ImmutableMap.of("expr", expression("a_x")), + exchange(LOCAL, GATHER, orderBy, + sort(orderBy, + exchange(LOCAL, REPARTITION, + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x")), + filter("msg.x < bigint '10'", + values("msg"))) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))))))); + } + + @Test + public void testDereferencePushdownUnnest() + { + assertPlan("WITH t(msg, array) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE)), ARRAY[1, 2, 3]))) " + + "SELECT a.msg.x FROM t a JOIN t b ON a.msg.y = b.msg.y CROSS JOIN UNNEST (a.array) WHERE a.msg.x + b.msg.x < bigint '10'", + output(ImmutableList.of("expr"), + project(ImmutableMap.of("expr", expression("a_x")), + unnest( + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + Optional.of("a_x + b_x < bigint '10'"), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x"), "a_z", expression("array")), + values("msg", "array")) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java index 81aae0689f84..9d8525891c1e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java @@ -21,6 +21,7 @@ import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.DecimalLiteral; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; @@ -34,8 +35,10 @@ import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.SearchedCaseExpression; import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.sql.tree.SubscriptExpression; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; @@ -411,6 +414,43 @@ protected Boolean visitInListExpression(InListExpression actual, Node expected) return process(actual.getValues(), expectedInList.getValues()); } + @Override + protected Boolean visitDereferenceExpression(DereferenceExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof DereferenceExpression)) { + return false; + } + + DereferenceExpression expected = (DereferenceExpression) expectedExpression; + if (actual.getField().equals(expected.getField())) { + return process(actual.getBase(), expected.getBase()); + } + return false; + } + + @Override + protected Boolean visitSubscriptExpression(SubscriptExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SubscriptExpression)) { + return false; + } + + SubscriptExpression expected = (SubscriptExpression) expectedExpression; + + return process(actual.getBase(), expected.getBase()) && process(actual.getIndex(), expected.getIndex()); + } + + @Override + protected Boolean visitSearchedCaseExpression(SearchedCaseExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SearchedCaseExpression)) { + return false; + } + + SearchedCaseExpression expected = (SearchedCaseExpression) expectedExpression; + return process(actual.getDefaultValue(), expected.getDefaultValue()) && process(actual.getWhenClauses(), expected.getWhenClauses()); + } + private boolean process(List actuals, List expecteds) { if (actuals.size() != expecteds.size()) {