diff --git a/docs/changelog/100642.yaml b/docs/changelog/100642.yaml new file mode 100644 index 0000000000000..805a20174e11d --- /dev/null +++ b/docs/changelog/100642.yaml @@ -0,0 +1,6 @@ +pr: 100642 +summary: "ESQL: Alias duplicated aggregations in a stats" +area: ES|QL +type: enhancement +issues: + - 100544 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index d671ba6ec13b1..acf42d908ed66 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -681,3 +681,19 @@ c:l | job_positions:s 4 |Reporting Analyst 4 |Tech Lead ; + +duplicateAggregationsWithoutGrouping +from employees | eval x = salary | stats c = count(), m = min(x), m1 = min(salary), c1 = count(1); + +c:l | m:i | m1:i | c1:l +100 | 25324 | 25324 | 100 +; + +duplicateAggregationsWithGrouping +from employees | eval x = salary | stats c = count(), m = min(x), m1 = min(salary), c1 = count(1) by gender | sort gender; + +c:l| m:i | m1:i | c1:l| gender:s +33 | 25976 | 25976 | 33 | F +57 | 25945 | 25945 | 57 | M +10 | 25324 | 25324 | 10 | null +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index 23fa051c1d7a2..ed215efc4c066 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.util.Maps; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockUtils; @@ -57,8 +58,10 @@ import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.ql.rule.Rule; import org.elasticsearch.xpack.ql.rule.RuleExecutor; +import org.elasticsearch.xpack.ql.type.DataTypes; import org.elasticsearch.xpack.ql.util.CollectionUtils; import org.elasticsearch.xpack.ql.util.Holder; +import org.elasticsearch.xpack.ql.util.StringUtils; import java.time.ZoneId; import java.util.ArrayList; @@ -95,13 +98,14 @@ protected static List> rules() { new SubstituteSurrogates(), new ReplaceRegexMatch(), new ReplaceAliasingEvalWithProject() - // new ReplaceTextFieldAttributesWithTheKeywordSubfield() + // new NormalizeAggregate(), - waits on https://github.com/elastic/elasticsearch/issues/100634 ); var operators = new Batch<>( "Operator Optimization", new CombineProjections(), new CombineEvals(), + new ReplaceDuplicateAggWithEval(), new PruneEmptyPlans(), new PropagateEmptyRelation(), new ConvertStringToByteRef(), @@ -947,4 +951,127 @@ private LogicalPlan rule(Eval eval) { } } + + /** + * Normalize aggregation functions by: + * 1. replaces reference to field attributes with their source + * 2. in case of Count, aligns the various forms (Count(1), Count(0), Count(), Count(*)) to Count(*) + */ + // TODO waiting on https://github.com/elastic/elasticsearch/issues/100634 + static class NormalizeAggregate extends Rule { + + @Override + public LogicalPlan apply(LogicalPlan plan) { + AttributeMap aliases = new AttributeMap<>(); + + // traverse the tree bottom-up + // 1. if it's Aggregate, normalize the aggregates + // regardless, collect the attributes but only if they refer to an attribute or literal + plan = plan.transformUp(p -> { + if (p instanceof Aggregate agg) { + p = normalize(agg, aliases); + } + p.forEachExpression(Alias.class, a -> { + var child = a.child(); + if (child.foldable() || child instanceof NamedExpression) { + aliases.putIfAbsent(a.toAttribute(), child); + } + }); + + return p; + }); + return plan; + } + + private static LogicalPlan normalize(Aggregate aggregate, AttributeMap aliases) { + var aggs = aggregate.aggregates(); + List newAggs = new ArrayList<>(aggs.size()); + boolean changed = false; + + for (NamedExpression agg : aggs) { + if (agg instanceof Alias as && as.child() instanceof AggregateFunction af) { + // replace field reference + if (af.field() instanceof NamedExpression ne) { + Attribute attr = ne.toAttribute(); + var resolved = aliases.resolve(attr, attr); + if (resolved != attr) { + changed = true; + var newChildren = CollectionUtils.combine(Collections.singletonList(resolved), af.parameters()); + // update the reference so Count can pick it up + af = (AggregateFunction) af.replaceChildren(newChildren); + agg = as.replaceChild(af); + } + } + // handle Count(*) + if (af instanceof Count count) { + var field = af.field(); + if (field.foldable()) { + var fold = field.fold(); + if (fold != null && StringUtils.WILDCARD.equals(fold) == false) { + changed = true; + var source = count.source(); + agg = as.replaceChild(new Count(source, new Literal(source, StringUtils.WILDCARD, DataTypes.KEYWORD))); + } + } + } + } + newAggs.add(agg); + } + return changed ? new Aggregate(aggregate.source(), aggregate.child(), aggregate.groupings(), newAggs) : aggregate; + } + } + + /** + * Replace aggregations that are duplicated inside an Aggregate with an Eval to avoid duplicated compute. + * stats a = min(x), b = min(x), c = count(*), d = count() by g + * becomes + * stats a = min(x), c = count(*) by g + * eval b = a, d = c + * keep a, b, c, d, g + */ + static class ReplaceDuplicateAggWithEval extends OptimizerRules.OptimizerRule { + + ReplaceDuplicateAggWithEval() { + super(TransformDirection.UP); + } + + @Override + protected LogicalPlan rule(Aggregate aggregate) { + LogicalPlan plan = aggregate; + + boolean foundDuplicate = false; + var aggs = aggregate.aggregates(); + Map seenAggs = Maps.newMapWithExpectedSize(aggs.size()); + List projections = new ArrayList<>(); + List keptAggs = new ArrayList<>(aggs.size()); + + for (NamedExpression agg : aggs) { + var attr = agg.toAttribute(); + if (agg instanceof Alias as && as.child() instanceof AggregateFunction af) { + var seen = seenAggs.putIfAbsent(af, attr); + if (seen != null) { + foundDuplicate = true; + projections.add(as.replaceChild(seen)); + } + // otherwise keep the agg in place + else { + keptAggs.add(agg); + projections.add(attr); + } + } else { + keptAggs.add(agg); + projections.add(attr); + } + } + + // at least one duplicate found - add the projection (to keep the output in place) + if (foundDuplicate) { + var source = aggregate.source(); + var newAggregate = new Aggregate(source, aggregate.child(), aggregate.groupings(), keptAggs); + plan = new Project(source, newAggregate, projections); + } + + return plan; + } + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 511a7ee08b5e1..72e12697488df 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.xpack.esql.plan.physical.LimitExec; import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; import org.elasticsearch.xpack.esql.planner.FilterTests; import org.elasticsearch.xpack.esql.planner.Mapper; import org.elasticsearch.xpack.esql.planner.PlannerUtils; @@ -41,6 +42,7 @@ import org.elasticsearch.xpack.esql.stats.Metrics; import org.elasticsearch.xpack.esql.stats.SearchStats; import org.elasticsearch.xpack.esql.type.EsqlDataTypes; +import org.elasticsearch.xpack.ql.expression.Alias; import org.elasticsearch.xpack.ql.expression.Expressions; import org.elasticsearch.xpack.ql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.ql.index.EsIndex; @@ -299,6 +301,16 @@ public void testAnotherCountAllWithFilter() { assertThat(expected.toString(), is(esStatsQuery.query().toString())); } + /** + * Expected + * ProjectExec[[c{r}#3, c{r}#3 AS call, c_literal{r}#7]] + * \_LimitExec[500[INTEGER]] + * \_AggregateExec[[],[COUNT([2a][KEYWORD]) AS c, COUNT(1[INTEGER]) AS c_literal],FINAL,null] + * \_ExchangeExec[[count{r}#18, seen{r}#19, count{r}#20, seen{r}#21],true] + * \_EsStatsQueryExec[test], stats[Stat[name=*, type=COUNT, query=null], Stat[name=*, type=COUNT, query=null]]], + * query[{"esql_single_value":{"field":"emp_no","next":{"range":{"emp_no":{"gt":10010,"boost":1.0}}}}}] + * [count{r}#23, seen{r}#24, count{r}#25, seen{r}#26], limit[], + */ public void testMultiCountAllWithFilter() { var plan = plan(""" from test @@ -306,14 +318,19 @@ public void testMultiCountAllWithFilter() { | stats c = count(), call = count(*), c_literal = count(1) """, IS_SV_STATS); - var limit = as(plan, LimitExec.class); + var project = as(plan, ProjectExec.class); + var projections = project.projections(); + assertThat(Expressions.names(projections), contains("c", "call", "c_literal")); + var alias = as(projections.get(1), Alias.class); + assertThat(Expressions.name(alias.child()), is("c")); + var limit = as(project.child(), LimitExec.class); var agg = as(limit.child(), AggregateExec.class); assertThat(agg.getMode(), is(FINAL)); - assertThat(Expressions.names(agg.aggregates()), contains("c", "call", "c_literal")); + assertThat(Expressions.names(agg.aggregates()), contains("c", "c_literal")); var exchange = as(agg.child(), ExchangeExec.class); var esStatsQuery = as(exchange.child(), EsStatsQueryExec.class); assertThat(esStatsQuery.limit(), is(nullValue())); - assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen", "count", "seen", "count", "seen")); + assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen", "count", "seen")); var expected = wrapWithSingleQuery(QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no"); assertThat(expected.toString(), is(esStatsQuery.query().toString())); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index a22bb3b91ff0b..285ad7021e83f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Max; import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; @@ -1923,6 +1924,147 @@ public void testPruneRenameOnAggBy() { var source = as(agg.child(), EsRelation.class); } + /** + * Expects + * Project[[c1{r}#2, c2{r}#4, cs{r}#6, cm{r}#8, cexp{r}#10]] + * \_Eval[[c1{r}#2 AS c2, c1{r}#2 AS cs, c1{r}#2 AS cm, c1{r}#2 AS cexp]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[],[COUNT([2a][KEYWORD]) AS c1]] + * \_EsRelation[test][_meta_field{f}#17, emp_no{f}#11, first_name{f}#12, ..] + */ + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/100634") + public void testEliminateDuplicateAggsCountAll() { + var plan = plan(""" + from test + | stats c1 = count(1), c2 = count(2), cs = count(*), cm = count(), cexp = count("123") + """); + + var project = as(plan, Project.class); + assertThat(Expressions.names(project.projections()), contains("c1", "c2", "cs", "cm", "cexp")); + var eval = as(project.child(), Eval.class); + var fields = eval.fields(); + assertThat(Expressions.names(fields), contains("c2", "cs", "cm", "cexp")); + for (Alias field : fields) { + assertThat(Expressions.name(field.child()), is("c1")); + } + var limit = as(eval.child(), Limit.class); + var agg = as(limit.child(), Aggregate.class); + var aggs = agg.aggregates(); + assertThat(Expressions.names(aggs), contains("c1")); + aggFieldName(aggs.get(0), Count.class, "*"); + var source = as(agg.child(), EsRelation.class); + } + + /** + * Expects + * Project[[c1{r}#7, cx{r}#10, cs{r}#12, cy{r}#15]] + * \_Eval[[c1{r}#7 AS cx, c1{r}#7 AS cs, c1{r}#7 AS cy]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[],[COUNT([2a][KEYWORD]) AS c1]] + * \_EsRelation[test][_meta_field{f}#22, emp_no{f}#16, first_name{f}#17, ..] + */ + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/100634") + public void testEliminateDuplicateAggsWithAliasedFields() { + var plan = plan(""" + from test + | eval x = 1 + | eval y = x + | stats c1 = count(1), cx = count(x), cs = count(*), cy = count(y) + """); + + var project = as(plan, Project.class); + assertThat(Expressions.names(project.projections()), contains("c1", "cx", "cs", "cy")); + var eval = as(project.child(), Eval.class); + var fields = eval.fields(); + assertThat(Expressions.names(fields), contains("cx", "cs", "cy")); + for (Alias field : fields) { + assertThat(Expressions.name(field.child()), is("c1")); + } + var limit = as(eval.child(), Limit.class); + var agg = as(limit.child(), Aggregate.class); + var aggs = agg.aggregates(); + assertThat(Expressions.names(aggs), contains("c1")); + aggFieldName(aggs.get(0), Count.class, "*"); + var source = as(agg.child(), EsRelation.class); + } + + /** + * Expects + * Project[[min{r}#1385, max{r}#1388, min{r}#1385 AS min2, max{r}#1388 AS max2, gender{f}#1398]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[gender{f}#1398],[MIN(salary{f}#1401) AS min, MAX(salary{f}#1401) AS max, gender{f}#1398]] + * \_EsRelation[test][_meta_field{f}#1402, emp_no{f}#1396, first_name{f}#..] + */ + public void testEliminateDuplicateAggsMixed() { + var plan = plan(""" + from test + | stats min = min(salary), max = max(salary), min2 = min(salary), max2 = max(salary) by gender + """); + + var project = as(plan, Project.class); + var projections = project.projections(); + assertThat(Expressions.names(projections), contains("min", "max", "min2", "max2", "gender")); + as(projections.get(0), ReferenceAttribute.class); + as(projections.get(1), ReferenceAttribute.class); + assertThat(Expressions.name(aliased(projections.get(2), ReferenceAttribute.class)), is("min")); + assertThat(Expressions.name(aliased(projections.get(3), ReferenceAttribute.class)), is("max")); + + var limit = as(project.child(), Limit.class); + var agg = as(limit.child(), Aggregate.class); + var aggs = agg.aggregates(); + assertThat(Expressions.names(aggs), contains("min", "max", "gender")); + aggFieldName(aggs.get(0), Min.class, "salary"); + aggFieldName(aggs.get(1), Max.class, "salary"); + var source = as(agg.child(), EsRelation.class); + } + + /** + * Expects + * EsqlProject[[a{r}#5, c{r}#8]] + * \_Eval[[null[INTEGER] AS x]] + * \_EsRelation[test][_meta_field{f}#15, emp_no{f}#9, first_name{f}#10, g..] + */ + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/100634") + public void testEliminateDuplicateAggWithNull() { + var plan = plan(""" + from test + | eval x = null + 1 + | stats a = avg(x), c = count(x) + """); + fail("Awaits fix"); + } + + /** + * Expects + * Project[[max(x){r}#11, max(x){r}#11 AS max(y), max(x){r}#11 AS max(z)]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[],[MAX(salary{f}#21) AS max(x)]] + * \_EsRelation[test][_meta_field{f}#22, emp_no{f}#16, first_name{f}#17, ..] + */ + public void testEliminateDuplicateAggsNonCount() { + var plan = plan(""" + from test + | eval x = salary + | eval y = x + | eval z = y + | stats max(x), max(y), max(z) + """); + + var project = as(plan, Project.class); + var projections = project.projections(); + assertThat(Expressions.names(projections), contains("max(x)", "max(y)", "max(z)")); + as(projections.get(0), ReferenceAttribute.class); + assertThat(Expressions.name(aliased(projections.get(1), ReferenceAttribute.class)), is("max(x)")); + assertThat(Expressions.name(aliased(projections.get(2), ReferenceAttribute.class)), is("max(x)")); + + var limit = as(project.child(), Limit.class); + var agg = as(limit.child(), Aggregate.class); + var aggs = agg.aggregates(); + assertThat(Expressions.names(aggs), contains("max(x)")); + aggFieldName(aggs.get(0), Max.class, "salary"); + var source = as(agg.child(), EsRelation.class); + } + private T aliased(Expression exp, Class clazz) { var alias = as(exp, Alias.class); return as(alias.child(), clazz); @@ -1932,7 +2074,8 @@ private void aggFieldName(Expression exp, Class var alias = as(exp, Alias.class); var af = as(alias.child(), aggType); var field = af.field(); - assertThat(Expressions.name(field), is(fieldName)); + var name = field.foldable() ? BytesRefs.toString(field.fold()) : Expressions.name(field); + assertThat(name, is(fieldName)); } private LogicalPlan optimizedPlan(String query) {