Skip to content

Commit

Permalink
ESQL: Add tests to call aggregation intermediate states (elastic#110279)
Browse files Browse the repository at this point in the history
Test aggregations' intermediate states in the base aggregation test class.

Added another "middleware" to add "no rows" test cases.
  • Loading branch information
ivancea authored Jul 4, 2024
1 parent f3c811c commit 8b7d833
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,22 @@
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.esql.planner.ToAggregator;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.oneOf;

/**
* Base class for aggregation tests.
Expand All @@ -47,7 +53,43 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa
*/
protected static Iterable<Object[]> parameterSuppliersFromTypedDataWithDefaultChecks(List<TestCaseSupplier> suppliers) {
// TODO: Add case with no input expecting null
return parameterSuppliersFromTypedData(randomizeBytesRefsOffset(suppliers));
return parameterSuppliersFromTypedData(withNoRowsExpectingNull(randomizeBytesRefsOffset(suppliers)));
}

/**
 * Returns a copy of the given suppliers extended with one "no rows" supplier per distinct type signature.
 * <p>
 *     The extra supplier is derived from the first supplier seen for each signature: the data of every
 *     multi-row field is replaced with an empty list, and the result is matched against {@code nullValue()}.
 * </p>
 */
protected static List<TestCaseSupplier> withNoRowsExpectingNull(List<TestCaseSupplier> suppliers) {
    Set<List<DataType>> seenSignatures = new HashSet<>();
    List<TestCaseSupplier> extended = new ArrayList<>(suppliers);

    for (TestCaseSupplier base : suppliers) {
        // Only the first supplier of each signature gets a "no rows" variant
        if (seenSignatures.add(base.types()) == false) {
            continue;
        }

        String noRowsName = base.name() + " with no rows";

        extended.add(new TestCaseSupplier(noRowsName, base.types(), () -> {
            var baseCase = base.get();

            boolean hasMultiRowField = baseCase.getData().stream().anyMatch(TestCaseSupplier.TypedData::isMultiRow);
            if (hasMultiRowField == false) {
                // Fail if no multi-row data, at least until a real case is found
                fail("No multi-row data found in test case: " + baseCase);
            }

            // Multi-row fields lose all their rows; every other field is kept untouched
            var emptiedFields = baseCase.getData().stream().map(field -> {
                if (field.isMultiRow()) {
                    return field.withData(List.of());
                }
                return field;
            }).toList();

            return new TestCaseSupplier.TestCase(
                emptiedFields,
                baseCase.evaluatorToString(),
                baseCase.expectedType(),
                nullValue(),
                null,
                baseCase.getExpectedTypeError(),
                null,
                null
            );
        }));
    }

    return extended;
}

public void testAggregate() {
Expand All @@ -56,6 +98,12 @@ public void testAggregate() {
resolveExpression(expression, this::aggregateSingleMode, this::evaluate);
}

/**
 * Resolves the field expression of the test case (randomly either a deep copy or the plain one)
 * and hands it to {@code aggregateWithIntermediates}, or to {@code evaluate} as the second callback.
 */
public void testAggregateIntermediate() {
    Expression expression;
    if (randomBoolean()) {
        expression = buildDeepCopyOfFieldExpression(testCase);
    } else {
        expression = buildFieldExpression(testCase);
    }

    resolveExpression(expression, this::aggregateWithIntermediates, this::evaluate);
}

public void testFold() {
Expression expression = buildLiteralExpression(testCase);

Expand All @@ -80,17 +128,78 @@ public void testFold() {
});
}

private void aggregateSingleMode(AggregatorFunctionSupplier aggregatorFunctionSupplier) {
/**
 * Feeds the multi-row fields of the test case to an aggregator created in {@code AggregatorMode.SINGLE},
 * then checks the extracted result against the test case matcher and its expected warnings.
 */
private void aggregateSingleMode(Expression expression) {
    Object aggregated;
    try (var singleAggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {
        Page page = rows(testCase.getMultiRowFields());
        try {
            singleAggregator.processPage(page);
        } finally {
            // Release the page blocks whether or not processing succeeded
            page.releaseBlocks();
        }

        var expectedElementType = PlannerUtils.toElementType(testCase.expectedType());
        aggregated = extractResultFromAggregator(singleAggregator, expectedElementType);
    }

    // The result must never be NaN or infinite, and the test case matcher must not accept an infinity either
    assertThat(aggregated, not(equalTo(Double.NaN)));
    for (double infinity : new double[] { Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY }) {
        assert testCase.getMatcher().matches(infinity) == false;
        assertThat(aggregated, not(equalTo(infinity)));
    }
    assertThat(aggregated, testCase.getMatcher());

    var expectedWarnings = testCase.getExpectedWarnings();
    if (expectedWarnings != null) {
        assertWarnings(expectedWarnings);
    }
}

private void aggregateWithIntermediates(Expression expression) {
int intermediateBlockOffset = randomIntBetween(0, 10);
Block[] intermediateBlocks;
int intermediateStates;

// Input rows to intermediate states
try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.INITIAL)) {
intermediateStates = aggregator.evaluateBlockCount();

int intermediateBlockExtraSize = randomIntBetween(0, 10);
intermediateBlocks = new Block[intermediateBlockOffset + intermediateStates + intermediateBlockExtraSize];

Page inputPage = rows(testCase.getMultiRowFields());
try {
aggregator.processPage(inputPage);
} finally {
inputPage.releaseBlocks();
}

aggregator.evaluate(intermediateBlocks, intermediateBlockOffset, driverContext());

int positionCount = intermediateBlocks[intermediateBlockOffset].getPositionCount();

// Fill offset and extra blocks with nulls
for (int i = 0; i < intermediateBlockOffset; i++) {
intermediateBlocks[i] = driverContext().blockFactory().newConstantNullBlock(positionCount);
}
for (int i = intermediateBlockOffset + intermediateStates; i < intermediateBlocks.length; i++) {
intermediateBlocks[i] = driverContext().blockFactory().newConstantNullBlock(positionCount);
}
}

Object result;
try (var aggregator = new Aggregator(aggregatorFunctionSupplier.aggregator(driverContext()), AggregatorMode.SINGLE)) {
Page inputPage = rows(testCase.getMultiRowDataValues());
// Intermediate states to final result
try (
var aggregator = aggregator(
expression,
intermediaryInputChannels(intermediateStates, intermediateBlockOffset),
AggregatorMode.FINAL
)
) {
Page inputPage = new Page(intermediateBlocks);
try {
aggregator.processPage(inputPage);
} finally {
inputPage.releaseBlocks();
}

// ElementType from DataType
result = extractResultFromAggregator(aggregator, PlannerUtils.toElementType(testCase.expectedType()));
}

Expand Down Expand Up @@ -124,11 +233,7 @@ private void evaluate(Expression evaluableExpression) {
}
}

private void resolveExpression(
Expression expression,
Consumer<AggregatorFunctionSupplier> onAggregator,
Consumer<Expression> onEvaluableExpression
) {
private void resolveExpression(Expression expression, Consumer<Expression> onAggregator, Consumer<Expression> onEvaluableExpression) {
logger.info(
"Test Values: " + testCase.getData().stream().map(TestCaseSupplier.TypedData::toString).collect(Collectors.joining(","))
);
Expand All @@ -154,8 +259,7 @@ private void resolveExpression(
assertThat(expression, instanceOf(ToAggregator.class));
logger.info("Result type: " + expression.dataType());

var inputChannels = inputChannels();
onAggregator.accept(((ToAggregator) expression).supplier(inputChannels));
onAggregator.accept(expression);
}

private Object extractResultFromAggregator(Aggregator aggregator, ElementType expectedElementType) {
Expand All @@ -167,18 +271,23 @@ private Object extractResultFromAggregator(Aggregator aggregator, ElementType ex

var block = blocks[resultBlockIndex];

assertThat(block.elementType(), equalTo(expectedElementType));
// For null blocks, the element type is NULL, so if the provided matcher matches, the type works too
assertThat(block.elementType(), is(oneOf(expectedElementType, ElementType.NULL)));

return toJavaObject(blocks[resultBlockIndex], 0);
} finally {
Releasables.close(blocks);
}
}

private List<Integer> inputChannels() {
private List<Integer> initialInputChannels() {
// TODO: Randomize channels
// TODO: If surrogated, channels may change
return IntStream.range(0, testCase.getMultiRowDataValues().size()).boxed().toList();
return IntStream.range(0, testCase.getMultiRowFields().size()).boxed().toList();
}

/**
 * Lists the consecutive channel indices from {@code offset} (inclusive)
 * to {@code offset + intermediaryStates} (exclusive).
 */
private List<Integer> intermediaryInputChannels(int intermediaryStates, int offset) {
    int endExclusive = offset + intermediaryStates;
    return IntStream.range(offset, endExclusive).mapToObj(Integer::valueOf).toList();
}

/**
Expand Down Expand Up @@ -210,4 +319,10 @@ private Expression resolveSurrogates(Expression expression) {

return expression;
}

/**
 * Builds an {@code Aggregator} for the given expression, reading from {@code inputChannels} and running in {@code mode}.
 * The expression is cast to {@code ToAggregator} without a prior type check.
 */
private Aggregator aggregator(Expression expression, List<Integer> inputChannels, AggregatorMode mode) {
    var toAggregator = (ToAggregator) expression;
    var aggregatorFunction = toAggregator.supplier(inputChannels).aggregator(driverContext());

    return new Aggregator(aggregatorFunction, mode);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.elasticsearch.xpack.esql.optimizer.FoldNull;
import org.elasticsearch.xpack.esql.parser.ExpressionBuilder;
import org.elasticsearch.xpack.esql.planner.Layout;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.versionfield.Version;
import org.junit.After;
import org.junit.AfterClass;
Expand Down Expand Up @@ -214,24 +215,40 @@ protected final Page row(List<Object> values) {
}

/**
* Creates a page based on a list of lists, where each list represents a column.
* Creates a page based on a list of multi-row fields.
*/
protected final Page rows(List<List<Object>> values) {
if (values.isEmpty()) {
protected final Page rows(List<TestCaseSupplier.TypedData> multirowFields) {
if (multirowFields.isEmpty()) {
return new Page(0, BlockUtils.NO_BLOCKS);
}

var rowsCount = values.get(0).size();
var rowsCount = multirowFields.get(0).multiRowData().size();

values.stream().skip(1).forEach(l -> assertThat("All multi-row fields must have the same number of rows", l, hasSize(rowsCount)));
multirowFields.stream()
.skip(1)
.forEach(
field -> assertThat("All multi-row fields must have the same number of rows", field.multiRowData(), hasSize(rowsCount))
);

var rows = new ArrayList<List<Object>>();
for (int i = 0; i < rowsCount; i++) {
final int index = i;
rows.add(values.stream().map(l -> l.get(index)).toList());
}
var blocks = new Block[multirowFields.size()];

var blocks = BlockUtils.fromList(TestBlockFactory.getNonBreakingInstance(), rows);
for (int i = 0; i < multirowFields.size(); i++) {
var field = multirowFields.get(i);
try (
var wrapper = BlockUtils.wrapperFor(
TestBlockFactory.getNonBreakingInstance(),
PlannerUtils.toElementType(field.type()),
rowsCount
)
) {

for (var row : field.multiRowData()) {
wrapper.accept(row);
}

blocks[i] = wrapper.builder().build();
}
}

return new Page(rowsCount, blocks);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1301,8 +1301,8 @@ public List<Object> getDataValues() {
return data.stream().filter(d -> d.forceLiteral == false).map(TypedData::data).collect(Collectors.toList());
}

public List<List<Object>> getMultiRowDataValues() {
return data.stream().filter(TypedData::isMultiRow).map(TypedData::multiRowData).collect(Collectors.toList());
/**
 * Returns the entries of {@code data} for which {@code isMultiRow()} is true, in their original order.
 */
public List<TypedData> getMultiRowFields() {
    var multiRowFields = data.stream().filter(typedData -> typedData.isMultiRow());
    return multiRowFields.collect(Collectors.toList());
}

public boolean canGetDataAsLiterals() {
Expand Down

0 comments on commit 8b7d833

Please sign in to comment.