diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index 7fdc6ed5ebfec..e6f992bf28c67 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -696,6 +696,7 @@ public Optional visitTopNRowNumber(TopNRowNumberNode node, Context con new WindowNode.Specification( partitionBy, node.getSpecification().getOrderingScheme().map(scheme -> getCanonicalOrderingScheme(scheme, context.getExpressions()))), + node.getRankingFunction(), rowNumberVariable, node.getMaxRowCountPerPartition(), node.isPartial(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index fd4dec41e7777..3e07991e26764 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -641,7 +641,7 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.of(new SimplifyCountOverConstant(metadata.getFunctionAndTypeManager()))), new LimitPushDown(), // Run LimitPushDown before WindowFilterPushDown - new WindowFilterPushDown(metadata), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits + new WindowFilterPushDown(metadata, featuresConfig.isNativeExecutionEnabled()), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits prefilterForLimitingAggregation, new IterativeOptimizer( metadata, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 6e96c4b29c2c9..dacd2c68f2762 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -493,6 +493,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, PreferredPr idAllocator.getNextId(), child.getNode(), node.getSpecification(), + node.getRankingFunction(), node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), true, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index 995a687c98b2c..cb7de4ce84ec6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -329,6 +329,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, HashComputa node.getId(), child.getNode(), node.getSpecification(), + node.getRankingFunction(), node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), node.isPartial(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java index 9110ec4ceb128..54bbb980bcd3f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -310,6 +310,7 @@ public Optional visitTopN(TopNNode node, Void context) new Specification( ImmutableList.copyOf(childDecorrelationResult.variablesToPropagate), Optional.of(orderingScheme)), + TopNRowNumberNode.RankingFunction.ROW_NUMBER, variableAllocator.newVariable("row_number", BIGINT), toIntExact(node.getCount()), false, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index fe8c1e11ec756..a39cc80cae745 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -701,6 +701,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext node.getId(), context.rewrite(node.getSource()), canonicalizeAndDistinct(node.getSpecification()), + node.getRankingFunction(), canonicalize(node.getRowNumberVariable()), node.getMaxRowCountPerPartition(), node.isPartial(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java index 337e9eb39df03..2cbcbabd2a03f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java @@ -46,6 +46,7 @@ import java.util.Optional; import java.util.OptionalInt; +import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled; import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRowNumber; import static com.facebook.presto.common.predicate.Marker.Bound.BELOW; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -65,10 +66,13 @@ public class WindowFilterPushDown private final RowExpressionDomainTranslator domainTranslator; private final LogicalRowExpressions logicalRowExpressions; - public WindowFilterPushDown(Metadata metadata) + private boolean isNativeExecution = false; + + public WindowFilterPushDown(Metadata metadata, boolean isNativeExecution) { this.metadata = requireNonNull(metadata, "metadata is null"); this.domainTranslator = new RowExpressionDomainTranslator(metadata); + this.isNativeExecution = isNativeExecution; this.logicalRowExpressions = new LogicalRowExpressions( new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager()), new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()), @@ -84,7 +88,7 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider requireNonNull(variableAllocator, "variableAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); - Rewriter rewriter = new Rewriter(idAllocator, metadata, domainTranslator, logicalRowExpressions, session); + Rewriter rewriter = new Rewriter(idAllocator, metadata, domainTranslator, logicalRowExpressions, session, isNativeExecution); PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, null); return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); } @@ -97,15 +101,17 @@ private static class Rewriter private final RowExpressionDomainTranslator domainTranslator; private final LogicalRowExpressions logicalRowExpressions; private final Session session; + private final boolean isNativeExecution; private boolean planChanged; - private Rewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, RowExpressionDomainTranslator domainTranslator, LogicalRowExpressions logicalRowExpressions, Session session) + private Rewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, RowExpressionDomainTranslator domainTranslator, LogicalRowExpressions logicalRowExpressions, Session session, boolean isNativeExecution) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null"); this.logicalRowExpressions = logicalRowExpressions; this.session = requireNonNull(session, "session is null"); + this.isNativeExecution = isNativeExecution; } public boolean isPlanChanged() @@ -138,6 +144,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) public PlanNode visitLimit(LimitNode node, RewriteContext context) { // Operators can handle MAX_VALUE rows per page, so do not optimize if count is greater than this value + // TODO (Aditi) : Don't think this check is needed for Native engine. if (node.getCount() > Integer.MAX_VALUE) { return context.defaultRewrite(node); } @@ -152,11 +159,11 @@ public PlanNode visitLimit(LimitNode node, RewriteContext context) planChanged = true; source = rowNumberNode; } - else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) { + else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager(), isNativeExecution) && isOptimizeTopNRowNumber(session)) { WindowNode windowNode = (WindowNode) source; // verify that unordered row_number window functions are replaced by RowNumberNode verify(windowNode.getOrderingScheme().isPresent()); - TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit); + TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit, metadata.getFunctionAndTypeManager()); if (windowNode.getPartitionBy().isEmpty()) { return topNRowNumberNode; } @@ -183,13 +190,13 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt()); } } - else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) { + else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager(), isNativeExecution) && isOptimizeTopNRowNumber(session)) { WindowNode windowNode = (WindowNode) source; VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable()); OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable); if (upperBound.isPresent()) { - source = convertToTopNRowNumber(windowNode, upperBound.getAsInt()); + source = convertToTopNRowNumber(windowNode, upperBound.getAsInt(), metadata.getFunctionAndTypeManager()); planChanged = true; return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt()); } @@ -273,13 +280,23 @@ private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPa return new RowNumberNode(node.getSourceLocation(), node.getId(), node.getSource(), node.getPartitionBy(), node.getRowNumberVariable(), Optional.of(newRowCountPerPartition), false, node.getHashVariable()); } - private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit) + private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit, FunctionAndTypeManager functionAndTypeManager) { + VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getWindowFunctions().keySet()); + FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(windowNode.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()); + + TopNRowNumberNode.RankingFunction rankingFunction = + isRowNumberMetadata(functionAndTypeManager, functionMetadata) ? + TopNRowNumberNode.RankingFunction.ROW_NUMBER : + isRankMetadata(functionAndTypeManager, functionMetadata) ? + TopNRowNumberNode.RankingFunction.RANK : + TopNRowNumberNode.RankingFunction.DENSE_RANK; return new TopNRowNumberNode( windowNode.getSourceLocation(), idAllocator.getNextId(), windowNode.getSource(), windowNode.getSpecification(), + rankingFunction, getOnlyElement(windowNode.getCreatedVariable()), limit, false, @@ -288,22 +305,49 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi private static boolean canReplaceWithRowNumber(WindowNode node, FunctionAndTypeManager functionAndTypeManager) { - return canOptimizeWindowFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent(); + if (node.getWindowFunctions().size() != 1) { + return false; + } + VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet()); + + return isRowNumberMetadata(functionAndTypeManager, + functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle())) + && !node.getOrderingScheme().isPresent(); } - private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager) + private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager, boolean isNativeExecution) { if (node.getWindowFunctions().size() != 1) { return false; } VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet()); - return isRowNumberMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle())); + FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()); + if (isNativeExecution) { + return isRowNumberMetadata(functionAndTypeManager, functionMetadata) + || node.getOrderingScheme().isPresent() && (isRankMetadata(functionAndTypeManager, functionMetadata) + || isDenseRankMetadata(functionAndTypeManager, functionMetadata)); + } + + return isRowNumberMetadata(functionAndTypeManager, functionMetadata); } + private static boolean isRowNumberMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata) { FunctionHandle rowNumberFunction = functionAndTypeManager.lookupFunction("row_number", ImmutableList.of()); return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rowNumberFunction)); } + + private static boolean isRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata) + { + FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("rank", ImmutableList.of()); + return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction)); + } + + private static boolean isDenseRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata) + { + FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("dense_rank", ImmutableList.of()); + return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction)); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java index bc6ee14a5e81f..91dcc6d523b16 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java @@ -36,10 +36,18 @@ public final class TopNRowNumberNode extends InternalPlanNode { + public enum RankingFunction + { + ROW_NUMBER, + RANK, + DENSE_RANK + } + private final PlanNode source; private final Specification specification; + private final RankingFunction rankingFunction; private final VariableReferenceExpression rowNumberVariable; - private final int maxRowCountPerPartition; + private final int maxRankPerPartition; private final boolean partial; private final Optional hashVariable; @@ -49,12 +57,13 @@ public TopNRowNumberNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, @JsonProperty("specification") Specification specification, + @JsonProperty("rankingType") RankingFunction rankingFunction, @JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable, - @JsonProperty("maxRowCountPerPartition") int maxRowCountPerPartition, + @JsonProperty("maxRowCountPerPartition") int maxRankPerPartition, @JsonProperty("partial") boolean partial, @JsonProperty("hashVariable") Optional hashVariable) { - this(sourceLocation, id, Optional.empty(), source, specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); + this(sourceLocation, id, Optional.empty(), source, specification, rankingFunction, rowNumberVariable, maxRankPerPartition, partial, hashVariable); } public TopNRowNumberNode( @@ -63,8 +72,9 @@ public TopNRowNumberNode( Optional statsEquivalentPlanNode, PlanNode source, Specification specification, + RankingFunction rankingFunction, VariableReferenceExpression rowNumberVariable, - int maxRowCountPerPartition, + int maxRankPerPartition, boolean partial, Optional hashVariable) { @@ -74,13 +84,14 @@ public TopNRowNumberNode( requireNonNull(specification, "specification is null"); checkArgument(specification.getOrderingScheme().isPresent(), "specification orderingScheme is absent"); requireNonNull(rowNumberVariable, "rowNumberVariable is null"); - checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0"); + checkArgument(maxRankPerPartition > 0, "maxRowCountPerPartition must be > 0"); requireNonNull(hashVariable, "hashVariable is null"); this.source = source; this.specification = specification; + this.rankingFunction = rankingFunction; this.rowNumberVariable = rowNumberVariable; - this.maxRowCountPerPartition = maxRowCountPerPartition; + this.maxRankPerPartition = maxRankPerPartition; this.partial = partial; this.hashVariable = hashVariable; } @@ -108,6 +119,12 @@ public PlanNode getSource() return source; } + @JsonProperty + public RankingFunction getRankingFunction() + { + return rankingFunction; + } + @JsonProperty public Specification getSpecification() { @@ -133,7 +150,7 @@ public VariableReferenceExpression getRowNumberVariable() @JsonProperty public int getMaxRowCountPerPartition() { - return maxRowCountPerPartition; + return maxRankPerPartition; } @JsonProperty @@ -157,12 +174,12 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new TopNRowNumberNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); + return new TopNRowNumberNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), specification, rankingFunction, rowNumberVariable, maxRankPerPartition, partial, hashVariable); } @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new TopNRowNumberNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); + return new TopNRowNumberNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, rankingFunction, rowNumberVariable, maxRankPerPartition, partial, hashVariable); } } diff --git a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index f1d880ebf1804..b0309bf6889f8 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -387,7 +387,8 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Void context) { printNode(node, "TopNRowNumber", - format("partition by = %s|order by = %s|n = %s", + format("function = %s; partition by = %s|order by = %s|n = %s", + node.getRankingFunction(), Joiner.on(", ").join(node.getPartitionBy()), Joiner.on(", ").join(node.getOrderingScheme().getOrderByVariables()), node.getMaxRowCountPerPartition()), NODE_COLORS.get(NodeType.WINDOW)); diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index 94b7788496e8b..366287f76b35b 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -1577,6 +1577,22 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( toVeloxQueryPlan(node->source, tableWriteInfo, taskId)); } +namespace { +core::TopNRowNumberNode::RankFunction prestoToVeloxRankFunction( + protocol::RankingFunction rankingFunction) { + switch (rankingFunction) { + case protocol::RankingFunction::ROW_NUMBER: + return core::TopNRowNumberNode::RankFunction::kRowNumber; + case protocol::RankingFunction::RANK: + return core::TopNRowNumberNode::RankFunction::kRank; + case protocol::RankingFunction::DENSE_RANK: + return core::TopNRowNumberNode::RankFunction::kDenseRank; + default: + VELOX_UNREACHABLE(); + } +} +}; // namespace + std::shared_ptr VeloxQueryPlanConverterBase::toVeloxQueryPlan( const std::shared_ptr& node, @@ -1600,7 +1616,8 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( rowNumberColumnName = node->rowNumberVariable.name; } - if (sortFields.empty()) { + if (node->rankingType == protocol::RankingFunction::ROW_NUMBER && + sortFields.empty()) { // May happen if all sorting keys are also used as partition keys. return std::make_shared( @@ -1613,6 +1630,7 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( return std::make_shared( node->id, + prestoToVeloxRankFunction(node->rankingType), partitionFields, sortFields, sortOrders, diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 6bfc625743dba..9acca25c30912 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -10500,6 +10500,44 @@ void from_json(const json& j, TopNNode& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +// Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() + +// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays +static const std::pair RankingFunction_enum_table[] = + { // NOLINT: cert-err58-cpp + {RankingFunction::ROW_NUMBER, "ROW_NUMBER"}, + {RankingFunction::RANK, "RANK"}, + {RankingFunction::DENSE_RANK, "DENSE_RANK"}}; +void to_json(json& j, const RankingFunction& e) { + static_assert( + std::is_enum::value, "RankingFunction must be an enum!"); + const auto* it = std::find_if( + std::begin(RankingFunction_enum_table), + std::end(RankingFunction_enum_table), + [e](const std::pair& ej_pair) -> bool { + return ej_pair.first == e; + }); + j = ((it != std::end(RankingFunction_enum_table)) + ? it + : std::begin(RankingFunction_enum_table)) + ->second; +} +void from_json(const json& j, RankingFunction& e) { + static_assert( + std::is_enum::value, "RankingFunction must be an enum!"); + const auto* it = std::find_if( + std::begin(RankingFunction_enum_table), + std::end(RankingFunction_enum_table), + [&j](const std::pair& ej_pair) -> bool { + return ej_pair.second == j; + }); + e = ((it != std::end(RankingFunction_enum_table)) + ? it + : std::begin(RankingFunction_enum_table)) + ->first; +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { TopNRowNumberNode::TopNRowNumberNode() noexcept { _type = "com.facebook.presto.sql.planner.plan.TopNRowNumberNode"; } @@ -10516,6 +10554,13 @@ void to_json(json& j, const TopNRowNumberNode& p) { "TopNRowNumberNode", "Specification", "specification"); + to_json_key( + j, + "rankingType", + p.rankingType, + "TopNRowNumberNode", + "RankingFunction", + "rankingType"); to_json_key( j, "rowNumberVariable", @@ -10552,6 +10597,13 @@ void from_json(const json& j, TopNRowNumberNode& p) { "TopNRowNumberNode", "Specification", "specification"); + from_json_key( + j, + "rankingType", + p.rankingType, + "TopNRowNumberNode", + "RankingFunction", + "rankingType"); from_json_key( j, "rowNumberVariable", diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index cbce83539ca17..95642aab96c58 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -2485,9 +2485,15 @@ void to_json(json& j, const TopNNode& p); void from_json(const json& j, TopNNode& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +enum class RankingFunction { ROW_NUMBER, RANK, DENSE_RANK }; +extern void to_json(json& j, const RankingFunction& e); +extern void from_json(const json& j, RankingFunction& e); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct TopNRowNumberNode : public PlanNode { std::shared_ptr source = {}; Specification specification = {}; + RankingFunction rankingType = {}; VariableReferenceExpression rowNumberVariable = {}; int maxRowCountPerPartition = {}; bool partial = {}; diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java index 781b9088c519c..2ce0f42fd559b 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java @@ -13,6 +13,11 @@ */ package com.facebook.presto.nativeworker; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.sql.analyzer.FeaturesConfig; + +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; import com.google.common.collect.ImmutableList; @@ -24,6 +29,7 @@ import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.*; public abstract class AbstractTestNativeWindowQueries extends AbstractTestQueryFramework @@ -32,6 +38,12 @@ protected enum FunctionType { RANK, VALUE, } + @Override + protected FeaturesConfig createFeaturesConfig() + { + return new FeaturesConfig().setNativeExecutionEnabled(true); + } + @Override protected void createTables() { @@ -179,6 +191,51 @@ public void testRowNumberWithFilter_2() assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey) rn, * from orders) WHERE rn = 1"); } + private static final PlanMatchPattern topNForFilter = anyTree( + anyNot(FilterNode.class, + node(TopNRowNumberNode.class, + anyTree( + tableScan("orders"))))); + + private static final PlanMatchPattern topNForLimit = anyTree( + limit(10, anyTree( + node(TopNRowNumberNode.class, + anyTree( + tableScan("orders")))))); + @Test + public void testTopNRowNumber() + { + String sql = "SELECT sum(rn) FROM (SELECT row_number() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders) WHERE rn <= 10"; + assertQuery(sql); + assertPlan(sql, topNForFilter); + + sql = "SELECT sum(rn) FROM (SELECT row_number() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders limit 10)"; + assertPlan(sql, topNForLimit); + } + + @Test + public void testTopNRank() + { + String sql = "SELECT sum(rn) FROM (SELECT rank() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders) WHERE rn <= 10"; + assertQuery(sql); + assertPlan(sql, topNForFilter); + + sql = "SELECT sum(rn) FROM (SELECT rank() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders limit 10)"; + assertPlan(sql, topNForFilter); + } + + @Test + public void testTopNDenseRank() + { + String sql = "SELECT sum(rn) FROM (SELECT dense_rank() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders) WHERE rn <= 10"; + assertQuery(sql); + assertPlan(sql, topNForFilter); + + sql = "SELECT sum(rn) FROM (SELECT dense_rank() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders limit 10)"; + assertQuery(sql); + assertPlan(sql, topNForLimit); + } + @Test public void testFirstValueOrderKey() { diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java index 596b49015c0a8..26225810b1645 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java @@ -495,6 +495,7 @@ public static Optional> getExternalWorkerLaunc // Write config file String configProperties = format("discovery.uri=%s%n" + "presto.version=testversion%n" + + "native-execution-enabled=true" + "system-memory-gb=4%n" + "http-server.http.port=%d", discoveryUri, port); diff --git a/presto-native-execution/velox b/presto-native-execution/velox index bf3fba7654c44..ab57e639893de 160000 --- a/presto-native-execution/velox +++ b/presto-native-execution/velox @@ -1 +1 @@ -Subproject commit bf3fba7654c44dd8e9c6de14d2b2c9d1a1911d4c +Subproject commit ab57e639893de7e805b45b150eb40af1ec595e7e