Skip to content

Commit

Permalink
Test aggregation metadata types matches function implementation
Browse files Browse the repository at this point in the history
Add tests to ensure input, intermediate, and output types declared in
aggregation metadata match function implementation
  • Loading branch information
dain committed Oct 12, 2021
1 parent 41fcd1e commit c9e8aa6
Show file tree
Hide file tree
Showing 20 changed files with 499 additions and 268 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,28 @@ public String name()
return name;
}

@VisibleForTesting
public List<Type> getParameterTypes()
{
return parameterTypes;
}

@VisibleForTesting
public Type getFinalType()
{
return finalType;
}

public Type getIntermediateType()
@VisibleForTesting
public Optional<Type> getIntermediateType()
{
if (intermediateType.size() == 1) {
return getOnlyElement(intermediateType);
if (intermediateType.isEmpty()) {
return Optional.empty();
}
else {
return RowType.anonymous(intermediateType);
if (intermediateType.size() == 1) {
return Optional.of(getOnlyElement(intermediateType));
}
return Optional.of(RowType.anonymous(intermediateType));
}

public List<Class<?>> getLambdaInterfaces()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
package io.trino.operator;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.ResolvedFunction;
import io.trino.operator.aggregation.AbstractTestAggregationFunction;
import io.trino.operator.aggregation.InternalAggregationFunction;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.Type;
Expand All @@ -26,7 +26,6 @@
import java.util.List;

import static io.trino.block.BlockAssertions.createBlockOfReals;
import static io.trino.metadata.MetadataManager.createTestMetadataManager;
import static io.trino.operator.aggregation.AggregationTestUtils.assertAggregation;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
Expand All @@ -36,35 +35,40 @@
public class TestRealAverageAggregation
extends AbstractTestAggregationFunction
{
private InternalAggregationFunction avgFunction;
private ResolvedFunction avgFunction;

@BeforeClass
public void setUp()
{
avgFunction = createTestMetadataManager().getAggregateFunctionImplementation(
metadata.resolveFunction(QualifiedName.of("avg"), fromTypes(REAL)));
avgFunction = metadata.resolveFunction(QualifiedName.of("avg"), fromTypes(REAL));
}

@Test
public void averageOfNullIsNull()
{
assertAggregation(avgFunction,
assertAggregation(
metadata,
avgFunction,
null,
createBlockOfReals(null, null));
}

@Test
public void averageOfSingleValueEqualsThatValue()
{
assertAggregation(avgFunction,
assertAggregation(
metadata,
avgFunction,
1.23f,
createBlockOfReals(1.23f));
}

@Test
public void averageOfTwoMaxFloatsEqualsMaxFloat()
{
assertAggregation(avgFunction,
assertAggregation(
metadata,
avgFunction,
Float.MAX_VALUE,
createBlockOfReals(Float.MAX_VALUE, Float.MAX_VALUE));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,11 @@ public final void destroyTestAggregationFunction()

protected abstract Block[] getSequenceBlocks(int start, int length);

protected final InternalAggregationFunction getFunction()
protected final ResolvedFunction getFunction()
{
ResolvedFunction resolvedFunction = metadata.resolveFunction(
return metadata.resolveFunction(
QualifiedName.of(getFunctionName()),
fromTypes(getFunctionParameterTypes()));
return metadata.getAggregateFunctionImplementation(resolvedFunction);
}

protected abstract String getFunctionName();
Expand Down Expand Up @@ -102,7 +101,7 @@ public void testMultiplePositions()
public void testAllPositionsNull()
{
// if there are no parameters skip this test
List<Type> parameterTypes = getFunction().getParameterTypes();
List<Type> parameterTypes = getFunction().getSignature().getArgumentTypes();
if (parameterTypes.isEmpty()) {
return;
}
Expand All @@ -118,7 +117,7 @@ public void testAllPositionsNull()
public void testMixedNullAndNonNullPositions()
{
// if there are no parameters skip this test
List<Type> parameterTypes = getFunction().getParameterTypes();
List<Type> parameterTypes = getFunction().getSignature().getArgumentTypes();
if (parameterTypes.isEmpty()) {
return;
}
Expand Down Expand Up @@ -154,7 +153,7 @@ public void testSlidingWindow()
}
Page inputPage = new Page(totalPositions, getSequenceBlocks(0, totalPositions));

InternalAggregationFunction function = getFunction();
InternalAggregationFunction function = metadata.getAggregateFunctionImplementation(getFunction());
List<Integer> channels = Ints.asList(createArgs(function));
AccumulatorFactory accumulatorFactory = function.bind(channels, Optional.empty());
PagesIndex pagesIndex = new PagesIndex.TestingFactory(false).newPagesIndex(function.getParameterTypes(), totalPositions);
Expand Down Expand Up @@ -213,7 +212,7 @@ protected static Block[] createAlternatingNullsBlock(List<Type> types, Block...

protected void testAggregation(Object expectedValue, Block... blocks)
{
assertAggregation(getFunction(), expectedValue, blocks);
assertAggregation(metadata, getFunction(), expectedValue, blocks);
}

protected void assertInvalidAggregation(Runnable runnable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@

import com.google.common.primitives.Ints;
import io.trino.block.BlockAssertions;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.operator.GroupByIdBlock;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import org.apache.commons.math3.util.Precision;

import java.util.Collections;
Expand All @@ -39,16 +43,16 @@ public final class AggregationTestUtils
{
private AggregationTestUtils() {}

public static void assertAggregation(InternalAggregationFunction function, Object expectedValue, Block... blocks)
public static void assertAggregation(Metadata metadata, ResolvedFunction resolvedFunction, Object expectedValue, Block... blocks)
{
assertAggregation(function, expectedValue, new Page(blocks));
assertAggregation(metadata, resolvedFunction, expectedValue, new Page(blocks));
}

public static void assertAggregation(InternalAggregationFunction function, Object expectedValue, Page page)
public static void assertAggregation(Metadata metadata, ResolvedFunction resolvedFunction, Object expectedValue, Page page)
{
BiFunction<Object, Object, Boolean> equalAssertion = makeValidityAssertion(expectedValue);

assertAggregation(function, equalAssertion, null, page, expectedValue);
assertAggregation(metadata, resolvedFunction, equalAssertion, null, page, expectedValue);
}

public static BiFunction<Object, Object, Boolean> makeValidityAssertion(Object expectedValue)
Expand All @@ -62,8 +66,15 @@ public static BiFunction<Object, Object, Boolean> makeValidityAssertion(Object e
return Objects::equals;
}

public static void assertAggregation(InternalAggregationFunction function, BiFunction<Object, Object, Boolean> equalAssertion, String testDescription, Page page, Object expectedValue)
public static void assertAggregation(Metadata metadata, ResolvedFunction resolvedFunction, BiFunction<Object, Object, Boolean> equalAssertion, String testDescription, Page page, Object expectedValue)
{
AggregationFunctionMetadata functionMetadata = metadata.getAggregationFunctionMetadata(resolvedFunction);
InternalAggregationFunction function = metadata.getAggregateFunctionImplementation(resolvedFunction);

assertEquals(function.getParameterTypes(), resolvedFunction.getSignature().getArgumentTypes());
assertEquals(function.getFinalType(), resolvedFunction.getSignature().getReturnType());
assertEquals(function.getIntermediateType().map(Type::getTypeSignature), functionMetadata.getIntermediateType());

int positions = page.getPositionCount();
for (int i = 1; i < page.getChannelCount(); i++) {
assertEquals(positions, page.getBlock(i).getPositionCount(), "input blocks provided are not equal in position count");
Expand Down
Loading

0 comments on commit c9e8aa6

Please sign in to comment.