Skip to content

Commit

Permalink
[fix](Nereids) storage later agg rule process agg children by mistake (
Browse files Browse the repository at this point in the history
…#26101)

update Project#findProject
agg function's children could be any expression rather than only slot.
we use Project#findProject to process them. But this util could only
process slot. This PR update this util to let it could process all type
expression.
  • Loading branch information
morrySnow authored Nov 6, 2023
1 parent 1e2a614 commit c0ed5f7
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ private LogicalAggregate<? extends Plan> storageLayerAggregate(

if (project != null) {
argumentsOfAggregateFunction = Project.findProject(
(List<SlotReference>) (List) argumentsOfAggregateFunction, project.getProjects())
argumentsOfAggregateFunction, project.getProjects())
.stream()
.map(p -> p instanceof Alias ? p.child(0) : p)
.collect(ImmutableList.toImmutableList());
Expand Down Expand Up @@ -431,8 +431,8 @@ private LogicalAggregate<? extends Plan> storageLayerAggregate(
Set<SlotReference> aggUsedSlots =
ExpressionUtils.collect(argumentsOfAggregateFunction, SlotReference.class::isInstance);

List<SlotReference> usedSlotInTable = (List<SlotReference>) (List) Project.findProject(aggUsedSlots,
(List<NamedExpression>) (List) logicalScan.getOutput());
List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots,
logicalScan.getOutput());

for (SlotReference slot : usedSlotInTable) {
Column column = slot.getColumn().get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public Rule build() {
if (!onlyHasSlots(outerAgg.getOutputExpressions())) {
return outerAgg;
}
List<NamedExpression> prunedInnerAggOutput = Project.findProject(outerAgg.getOutputSet(),
innerAgg.getOutputExpressions());
List<NamedExpression> prunedInnerAggOutput = (List<NamedExpression>) Project.findProject(
outerAgg.getOutputSet(), innerAgg.getOutputExpressions());
return innerAgg.withAggOutput(prunedInnerAggOutput);
}).toRule(RuleType.ELIMINATE_AGGREGATE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.Collection;
Expand Down Expand Up @@ -78,22 +78,25 @@ default List<NamedExpression> mergeProjections(Project childProject) {
/**
* find projects, if not found the slot, then throw AnalysisException
*/
static List<NamedExpression> findProject(
Collection<? extends Slot> slotReferences,
List<NamedExpression> projects) throws AnalysisException {
static List<? extends Expression> findProject(
Collection<? extends Expression> expressions,
List<? extends NamedExpression> projects) throws AnalysisException {
Map<ExprId, NamedExpression> exprIdToProject = projects.stream()
.collect(ImmutableMap.toImmutableMap(p -> p.getExprId(), p -> p));
.collect(ImmutableMap.toImmutableMap(NamedExpression::getExprId, p -> p));

return slotReferences.stream()
.map(slot -> {
ExprId exprId = slot.getExprId();
NamedExpression project = exprIdToProject.get(exprId);
if (project == null) {
throw new AnalysisException("ExprId " + slot.getExprId() + " no exists in " + projects);
return ExpressionUtils.rewriteDownShortCircuit(expressions,
expr -> {
if (expr instanceof Slot) {
Slot slot = (Slot) expr;
ExprId exprId = slot.getExprId();
NamedExpression project = exprIdToProject.get(exprId);
if (project == null) {
throw new AnalysisException("ExprId " + slot.getExprId() + " no exists in " + projects);
}
return project;
}
return project;
})
.collect(ImmutableList.toImmutableList());
return expr;
});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ public static Set<Expression> replace(Set<Expression> exprs,
}

public static <E extends Expression> List<E> rewriteDownShortCircuit(
List<E> exprs, Function<Expression, Expression> rewriteFunction) {
Collection<E> exprs, Function<Expression, Expression> rewriteFunction) {
return exprs.stream()
.map(expr -> (E) expr.rewriteDownShortCircuit(rewriteFunction))
.collect(ImmutableList.toImmutableList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

suite("aggregate_decimal256") {
sql "set enable_nereids_planner = true;"
sql "set enable_fallback_to_original_planner = false;"
sql "set enable_decimal256 = true;"
sql "drop table if exists test_aggregate_decimal256_sum;"
sql """ create table test_aggregate_decimal256_sum(k1 int, v1 decimal(38, 6), v2 decimal(38, 6))
Expand Down

0 comments on commit c0ed5f7

Please sign in to comment.