diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java index 05a6cec51284f..e20b9a987f5ef 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java @@ -23,7 +23,10 @@ import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.esql.planner.ToAggregator; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -31,8 +34,11 @@ import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.oneOf; /** * Base class for aggregation tests. @@ -47,7 +53,43 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa */ protected static Iterable parameterSuppliersFromTypedDataWithDefaultChecks(List suppliers) { // TODO: Add case with no input expecting null - return parameterSuppliersFromTypedData(randomizeBytesRefsOffset(suppliers)); + return parameterSuppliersFromTypedData(withNoRowsExpectingNull(randomizeBytesRefsOffset(suppliers))); + } + + /** + * Adds a test case with no rows, expecting null, to the list of suppliers. + */ + protected static List withNoRowsExpectingNull(List suppliers) { + List newSuppliers = new ArrayList<>(suppliers); + Set> uniqueSignatures = new HashSet<>(); + + for (TestCaseSupplier original : suppliers) { + if (uniqueSignatures.add(original.types())) { + newSuppliers.add(new TestCaseSupplier(original.name() + " with no rows", original.types(), () -> { + var testCase = original.get(); + + if (testCase.getData().stream().noneMatch(TestCaseSupplier.TypedData::isMultiRow)) { + // Fail if no multi-row data, at least until a real case is found + fail("No multi-row data found in test case: " + testCase); + } + + var newData = testCase.getData().stream().map(td -> td.isMultiRow() ? td.withData(List.of()) : td).toList(); + + return new TestCaseSupplier.TestCase( + newData, + testCase.evaluatorToString(), + testCase.expectedType(), + nullValue(), + null, + testCase.getExpectedTypeError(), + null, + null + ); + })); + } + } + + return newSuppliers; } public void testAggregate() { @@ -56,6 +98,12 @@ public void testAggregate() { resolveExpression(expression, this::aggregateSingleMode, this::evaluate); } + public void testAggregateIntermediate() { + Expression expression = randomBoolean() ? buildDeepCopyOfFieldExpression(testCase) : buildFieldExpression(testCase); + + resolveExpression(expression, this::aggregateWithIntermediates, this::evaluate); + } + public void testFold() { Expression expression = buildLiteralExpression(testCase); @@ -80,17 +128,78 @@ public void testFold() { }); } - private void aggregateSingleMode(AggregatorFunctionSupplier aggregatorFunctionSupplier) { + private void aggregateSingleMode(Expression expression) { + Object result; + try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) { + Page inputPage = rows(testCase.getMultiRowFields()); + try { + aggregator.processPage(inputPage); + } finally { + inputPage.releaseBlocks(); + } + + result = extractResultFromAggregator(aggregator, PlannerUtils.toElementType(testCase.expectedType())); + } + + assertThat(result, not(equalTo(Double.NaN))); + assert testCase.getMatcher().matches(Double.POSITIVE_INFINITY) == false; + assertThat(result, not(equalTo(Double.POSITIVE_INFINITY))); + assert testCase.getMatcher().matches(Double.NEGATIVE_INFINITY) == false; + assertThat(result, not(equalTo(Double.NEGATIVE_INFINITY))); + assertThat(result, testCase.getMatcher()); + if (testCase.getExpectedWarnings() != null) { + assertWarnings(testCase.getExpectedWarnings()); + } + } + + private void aggregateWithIntermediates(Expression expression) { + int intermediateBlockOffset = randomIntBetween(0, 10); + Block[] intermediateBlocks; + int intermediateStates; + + // Input rows to intermediate states + try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.INITIAL)) { + intermediateStates = aggregator.evaluateBlockCount(); + + int intermediateBlockExtraSize = randomIntBetween(0, 10); + intermediateBlocks = new Block[intermediateBlockOffset + intermediateStates + intermediateBlockExtraSize]; + + Page inputPage = rows(testCase.getMultiRowFields()); + try { + aggregator.processPage(inputPage); + } finally { + inputPage.releaseBlocks(); + } + + aggregator.evaluate(intermediateBlocks, intermediateBlockOffset, driverContext()); + + int positionCount = intermediateBlocks[intermediateBlockOffset].getPositionCount(); + + // Fill offset and extra blocks with nulls + for (int i = 0; i < intermediateBlockOffset; i++) { + intermediateBlocks[i] = driverContext().blockFactory().newConstantNullBlock(positionCount); + } + for (int i = intermediateBlockOffset + intermediateStates; i < intermediateBlocks.length; i++) { + intermediateBlocks[i] = driverContext().blockFactory().newConstantNullBlock(positionCount); + } + } + Object result; - try (var aggregator = new Aggregator(aggregatorFunctionSupplier.aggregator(driverContext()), AggregatorMode.SINGLE)) { - Page inputPage = rows(testCase.getMultiRowDataValues()); + // Intermediate states to final result + try ( + var aggregator = aggregator( + expression, + intermediaryInputChannels(intermediateStates, intermediateBlockOffset), + AggregatorMode.FINAL + ) + ) { + Page inputPage = new Page(intermediateBlocks); try { aggregator.processPage(inputPage); } finally { inputPage.releaseBlocks(); } - // ElementType from DataType result = extractResultFromAggregator(aggregator, PlannerUtils.toElementType(testCase.expectedType())); } @@ -124,11 +233,7 @@ private void evaluate(Expression evaluableExpression) { } } - private void resolveExpression( - Expression expression, - Consumer onAggregator, - Consumer onEvaluableExpression - ) { + private void resolveExpression(Expression expression, Consumer onAggregator, Consumer onEvaluableExpression) { logger.info( "Test Values: " + testCase.getData().stream().map(TestCaseSupplier.TypedData::toString).collect(Collectors.joining(",")) ); @@ -154,8 +259,7 @@ private void resolveExpression( assertThat(expression, instanceOf(ToAggregator.class)); logger.info("Result type: " + expression.dataType()); - var inputChannels = inputChannels(); - onAggregator.accept(((ToAggregator) expression).supplier(inputChannels)); + onAggregator.accept(expression); } private Object extractResultFromAggregator(Aggregator aggregator, ElementType expectedElementType) { @@ -167,7 +271,8 @@ private Object extractResultFromAggregator(Aggregator aggregator, ElementType ex var block = blocks[resultBlockIndex]; - assertThat(block.elementType(), equalTo(expectedElementType)); + // For null blocks, the element type is NULL, so if the provided matcher matches, the type works too + assertThat(block.elementType(), is(oneOf(expectedElementType, ElementType.NULL))); return toJavaObject(blocks[resultBlockIndex], 0); } finally { @@ -175,10 +280,14 @@ private Object extractResultFromAggregator(Aggregator aggregator, ElementType ex } } - private List inputChannels() { + private List initialInputChannels() { // TODO: Randomize channels // TODO: If surrogated, channels may change - return IntStream.range(0, testCase.getMultiRowDataValues().size()).boxed().toList(); + return IntStream.range(0, testCase.getMultiRowFields().size()).boxed().toList(); + } + + private List intermediaryInputChannels(int intermediaryStates, int offset) { + return IntStream.range(offset, offset + intermediaryStates).boxed().toList(); } /** @@ -210,4 +319,10 @@ private Expression resolveSurrogates(Expression expression) { return expression; } + + private Aggregator aggregator(Expression expression, List inputChannels, AggregatorMode mode) { + AggregatorFunctionSupplier aggregatorFunctionSupplier = ((ToAggregator) expression).supplier(inputChannels); + + return new Aggregator(aggregatorFunctionSupplier.aggregator(driverContext()), mode); + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index dc650e3fcd965..f8a5d997f4c54 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -49,6 +49,7 @@ import org.elasticsearch.xpack.esql.optimizer.FoldNull; import org.elasticsearch.xpack.esql.parser.ExpressionBuilder; import org.elasticsearch.xpack.esql.planner.Layout; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.versionfield.Version; import org.junit.After; import org.junit.AfterClass; @@ -214,24 +215,40 @@ protected final Page row(List values) { } /** - * Creates a page based on a list of lists, where each list represents a column. + * Creates a page based on a list of multi-row fields. */ - protected final Page rows(List> values) { - if (values.isEmpty()) { + protected final Page rows(List multirowFields) { + if (multirowFields.isEmpty()) { return new Page(0, BlockUtils.NO_BLOCKS); } - var rowsCount = values.get(0).size(); + var rowsCount = multirowFields.get(0).multiRowData().size(); - values.stream().skip(1).forEach(l -> assertThat("All multi-row fields must have the same number of rows", l, hasSize(rowsCount))); + multirowFields.stream() + .skip(1) + .forEach( + field -> assertThat("All multi-row fields must have the same number of rows", field.multiRowData(), hasSize(rowsCount)) + ); - var rows = new ArrayList>(); - for (int i = 0; i < rowsCount; i++) { - final int index = i; - rows.add(values.stream().map(l -> l.get(index)).toList()); - } + var blocks = new Block[multirowFields.size()]; - var blocks = BlockUtils.fromList(TestBlockFactory.getNonBreakingInstance(), rows); + for (int i = 0; i < multirowFields.size(); i++) { + var field = multirowFields.get(i); + try ( + var wrapper = BlockUtils.wrapperFor( + TestBlockFactory.getNonBreakingInstance(), + PlannerUtils.toElementType(field.type()), + rowsCount + ) + ) { + + for (var row : field.multiRowData()) { + wrapper.accept(row); + } + + blocks[i] = wrapper.builder().build(); + } + } return new Page(rowsCount, blocks); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index 9095f5da63bf3..77c45bbd69854 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -1301,8 +1301,8 @@ public List getDataValues() { return data.stream().filter(d -> d.forceLiteral == false).map(TypedData::data).collect(Collectors.toList()); } - public List> getMultiRowDataValues() { - return data.stream().filter(TypedData::isMultiRow).map(TypedData::multiRowData).collect(Collectors.toList()); + public List getMultiRowFields() { + return data.stream().filter(TypedData::isMultiRow).collect(Collectors.toList()); } public boolean canGetDataAsLiterals() {