Skip to content

Commit

Permalink
[refactor](Nereids) refactor push down element_at on variant
Browse files Browse the repository at this point in the history
intro a new rule VARIANT_SUB_PATH_PRUNING to prune variant sub path
  • Loading branch information
morrySnow committed Jun 19, 2024
1 parent fdb5891 commit d5617fe
Show file tree
Hide file tree
Showing 36 changed files with 1,118 additions and 733 deletions.
22 changes: 4 additions & 18 deletions fe/.idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
Expand Down Expand Up @@ -64,7 +63,6 @@
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -131,15 +129,6 @@ public class StatementContext implements Closeable {
private final List<Expression> joinFilters = new ArrayList<>();

private final List<Hint> hints = new ArrayList<>();
// Root Slot -> Paths -> Sub-column Slots
private final Map<Slot, Map<List<String>, SlotReference>> subColumnSlotRefMap
= Maps.newHashMap();

// Map from rewritten slot to original expr
private final Map<Slot, Expression> subColumnOriginalExprMap = Maps.newHashMap();

// Map from original expr to rewritten slot
private final Map<Expression, Slot> originalExprToRewrittenSubColumn = Maps.newHashMap();

// Map slot to its relation, currently used in SlotReference to find its original
// Relation for example LogicalOlapScan
Expand Down Expand Up @@ -265,58 +254,10 @@ public Optional<SqlCacheContext> getSqlCacheContext() {
return Optional.ofNullable(sqlCacheContext);
}

public Set<SlotReference> getAllPathsSlots() {
Set<SlotReference> allSlotReferences = Sets.newHashSet();
for (Map<List<String>, SlotReference> slotReferenceMap : subColumnSlotRefMap.values()) {
allSlotReferences.addAll(slotReferenceMap.values());
}
return allSlotReferences;
}

public Expression getOriginalExpr(SlotReference rewriteSlot) {
return subColumnOriginalExprMap.getOrDefault(rewriteSlot, null);
}

public Slot getRewrittenSlotRefByOriginalExpr(Expression originalExpr) {
return originalExprToRewrittenSubColumn.getOrDefault(originalExpr, null);
}

/**
* Add a slot ref attached with paths in context to avoid duplicated slot
*/
public void addPathSlotRef(Slot root, List<String> paths, SlotReference slotRef, Expression originalExpr) {
subColumnSlotRefMap.computeIfAbsent(root, k -> Maps.newTreeMap((lst1, lst2) -> {
Iterator<String> it1 = lst1.iterator();
Iterator<String> it2 = lst2.iterator();
while (it1.hasNext() && it2.hasNext()) {
int result = it1.next().compareTo(it2.next());
if (result != 0) {
return result;
}
}
return Integer.compare(lst1.size(), lst2.size());
}));
subColumnSlotRefMap.get(root).put(paths, slotRef);
subColumnOriginalExprMap.put(slotRef, originalExpr);
originalExprToRewrittenSubColumn.put(originalExpr, slotRef);
}

public SlotReference getPathSlot(Slot root, List<String> paths) {
Map<List<String>, SlotReference> pathsSlotsMap = subColumnSlotRefMap.getOrDefault(root, null);
if (pathsSlotsMap == null) {
return null;
}
return pathsSlotsMap.getOrDefault(paths, null);
}

public void addSlotToRelation(Slot slot, Relation relation) {
slotToRelation.put(slot, relation);
}

public Relation getRelationBySlot(Slot slot) {
return slotToRelation.getOrDefault(slot, null);
}

public boolean isDpHyp() {
return isDpHyp;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HighOrderFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.PushDownToProjectionFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf;
Expand All @@ -102,7 +101,6 @@
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TFunctionBinaryType;

import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -211,20 +209,6 @@ private OlapTable getOlapTableDirectly(SlotRef left) {

@Override
public Expr visitElementAt(ElementAt elementAt, PlanTranslatorContext context) {
if (PushDownToProjectionFunction.validToPushDown(elementAt)) {
if (ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable() != null
&& !ConnectContext.get().getSessionVariable().isEnableRewriteElementAtToSlot()) {
throw new AnalysisException(
"set enable_rewrite_element_at_to_slot=true when using element_at function for variant type");
}
SlotReference rewrittenSlot = (SlotReference) context.getConnectContext()
.getStatementContext().getRewrittenSlotRefByOriginalExpr(elementAt);
// rewrittenSlot == null means variant is not from table. so keep element_at function
if (rewrittenSlot != null) {
return context.findSlotRef(rewrittenSlot.getExprId());
}
}
return visitScalarFunction(elementAt, context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
import org.apache.doris.nereids.trees.expressions.WindowFrame;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.scalar.PushDownToProjectionFunction;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.AggPhase;
Expand Down Expand Up @@ -1250,8 +1249,7 @@ public PlanFragment visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, P
}
if (planNode instanceof ExchangeNode || planNode instanceof SortNode || planNode instanceof UnionNode
// this means we have filter->limit->project, need a SelectNode
|| (child instanceof PhysicalProject
&& !((PhysicalProject<?>) child).hasPushedDownToProjectionFunctions())) {
|| child instanceof PhysicalProject) {
// the three nodes don't support conjuncts, need create a SelectNode to filter data
SelectNode selectNode = new SelectNode(context.nextPlanNodeId(), planNode);
selectNode.setNereidsId(filter.getId());
Expand Down Expand Up @@ -1833,35 +1831,6 @@ && findOlapScanNodesByPassExchangeAndJoinNode(inputFragment.getPlanRoot())) {
return inputFragment;
}

// collect all valid PushDownToProjectionFunction from expression
private List<Expression> getPushDownToProjectionFunctionForRewritten(NamedExpression expression) {
List<Expression> targetExprList = expression.collectToList(PushDownToProjectionFunction.class::isInstance);
return targetExprList.stream()
.filter(PushDownToProjectionFunction::validToPushDown)
.collect(Collectors.toList());
}

// register rewritten slots from original PushDownToProjectionFunction
private void registerRewrittenSlot(PhysicalProject<? extends Plan> project, OlapScanNode olapScanNode) {
// register slots that are rewritten from element_at/etc..
List<Expression> allPushDownProjectionFunctions = project.getProjects().stream()
.map(this::getPushDownToProjectionFunctionForRewritten)
.flatMap(List::stream)
.collect(Collectors.toList());
for (Expression expr : allPushDownProjectionFunctions) {
PushDownToProjectionFunction function = (PushDownToProjectionFunction) expr;
if (context != null
&& context.getConnectContext() != null
&& context.getConnectContext().getStatementContext() != null) {
Slot argumentSlot = function.getInputSlots().stream().findFirst().get();
Expression rewrittenSlot = PushDownToProjectionFunction.rewriteToSlot(
function, (SlotReference) argumentSlot);
TupleDescriptor tupleDescriptor = context.getTupleDesc(olapScanNode.getTupleId());
context.createSlotDesc(tupleDescriptor, (SlotReference) rewrittenSlot);
}
}
}

// TODO: generate expression mapping when be project could do in ExecNode.
@Override
public PlanFragment visitPhysicalProject(PhysicalProject<? extends Plan> project, PlanTranslatorContext context) {
Expand All @@ -1876,12 +1845,6 @@ public PlanFragment visitPhysicalProject(PhysicalProject<? extends Plan> project

PlanFragment inputFragment = project.child(0).accept(this, context);

if (inputFragment.getPlanRoot() instanceof OlapScanNode) {
// function already pushed down in projection
// e.g. select count(distinct cast(element_at(v, 'a') as int)) from tbl;
registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot());
}

PlanNode inputPlanNode = inputFragment.getPlanRoot();
List<Expr> projectionExprs = null;
List<Expr> allProjectionExprs = Lists.newArrayList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,12 @@ public SlotDescriptor createSlotDesc(TupleDescriptor tupleDesc, SlotReference sl
slotDescriptor.setLabel(slotReference.getName());
} else {
slotRef = new SlotRef(slotDescriptor);
if (slotReference.hasSubColPath()) {
slotDescriptor.setSubColLables(slotReference.getSubColPath());
if (slotReference.hasSubColPath() && slotReference.getColumn().isPresent()) {
slotDescriptor.setSubColLables(slotReference.getSubPath());
// use lower case name for variant's root, since backend treat parent column as lower case
// see issue: https://github.com/apache/doris/pull/32999/commits
slotDescriptor.setMaterializedColumnName(slotRef.getColumnName().toLowerCase()
+ "." + String.join(".", slotReference.getSubColPath()));
+ "." + String.join(".", slotReference.getSubPath()));
}
}
slotRef.setTable(table);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
import org.apache.doris.nereids.rules.analysis.BindSink;
import org.apache.doris.nereids.rules.analysis.BindSlotWithPaths;
import org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable;
import org.apache.doris.nereids.rules.analysis.CheckAfterBind;
import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
Expand Down Expand Up @@ -136,7 +135,6 @@ private static List<RewriteJob> buildAnalyzeJobs(Optional<CustomTableResolver> c
new CheckPolicy()
),
bottomUp(new BindExpression()),
bottomUp(new BindSlotWithPaths()),
topDown(new BindSink()),
bottomUp(new CheckAfterBind()),
bottomUp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.doris.nereids.rules.rewrite.CheckMatchExpression;
import org.apache.doris.nereids.rules.rewrite.CheckMultiDistinct;
import org.apache.doris.nereids.rules.rewrite.CheckPrivileges;
import org.apache.doris.nereids.rules.rewrite.ClearContextStatus;
import org.apache.doris.nereids.rules.rewrite.CollectCteConsumerOutput;
import org.apache.doris.nereids.rules.rewrite.CollectFilterAboveConsumer;
import org.apache.doris.nereids.rules.rewrite.ColumnPruning;
Expand Down Expand Up @@ -134,12 +135,15 @@
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAggProject;
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoin;
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoinProject;
import org.apache.doris.nereids.rules.rewrite.VariantSubPathPruning;
import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin;
import org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply;
import org.apache.doris.nereids.rules.rewrite.batch.EliminateUselessPlanUnderApply;
import org.apache.doris.nereids.rules.rewrite.mv.SelectMaterializedIndexWithAggregate;
import org.apache.doris.nereids.rules.rewrite.mv.SelectMaterializedIndexWithoutAggregate;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -398,9 +402,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("adjust preagg status",
topDown(new AdjustPreAggStatus())
),
topic("topn optimize",
topDown(new DeferMaterializeTopNResult())
),
topic("Point query short circuit",
topDown(new LogicalResultSinkToShortCircuitPointQuery())),
topic("eliminate",
Expand All @@ -413,22 +414,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new SumLiteralRewrite(),
new MergePercentileToArray())
),
topic("add projection for join",
custom(RuleType.ADD_PROJECT_FOR_JOIN, AddProjectForJoin::new),
topDown(new MergeProjects())
),
// this rule batch must keep at the end of rewrite to do some plan check
topic("Final rewrite and check",
custom(RuleType.CHECK_DATA_TYPES, CheckDataTypes::new),
topDown(new PushDownFilterThroughProject(), new MergeProjects()),
custom(RuleType.ADJUST_CONJUNCTS_RETURN_TYPE, AdjustConjunctsReturnType::new),
bottomUp(
new ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE),
new CheckMatchExpression(),
new CheckMultiDistinct(),
new CheckAfterRewrite()
)
),
topic("Push project and filter on cte consumer to cte producer",
topDown(
new CollectFilterAboveConsumer(),
Expand Down Expand Up @@ -488,6 +473,39 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(List<RewriteJob> jobs) {
),
topic("or expansion",
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)),
topic("variant element_at push down",
custom(RuleType.VARIANT_SUB_PATH_PRUNING, VariantSubPathPruning::new),
custom(RuleType.CLEAR_CONTEXT_STATUS, ClearContextStatus::new),
custom(RuleType.REWRITE_CTE_CHILDREN, () -> new RewriteCteChildren(jobs(
// after variant sub path pruning, we need do column pruning again
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
bottomUp(ImmutableList.of(
new PushDownFilterThroughProject(),
new MergeProjects()
)),
custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, EliminateUnnecessaryProject::new),
topic("topn optimize",
topDown(new DeferMaterializeTopNResult())
),
topic("add projection for join",
custom(RuleType.ADD_PROJECT_FOR_JOIN, AddProjectForJoin::new),
topDown(new MergeProjects())
),
// this rule batch must keep at the end of rewrite to do some plan check
topic("Final rewrite and check",
custom(RuleType.CHECK_DATA_TYPES, CheckDataTypes::new),
topDown(new PushDownFilterThroughProject(), new MergeProjects()),
custom(RuleType.ADJUST_CONJUNCTS_RETURN_TYPE, AdjustConjunctsReturnType::new),
bottomUp(
new ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE),
new CheckMatchExpression(),
new CheckMultiDistinct(),
new CheckAfterRewrite()
)
),
topDown(new CollectCteConsumerOutput()))
))
),
topic("whole plan check",
custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@
*/
public class CommonSubExpressionOpt extends PlanPostProcessor {
@Override
public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) {
public PhysicalProject<? extends Plan> visitPhysicalProject(
PhysicalProject<? extends Plan> project, CascadesContext ctx) {
project.child().accept(this, ctx);
if (!project.hasPushedDownToProjectionFunctions()) {
List<List<NamedExpression>> multiLayers = computeMultiLayerProjections(
project.getInputSlots(), project.getProjects());
project.setMultiLayerProjects(multiLayers);
}
List<List<NamedExpression>> multiLayers = computeMultiLayerProjections(
project.getInputSlots(), project.getProjects());
project.setMultiLayerProjects(multiLayers);
return project;
}

Expand Down
Loading

0 comments on commit d5617fe

Please sign in to comment.