Skip to content

Commit

Permalink
[fix](nereids) avoid redundant enumeration of the same LogicalProject in memo (
Browse files Browse the repository at this point in the history
apache#38317)

1. use a set to compare projects
2. use a map to store enforcers
3. avoid generating useless projects under the bottom join when doing join reorder
  • Loading branch information
morrySnow authored and gavinchou committed Aug 4, 2024
1 parent 9a216fe commit cdaabef
Show file tree
Hide file tree
Showing 23 changed files with 230 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ private PhysicalPlan chooseBestPlan(Group rootGroup, PhysicalProperties physical
GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow(
() -> new AnalysisException("lowestCostPlans with physicalProperties("
+ physicalProperties + ") doesn't exist in root group")).second;
if (rootGroup.getEnforcers().contains(groupExpression)) {
if (rootGroup.getEnforcers().containsKey(groupExpression)) {
rootGroup.addChosenEnforcerId(groupExpression.getId().asInt());
rootGroup.addChosenEnforcerProperties(physicalProperties);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ public final void execute() throws AnalysisException {
GroupExpressionMatching groupExpressionMatching
= new GroupExpressionMatching(rule.getPattern(), groupExpression);
for (Plan plan : groupExpressionMatching) {
if (rule.isExploration()
&& context.getCascadesContext().getMemo().getGroupExpressionsSize() > context.getCascadesContext()
.getConnectContext().getSessionVariable().memoMaxGroupExpressionSize) {
break;
}
List<Plan> newPlans = rule.transform(plan, context.getCascadesContext());
for (Plan newPlan : newPlans) {
if (newPlan == plan) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public class Group {

private final List<GroupExpression> logicalExpressions = Lists.newArrayList();
private final List<GroupExpression> physicalExpressions = Lists.newArrayList();
private final List<GroupExpression> enforcers = Lists.newArrayList();
private final Map<GroupExpression, GroupExpression> enforcers = Maps.newHashMap();
private boolean isStatsReliable = true;
private LogicalProperties logicalProperties;

Expand Down Expand Up @@ -239,10 +239,10 @@ public GroupExpression getBestPlan(PhysicalProperties properties) {

public void addEnforcer(GroupExpression enforcer) {
enforcer.setOwnerGroup(this);
enforcers.add(enforcer);
enforcers.put(enforcer, enforcer);
}

public List<GroupExpression> getEnforcers() {
public Map<GroupExpression, GroupExpression> getEnforcers() {
return enforcers;
}

Expand Down Expand Up @@ -346,9 +346,9 @@ public void mergeTo(Group target) {
parentExpressions.keySet().forEach(parent -> target.addParentExpression(parent));

// move enforcers Ownership
enforcers.forEach(ge -> ge.children().set(0, target));
enforcers.forEach((k, v) -> k.children().set(0, target));
// TODO: dedup?
enforcers.forEach(enforcer -> target.addEnforcer(enforcer));
enforcers.forEach((k, v) -> target.addEnforcer(k));
enforcers.clear();

// move LogicalExpression PhysicalExpression Ownership
Expand Down Expand Up @@ -458,7 +458,7 @@ public String toString() {
str.append(" ").append(physicalExpression).append("\n");
}
str.append(" enforcers:\n");
for (GroupExpression enforcer : enforcers) {
for (GroupExpression enforcer : enforcers.keySet()) {
str.append(" ").append(enforcer).append("\n");
}
if (!chosenEnforcerIdList.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,7 @@ public void mergeGroup(Group source, Group destination, HashMap<Long, Group> pla
return;
}
Group parentOwnerGroup = srcParent.getOwnerGroup();
HashSet<GroupExpression> enforcers = new HashSet<>(parentOwnerGroup.getEnforcers());
if (enforcers.contains(srcParent)) {
if (parentOwnerGroup.getEnforcers().containsKey(srcParent)) {
continue;
}
needReplaceChild.add(srcParent);
Expand Down Expand Up @@ -946,7 +945,7 @@ private List<GroupExpression> extractGroupExpressionSatisfyProp(Group group, Phy
List<GroupExpression> exprs = Lists.newArrayList(bestExpr);
Set<GroupExpression> hasVisited = new HashSet<>();
hasVisited.add(bestExpr);
Stream.concat(group.getPhysicalExpressions().stream(), group.getEnforcers().stream())
Stream.concat(group.getPhysicalExpressions().stream(), group.getEnforcers().keySet().stream())
.forEach(groupExpression -> {
if (!groupExpression.getInputPropertiesListOrEmpty(prop).isEmpty()
&& !groupExpression.equals(bestExpr) && !hasVisited.contains(groupExpression)) {
Expand All @@ -969,7 +968,7 @@ private List<List<PhysicalProperties>> extractInputProperties(GroupExpression gr
res.add(groupExpression.getInputPropertiesList(prop));

// return optimized input for enforcer
if (groupExpression.getOwnerGroup().getEnforcers().contains(groupExpression)) {
if (groupExpression.getOwnerGroup().getEnforcers().containsKey(groupExpression)) {
return res;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ public boolean isRewrite() {
return ruleType.getRuleTypeClass() == RuleTypeClass.REWRITE;
}

/**
 * Returns whether this rule's type belongs to the {@code EXPLORATION} rule type class.
 *
 * @return {@code true} if {@code ruleType.getRuleTypeClass()} is {@code RuleTypeClass.EXPLORATION}
 */
public boolean isExploration() {
return ruleType.getRuleTypeClass() == RuleTypeClass.EXPLORATION;
}

@Override
public String toString() {
return getRuleType().toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,29 +381,29 @@ public enum RuleType {
EAGER_SPLIT(RuleTypeClass.EXPLORATION),

EXPLORATION_SENTINEL(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_JOIN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_JOIN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_FILTER_JOIN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_PROJECT_JOIN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_ONLY_JOIN(RuleTypeClass.EXPLORATION),

MATERIALIZED_VIEW_PROJECT_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_FILTER_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_PROJECT_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_ONLY_AGGREGATE(RuleTypeClass.EXPLORATION),

MATERIALIZED_VIEW_PROJECT_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_FILTER_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_PROJECT_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_ONLY_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.EXPLORATION),

MATERIALIZED_VIEW_FILTER_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_FILTER_PROJECT_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_FILTER_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_ONLY_SCAN(RuleTypeClass.EXPLORATION),
MATERIALIZED_VIEW_PROJECT_JOIN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_JOIN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_FILTER_JOIN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_PROJECT_JOIN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_ONLY_JOIN(RuleTypeClass.MATERIALIZE_VIEW),

MATERIALIZED_VIEW_PROJECT_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_FILTER_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_PROJECT_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_ONLY_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),

MATERIALIZED_VIEW_PROJECT_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_FILTER_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_PROJECT_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_ONLY_AGGREGATE_ON_NONE_AGGREGATE(RuleTypeClass.MATERIALIZE_VIEW),

MATERIALIZED_VIEW_FILTER_SCAN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_SCAN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_FILTER_PROJECT_SCAN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_PROJECT_FILTER_SCAN(RuleTypeClass.MATERIALIZE_VIEW),
MATERIALIZED_VIEW_ONLY_SCAN(RuleTypeClass.MATERIALIZE_VIEW),

// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
Expand Down Expand Up @@ -491,6 +491,7 @@ public <INPUT_TYPE extends Plan, OUTPUT_TYPE extends Plan> Rule build(
enum RuleTypeClass {
REWRITE,
EXPLORATION,
MATERIALIZE_VIEW,
// This type is used for unit test only.
CHECK,
IMPLEMENTATION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ public static Set<Slot> joinChildConditionSlots(LogicalJoin<? extends Plan, ? ex
.collect(Collectors.toSet());
}

/**
 * Skips creating a new project when {@code plan} already outputs exactly the required expression ids.
 *
 * @param requiredExprIds the expression ids that must be output
 * @param plan the plan whose output is checked against {@code requiredExprIds}
 * @return {@code plan} itself when its output expression id set equals {@code requiredExprIds};
 *         otherwise the result of {@code newProject(requiredExprIds, plan)}
 */
public static Plan newProjectIfNeeded(Set<ExprId> requiredExprIds, Plan plan) {
// Set equality: an identical output id set means a project would add nothing.
if (requiredExprIds.equals(plan.getOutputExprIdSet())) {
return plan;
}
return newProject(requiredExprIds, plan);
}

public static Plan newProject(Set<ExprId> requiredExprIds, Plan plan) {
List<NamedExpression> projects = plan.getOutput().stream()
.filter(namedExpr -> requiredExprIds.contains(namedExpr.getExprId()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public Rule build() {
newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin);
Plan right = CBOUtils.newProject(topUsedExprIds, b);
Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, b);

LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(newTopHashConjuncts,
newTopOtherConjuncts, left, right, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public Rule build() {
newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin);
Plan right = CBOUtils.newProject(topUsedExprIds, c);
Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, c);

LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(
newTopHashConjuncts, newTopOtherConjuncts, left, right, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public Rule build() {
topProject.getProjects().forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
Plan left = CBOUtils.newProject(topUsedExprIds, a);
Plan left = CBOUtils.newProjectIfNeeded(topUsedExprIds, a);
Plan right = CBOUtils.newProject(topUsedExprIds, newBottomJoin);

LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public List<Rule> buildRules() {
.forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
Plan newBottomJoin = topJoin.withChildrenNoContext(a, c, null);
Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin);
Plan right = CBOUtils.newProject(topUsedExprIds, b);
Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, b);

Plan newTopJoin = bottomJoin.withChildrenNoContext(left, right, null);
return topProject.withChildren(newTopJoin);
Expand Down Expand Up @@ -102,7 +102,7 @@ public List<Rule> buildRules() {
.forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
Plan newBottomJoin = topJoin.withChildrenNoContext(a, b, null);
Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin);
Plan right = CBOUtils.newProject(topUsedExprIds, c);
Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, c);

Plan newTopJoin = bottomJoin.withChildrenNoContext(left, right, null);
return topProject.withChildren(newTopJoin);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public Rule build() {
topProject.getProjects().forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds()));
bottomJoin.getHashJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
bottomJoin.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
Plan left = CBOUtils.newProject(topUsedExprIds, a);
Plan left = CBOUtils.newProjectIfNeeded(topUsedExprIds, a);
Plan right = CBOUtils.newProject(topUsedExprIds, newBottomJoin);

LogicalJoin newTopJoin = bottomJoin.withChildrenNoContext(left, right, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public Rule build() {
bottomJoin.getHashJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
bottomJoin.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin);
Plan right = CBOUtils.newProject(topUsedExprIds, b);
Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, b);

LogicalJoin newTopJoin = bottomJoin.withChildrenNoContext(left, right, null);
newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public Rule build() {
bottomSemi.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));

Plan left = CBOUtils.newProject(topUsedExprIds, newBottomSemi);
Plan right = CBOUtils.newProject(topUsedExprIds, b);
Plan right = CBOUtils.newProjectIfNeeded(topUsedExprIds, b);

LogicalJoin newTopSemi = bottomSemi.withChildrenNoContext(left, right, null);
newTopSemi.getJoinReorderContext().copyFrom(topSemi.getJoinReorderContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import org.apache.doris.nereids.util.Utils;

import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
Expand All @@ -47,6 +49,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
* Logical project plan.
Expand All @@ -55,6 +58,7 @@ public class LogicalProject<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_
implements Project, OutputPrunable {

private final List<NamedExpression> projects;
private final Supplier<Set<NamedExpression>> projectsSet;
private final List<NamedExpression> excepts;
private final boolean isDistinct;

Expand Down Expand Up @@ -84,6 +88,7 @@ private LogicalProject(List<NamedExpression> projects, List<NamedExpression> exc
this.projects = projects.isEmpty()
? ImmutableList.of(ExpressionUtils.selectMinimumColumn(child.get(0).getOutput()))
: projects;
this.projectsSet = Suppliers.memoize(() -> ImmutableSet.copyOf(this.projects));
this.excepts = Utils.fastToImmutableList(excepts);
this.isDistinct = isDistinct;
}
Expand Down Expand Up @@ -139,7 +144,7 @@ public boolean equals(Object o) {
return false;
}
LogicalProject<?> that = (LogicalProject<?>) o;
boolean equal = projects.equals(that.projects)
boolean equal = projectsSet.get().equals(that.projectsSet.get())
&& excepts.equals(that.excepts)
&& isDistinct == that.isDistinct;
// TODO: should add exprId for UnBoundStar and BoundStar for equality comparison
Expand All @@ -151,7 +156,7 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(projects);
return Objects.hash(projectsSet.get());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,24 @@
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
* Physical project plan.
*/
public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Project {

private final List<NamedExpression> projects;
private final Supplier<Set<NamedExpression>> projectsSet;
//multiLayerProjects is used to extract common expressions
// projects: (A+B) * 2, (A+B) * 3
// multiLayerProjects:
Expand All @@ -62,6 +67,7 @@ public PhysicalProject(List<NamedExpression> projects, Optional<GroupExpression>
LogicalProperties logicalProperties, CHILD_TYPE child) {
super(PlanType.PHYSICAL_PROJECT, groupExpression, logicalProperties, child);
this.projects = ImmutableList.copyOf(Objects.requireNonNull(projects, "projects can not be null"));
this.projectsSet = Suppliers.memoize(() -> ImmutableSet.copyOf(this.projects));
}

public PhysicalProject(List<NamedExpression> projects, Optional<GroupExpression> groupExpression,
Expand All @@ -70,6 +76,7 @@ public PhysicalProject(List<NamedExpression> projects, Optional<GroupExpression>
super(PlanType.PHYSICAL_PROJECT, groupExpression, logicalProperties, physicalProperties, statistics,
child);
this.projects = ImmutableList.copyOf(Objects.requireNonNull(projects, "projects can not be null"));
this.projectsSet = Suppliers.memoize(() -> ImmutableSet.copyOf(this.projects));
}

public List<NamedExpression> getProjects() {
Expand All @@ -96,13 +103,13 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) {
return false;
}
PhysicalProject that = (PhysicalProject) o;
return projects.equals(that.projects);
PhysicalProject<?> that = (PhysicalProject<?>) o;
return projectsSet.get().equals(that.projectsSet.get());
}

@Override
public int hashCode() {
return Objects.hash(projects);
return Objects.hash(projectsSet.get());
}

@Override
Expand Down
Loading

0 comments on commit cdaabef

Please sign in to comment.