From c0ed5f74ad28761280d124b7176d727a4d233b51 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Mon, 6 Nov 2023 11:07:52 +0800 Subject: [PATCH] [fix](Nereids) storage later agg rule process agg children by mistake (#26101) update Project#findProject agg function's children could be any expression rather than only slot. we use Project#findProject to process them. But this util could only process slot. This PR update this util to let it could process all type expression. --- .../implementation/AggregateStrategies.java | 6 ++-- .../rules/rewrite/EliminateAggregate.java | 4 +-- .../nereids/trees/plans/algebra/Project.java | 31 ++++++++++--------- .../doris/nereids/util/ExpressionUtils.java | 2 +- .../aggregate/aggregate_decimal256.groovy | 1 + 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index a020656c686469..6521929f164c13 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -397,7 +397,7 @@ private LogicalAggregate storageLayerAggregate( if (project != null) { argumentsOfAggregateFunction = Project.findProject( - (List) (List) argumentsOfAggregateFunction, project.getProjects()) + argumentsOfAggregateFunction, project.getProjects()) .stream() .map(p -> p instanceof Alias ? p.child(0) : p) .collect(ImmutableList.toImmutableList()); @@ -431,8 +431,8 @@ private LogicalAggregate storageLayerAggregate( Set aggUsedSlots = ExpressionUtils.collect(argumentsOfAggregateFunction, SlotReference.class::isInstance); - List usedSlotInTable = (List) (List) Project.findProject(aggUsedSlots, - (List) (List) logicalScan.getOutput()); + List usedSlotInTable = (List) Project.findProject(aggUsedSlots, + logicalScan.getOutput()); for (SlotReference slot : usedSlotInTable) { Column column = slot.getColumn().get(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAggregate.java index 88078246ddf058..a848fcdc3f5c7f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAggregate.java @@ -41,8 +41,8 @@ public Rule build() { if (!onlyHasSlots(outerAgg.getOutputExpressions())) { return outerAgg; } - List prunedInnerAggOutput = Project.findProject(outerAgg.getOutputSet(), - innerAgg.getOutputExpressions()); + List prunedInnerAggOutput = (List) Project.findProject( + outerAgg.getOutputSet(), innerAgg.getOutputExpressions()); return innerAgg.withAggOutput(prunedInnerAggOutput); }).toRule(RuleType.ELIMINATE_AGGREGATE); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java index 5b214ad0900379..1f1573fd65bdd5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java @@ -23,9 +23,9 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Collection; @@ -78,22 +78,25 @@ default List mergeProjections(Project childProject) { /** * find projects, if not found the slot, then throw AnalysisException */ - static List findProject( - Collection slotReferences, - List projects) throws AnalysisException { + static List findProject( + Collection expressions, + List projects) throws AnalysisException { Map exprIdToProject = projects.stream() - .collect(ImmutableMap.toImmutableMap(p -> p.getExprId(), p -> p)); + .collect(ImmutableMap.toImmutableMap(NamedExpression::getExprId, p -> p)); - return slotReferences.stream() - .map(slot -> { - ExprId exprId = slot.getExprId(); - NamedExpression project = exprIdToProject.get(exprId); - if (project == null) { - throw new AnalysisException("ExprId " + slot.getExprId() + " no exists in " + projects); + return ExpressionUtils.rewriteDownShortCircuit(expressions, + expr -> { + if (expr instanceof Slot) { + Slot slot = (Slot) expr; + ExprId exprId = slot.getExprId(); + NamedExpression project = exprIdToProject.get(exprId); + if (project == null) { + throw new AnalysisException("ExprId " + slot.getExprId() + " no exists in " + projects); + } + return project; } - return project; - }) - .collect(ImmutableList.toImmutableList()); + return expr; + }); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 1e67808c614ed3..0c5faa20957666 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -282,7 +282,7 @@ public static Set replace(Set exprs, } public static List rewriteDownShortCircuit( - List exprs, Function rewriteFunction) { + Collection exprs, Function rewriteFunction) { return exprs.stream() .map(expr -> (E) expr.rewriteDownShortCircuit(rewriteFunction)) .collect(ImmutableList.toImmutableList()); diff --git a/regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy b/regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy index 8ea45c68bf9d31..88121ebb14593a 100644 --- a/regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy +++ b/regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy @@ -17,6 +17,7 @@ suite("aggregate_decimal256") { sql "set enable_nereids_planner = true;" + sql "set enable_fallback_to_original_planner = false;" sql "set enable_decimal256 = true;" sql "drop table if exists test_aggregate_decimal256_sum;" sql """ create table test_aggregate_decimal256_sum(k1 int, v1 decimal(38, 6), v2 decimal(38, 6))