Skip to content

Commit

Permalink
[refactor](Nereids) refactor push down element_at on variant
Browse files Browse the repository at this point in the history
intro a new rule VARIANT_SUB_PATH_PRUNING to prune variant sub path
  • Loading branch information
morrySnow committed Jun 19, 2024
1 parent ab764d2 commit 09d4cc3
Show file tree
Hide file tree
Showing 34 changed files with 1,047 additions and 698 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
Expand Down Expand Up @@ -64,7 +63,6 @@
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -131,15 +129,6 @@ public class StatementContext implements Closeable {
private final List<Expression> joinFilters = new ArrayList<>();

private final List<Hint> hints = new ArrayList<>();
// Root Slot -> Paths -> Sub-column Slots
private final Map<Slot, Map<List<String>, SlotReference>> subColumnSlotRefMap
= Maps.newHashMap();

// Map from rewritten slot to original expr
private final Map<Slot, Expression> subColumnOriginalExprMap = Maps.newHashMap();

// Map from original expr to rewritten slot
private final Map<Expression, Slot> originalExprToRewrittenSubColumn = Maps.newHashMap();

// Map slot to its relation, currently used in SlotReference to find its original
// Relation for example LogicalOlapScan
Expand Down Expand Up @@ -265,58 +254,10 @@ public Optional<SqlCacheContext> getSqlCacheContext() {
return Optional.ofNullable(sqlCacheContext);
}

public Set<SlotReference> getAllPathsSlots() {
Set<SlotReference> allSlotReferences = Sets.newHashSet();
for (Map<List<String>, SlotReference> slotReferenceMap : subColumnSlotRefMap.values()) {
allSlotReferences.addAll(slotReferenceMap.values());
}
return allSlotReferences;
}

public Expression getOriginalExpr(SlotReference rewriteSlot) {
return subColumnOriginalExprMap.getOrDefault(rewriteSlot, null);
}

public Slot getRewrittenSlotRefByOriginalExpr(Expression originalExpr) {
return originalExprToRewrittenSubColumn.getOrDefault(originalExpr, null);
}

/**
* Add a slot ref attached with paths in context to avoid duplicated slot
*/
public void addPathSlotRef(Slot root, List<String> paths, SlotReference slotRef, Expression originalExpr) {
subColumnSlotRefMap.computeIfAbsent(root, k -> Maps.newTreeMap((lst1, lst2) -> {
Iterator<String> it1 = lst1.iterator();
Iterator<String> it2 = lst2.iterator();
while (it1.hasNext() && it2.hasNext()) {
int result = it1.next().compareTo(it2.next());
if (result != 0) {
return result;
}
}
return Integer.compare(lst1.size(), lst2.size());
}));
subColumnSlotRefMap.get(root).put(paths, slotRef);
subColumnOriginalExprMap.put(slotRef, originalExpr);
originalExprToRewrittenSubColumn.put(originalExpr, slotRef);
}

public SlotReference getPathSlot(Slot root, List<String> paths) {
Map<List<String>, SlotReference> pathsSlotsMap = subColumnSlotRefMap.getOrDefault(root, null);
if (pathsSlotsMap == null) {
return null;
}
return pathsSlotsMap.getOrDefault(paths, null);
}

public void addSlotToRelation(Slot slot, Relation relation) {
slotToRelation.put(slot, relation);
}

public Relation getRelationBySlot(Slot slot) {
return slotToRelation.getOrDefault(slot, null);
}

public boolean isDpHyp() {
return isDpHyp;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HighOrderFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.PushDownToProjectionFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf;
Expand All @@ -102,7 +101,6 @@
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TFunctionBinaryType;

import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -211,20 +209,6 @@ private OlapTable getOlapTableDirectly(SlotRef left) {

@Override
public Expr visitElementAt(ElementAt elementAt, PlanTranslatorContext context) {
if (PushDownToProjectionFunction.validToPushDown(elementAt)) {
if (ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable() != null
&& !ConnectContext.get().getSessionVariable().isEnableRewriteElementAtToSlot()) {
throw new AnalysisException(
"set enable_rewrite_element_at_to_slot=true when using element_at function for variant type");
}
SlotReference rewrittenSlot = (SlotReference) context.getConnectContext()
.getStatementContext().getRewrittenSlotRefByOriginalExpr(elementAt);
// rewrittenSlot == null means variant is not from table. so keep element_at function
if (rewrittenSlot != null) {
return context.findSlotRef(rewrittenSlot.getExprId());
}
}
return visitScalarFunction(elementAt, context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
import org.apache.doris.nereids.trees.expressions.WindowFrame;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.scalar.PushDownToProjectionFunction;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.AggPhase;
Expand Down Expand Up @@ -1250,8 +1249,7 @@ public PlanFragment visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, P
}
if (planNode instanceof ExchangeNode || planNode instanceof SortNode || planNode instanceof UnionNode
// this means we have filter->limit->project, need a SelectNode
|| (child instanceof PhysicalProject
&& !((PhysicalProject<?>) child).hasPushedDownToProjectionFunctions())) {
|| child instanceof PhysicalProject) {
// the three nodes don't support conjuncts, need create a SelectNode to filter data
SelectNode selectNode = new SelectNode(context.nextPlanNodeId(), planNode);
selectNode.setNereidsId(filter.getId());
Expand Down Expand Up @@ -1833,35 +1831,6 @@ && findOlapScanNodesByPassExchangeAndJoinNode(inputFragment.getPlanRoot())) {
return inputFragment;
}

// collect all valid PushDownToProjectionFunction from expression
private List<Expression> getPushDownToProjectionFunctionForRewritten(NamedExpression expression) {
List<Expression> targetExprList = expression.collectToList(PushDownToProjectionFunction.class::isInstance);
return targetExprList.stream()
.filter(PushDownToProjectionFunction::validToPushDown)
.collect(Collectors.toList());
}

// register rewritten slots from original PushDownToProjectionFunction
private void registerRewrittenSlot(PhysicalProject<? extends Plan> project, OlapScanNode olapScanNode) {
// register slots that are rewritten from element_at/etc..
List<Expression> allPushDownProjectionFunctions = project.getProjects().stream()
.map(this::getPushDownToProjectionFunctionForRewritten)
.flatMap(List::stream)
.collect(Collectors.toList());
for (Expression expr : allPushDownProjectionFunctions) {
PushDownToProjectionFunction function = (PushDownToProjectionFunction) expr;
if (context != null
&& context.getConnectContext() != null
&& context.getConnectContext().getStatementContext() != null) {
Slot argumentSlot = function.getInputSlots().stream().findFirst().get();
Expression rewrittenSlot = PushDownToProjectionFunction.rewriteToSlot(
function, (SlotReference) argumentSlot);
TupleDescriptor tupleDescriptor = context.getTupleDesc(olapScanNode.getTupleId());
context.createSlotDesc(tupleDescriptor, (SlotReference) rewrittenSlot);
}
}
}

// TODO: generate expression mapping when be project could do in ExecNode.
@Override
public PlanFragment visitPhysicalProject(PhysicalProject<? extends Plan> project, PlanTranslatorContext context) {
Expand All @@ -1876,12 +1845,6 @@ public PlanFragment visitPhysicalProject(PhysicalProject<? extends Plan> project

PlanFragment inputFragment = project.child(0).accept(this, context);

if (inputFragment.getPlanRoot() instanceof OlapScanNode) {
// function already pushed down in projection
// e.g. select count(distinct cast(element_at(v, 'a') as int)) from tbl;
registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot());
}

PlanNode inputPlanNode = inputFragment.getPlanRoot();
List<Expr> projectionExprs = null;
List<Expr> allProjectionExprs = Lists.newArrayList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,12 @@ public SlotDescriptor createSlotDesc(TupleDescriptor tupleDesc, SlotReference sl
slotDescriptor.setLabel(slotReference.getName());
} else {
slotRef = new SlotRef(slotDescriptor);
if (slotReference.hasSubColPath()) {
slotDescriptor.setSubColLables(slotReference.getSubColPath());
if (slotReference.hasSubColPath() && slotReference.getColumn().isPresent()) {
slotDescriptor.setSubColLables(slotReference.getSubPath());
// use lower case name for variant's root, since backend treat parent column as lower case
// see issue: https://github.com/apache/doris/pull/32999/commits
slotDescriptor.setMaterializedColumnName(slotRef.getColumnName().toLowerCase()
+ "." + String.join(".", slotReference.getSubColPath()));
+ "." + String.join(".", slotReference.getSubPath()));
}
}
slotRef.setTable(table);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
import org.apache.doris.nereids.rules.analysis.BindSink;
import org.apache.doris.nereids.rules.analysis.BindSlotWithPaths;
import org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable;
import org.apache.doris.nereids.rules.analysis.CheckAfterBind;
import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
Expand Down Expand Up @@ -136,7 +135,6 @@ private static List<RewriteJob> buildAnalyzeJobs(Optional<CustomTableResolver> c
new CheckPolicy()
),
bottomUp(new BindExpression()),
bottomUp(new BindSlotWithPaths()),
topDown(new BindSink()),
bottomUp(new CheckAfterBind()),
bottomUp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAggProject;
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoin;
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoinProject;
import org.apache.doris.nereids.rules.rewrite.VariantSubPathPruning;
import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin;
import org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply;
import org.apache.doris.nereids.rules.rewrite.batch.EliminateUselessPlanUnderApply;
Expand Down Expand Up @@ -398,9 +399,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("adjust preagg status",
topDown(new AdjustPreAggStatus())
),
topic("topn optimize",
topDown(new DeferMaterializeTopNResult())
),
topic("Point query short circuit",
topDown(new LogicalResultSinkToShortCircuitPointQuery())),
topic("eliminate",
Expand Down Expand Up @@ -488,6 +486,16 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(List<RewriteJob> jobs) {
),
topic("or expansion",
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)),
topic("variant element_at push down",
custom(RuleType.VARIANT_SUB_PATH_PRUNING, VariantSubPathPruning::new),
custom(RuleType.REWRITE_CTE_CHILDREN, () -> new RewriteCteChildren(jobs(
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
topic("topn optimize",
topDown(new DeferMaterializeTopNResult())
),
topDown(new CollectCteConsumerOutput()))
))
),
topic("whole plan check",
custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@
*/
public class CommonSubExpressionOpt extends PlanPostProcessor {
@Override
public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) {
public PhysicalProject<? extends Plan> visitPhysicalProject(
PhysicalProject<? extends Plan> project, CascadesContext ctx) {
project.child().accept(this, ctx);
if (!project.hasPushedDownToProjectionFunctions()) {
List<List<NamedExpression>> multiLayers = computeMultiLayerProjections(
project.getInputSlots(), project.getProjects());
project.setMultiLayerProjects(multiLayers);
}
List<List<NamedExpression>> multiLayers = computeMultiLayerProjections(
project.getInputSlots(), project.getProjects());
project.setMultiLayerProjects(multiLayers);
return project;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
Expand All @@ -36,14 +37,10 @@ public Plan visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, CascadesC
}

PhysicalProject<? extends Plan> project = (PhysicalProject<? extends Plan>) child;
if (project.hasPushedDownToProjectionFunctions()) {
// ignore project which is pulled up from LogicalOlapScan
return filter;
}
PhysicalFilter<? extends Plan> newFilter = filter.withConjunctsAndChild(
ExpressionUtils.replace(filter.getConjuncts(), project.getAliasToProducer()),
project.child());
return ((PhysicalProject) project.withChildren(newFilter.accept(this, context)))
return ((AbstractPhysicalPlan) project.withChildren(newFilter.accept(this, context)))
.copyStatsAndGroupIdFrom(project);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public Plan visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, CascadesC

Plan child = filter.child();
// Forbidden filter-project, we must make filter-project -> project-filter.
if (child instanceof PhysicalProject && !((PhysicalProject<?>) child).hasPushedDownToProjectionFunctions()) {
if (child instanceof PhysicalProject) {
throw new AnalysisException(
"Nereids generate a filter-project plan, but backend not support:\n" + filter.treeString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ public enum RuleType {
BINDING_SET_OPERATION_SLOT(RuleTypeClass.REWRITE),
BINDING_INLINE_TABLE_SLOT(RuleTypeClass.REWRITE),

BINDING_SLOT_WITH_PATHS_SCAN(RuleTypeClass.REWRITE),
COUNT_LITERAL_REWRITE(RuleTypeClass.REWRITE),
SUM_LITERAL_REWRITE(RuleTypeClass.REWRITE),
REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT(RuleTypeClass.REWRITE),
Expand Down Expand Up @@ -179,6 +178,7 @@ public enum RuleType {
PUSH_DOWN_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE),

ADD_PROJECT_FOR_JOIN(RuleTypeClass.REWRITE),
VARIANT_SUB_PATH_PRUNING(RuleTypeClass.REWRITE),

COLUMN_PRUNING(RuleTypeClass.REWRITE),
ELIMINATE_SORT(RuleTypeClass.REWRITE),
Expand Down Expand Up @@ -271,13 +271,11 @@ public enum RuleType {

OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE),

OLAP_SCAN_WITH_PROJECT_PARTITION_PRUNE(RuleTypeClass.REWRITE),
FILE_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE),
PUSH_CONJUNCTS_INTO_JDBC_SCAN(RuleTypeClass.REWRITE),
PUSH_CONJUNCTS_INTO_ODBC_SCAN(RuleTypeClass.REWRITE),
PUSH_CONJUNCTS_INTO_ES_SCAN(RuleTypeClass.REWRITE),
OLAP_SCAN_TABLET_PRUNE(RuleTypeClass.REWRITE),
OLAP_SCAN_WITH_PROJECT_TABLET_PRUNE(RuleTypeClass.REWRITE),
PUSH_AGGREGATE_TO_OLAP_SCAN(RuleTypeClass.REWRITE),
EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE),
HIDE_ONE_ROW_RELATION_UNDER_UNION(RuleTypeClass.REWRITE),
Expand Down Expand Up @@ -320,8 +318,6 @@ public enum RuleType {
// adjust nullable
ADJUST_NULLABLE(RuleTypeClass.REWRITE),
ADJUST_CONJUNCTS_RETURN_TYPE(RuleTypeClass.REWRITE),
// ensure having project on the top join
ENSURE_PROJECT_ON_TOP_JOIN(RuleTypeClass.REWRITE),

PULL_UP_CTE_ANCHOR(RuleTypeClass.REWRITE),
CTE_INLINE(RuleTypeClass.REWRITE),
Expand Down
Loading

0 comments on commit 09d4cc3

Please sign in to comment.