Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat](nereids) support pull up predicate from set operator #39450

Merged
merged 24 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
5d41847
[Feat](nereids) support pull up predicate from set operator
feiniaofeiafei Aug 15, 2024
da19001
[Feat](nereids) support pull up predicate from set operator
feiniaofeiafei Aug 15, 2024
d15f9f5
fix regression
feiniaofeiafei Aug 16, 2024
12ad38a
[Feat](nereids) support pull up predicate from set operator
feiniaofeiafei Aug 16, 2024
31b32c3
adjust order in rewrite and fix ut
feiniaofeiafei Aug 16, 2024
bb3b364
[Feat](nereids) support pull up predicate from set operator
feiniaofeiafei Aug 17, 2024
f5c84b9
[Feat](nereids) support pull up predicate from set operator
feiniaofeiafei Aug 18, 2024
91a93cc
fix regression
feiniaofeiafei Aug 19, 2024
4ad57a7
[Feat](nereids) support pull up predicate from set operator
feiniaofeiafei Aug 19, 2024
4859fed
add infer predicate of intersect and except
feiniaofeiafei Aug 21, 2024
497109e
add infer predicate for except and intersect
feiniaofeiafei Aug 21, 2024
78e9b58
use regularChildOutput in setOperation and modify tests
feiniaofeiafei Aug 22, 2024
638198c
remove third infer predicates
feiniaofeiafei Aug 26, 2024
bf6c399
fix regression
feiniaofeiafei Aug 26, 2024
2218435
fix regression
feiniaofeiafei Aug 26, 2024
39ee741
remove useless code
feiniaofeiafei Aug 27, 2024
9327850
fix regression
feiniaofeiafei Aug 27, 2024
ed7b314
fix regression
feiniaofeiafei Aug 27, 2024
b74a06a
fix regression
feiniaofeiafei Aug 27, 2024
5a5e271
add comments
feiniaofeiafei Aug 29, 2024
a5b7f1b
remove useless code
feiniaofeiafei Aug 29, 2024
1bd2435
support pull up from union with child and const exprs
feiniaofeiafei Aug 30, 2024
8c3e6df
fix ut
feiniaofeiafei Aug 30, 2024
13b6815
fix regression
feiniaofeiafei Aug 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.QueryColumnCollector;
Expand Down Expand Up @@ -293,6 +292,21 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new ConvertInnerOrCrossJoin()),
topDown(new ProjectOtherJoinConditionForNestedLoopJoin())
),
topic("Set operation optimization",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment explaining why this rewrite topic was moved here.

// Do MergeSetOperation first because we hope to match pattern of Distinct SetOperator.
topDown(new PushProjectThroughUnion(), new MergeProjects()),
bottomUp(new MergeSetOperations(), new MergeSetOperationsExcept()),
bottomUp(new PushProjectIntoOneRowRelation()),
topDown(new MergeOneRowRelationIntoUnion()),
costBased(topDown(new InferSetOperatorDistinct())),
topDown(new BuildAggForUnion()),
bottomUp(new EliminateEmptyRelation()),
morrySnow marked this conversation as resolved.
Show resolved Hide resolved
// when union has empty relation child and constantExprsList is not empty,
// after EliminateEmptyRelation, project can be pushed into union
topDown(new PushProjectIntoUnion())
),
// The "Column pruning and infer predicate" topic is placed after "Set operation optimization"
// because pulling up predicates from a union requires EliminateEmptyRelation to have run on the union's children.
topic("Column pruning and infer predicate",
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
custom(RuleType.INFER_PREDICATES, InferPredicates::new),
Expand All @@ -306,24 +320,11 @@ public class Rewriter extends AbstractBatchJobExecutor {
// after eliminate outer join, we can move some filters to join.otherJoinConjuncts,
// this can help to translate plan to backend
topDown(new PushFilterInsideJoin()),
topDown(new FindHashConditionForJoin()),
topDown(new ExpressionNormalization())
morrySnow marked this conversation as resolved.
Show resolved Hide resolved
topDown(new FindHashConditionForJoin())
),

// this rule should invoke after ColumnPruning
custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, EliminateUnnecessaryProject::new),

topic("Set operation optimization",
// Do MergeSetOperation first because we hope to match pattern of Distinct SetOperator.
topDown(new PushProjectThroughUnion(), new MergeProjects()),
bottomUp(new MergeSetOperations(), new MergeSetOperationsExcept()),
bottomUp(new PushProjectIntoOneRowRelation()),
topDown(new MergeOneRowRelationIntoUnion()),
topDown(new PushProjectIntoUnion()),
costBased(topDown(new InferSetOperatorDistinct())),
topDown(new BuildAggForUnion())
),

topic("Eliminate GroupBy",
topDown(new EliminateGroupBy(),
new MergeAggregate(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterScanRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectScanRule;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.implementation.AggregateStrategies;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
Expand Down Expand Up @@ -87,6 +88,7 @@
import org.apache.doris.nereids.rules.implementation.LogicalWindowToPhysicalWindow;
import org.apache.doris.nereids.rules.rewrite.ConvertOuterJoinToAntiJoin;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.EliminateOuterJoin;
import org.apache.doris.nereids.rules.rewrite.MaxMinFilterPushDown;
import org.apache.doris.nereids.rules.rewrite.MergeFilters;
Expand Down Expand Up @@ -154,7 +156,12 @@ public class RuleSet {
new PushDownAliasThroughJoin(),
new PushDownFilterThroughWindow(),
new PushDownFilterThroughPartitionTopN(),
new ExpressionOptimization()
new ExpressionOptimization(),
// some useless predicates(e.g. 1=1) can be inferred by InferPredicates,
// the FoldConstantRule in ExpressionNormalization can fold 1=1 to true
// and EliminateFilter can eliminate the useless filter
new ExpressionNormalization(),
new EliminateFilter()
morrySnow marked this conversation as resolved.
Show resolved Hide resolved
);

public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.UnaryNode;
Expand Down Expand Up @@ -45,6 +46,9 @@
/**
* try to eliminate sub plan tree which contains EmptyRelation
*/
@DependsRules ({
BuildAggForUnion.class
})
public class EliminateEmptyRelation implements RewriteRuleFactory {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -107,6 +113,45 @@ public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext
return filter;
}

/**
 * Rewrites the children of an EXCEPT, then pushes the predicates pulled up from the EXCEPT
 * into every child except the first one.
 *
 * <p>The first child is re-added as is. For each remaining child, the pulled-up predicates
 * (expressed on the EXCEPT output slots) are rewritten onto that child's regular output slots,
 * matched by column position, and passed to inferNewPredicate together with the child.
 *
 * @param except the EXCEPT node being visited
 * @param context the job context forwarded to the child visits
 * @return the EXCEPT with its rewritten children, or the visited EXCEPT unchanged when
 *         no predicate could be pulled up
 */
@Override
public Plan visitLogicalExcept(LogicalExcept except, JobContext context) {
    except = visitChildren(this, except, context);
    Set<Expression> pulledUp = pullUpPredicates(except);
    if (pulledUp.isEmpty()) {
        return except;
    }
    ImmutableList.Builder<Plan> rewrittenChildren = ImmutableList.builder();
    // the first child is kept untouched
    rewrittenChildren.add(except.child(0));
    for (int childIdx = 1; childIdx < except.arity(); ++childIdx) {
        // map each EXCEPT output slot to the slot at the same position in this child
        Map<Expression, Expression> outputToChildSlot = new HashMap<>();
        int column = 0;
        for (NamedExpression outputSlot : except.getOutput()) {
            outputToChildSlot.put(outputSlot, except.getRegularChildOutput(childIdx).get(column));
            column++;
        }
        Set<Expression> childPredicates = ExpressionUtils.replace(pulledUp, outputToChildSlot);
        rewrittenChildren.add(inferNewPredicate(except.child(childIdx), childPredicates));
    }
    return except.withChildren(rewrittenChildren.build());
}

/**
 * Rewrites the children of an INTERSECT, then pushes the predicates pulled up from the
 * INTERSECT into every child.
 *
 * <p>For each child, the pulled-up predicates (expressed on the INTERSECT output slots) are
 * rewritten onto that child's regular output slots, matched by column position, and passed
 * to inferNewPredicate together with the child.
 *
 * @param intersect the INTERSECT node being visited
 * @param context the job context forwarded to the child visits
 * @return the INTERSECT with its rewritten children, or the visited INTERSECT unchanged when
 *         no predicate could be pulled up
 */
@Override
public Plan visitLogicalIntersect(LogicalIntersect intersect, JobContext context) {
    intersect = visitChildren(this, intersect, context);
    Set<Expression> pulledUp = pullUpPredicates(intersect);
    if (pulledUp.isEmpty()) {
        return intersect;
    }
    ImmutableList.Builder<Plan> rewrittenChildren = ImmutableList.builder();
    for (int childIdx = 0; childIdx < intersect.arity(); ++childIdx) {
        // map each INTERSECT output slot to the slot at the same position in this child
        Map<Expression, Expression> outputToChildSlot = new HashMap<>();
        int column = 0;
        for (NamedExpression outputSlot : intersect.getOutput()) {
            outputToChildSlot.put(outputSlot, intersect.getRegularChildOutput(childIdx).get(column));
            column++;
        }
        Set<Expression> childPredicates = ExpressionUtils.replace(pulledUp, outputToChildSlot);
        rewrittenChildren.add(inferNewPredicate(intersect.child(childIdx), childPredicates));
    }
    return intersect.withChildren(rewrittenChildren.build());
}

private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expression> condition) {
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,22 @@
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

Expand All @@ -38,6 +45,8 @@
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -60,6 +69,78 @@ public ImmutableSet<Expression> visit(Plan plan, Void context) {
return ImmutableSet.of();
}

/**
 * Pulls up an equality predicate for every column of a one-row relation whose value is a
 * non-null literal, e.g. {@code select 1 as a} yields {@code a = 1}.
 *
 * <p>Columns that are not an alias over a literal produce no predicate. NULL literals are
 * skipped as well: {@code a = NULL} evaluates to NULL rather than TRUE, so it does not hold
 * for the relation's row. Pushing it into the siblings of an INTERSECT/EXCEPT would filter out
 * NULL rows that those operators are expected to match.
 *
 * @param r the one-row relation being visited
 * @param context unused
 * @return the equality predicates that hold for the single row of {@code r}
 */
@Override
public ImmutableSet<Expression> visitLogicalOneRowRelation(LogicalOneRowRelation r, Void context) {
    ImmutableSet.Builder<Expression> predicates = ImmutableSet.builder();
    for (NamedExpression expr : r.getProjects()) {
        if (!(expr instanceof Alias)) {
            continue;
        }
        Expression value = expr.child(0);
        // only a non-null literal gives a predicate that is TRUE for the row
        if (value instanceof Literal && !(value instanceof NullLiteral)) {
            predicates.add(new EqualTo(expr.toSlot(), value));
        }
    }
    return predicates.build();
}

/**
 * Collects the predicates of every INTERSECT child, rewritten onto the INTERSECT output slots.
 *
 * <p>Children that yield no predicates are skipped. The predicates of the remaining children
 * are mapped from the child's regular output slots to the INTERSECT output slots by column
 * position, merged into one set and handed to getAvailableExpressions together with the
 * INTERSECT. The whole computation is supplied to cacheOrElse.
 *
 * @param intersect the INTERSECT node being visited
 * @param context forwarded to the child visits
 * @return the result of getAvailableExpressions for the merged predicates
 */
@Override
public ImmutableSet<Expression> visitLogicalIntersect(LogicalIntersect intersect, Void context) {
    return cacheOrElse(intersect, () -> {
        ImmutableSet.Builder<Expression> merged = ImmutableSet.builder();
        int childCount = intersect.children().size();
        for (int childIdx = 0; childIdx < childCount; ++childIdx) {
            Set<Expression> childFilters = intersect.child(childIdx).accept(this, context);
            if (!childFilters.isEmpty()) {
                // map each child slot to the INTERSECT output slot at the same position
                Map<Expression, Expression> childSlotToOutput = new HashMap<>();
                int column = 0;
                for (NamedExpression outputSlot : intersect.getOutput()) {
                    childSlotToOutput.put(intersect.getRegularChildOutput(childIdx).get(column), outputSlot);
                    column++;
                }
                merged.addAll(ExpressionUtils.replace(childFilters, childSlotToOutput));
            }
        }
        return getAvailableExpressions(merged.build(), intersect);
    });
}

/**
 * Pulls up the predicates of the first EXCEPT child, rewritten onto the EXCEPT output slots.
 *
 * <p>Only child 0 is consulted; the other children are not visited here. An EXCEPT without
 * children, or whose first child yields no predicates, produces an empty set. The whole
 * computation is supplied to cacheOrElse.
 *
 * @param except the EXCEPT node being visited
 * @param context forwarded to the visit of the first child
 * @return the first child's predicates expressed on the EXCEPT output slots
 */
@Override
public ImmutableSet<Expression> visitLogicalExcept(LogicalExcept except, Void context) {
    return cacheOrElse(except, () -> {
        if (except.arity() < 1) {
            return ImmutableSet.of();
        }
        Set<Expression> leadingChildFilters = except.child(0).accept(this, context);
        if (leadingChildFilters.isEmpty()) {
            return ImmutableSet.of();
        }
        // map each slot of the first child to the EXCEPT output slot at the same position
        Map<Expression, Expression> childSlotToOutput = new HashMap<>();
        int column = 0;
        for (NamedExpression outputSlot : except.getOutput()) {
            childSlotToOutput.put(except.getRegularChildOutput(0).get(column), outputSlot);
            column++;
        }
        Set<Expression> onExceptOutput = ExpressionUtils.replace(leadingChildFilters, childSlotToOutput);
        return ImmutableSet.copyOf(onExceptOutput);
    });
}

/**
 * Pulls up predicates from a union, choosing the strategy by what the union contains.
 *
 * <ul>
 *   <li>constant rows only: predicates come from getFiltersFromUnionConstExprs;</li>
 *   <li>children only: predicates come from getFiltersFromUnionChild;</li>
 *   <li>both: the predicates from the children are returned only when
 *       ExpressionUtils.unionConstExprsSatisfyConjuncts accepts them for this union,
 *       otherwise nothing is pulled up;</li>
 *   <li>neither: nothing is pulled up.</li>
 * </ul>
 *
 * <p>The computation is supplied to cacheOrElse.
 */
@Override
public ImmutableSet<Expression> visitLogicalUnion(LogicalUnion union, Void context) {
return cacheOrElse(union, () -> {
// case 1: the union consists of constant rows only
if (!union.getConstantExprsList().isEmpty() && union.arity() == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why could not process union with both constantExprs and normal children? how about

union
|-- const(1 as a)
+-- filter(a = 1)
    +-- project(a)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return getFiltersFromUnionConstExprs(union);
// case 2: the union has children but no constant rows
} else if (union.getConstantExprsList().isEmpty() && union.arity() != 0) {
return getFiltersFromUnionChild(union, context);
// case 3: the union has both children and constant rows
} else if (!union.getConstantExprsList().isEmpty() && union.arity() != 0) {
HashSet<Expression> fromChildFilters = new HashSet<>(getFiltersFromUnionChild(union, context));
if (fromChildFilters.isEmpty()) {
return ImmutableSet.of();
}
// presumably checks that every constant row also satisfies the child predicates;
// verify against ExpressionUtils.unionConstExprsSatisfyConjuncts
if (!ExpressionUtils.unionConstExprsSatisfyConjuncts(union, fromChildFilters)) {
return ImmutableSet.of();
}
return ImmutableSet.copyOf(fromChildFilters);
}
// neither constant rows nor children: nothing to pull up
return ImmutableSet.of();
});
}

@Override
public ImmutableSet<Expression> visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) {
return cacheOrElse(filter, () -> {
Expand All @@ -77,6 +158,10 @@ public ImmutableSet<Expression> visitLogicalJoin(LogicalJoin<? extends Plan, ? e
ImmutableSet<Expression> rightPredicates = join.right().accept(this, context);
predicates.addAll(leftPredicates);
predicates.addAll(rightPredicates);
if (join.getJoinType() == JoinType.CROSS_JOIN || join.getJoinType() == JoinType.INNER_JOIN) {
predicates.addAll(join.getHashJoinConjuncts());
predicates.addAll(join.getOtherJoinConjuncts());
}
return getAvailableExpressions(predicates, join);
});
}
Expand Down Expand Up @@ -138,6 +223,9 @@ private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Ex
}

private ImmutableSet<Expression> getAvailableExpressions(Set<Expression> predicates, Plan plan) {
if (predicates.isEmpty()) {
return ImmutableSet.of();
}
Set<Expression> inferPredicates = PredicatePropagation.infer(predicates);
Builder<Expression> newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size() + 10);
Set<Slot> outputSet = plan.getOutputSet();
Expand All @@ -159,4 +247,55 @@ private ImmutableSet<Expression> getAvailableExpressions(Set<Expression> predica
/**
 * Tells whether the expression tree contains an aggregate function.
 *
 * @param expression the expression to inspect
 * @return the result of {@code anyMatch} with an {@link AggregateFunction} instance check
 */
private boolean hasAgg(Expression expression) {
    boolean containsAggregate = expression.anyMatch(node -> node instanceof AggregateFunction);
    return containsAggregate;
}

/**
 * Computes the predicates shared by all children of a union, expressed on the union output slots.
 *
 * <p>Each child's predicates are mapped from the child's regular output slots to the union
 * output slots by column position. The first child seeds the result and every later child
 * narrows it to the predicates it also provides. The method returns an empty set as soon as a
 * child has no predicates or the running intersection becomes empty.
 *
 * @param union the union whose children are inspected
 * @param context forwarded to the child visits
 * @return the predicates common to every child, or an empty set
 */
private ImmutableSet<Expression> getFiltersFromUnionChild(LogicalUnion union, Void context) {
    Set<Expression> common = new HashSet<>();
    for (int childIdx = 0; childIdx < union.getArity(); ++childIdx) {
        Set<Expression> childFilters = union.child(childIdx).accept(this, context);
        if (childFilters.isEmpty()) {
            return ImmutableSet.of();
        }
        // map each child slot to the union output slot at the same position
        Map<Expression, Expression> childSlotToOutput = new HashMap<>();
        int column = 0;
        for (NamedExpression outputSlot : union.getOutput()) {
            childSlotToOutput.put(union.getRegularChildOutput(childIdx).get(column), outputSlot);
            column++;
        }
        Set<Expression> onUnionOutput = ExpressionUtils.replace(childFilters, childSlotToOutput);
        if (childIdx != 0) {
            common.retainAll(onUnionOutput);
        } else {
            common.addAll(onUnionOutput);
        }
        if (common.isEmpty()) {
            return ImmutableSet.of();
        }
    }
    return ImmutableSet.copyOf(common);
}

/**
 * Derives per-column predicates from the constant rows of a union.
 *
 * <p>For every output column, the values of that column across all constant rows are collected.
 * If each of them is an alias over a non-null literal, the column yields
 * {@code col = v} (one distinct value) or {@code col IN (v1, v2, ...)} (several distinct values).
 *
 * <p>A column yields no predicate when any constant row holds a non-literal or a NULL literal
 * in it. A NULL value satisfies neither {@code col = v} nor {@code col IN (...)}, since both
 * evaluate to NULL. Dropping the NULL and keeping a predicate built from the other values would
 * therefore claim something that is not true for every row of the union.
 *
 * @param union the union whose constant expression rows are inspected
 * @return predicates that hold for every constant row of {@code union}
 */
private ImmutableSet<Expression> getFiltersFromUnionConstExprs(LogicalUnion union) {
    List<List<NamedExpression>> constExprs = union.getConstantExprsList();
    ImmutableSet.Builder<Expression> filtersFromConstExprs = ImmutableSet.builder();
    for (int col = 0; col < union.getOutput().size(); ++col) {
        Expression compareExpr = union.getOutput().get(col);
        Set<Expression> options = new HashSet<>();
        for (List<NamedExpression> constExpr : constExprs) {
            NamedExpression cell = constExpr.get(col);
            boolean nonNullLiteral = cell instanceof Alias
                    && ((Alias) cell).child() instanceof Literal
                    && !(((Alias) cell).child() instanceof NullLiteral);
            if (!nonNullLiteral) {
                // one unusable row invalidates the whole column
                options.clear();
                break;
            }
            options.add(((Alias) cell).child());
        }
        if (options.size() > 1) {
            filtersFromConstExprs.add(new InPredicate(compareExpr, options));
        } else if (options.size() == 1) {
            filtersFromConstExprs.add(new EqualTo(compareExpr, options.iterator().next()));
        }
    }
    return filtersFromConstExprs.build();
}
}
Loading
Loading