Skip to content

Commit

Permalink
[native] TopNRank optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
aditi-pandit committed Nov 25, 2024
1 parent 461ae13 commit 99600e4
Show file tree
Hide file tree
Showing 16 changed files with 226 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ public Optional<PlanNode> visitTopNRowNumber(TopNRowNumberNode node, Context con
new WindowNode.Specification(
partitionBy,
node.getSpecification().getOrderingScheme().map(scheme -> getCanonicalOrderingScheme(scheme, context.getExpressions()))),
node.getRankingFunction(),
rowNumberVariable,
node.getMaxRowCountPerPartition(),
node.isPartial(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ public PlanOptimizers(
estimatedExchangesCostCalculator,
ImmutableSet.of(new SimplifyCountOverConstant(metadata.getFunctionAndTypeManager()))),
new LimitPushDown(), // Run LimitPushDown before WindowFilterPushDown
new WindowFilterPushDown(metadata), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits
new WindowFilterPushDown(metadata, featuresConfig.isNativeExecutionEnabled()), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits
prefilterForLimitingAggregation,
new IterativeOptimizer(
metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, PreferredPr
idAllocator.getNextId(),
child.getNode(),
node.getSpecification(),
node.getRankingFunction(),
node.getRowNumberVariable(),
node.getMaxRowCountPerPartition(),
true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, HashComputa
node.getId(),
child.getNode(),
node.getSpecification(),
node.getRankingFunction(),
node.getRowNumberVariable(),
node.getMaxRowCountPerPartition(),
node.isPartial(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ public Optional<DecorrelationResult> visitTopN(TopNNode node, Void context)
new Specification(
ImmutableList.copyOf(childDecorrelationResult.variablesToPropagate),
Optional.of(orderingScheme)),
TopNRowNumberNode.RankingFunction.ROW_NUMBER,
variableAllocator.newVariable("row_number", BIGINT),
toIntExact(node.getCount()),
false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Set<Va
node.getStatsEquivalentPlanNode(),
source,
node.getSpecification(),
node.getRankingFunction(),
node.getRowNumberVariable(),
node.getMaxRowCountPerPartition(),
node.isPartial(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Void>
node.getId(),
context.rewrite(node.getSource()),
canonicalizeAndDistinct(node.getSpecification()),
node.getRankingFunction(),
canonicalize(node.getRowNumberVariable()),
node.getMaxRowCountPerPartition(),
node.isPartial(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import java.util.Optional;
import java.util.OptionalInt;

import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled;
import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRowNumber;
import static com.facebook.presto.common.predicate.Marker.Bound.BELOW;
import static com.facebook.presto.common.type.BigintType.BIGINT;
Expand All @@ -65,10 +66,13 @@ public class WindowFilterPushDown
private final RowExpressionDomainTranslator domainTranslator;
private final LogicalRowExpressions logicalRowExpressions;

public WindowFilterPushDown(Metadata metadata)
private boolean isNativeExecution = false;

public WindowFilterPushDown(Metadata metadata, boolean isNativeExecution)
{
this.metadata = requireNonNull(metadata, "metadata is null");
this.domainTranslator = new RowExpressionDomainTranslator(metadata);
this.isNativeExecution = isNativeExecution;
this.logicalRowExpressions = new LogicalRowExpressions(
new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager()),
new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()),
Expand All @@ -84,7 +88,7 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider
requireNonNull(variableAllocator, "variableAllocator is null");
requireNonNull(idAllocator, "idAllocator is null");

Rewriter rewriter = new Rewriter(idAllocator, metadata, domainTranslator, logicalRowExpressions, session);
Rewriter rewriter = new Rewriter(idAllocator, metadata, domainTranslator, logicalRowExpressions, session, isNativeExecution);
PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, null);
return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
}
Expand All @@ -97,15 +101,17 @@ private static class Rewriter
private final RowExpressionDomainTranslator domainTranslator;
private final LogicalRowExpressions logicalRowExpressions;
private final Session session;
private final boolean isNativeExecution;
private boolean planChanged;

private Rewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, RowExpressionDomainTranslator domainTranslator, LogicalRowExpressions logicalRowExpressions, Session session)
private Rewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, RowExpressionDomainTranslator domainTranslator, LogicalRowExpressions logicalRowExpressions, Session session, boolean isNativeExecution)
{
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null");
this.logicalRowExpressions = logicalRowExpressions;
this.session = requireNonNull(session, "session is null");
this.isNativeExecution = isNativeExecution;
}

public boolean isPlanChanged()
Expand Down Expand Up @@ -138,6 +144,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
{
// Operators can handle MAX_VALUE rows per page, so do not optimize if count is greater than this value
// TODO (Aditi) : Don't think this check is needed for Native engine.
if (node.getCount() > Integer.MAX_VALUE) {
return context.defaultRewrite(node);
}
Expand All @@ -152,11 +159,11 @@ public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
planChanged = true;
source = rowNumberNode;
}
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager(), isNativeExecution) && isOptimizeTopNRowNumber(session)) {
WindowNode windowNode = (WindowNode) source;
// verify that unordered row_number window functions are replaced by RowNumberNode
verify(windowNode.getOrderingScheme().isPresent());
TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit, metadata.getFunctionAndTypeManager());
if (windowNode.getPartitionBy().isEmpty()) {
return topNRowNumberNode;
}
Expand All @@ -183,13 +190,13 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
}
}
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) {
else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager(), isNativeExecution) && isOptimizeTopNRowNumber(session)) {
WindowNode windowNode = (WindowNode) source;
VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable());
OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable);

if (upperBound.isPresent()) {
source = convertToTopNRowNumber(windowNode, upperBound.getAsInt());
source = convertToTopNRowNumber(windowNode, upperBound.getAsInt(), metadata.getFunctionAndTypeManager());
planChanged = true;
return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt());
}
Expand Down Expand Up @@ -273,13 +280,23 @@ private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPa
return new RowNumberNode(node.getSourceLocation(), node.getId(), node.getSource(), node.getPartitionBy(), node.getRowNumberVariable(), Optional.of(newRowCountPerPartition), false, node.getHashVariable());
}

private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit)
private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit, FunctionAndTypeManager functionAndTypeManager)
{
VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getWindowFunctions().keySet());
FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(windowNode.getWindowFunctions().get(rowNumberVariable).getFunctionHandle());

TopNRowNumberNode.RankingFunction rankingFunction =
isRowNumberMetadata(functionAndTypeManager, functionMetadata) ?
TopNRowNumberNode.RankingFunction.ROW_NUMBER :
isRankMetadata(functionAndTypeManager, functionMetadata) ?
TopNRowNumberNode.RankingFunction.RANK :
TopNRowNumberNode.RankingFunction.DENSE_RANK;
return new TopNRowNumberNode(
windowNode.getSourceLocation(),
idAllocator.getNextId(),
windowNode.getSource(),
windowNode.getSpecification(),
rankingFunction,
getOnlyElement(windowNode.getCreatedVariable()),
limit,
false,
Expand All @@ -288,22 +305,49 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi

private static boolean canReplaceWithRowNumber(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
{
return canOptimizeWindowFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent();
if (node.getWindowFunctions().size() != 1) {
return false;
}
VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet());

return isRowNumberMetadata(functionAndTypeManager,
functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()))
&& !node.getOrderingScheme().isPresent();
}

private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager)
private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager, boolean isNativeExecution)
{
if (node.getWindowFunctions().size() != 1) {
return false;
}
VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet());
return isRowNumberMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle()));
FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle());
if (isNativeExecution) {
return isRowNumberMetadata(functionAndTypeManager, functionMetadata)
|| node.getOrderingScheme().isPresent() && (isRankMetadata(functionAndTypeManager, functionMetadata)
|| isDenseRankMetadata(functionAndTypeManager, functionMetadata));
}

return isRowNumberMetadata(functionAndTypeManager, functionMetadata);
}


private static boolean isRowNumberMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
{
FunctionHandle rowNumberFunction = functionAndTypeManager.lookupFunction("row_number", ImmutableList.of());
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rowNumberFunction));
}

private static boolean isRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
{
FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("rank", ImmutableList.of());
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction));
}

private static boolean isDenseRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata)
{
FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("dense_rank", ImmutableList.of());
return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,18 @@
public final class TopNRowNumberNode
extends InternalPlanNode
{
public enum RankingFunction
{
ROW_NUMBER,
RANK,
DENSE_RANK
}

private final PlanNode source;
private final Specification specification;
private final RankingFunction rankingFunction;
private final VariableReferenceExpression rowNumberVariable;
private final int maxRowCountPerPartition;
private final int maxRankPerPartition;
private final boolean partial;
private final Optional<VariableReferenceExpression> hashVariable;

Expand All @@ -49,12 +57,13 @@ public TopNRowNumberNode(
@JsonProperty("id") PlanNodeId id,
@JsonProperty("source") PlanNode source,
@JsonProperty("specification") Specification specification,
@JsonProperty("rankingType") RankingFunction rankingFunction,
@JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable,
@JsonProperty("maxRowCountPerPartition") int maxRowCountPerPartition,
@JsonProperty("maxRowCountPerPartition") int maxRankPerPartition,
@JsonProperty("partial") boolean partial,
@JsonProperty("hashVariable") Optional<VariableReferenceExpression> hashVariable)
{
this(sourceLocation, id, Optional.empty(), source, specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable);
this(sourceLocation, id, Optional.empty(), source, specification, rankingFunction, rowNumberVariable, maxRankPerPartition, partial, hashVariable);
}

public TopNRowNumberNode(
Expand All @@ -63,8 +72,9 @@ public TopNRowNumberNode(
Optional<PlanNode> statsEquivalentPlanNode,
PlanNode source,
Specification specification,
RankingFunction rankingFunction,
VariableReferenceExpression rowNumberVariable,
int maxRowCountPerPartition,
int maxRankPerPartition,
boolean partial,
Optional<VariableReferenceExpression> hashVariable)
{
Expand All @@ -74,13 +84,14 @@ public TopNRowNumberNode(
requireNonNull(specification, "specification is null");
checkArgument(specification.getOrderingScheme().isPresent(), "specification orderingScheme is absent");
requireNonNull(rowNumberVariable, "rowNumberVariable is null");
checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0");
checkArgument(maxRankPerPartition > 0, "maxRowCountPerPartition must be > 0");
requireNonNull(hashVariable, "hashVariable is null");

this.source = source;
this.specification = specification;
this.rankingFunction = rankingFunction;
this.rowNumberVariable = rowNumberVariable;
this.maxRowCountPerPartition = maxRowCountPerPartition;
this.maxRankPerPartition = maxRankPerPartition;
this.partial = partial;
this.hashVariable = hashVariable;
}
Expand Down Expand Up @@ -108,6 +119,12 @@ public PlanNode getSource()
return source;
}

@JsonProperty
public RankingFunction getRankingFunction()
{
return rankingFunction;
}

@JsonProperty
public Specification getSpecification()
{
Expand All @@ -133,7 +150,7 @@ public VariableReferenceExpression getRowNumberVariable()
@JsonProperty
public int getMaxRowCountPerPartition()
{
return maxRowCountPerPartition;
return maxRankPerPartition;
}

@JsonProperty
Expand All @@ -157,12 +174,12 @@ public <R, C> R accept(InternalPlanVisitor<R, C> visitor, C context)
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new TopNRowNumberNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable);
return new TopNRowNumberNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), specification, rankingFunction, rowNumberVariable, maxRankPerPartition, partial, hashVariable);
}

@Override
public PlanNode assignStatsEquivalentPlanNode(Optional<PlanNode> statsEquivalentPlanNode)
{
return new TopNRowNumberNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable);
return new TopNRowNumberNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, rankingFunction, rowNumberVariable, maxRankPerPartition, partial, hashVariable);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Void context)
{
printNode(node,
"TopNRowNumber",
format("partition by = %s|order by = %s|n = %s",
format("function = %s; partition by = %s|order by = %s|n = %s",
node.getRankingFunction(),
Joiner.on(", ").join(node.getPartitionBy()),
Joiner.on(", ").join(node.getOrderingScheme().getOrderByVariables()), node.getMaxRowCountPerPartition()),
NODE_COLORS.get(NodeType.WINDOW));
Expand Down
Loading

0 comments on commit 99600e4

Please sign in to comment.