From d406ff22bc05b3c38d49454f029a763df96a2cb8 Mon Sep 17 00:00:00 2001 From: Alan Post Date: Mon, 30 Sep 2019 12:46:27 +0200 Subject: [PATCH 1/3] Add test for Accumulator.addInput from WindowIndex Add a test for the WindowNode use case for Accumulators. Fix a few bugs in the aggregation tests uncovered by the additional coverage. The new tests didn't uncover any product bugs. --- .../AbstractTestAggregationFunction.java | 63 +++++++++++++++++++ .../TestBooleanAndAggregation.java | 2 +- .../aggregation/TestBooleanOrAggregation.java | 2 +- .../TestVarBinaryMaxAggregation.java | 4 +- .../TestVarBinaryMinAggregation.java | 4 +- 5 files changed, 69 insertions(+), 6 deletions(-) diff --git a/presto-main/src/test/java/io/prestosql/operator/aggregation/AbstractTestAggregationFunction.java b/presto-main/src/test/java/io/prestosql/operator/aggregation/AbstractTestAggregationFunction.java index 81a8ddc6e7e3..698f5a4f6428 100644 --- a/presto-main/src/test/java/io/prestosql/operator/aggregation/AbstractTestAggregationFunction.java +++ b/presto-main/src/test/java/io/prestosql/operator/aggregation/AbstractTestAggregationFunction.java @@ -13,11 +13,17 @@ */ package io.prestosql.operator.aggregation; +import com.google.common.primitives.Ints; +import io.prestosql.block.BlockAssertions; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; +import io.prestosql.operator.PagesIndex; +import io.prestosql.operator.window.PagesWindowIndex; +import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockBuilder; import io.prestosql.spi.block.RunLengthEncodedBlock; +import io.prestosql.spi.function.WindowIndex; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; import io.prestosql.sql.analyzer.TypeSignatureProvider; @@ -27,10 +33,14 @@ import org.testng.annotations.Test; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; import static io.prestosql.operator.aggregation.AggregationTestUtils.assertAggregation; +import static io.prestosql.operator.aggregation.AggregationTestUtils.createArgs; +import static io.prestosql.operator.aggregation.AggregationTestUtils.getFinalBlock; +import static io.prestosql.operator.aggregation.AggregationTestUtils.makeValidityAssertion; import static io.prestosql.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.prestosql.testing.assertions.PrestoExceptionAssert.assertPrestoExceptionThrownBy; @@ -131,6 +141,59 @@ public void testPositiveOnlyValues() testAggregation(getExpectedValue(2, 4), getSequenceBlocks(2, 4)); } + @Test + public void testSlidingWindow() + { + // Builds trailing windows of length 0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0 + int totalPositions = 12; + int[] windowWidths = new int[totalPositions]; + Object[] expectedValues = new Object[totalPositions]; + + for (int i = 0; i < totalPositions; ++i) { + int windowWidth = Integer.min(i, totalPositions - 1 - i); + windowWidths[i] = windowWidth; + expectedValues[i] = getExpectedValue(i, windowWidth); + } + Page inputPage = new Page(totalPositions, getSequenceBlocks(0, totalPositions)); + + InternalAggregationFunction function = getFunction(); + List channels = Ints.asList(createArgs(function)); + AccumulatorFactory accumulatorFactory = function.bind(channels, Optional.empty()); + PagesIndex pagesIndex = new PagesIndex.TestingFactory(false).newPagesIndex(function.getParameterTypes(), totalPositions); + pagesIndex.addPage(inputPage); + WindowIndex windowIndex = new PagesWindowIndex(pagesIndex, 0, totalPositions - 1); + + Accumulator aggregation = accumulatorFactory.createAccumulator(); + int oldStart = 0; + int oldWidth = 0; + for (int start = 0; start < totalPositions; ++start) { + int width = windowWidths[start]; + // Note that add/removeInput's interval is inclusive on both ends + if (accumulatorFactory.hasRemoveInput()) { + for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) { + if (oldi < start || oldi >= start + width) { + aggregation.removeInput(windowIndex, channels, oldi, oldi); + } + } + for (int newi = start; newi < start + width; ++newi) { + if (newi < oldStart || newi >= oldStart + oldWidth) { + aggregation.addInput(windowIndex, channels, newi, newi); + } + } + } + else { + aggregation = accumulatorFactory.createAccumulator(); + aggregation.addInput(windowIndex, channels, start, start + width - 1); + } + oldStart = start; + oldWidth = width; + Block block = getFinalBlock(aggregation); + makeValidityAssertion(expectedValues[start]).apply( + BlockAssertions.getOnlyValue(aggregation.getFinalType(), block), + expectedValues[start]); + } + } + protected static Block[] createAlternatingNullsBlock(List types, Block... sequenceBlocks) { Block[] alternatingNullsBlocks = new Block[sequenceBlocks.length]; diff --git a/presto-main/src/test/java/io/prestosql/operator/aggregation/TestBooleanAndAggregation.java b/presto-main/src/test/java/io/prestosql/operator/aggregation/TestBooleanAndAggregation.java index e90794822883..ef4b1f3903a5 100644 --- a/presto-main/src/test/java/io/prestosql/operator/aggregation/TestBooleanAndAggregation.java +++ b/presto-main/src/test/java/io/prestosql/operator/aggregation/TestBooleanAndAggregation.java @@ -44,7 +44,7 @@ protected Boolean getExpectedValue(int start, int length) if (length == 0) { return null; } - return length > 1 ? FALSE : TRUE; + return (length > 1 || (start % 2 == 1)) ? FALSE : TRUE; } @Override diff --git a/presto-main/src/test/java/io/prestosql/operator/aggregation/TestBooleanOrAggregation.java b/presto-main/src/test/java/io/prestosql/operator/aggregation/TestBooleanOrAggregation.java index c576480e8676..ef02fb69c083 100644 --- a/presto-main/src/test/java/io/prestosql/operator/aggregation/TestBooleanOrAggregation.java +++ b/presto-main/src/test/java/io/prestosql/operator/aggregation/TestBooleanOrAggregation.java @@ -44,7 +44,7 @@ protected Boolean getExpectedValue(int start, int length) if (length == 0) { return null; } - return length > 1 ? TRUE : FALSE; + return (length > 1 || (start % 2 == 1)) ? TRUE : FALSE; } @Override diff --git a/presto-main/src/test/java/io/prestosql/operator/aggregation/TestVarBinaryMaxAggregation.java b/presto-main/src/test/java/io/prestosql/operator/aggregation/TestVarBinaryMaxAggregation.java index 58d454070022..dec7133b78dd 100644 --- a/presto-main/src/test/java/io/prestosql/operator/aggregation/TestVarBinaryMaxAggregation.java +++ b/presto-main/src/test/java/io/prestosql/operator/aggregation/TestVarBinaryMaxAggregation.java @@ -33,7 +33,7 @@ public class TestVarBinaryMaxAggregation protected Block[] getSequenceBlocks(int start, int length) { BlockBuilder blockBuilder = VARBINARY.createBlockBuilder(null, length); - for (int i = 0; i < length; i++) { + for (int i = start; i < start + length; i++) { VARBINARY.writeSlice(blockBuilder, Slices.wrappedBuffer(Ints.toByteArray(i))); } return new Block[] {blockBuilder.build()}; @@ -46,7 +46,7 @@ protected Object getExpectedValue(int start, int length) return null; } Slice max = null; - for (int i = 0; i < length; i++) { + for (int i = start; i < start + length; i++) { Slice slice = Slices.wrappedBuffer(Ints.toByteArray(i)); max = (max == null) ? slice : Ordering.natural().max(max, slice); } diff --git a/presto-main/src/test/java/io/prestosql/operator/aggregation/TestVarBinaryMinAggregation.java b/presto-main/src/test/java/io/prestosql/operator/aggregation/TestVarBinaryMinAggregation.java index bca938cb7300..71ee56ae1441 100644 --- a/presto-main/src/test/java/io/prestosql/operator/aggregation/TestVarBinaryMinAggregation.java +++ b/presto-main/src/test/java/io/prestosql/operator/aggregation/TestVarBinaryMinAggregation.java @@ -33,7 +33,7 @@ public class TestVarBinaryMinAggregation protected Block[] getSequenceBlocks(int start, int length) { BlockBuilder blockBuilder = VARBINARY.createBlockBuilder(null, length); - for (int i = 0; i < length; i++) { + for (int i = start; i < start + length; i++) { VARBINARY.writeSlice(blockBuilder, Slices.wrappedBuffer(Ints.toByteArray(i))); } return new Block[] {blockBuilder.build()}; @@ -46,7 +46,7 @@ protected Object getExpectedValue(int start, int length) return null; } Slice min = null; - for (int i = 0; i < length; i++) { + for (int i = start; i < start + length; i++) { Slice slice = Slices.wrappedBuffer(Ints.toByteArray(i)); min = (min == null) ? slice : Ordering.natural().min(min, slice); } From 78ff71cd88882c4a8c3f0fad1b6daf5dff0c275d Mon Sep 17 00:00:00 2001 From: Alan Post Date: Mon, 30 Sep 2019 12:46:28 +0200 Subject: [PATCH 2/3] Roll window aggregations instead of recomputing Add a removeInput() function to some Accumulators, and when it exists, use it in aggregate window functions to roll the aggregation forward incrementally. Dramatically speeds up queries such as: SELECT COUNT(quantity) OVER (ROWS BETWEEN 2000 PRECEDING AND 2000 FOLLOWING) Extracted from: https://github.com/prestodb/presto/pull/8974 --- .../AbstractMinMaxAggregationFunction.java | 2 + .../AbstractMinMaxNAggregationFunction.java | 2 + .../operator/aggregation/Accumulator.java | 2 + .../aggregation/AccumulatorCompiler.java | 25 +++++-- .../aggregation/AccumulatorFactory.java | 2 + .../AggregationFromAnnotationsParser.java | 17 ++++- .../AggregationImplementation.java | 34 +++++++++- .../aggregation/AggregationMetadata.java | 11 ++++ .../ArbitraryAggregationFunction.java | 2 + .../aggregation/AverageAggregations.java | 15 +++++ .../ChecksumAggregationFunction.java | 2 + .../aggregation/CountAggregation.java | 7 ++ .../operator/aggregation/CountColumn.java | 8 +++ .../aggregation/CountIfAggregation.java | 9 +++ .../DecimalAverageAggregation.java | 2 + .../aggregation/DecimalSumAggregation.java | 2 + .../GenericAccumulatorFactory.java | 21 ++++++ .../GenericAccumulatorFactoryBinder.java | 5 ++ .../aggregation/MapAggregationFunction.java | 2 + .../aggregation/MapUnionAggregation.java | 2 + .../MergeQuantileDigestFunction.java | 2 + .../aggregation/ParametricAggregation.java | 3 + .../QuantileDigestAggregationFunction.java | 2 + .../aggregation/RealAverageAggregation.java | 9 +++ .../ReduceAggregationFunction.java | 2 + .../arrayagg/ArrayAggregationFunction.java | 2 + .../aggregation/histogram/Histogram.java | 2 + .../minmaxby/AbstractMinMaxBy.java | 2 + .../AbstractMinMaxByNAggregationFunction.java | 2 + .../MultimapAggregationFunction.java | 2 + .../window/AggregateWindowFunction.java | 55 ++++++++++++++-- .../aggregation/AggregationTestUtils.java | 20 +++--- .../window/TestAggregateWindowFunction.java | 65 +++++++++++++++++++ .../spi/function/RemoveInputFunction.java | 26 ++++++++ 34 files changed, 344 insertions(+), 22 deletions(-) create mode 100644 presto-spi/src/main/java/io/prestosql/spi/function/RemoveInputFunction.java diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java index 13f01640fc88..47dd534d10b6 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java @@ -36,6 +36,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.internalOperator; @@ -144,6 +145,7 @@ else if (type.getJavaType() == boolean.class) { generateAggregationName(getSignature().getName(), type.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createParameterMetadata(type), inputFunction, + Optional.empty(), combineFunction, outputFunction, ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxNAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxNAggregationFunction.java index 44ce2a44131d..812c3d0cc074 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxNAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxNAggregationFunction.java @@ -31,6 +31,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import java.util.function.Function; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -97,6 +98,7 @@ protected InternalAggregationFunction generateAggregation(Type type) generateAggregationName(getSignature().getName(), type.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), inputParameterMetadata, INPUT_FUNCTION.bindTo(comparator).bindTo(type), + Optional.empty(), COMBINE_FUNCTION, OUTPUT_FUNCTION.bindTo(outputType), ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/Accumulator.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/Accumulator.java index bbe1c18f4e98..bb27d2fd63c9 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/Accumulator.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/Accumulator.java @@ -33,6 +33,8 @@ public interface Accumulator void addInput(WindowIndex index, List channels, int startPosition, int endPosition); + void removeInput(WindowIndex index, List channels, int startPosition, int endPosition); + void addIntermediate(Block block); void evaluateIntermediate(BlockBuilder blockBuilder); diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AccumulatorCompiler.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AccumulatorCompiler.java index ad4de9ee4fd1..1d11ce312761 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AccumulatorCompiler.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AccumulatorCompiler.java @@ -96,6 +96,7 @@ public static GenericAccumulatorFactoryBinder generateAccumulatorFactoryBinder(A return new GenericAccumulatorFactoryBinder( metadata.getAccumulatorStateDescriptors(), accumulatorClass, + metadata.getRemoveInputFunction().isPresent(), groupedAccumulatorClass); } @@ -157,14 +158,25 @@ private static Class generateAccumulatorClass( metadata.getInputFunction(), callSiteBinder, grouped); - generateAddInputWindowIndex( + generateAddOrRemoveInputWindowIndex( definition, stateFields, metadata.getValueInputMetadata(), metadata.getLambdaInterfaces(), lambdaProviderFields, metadata.getInputFunction(), + "addInput", callSiteBinder); + metadata.getRemoveInputFunction().ifPresent( + removeInputFunction -> generateAddOrRemoveInputWindowIndex( + definition, + stateFields, + metadata.getValueInputMetadata(), + metadata.getLambdaInterfaces(), + lambdaProviderFields, + removeInputFunction, + "removeInput", + callSiteBinder)); generateGetEstimatedSize(definition, stateFields); generateGetIntermediateType( @@ -317,13 +329,14 @@ private static void generateAddInput( body.ret(); } - private static void generateAddInputWindowIndex( + private static void generateAddOrRemoveInputWindowIndex( ClassDefinition definition, List stateField, List parameterMetadatas, List> lambdaInterfaces, List lambdaProviderFields, MethodHandle inputFunction, + String generatedFunctionName, CallSiteBinder callSiteBinder) { // TODO: implement masking based on maskChannel field once Window Functions support DISTINCT arguments to the functions. @@ -333,7 +346,11 @@ private static void generateAddInputWindowIndex( Parameter startPosition = arg("startPosition", int.class); Parameter endPosition = arg("endPosition", int.class); - MethodDefinition method = definition.declareMethod(a(PUBLIC), "addInput", type(void.class), ImmutableList.of(index, channels, startPosition, endPosition)); + MethodDefinition method = definition.declareMethod( + a(PUBLIC), + generatedFunctionName, + type(void.class), + ImmutableList.of(index, channels, startPosition, endPosition)); Scope scope = method.getScope(); Variable position = scope.declareVariable(int.class, "position"); @@ -342,7 +359,7 @@ private static void generateAddInputWindowIndex( BytecodeExpression invokeInputFunction = invokeDynamic( BOOTSTRAP_METHOD, ImmutableList.of(binding.getBindingId()), - "input", + generatedFunctionName, binding.getType(), getInvokeFunctionOnWindowIndexParameters( scope, diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AccumulatorFactory.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AccumulatorFactory.java index 8ab907c9c560..ea85f7f43a81 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AccumulatorFactory.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AccumulatorFactory.java @@ -19,6 +19,8 @@ public interface AccumulatorFactory { List getInputChannels(); + boolean hasRemoveInput(); + Accumulator createAccumulator(); Accumulator createIntermediateAccumulator(); diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationFromAnnotationsParser.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationFromAnnotationsParser.java index 2afa4161faf9..197dddfeb02a 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationFromAnnotationsParser.java @@ -16,6 +16,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.MoreCollectors; import io.prestosql.operator.ParametricImplementationsGroup; import io.prestosql.operator.annotations.FunctionsParserHelper; import io.prestosql.spi.function.AccumulatorState; @@ -24,6 +25,7 @@ import io.prestosql.spi.function.CombineFunction; import io.prestosql.spi.function.InputFunction; import io.prestosql.spi.function.OutputFunction; +import io.prestosql.spi.function.RemoveInputFunction; import io.prestosql.spi.type.TypeSignature; import javax.annotation.Nullable; @@ -75,8 +77,9 @@ public static List parseFunctionDefinitions(Class aggr Optional aggregationStateSerializerFactory = getAggregationStateSerializerFactory(aggregationDefinition, stateClass); for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateClass)) { for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) { + Optional removeInputFunction = getRemoveInputFunction(aggregationDefinition, inputFunction); for (AggregationHeader header : parseHeaders(aggregationDefinition, outputFunction)) { - AggregationImplementation onlyImplementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory); + AggregationImplementation onlyImplementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, removeInputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory); ParametricImplementationsGroup implementations = ParametricImplementationsGroup.of(onlyImplementation); builder.add(new ParametricAggregation(implementations.getSignature(), header, implementations)); } @@ -97,7 +100,8 @@ public static ParametricAggregation parseFunctionDefinition(Class aggregation Optional aggregationStateSerializerFactory = getAggregationStateSerializerFactory(aggregationDefinition, stateClass); Method outputFunction = getOnlyElement(getOutputFunctions(aggregationDefinition, stateClass)); for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) { - AggregationImplementation implementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory); + Optional removeInputFunction = getRemoveInputFunction(aggregationDefinition, inputFunction); + AggregationImplementation implementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, removeInputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory); implementationsBuilder.addImplementation(implementation); } } @@ -204,6 +208,15 @@ private static List getInputFunctions(Class clazz, Class stateClas return inputFunctions; } + private static Optional getRemoveInputFunction(Class clazz, Method inputFunction) + { + // Only include methods which take the same parameters as the corresponding input function + return FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, RemoveInputFunction.class).stream() + .filter(method -> Arrays.equals(method.getParameterTypes(), inputFunction.getParameterTypes())) + .filter(method -> Arrays.deepEquals(method.getParameterAnnotations(), inputFunction.getParameterAnnotations())) + .collect(MoreCollectors.toOptional()); + } + private static Set> getStateClasses(Class clazz) { ImmutableSet.Builder> builder = ImmutableSet.builder(); diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationImplementation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationImplementation.java index 865a558a62c2..b29e486e8c9d 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationImplementation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationImplementation.java @@ -33,6 +33,7 @@ import io.prestosql.spi.function.SqlType; import io.prestosql.spi.function.TypeParameter; import io.prestosql.spi.type.TypeSignature; +import io.prestosql.util.Reflection; import java.lang.annotation.Annotation; import java.lang.invoke.MethodHandle; @@ -92,11 +93,13 @@ public boolean isBlockPosition() private final Class definitionClass; private final Class stateClass; private final MethodHandle inputFunction; + private final Optional removeInputFunction; private final MethodHandle outputFunction; private final MethodHandle combineFunction; private final Optional stateSerializerFactory; private final List argumentNativeContainerTypes; private final List inputDependencies; + private final List removeInputDependencies; private final List combineDependencies; private final List outputDependencies; private final List stateSerializerFactoryDependencies; @@ -107,11 +110,13 @@ public AggregationImplementation( Class definitionClass, Class stateClass, MethodHandle inputFunction, + Optional removeInputFunction, MethodHandle outputFunction, MethodHandle combineFunction, Optional stateSerializerFactory, List argumentNativeContainerTypes, List inputDependencies, + List removeInputDependencies, List combineDependencies, List outputDependencies, List stateSerializerFactoryDependencies, @@ -121,11 +126,13 @@ public AggregationImplementation( this.definitionClass = requireNonNull(definitionClass, "definition class cannot be null"); this.stateClass = requireNonNull(stateClass, "stateClass cannot be null"); this.inputFunction = requireNonNull(inputFunction, "inputFunction cannot be null"); + this.removeInputFunction = requireNonNull(removeInputFunction, "removeInputFunction cannot be null"); this.outputFunction = requireNonNull(outputFunction, "outputFunction cannot be null"); this.combineFunction = requireNonNull(combineFunction, "combineFunction cannot be null"); this.stateSerializerFactory = requireNonNull(stateSerializerFactory, "stateSerializerFactory cannot be null"); this.argumentNativeContainerTypes = requireNonNull(argumentNativeContainerTypes, "argumentNativeContainerTypes cannot be null"); this.inputDependencies = requireNonNull(inputDependencies, "inputDependencies cannot be null"); + this.removeInputDependencies = requireNonNull(removeInputDependencies, "removeInputDependencies cannot be null"); this.outputDependencies = requireNonNull(outputDependencies, "outputDependencies cannot be null"); this.combineDependencies = requireNonNull(combineDependencies, "combineDependencies cannot be null"); this.stateSerializerFactoryDependencies = requireNonNull(stateSerializerFactoryDependencies, "stateSerializerFactoryDependencies cannot be null"); @@ -159,6 +166,11 @@ public MethodHandle getInputFunction() return inputFunction; } + public Optional getRemoveInputFunction() + { + return removeInputFunction; + } + public MethodHandle getOutputFunction() { return outputFunction; @@ -179,6 +191,11 @@ public List getInputDependencies() return inputDependencies; } + public List getRemoveInputDependencies() + { + return removeInputDependencies; + } + public List getOutputDependencies() { return outputDependencies; @@ -226,11 +243,13 @@ public static final class Parser private final Class aggregationDefinition; private final Class stateClass; private final MethodHandle inputHandle; + private final Optional removeInputHandle; private final MethodHandle outputHandle; private final MethodHandle combineHandle; private final Optional stateSerializerFactoryHandle; private final List argumentNativeContainerTypes; private final List inputDependencies; + private final List removeInputDependencies; private final List combineDependencies; private final List outputDependencies; private final List stateSerializerFactoryDependencies; @@ -250,6 +269,7 @@ private Parser( AggregationHeader header, Class stateClass, Method inputFunction, + Optional removeInputFunction, Method outputFunction, Method combineFunction, Optional stateSerializerFactoryFunction) @@ -266,6 +286,7 @@ private Parser( // parse dependencies inputDependencies = parseImplementationDependencies(inputFunction); + removeInputDependencies = removeInputFunction.map(this::parseImplementationDependencies).orElse(ImmutableList.of()); outputDependencies = parseImplementationDependencies(outputFunction); combineDependencies = parseImplementationDependencies(combineFunction); stateSerializerFactoryDependencies = stateSerializerFactoryFunction.map(this::parseImplementationDependencies).orElse(ImmutableList.of()); @@ -275,7 +296,12 @@ private Parser( // parse constraints longVariableConstraints = FunctionsParserHelper.parseLongVariableConstraints(inputFunction); - List allDependencies = Stream.of(inputDependencies.stream(), outputDependencies.stream(), combineDependencies.stream()) + List allDependencies = + Stream.of( + inputDependencies.stream(), + removeInputDependencies.stream(), + outputDependencies.stream(), + combineDependencies.stream()) .reduce(Stream::concat) .orElseGet(Stream::empty) .collect(toImmutableList()); @@ -297,6 +323,7 @@ private Parser( } inputHandle = methodHandle(inputFunction); + removeInputHandle = removeInputFunction.map(Reflection::methodHandle); combineHandle = methodHandle(combineFunction); outputHandle = methodHandle(outputFunction); } @@ -316,11 +343,13 @@ private AggregationImplementation get() aggregationDefinition, stateClass, inputHandle, + removeInputHandle, outputHandle, combineHandle, stateSerializerFactoryHandle, argumentNativeContainerTypes, inputDependencies, + removeInputDependencies, combineDependencies, outputDependencies, stateSerializerFactoryDependencies, @@ -332,11 +361,12 @@ public static AggregationImplementation parseImplementation( AggregationHeader header, Class stateClass, Method inputFunction, + Optional removeInputFunction, Method outputFunction, Method combineFunction, Optional stateSerializerFactoryFunction) { - return new Parser(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, stateSerializerFactoryFunction).get(); + return new Parser(aggregationDefinition, header, stateClass, inputFunction, removeInputFunction, outputFunction, combineFunction, stateSerializerFactoryFunction).get(); } private static List parseParameterMetadataTypes(Method method) diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationMetadata.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationMetadata.java index 62185169c9f4..3e88b15fee84 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationMetadata.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AggregationMetadata.java @@ -25,6 +25,7 @@ import java.lang.invoke.MethodHandle; import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -44,6 +45,7 @@ public class AggregationMetadata private final List valueInputMetadata; private final List> lambdaInterfaces; private final MethodHandle inputFunction; + private final Optional removeInputFunction; private final MethodHandle combineFunction; private final MethodHandle outputFunction; private final List accumulatorStateDescriptors; @@ -53,6 +55,7 @@ public AggregationMetadata( String name, List valueInputMetadata, MethodHandle inputFunction, + Optional removeInputFunction, MethodHandle combineFunction, MethodHandle outputFunction, List accumulatorStateDescriptors, @@ -62,6 +65,7 @@ public AggregationMetadata( name, valueInputMetadata, inputFunction, + removeInputFunction, combineFunction, outputFunction, accumulatorStateDescriptors, @@ -73,6 +77,7 @@ public AggregationMetadata( String name, List valueInputMetadata, MethodHandle inputFunction, + Optional removeInputFunction, MethodHandle combineFunction, MethodHandle outputFunction, List accumulatorStateDescriptors, @@ -83,6 +88,7 @@ public AggregationMetadata( this.valueInputMetadata = ImmutableList.copyOf(requireNonNull(valueInputMetadata, "valueInputMetadata is null")); this.name = requireNonNull(name, "name is null"); this.inputFunction = requireNonNull(inputFunction, "inputFunction is null"); + this.removeInputFunction = requireNonNull(removeInputFunction, "removeInputFunction is null"); this.combineFunction = requireNonNull(combineFunction, "combineFunction is null"); this.outputFunction = requireNonNull(outputFunction, "outputFunction is null"); this.accumulatorStateDescriptors = requireNonNull(accumulatorStateDescriptors, "accumulatorStateDescriptors is null"); @@ -118,6 +124,11 @@ public MethodHandle getInputFunction() return inputFunction; } + public Optional getRemoveInputFunction() + { + return removeInputFunction; + } + public MethodHandle getCombineFunction() { return combineFunction; diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/ArbitraryAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/ArbitraryAggregationFunction.java index 9de5a77f9a8b..b36eec70fb14 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/ArbitraryAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/ArbitraryAggregationFunction.java @@ -33,6 +33,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.typeVariable; @@ -136,6 +137,7 @@ else if (type.getJavaType() == boolean.class) { generateAggregationName(NAME, type.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), inputParameterMetadata, inputFunction, + Optional.empty(), combineFunction, outputFunction.bindTo(type), ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AverageAggregations.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AverageAggregations.java index afcaacfad0d3..82510f85b919 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AverageAggregations.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AverageAggregations.java @@ -20,6 +20,7 @@ import io.prestosql.spi.function.CombineFunction; import io.prestosql.spi.function.InputFunction; import io.prestosql.spi.function.OutputFunction; +import io.prestosql.spi.function.RemoveInputFunction; import io.prestosql.spi.function.SqlType; import io.prestosql.spi.type.StandardTypes; @@ -44,6 +45,20 @@ public static void input(@AggregationState LongAndDoubleState state, @SqlType(St state.setDouble(state.getDouble() + value); } + @RemoveInputFunction + public static void removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.BIGINT) long value) + { + state.setLong(state.getLong() - 1); + state.setDouble(state.getDouble() - value); + } + + @RemoveInputFunction + public static void removeInput(@AggregationState LongAndDoubleState state, @SqlType(StandardTypes.DOUBLE) double value) + { + state.setLong(state.getLong() - 1); + state.setDouble(state.getDouble() - value); + } + @CombineFunction public static void combine(@AggregationState LongAndDoubleState state, @AggregationState LongAndDoubleState otherState) { diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/ChecksumAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/ChecksumAggregationFunction.java index 5725a8adb589..af7ed67a6a7a 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/ChecksumAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/ChecksumAggregationFunction.java @@ -29,6 +29,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.Slices.wrappedLongArray; @@ -86,6 +87,7 @@ private static InternalAggregationFunction generateAggregation(Type type) generateAggregationName(NAME, type.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(type), INPUT_FUNCTION.bindTo(type), + Optional.empty(), COMBINE_FUNCTION, OUTPUT_FUNCTION, ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/CountAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/CountAggregation.java index 1be8c7e77822..07c8219a21d1 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/CountAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/CountAggregation.java @@ -20,6 +20,7 @@ import io.prestosql.spi.function.CombineFunction; import io.prestosql.spi.function.InputFunction; import io.prestosql.spi.function.OutputFunction; +import io.prestosql.spi.function.RemoveInputFunction; import io.prestosql.spi.type.StandardTypes; import static io.prestosql.spi.type.BigintType.BIGINT; @@ -35,6 +36,12 @@ public static void input(@AggregationState LongState state) state.setLong(state.getLong() + 1); } + @RemoveInputFunction + public static void removeInput(@AggregationState LongState state) + { + state.setLong(state.getLong() - 1); + } + @CombineFunction public static void combine(@AggregationState LongState state, @AggregationState LongState otherState) { diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/CountColumn.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/CountColumn.java index c072b3f20f8a..cf93fe76cdd9 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/CountColumn.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/CountColumn.java @@ -30,6 +30,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.typeVariable; @@ -48,6 +49,7 @@ public class CountColumn public static final CountColumn COUNT_COLUMN = new CountColumn(); private static final String NAME = "count"; private static final MethodHandle INPUT_FUNCTION = methodHandle(CountColumn.class, "input", LongState.class, Block.class, int.class); + private static final MethodHandle REMOVE_INPUT_FUNCTION = methodHandle(CountColumn.class, "removeInput", LongState.class, Block.class, int.class); private static final MethodHandle COMBINE_FUNCTION = methodHandle(CountColumn.class, "combine", LongState.class, LongState.class); private static final MethodHandle OUTPUT_FUNCTION = methodHandle(CountColumn.class, "output", LongState.class, BlockBuilder.class); @@ -87,6 +89,7 @@ private static InternalAggregationFunction generateAggregation(Type type) generateAggregationName(NAME, BIGINT.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(type), INPUT_FUNCTION, + Optional.of(REMOVE_INPUT_FUNCTION), COMBINE_FUNCTION, OUTPUT_FUNCTION, ImmutableList.of(new AccumulatorStateDescriptor( @@ -109,6 +112,11 @@ public static void input(LongState state, Block block, int index) state.setLong(state.getLong() + 1); } + public static void removeInput(LongState state, Block block, int index) + { + state.setLong(state.getLong() - 1); + } + public static void combine(LongState state, LongState otherState) { state.setLong(state.getLong() + otherState.getLong()); diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/CountIfAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/CountIfAggregation.java index 3590a8497406..a5c56b19a3cd 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/CountIfAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/CountIfAggregation.java @@ -20,6 +20,7 @@ import io.prestosql.spi.function.CombineFunction; import io.prestosql.spi.function.InputFunction; import io.prestosql.spi.function.OutputFunction; +import io.prestosql.spi.function.RemoveInputFunction; import io.prestosql.spi.function.SqlType; import io.prestosql.spi.type.StandardTypes; @@ -38,6 +39,14 @@ public static void input(@AggregationState LongState state, @SqlType(StandardTyp } } + @RemoveInputFunction + public static void removeInput(@AggregationState LongState state, @SqlType(StandardTypes.BOOLEAN) boolean value) + { + if (value) { + state.setLong(state.getLong() - 1); + } + } + @CombineFunction public static void combine(@AggregationState LongState state, @AggregationState LongState otherState) { diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/DecimalAverageAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/DecimalAverageAggregation.java index 577532b9c87c..2b1bc453c6a5 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/DecimalAverageAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/DecimalAverageAggregation.java @@ -38,6 +38,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -118,6 +119,7 @@ private static InternalAggregationFunction generateAggregation(Type type) generateAggregationName(NAME, type.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(type), inputFunction, + Optional.empty(), COMBINE_FUNCTION, outputFunction, ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/DecimalSumAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/DecimalSumAggregation.java index f3811f480dbd..835185f5327f 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/DecimalSumAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/DecimalSumAggregation.java @@ -35,6 +35,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -107,6 +108,7 @@ private static InternalAggregationFunction generateAggregation(Type inputType, T generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(inputType), inputFunction.bindTo(inputType), + Optional.empty(), COMBINE_FUNCTION, LONG_DECIMAL_OUTPUT_FUNCTION.bindTo(outputType), ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/GenericAccumulatorFactory.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/GenericAccumulatorFactory.java index 865a8ebcc04d..f156e2b14b26 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/GenericAccumulatorFactory.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/GenericAccumulatorFactory.java @@ -60,6 +60,7 @@ public class GenericAccumulatorFactory private final List sourceTypes; private final List orderByChannels; private final List orderings; + private final boolean accumulatorHasRemoveInput; @Nullable private final JoinCompiler joinCompiler; @@ -72,6 +73,7 @@ public class GenericAccumulatorFactory public GenericAccumulatorFactory( List stateDescriptors, Constructor accumulatorConstructor, + boolean accumulatorHasRemoveInput, Constructor groupedAccumulatorConstructor, List lambdaProviders, List inputChannels, @@ -86,6 +88,7 @@ public GenericAccumulatorFactory( { this.stateDescriptors = requireNonNull(stateDescriptors, "stateDescriptors is null"); this.accumulatorConstructor = requireNonNull(accumulatorConstructor, "accumulatorConstructor is null"); + this.accumulatorHasRemoveInput = accumulatorHasRemoveInput; this.groupedAccumulatorConstructor = requireNonNull(groupedAccumulatorConstructor, "groupedAccumulatorConstructor is null"); this.lambdaProviders = ImmutableList.copyOf(requireNonNull(lambdaProviders, "lambdaProviders is null")); this.maskChannel = requireNonNull(maskChannel, "maskChannel is null"); @@ -108,6 +111,12 @@ public List getInputChannels() return inputChannels; } + @Override + public boolean hasRemoveInput() + { + return accumulatorHasRemoveInput; + } + @Override public Accumulator createAccumulator() { @@ -289,6 +298,12 @@ public void addInput(WindowIndex index, List channels, int startPositio throw new UnsupportedOperationException(); } + @Override + public void removeInput(WindowIndex index, List channels, int startPosition, int endPosition) + { + throw new UnsupportedOperationException(); + } + @Override public void addIntermediate(Block block) { @@ -474,6 +489,12 @@ public void addInput(WindowIndex index, List channels, int startPositio throw new UnsupportedOperationException(); } + @Override + public void removeInput(WindowIndex index, List channels, int startPosition, int endPosition) + { + throw new UnsupportedOperationException(); + } + @Override public void addIntermediate(Block block) { diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/GenericAccumulatorFactoryBinder.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/GenericAccumulatorFactoryBinder.java index 1e18567d7dd5..44a27e93e4ea 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/GenericAccumulatorFactoryBinder.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/GenericAccumulatorFactoryBinder.java @@ -32,11 +32,13 @@ public class GenericAccumulatorFactoryBinder { private final List stateDescriptors; private final Constructor accumulatorConstructor; + private final boolean accumulatorHasRemoveInput; private final Constructor groupedAccumulatorConstructor; public GenericAccumulatorFactoryBinder( List stateDescriptors, Class accumulatorClass, + boolean accumulatorHasRemoveInput, Class groupedAccumulatorClass) { this.stateDescriptors = requireNonNull(stateDescriptors, "stateDescriptors is null"); @@ -48,6 +50,8 @@ public GenericAccumulatorFactoryBinder( Optional.class, /* Optional maskChannel */ List.class /* List lambdaProviders */); + this.accumulatorHasRemoveInput = accumulatorHasRemoveInput; + groupedAccumulatorConstructor = groupedAccumulatorClass.getConstructor( List.class, /* List stateDescriptors */ List.class, /* List inputChannel */ @@ -75,6 +79,7 @@ public AccumulatorFactory bind( return new GenericAccumulatorFactory( stateDescriptors, accumulatorConstructor, + accumulatorHasRemoveInput, groupedAccumulatorConstructor, lambdaProviders, argumentChannels, diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/MapAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/MapAggregationFunction.java index 3ca9256771ea..f0c6310c1495 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/MapAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/MapAggregationFunction.java @@ -31,6 +31,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.comparableTypeParameter; @@ -90,6 +91,7 @@ private static InternalAggregationFunction generateAggregation(Type keyType, Typ generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(keyType, valueType), INPUT_FUNCTION.bindTo(keyType).bindTo(valueType), + Optional.empty(), COMBINE_FUNCTION, OUTPUT_FUNCTION, ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/MapUnionAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/MapUnionAggregation.java index 4859d003a354..b83d00ad8a3f 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/MapUnionAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/MapUnionAggregation.java @@ -31,6 +31,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.comparableTypeParameter; @@ -84,6 +85,7 @@ private static InternalAggregationFunction generateAggregation(Type keyType, Typ generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(outputType), INPUT_FUNCTION.bindTo(keyType).bindTo(valueType), + Optional.empty(), COMBINE_FUNCTION, OUTPUT_FUNCTION, ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/MergeQuantileDigestFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/MergeQuantileDigestFunction.java index 7ec00f7bde7b..73b07570235a 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/MergeQuantileDigestFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/MergeQuantileDigestFunction.java @@ -34,6 +34,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.metadata.Signature.comparableTypeParameter; @@ -92,6 +93,7 @@ private static InternalAggregationFunction generateAggregation(Type valueType, Q generateAggregationName(NAME, type.getTypeSignature(), ImmutableList.of(type.getTypeSignature())), createInputParameterMetadata(type), INPUT_FUNCTION.bindTo(type), + Optional.empty(), COMBINE_FUNCTION, OUTPUT_FUNCTION.bindTo(stateSerializer), ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/ParametricAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/ParametricAggregation.java index 354785d6ade3..385359216ffb 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/ParametricAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/ParametricAggregation.java @@ -86,6 +86,8 @@ public InternalAggregationFunction specialize(BoundVariables variables, int arit // Bind provided dependencies to aggregation method handlers MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), variables, metadata); + Optional removeInputHandle = concreteImplementation.getRemoveInputFunction().map( + removeInputFunction -> bindDependencies(removeInputFunction, concreteImplementation.getRemoveInputDependencies(), variables, metadata)); MethodHandle combineHandle = bindDependencies(concreteImplementation.getCombineFunction(), concreteImplementation.getCombineDependencies(), variables, metadata); MethodHandle outputHandle = bindDependencies(concreteImplementation.getOutputFunction(), concreteImplementation.getOutputDependencies(), variables, metadata); @@ -100,6 +102,7 @@ public InternalAggregationFunction specialize(BoundVariables variables, int arit aggregationName, parametersMetadata, inputHandle, + removeInputHandle, combineHandle, outputHandle, ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/QuantileDigestAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/QuantileDigestAggregationFunction.java index a6d958372078..e323f3b6a746 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/QuantileDigestAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/QuantileDigestAggregationFunction.java @@ -31,6 +31,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -105,6 +106,7 @@ private static InternalAggregationFunction generateAggregation(Type valueType, Q generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(inputTypes), getMethodHandle(valueType, arity), + Optional.empty(), COMBINE_FUNCTION, OUTPUT_FUNCTION.bindTo(stateSerializer), ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/RealAverageAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/RealAverageAggregation.java index 400a9fa4ce67..38ac07118022 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/RealAverageAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/RealAverageAggregation.java @@ -30,6 +30,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata; import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; @@ -50,6 +51,7 @@ public class RealAverageAggregation private static final String NAME = "avg"; private static final MethodHandle INPUT_FUNCTION = methodHandle(RealAverageAggregation.class, "input", LongState.class, DoubleState.class, long.class); + private static final MethodHandle REMOVE_INPUT_FUNCTION = methodHandle(RealAverageAggregation.class, "removeInput", LongState.class, DoubleState.class, long.class); private static final MethodHandle COMBINE_FUNCTION = methodHandle(RealAverageAggregation.class, "combine", LongState.class, DoubleState.class, LongState.class, DoubleState.class); private static final MethodHandle OUTPUT_FUNCTION = methodHandle(RealAverageAggregation.class, "output", LongState.class, DoubleState.class, BlockBuilder.class); @@ -81,6 +83,7 @@ public InternalAggregationFunction specialize(BoundVariables boundVariables, int generateAggregationName(NAME, parseTypeSignature(StandardTypes.REAL), ImmutableList.of(parseTypeSignature(StandardTypes.REAL))), ImmutableList.of(new ParameterMetadata(STATE), new ParameterMetadata(STATE), new ParameterMetadata(INPUT_CHANNEL, REAL)), INPUT_FUNCTION, + Optional.of(REMOVE_INPUT_FUNCTION), COMBINE_FUNCTION, OUTPUT_FUNCTION, ImmutableList.of( @@ -118,6 +121,12 @@ public static void input(LongState count, DoubleState sum, long value) sum.setDouble(sum.getDouble() + intBitsToFloat((int) value)); } + public static void removeInput(LongState count, DoubleState sum, long value) + { + count.setLong(count.getLong() - 1); + sum.setDouble(sum.getDouble() - intBitsToFloat((int) value)); + } + public static void combine(LongState count, DoubleState sum, LongState otherCount, DoubleState otherSum) { count.setLong(count.getLong() + otherCount.getLong()); diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/ReduceAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/ReduceAggregationFunction.java index 3d81ad87fed7..32dcb9ce592f 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/ReduceAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/ReduceAggregationFunction.java @@ -31,6 +31,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static io.prestosql.metadata.Signature.typeVariable; import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL; @@ -135,6 +136,7 @@ else if (stateType.getJavaType() == boolean.class) { inputMethodHandle.asType( inputMethodHandle.type() .changeParameterType(1, inputType.getJavaType())), + Optional.empty(), combineMethodHandle, outputMethodHandle, ImmutableList.of(stateDescriptor), diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/arrayagg/ArrayAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/arrayagg/ArrayAggregationFunction.java index b88f5dbb6133..e91f73ff4215 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/arrayagg/ArrayAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/arrayagg/ArrayAggregationFunction.java @@ -34,6 +34,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.typeVariable; @@ -99,6 +100,7 @@ private static InternalAggregationFunction generateAggregation(Type type, ArrayA generateAggregationName(NAME, type.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), inputParameterMetadata, inputFunction, + Optional.empty(), combineFunction, outputFunction, ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/histogram/Histogram.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/histogram/Histogram.java index 3267c1e17a1a..36626aea06c9 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/histogram/Histogram.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/histogram/Histogram.java @@ -31,6 +31,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.comparableTypeParameter; @@ -98,6 +99,7 @@ private static InternalAggregationFunction generateAggregation( generateAggregationName(functionName, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(keyType), inputFunction, + Optional.empty(), COMBINE_FUNCTION, outputFunction, ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java index d3d233f96fbd..8095af86c448 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java @@ -45,6 +45,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; import java.util.Map; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.bytecode.Access.FINAL; @@ -152,6 +153,7 @@ private InternalAggregationFunction generateAggregation(Type valueType, Type key generateAggregationName(getSignature().getName(), valueType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(valueType, keyType), inputMethod, + Optional.empty(), combineMethod, outputMethod, ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java index bfc4cc257160..2480f637b440 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxByNAggregationFunction.java @@ -35,6 +35,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import java.util.function.Function; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -162,6 +163,7 @@ protected InternalAggregationFunction generateAggregation(Type valueType, Type k generateAggregationName(name, valueType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), inputParameterMetadata, INPUT_FUNCTION.bindTo(comparator).bindTo(valueType).bindTo(keyType), + Optional.empty(), COMBINE_FUNCTION, OUTPUT_FUNCTION.bindTo(outputType), ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/multimapagg/MultimapAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/multimapagg/MultimapAggregationFunction.java index c925eb70b8d6..d1faaeb5b9d6 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/multimapagg/MultimapAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/multimapagg/MultimapAggregationFunction.java @@ -34,6 +34,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.comparableTypeParameter; @@ -96,6 +97,7 @@ private InternalAggregationFunction generateAggregation(Type keyType, Type value generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), createInputParameterMetadata(keyType, valueType), INPUT_FUNCTION, + Optional.empty(), COMBINE_FUNCTION, OUTPUT_FUNCTION.bindTo(keyType).bindTo(valueType), ImmutableList.of(new AccumulatorStateDescriptor( diff --git a/presto-main/src/main/java/io/prestosql/operator/window/AggregateWindowFunction.java b/presto-main/src/main/java/io/prestosql/operator/window/AggregateWindowFunction.java index aac9eb6f842f..36e01a631e42 100644 --- a/presto-main/src/main/java/io/prestosql/operator/window/AggregateWindowFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/window/AggregateWindowFunction.java @@ -25,6 +25,8 @@ import java.util.List; import java.util.Optional; +import static java.lang.Math.max; +import static java.lang.Math.min; import static java.util.Objects.requireNonNull; public class AggregateWindowFunction @@ -32,6 +34,7 @@ public class AggregateWindowFunction { private final List argumentChannels; private final AccumulatorFactory accumulatorFactory; + private final boolean accumulatorHasRemoveInput; private WindowIndex windowIndex; private Accumulator accumulator; @@ -42,6 +45,7 @@ private AggregateWindowFunction(InternalAggregationFunction function, List= currentEnd)) { currentEnd = frameEnd; } else { - // different frame - resetAccumulator(); - accumulate(frameStart, frameEnd); - currentStart = frameStart; - currentEnd = frameEnd; + buildNewFrame(frameStart, frameEnd); } accumulator.evaluateFinal(output); } + private void buildNewFrame(int frameStart, int frameEnd) + { + if (accumulatorHasRemoveInput) { + // Note that all the start/end intervals are inclusive on both ends! + if (currentStart < 0) { + currentStart = 0; + currentEnd = -1; + } + int overlapStart = max(frameStart, currentStart); + int overlapEnd = min(frameEnd, currentEnd); + int prefixRemoveLength = overlapStart - currentStart; + int suffixRemoveLength = currentEnd - overlapEnd; + + if ((overlapEnd - overlapStart + 1) > (prefixRemoveLength + suffixRemoveLength)) { + // It's worth keeping the overlap, and removing the now-unused prefix + if (currentStart < frameStart) { + remove(currentStart, frameStart - 1); + } + if (frameEnd < currentEnd) { + remove(frameEnd + 1, currentEnd); + } + if (frameStart < currentStart) { + accumulate(frameStart, currentStart - 1); + } + if (currentEnd < frameEnd) { + accumulate(currentEnd + 1, frameEnd); + } + currentStart = frameStart; + currentEnd = frameEnd; + return; + } + } + + // We couldn't or didn't want to modify the accumulation: instead, discard the current accumulation and start fresh. + resetAccumulator(); + accumulate(frameStart, frameEnd); + currentStart = frameStart; + currentEnd = frameEnd; + } + private void accumulate(int start, int end) { accumulator.addInput(windowIndex, argumentChannels, start, end); } + private void remove(int start, int end) + { + accumulator.removeInput(windowIndex, argumentChannels, start, end); + } + private void resetAccumulator() { if (currentStart >= 0) { diff --git a/presto-main/src/test/java/io/prestosql/operator/aggregation/AggregationTestUtils.java b/presto-main/src/test/java/io/prestosql/operator/aggregation/AggregationTestUtils.java index 1dc8088dc097..a1c71d50096a 100644 --- a/presto-main/src/test/java/io/prestosql/operator/aggregation/AggregationTestUtils.java +++ b/presto-main/src/test/java/io/prestosql/operator/aggregation/AggregationTestUtils.java @@ -45,18 +45,20 @@ public static void assertAggregation(InternalAggregationFunction function, Objec public static void assertAggregation(InternalAggregationFunction function, Object expectedValue, Page page) { - BiFunction equalAssertion; + BiFunction equalAssertion = makeValidityAssertion(expectedValue); + + assertAggregation(function, equalAssertion, null, page, expectedValue); + } + + public static BiFunction makeValidityAssertion(Object expectedValue) + { if (expectedValue instanceof Double && !expectedValue.equals(Double.NaN)) { - equalAssertion = (actual, expected) -> Precision.equals((double) actual, (double) expected, 1e-10); + return (actual, expected) -> Precision.equals((double) actual, (double) expected, 1e-10); } else if (expectedValue instanceof Float && !expectedValue.equals(Float.NaN)) { - equalAssertion = (actual, expected) -> Precision.equals((float) actual, (float) expected, 1e-10f); + return (actual, expected) -> Precision.equals((float) actual, (float) expected, 1e-10f); } - else { - equalAssertion = Objects::equals; - } - - assertAggregation(function, equalAssertion, null, page, expectedValue); + return Objects::equals; } public static void assertAggregation(InternalAggregationFunction function, BiFunction equalAssertion, String testDescription, Page page, Object expectedValue) @@ -330,7 +332,7 @@ public static GroupByIdBlock createGroupByIdBlock(int groupId, int positions) return new GroupByIdBlock(groupId, blockBuilder.build()); } - private static int[] createArgs(InternalAggregationFunction function) + static int[] createArgs(InternalAggregationFunction function) { int[] args = new int[function.getParameterTypes().size()]; for (int i = 0; i < args.length; i++) { diff --git a/presto-main/src/test/java/io/prestosql/operator/window/TestAggregateWindowFunction.java b/presto-main/src/test/java/io/prestosql/operator/window/TestAggregateWindowFunction.java index c84f1085f475..3758b75cd01a 100644 --- a/presto-main/src/test/java/io/prestosql/operator/window/TestAggregateWindowFunction.java +++ b/presto-main/src/test/java/io/prestosql/operator/window/TestAggregateWindowFunction.java @@ -19,6 +19,7 @@ import static io.prestosql.SessionTestUtils.TEST_SESSION; import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.spi.type.IntegerType.INTEGER; import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.testing.MaterializedResult.resultBuilder; @@ -57,6 +58,70 @@ public void testCountRowsOrdered() .build()); } + @Test + public void testCountRowsRolling() + { + assertWindowQuery("count(*) OVER (ORDER BY orderkey ROWS BETWEEN 4 PRECEDING AND 1 PRECEDING)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT) + .row(1, "O", 0L) + .row(2, "O", 1L) + .row(3, "F", 2L) + .row(4, "O", 3L) + .row(5, "F", 4L) + .row(6, "F", 4L) + .row(7, "O", 4L) + .row(32, "O", 4L) + .row(33, "F", 4L) + .row(34, "O", 4L) + .build()); + + assertWindowQuery("count(*) OVER (PARTITION BY orderstatus ORDER BY orderkey ROWS BETWEEN 4 PRECEDING AND 1 PRECEDING)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT) + .row(3, "F", 0L) + .row(5, "F", 1L) + .row(6, "F", 2L) + .row(33, "F", 3L) + .row(1, "O", 0L) + .row(2, "O", 1L) + .row(4, "O", 2L) + .row(7, "O", 3L) + .row(32, "O", 4L) + .row(34, "O", 4L) + .build()); + } + + @Test + public void testAverageRowsRolling() + { + assertWindowQuery("avg(orderkey) OVER (ORDER BY orderkey ROWS BETWEEN 4 PRECEDING AND 1 PRECEDING)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, DOUBLE) + .row(1, "O", null) + .row(2, "O", 1.0) + .row(3, "F", 1.5) + .row(4, "O", 2.0) + .row(5, "F", 2.5) + .row(6, "F", 3.5) + .row(7, "O", 4.5) + .row(32, "O", 5.5) + .row(33, "F", 12.5) + .row(34, "O", 19.5) + .build()); + + assertWindowQuery("avg(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey ROWS BETWEEN 4 PRECEDING AND 1 PRECEDING)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, DOUBLE) + .row(3, "F", null) + .row(5, "F", 3.0) + .row(6, "F", 4.0) + .row(33, "F", 4.666666666666667) + .row(1, "O", null) + .row(2, "O", 1.0) + .row(4, "O", 1.5) + .row(7, "O", 2.3333333333333334) + .row(32, "O", 3.5) + .row(34, "O", 11.25) + .build()); + } + @Test public void testCountRowsUnordered() { diff --git a/presto-spi/src/main/java/io/prestosql/spi/function/RemoveInputFunction.java b/presto-spi/src/main/java/io/prestosql/spi/function/RemoveInputFunction.java new file mode 100644 index 000000000000..1b8b1263dff1 --- /dev/null +++ b/presto-spi/src/main/java/io/prestosql/spi/function/RemoveInputFunction.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target(METHOD) +public @interface RemoveInputFunction +{ +} From e6b97d05a65d7fab8e30ccbeed92e72240b06fe1 Mon Sep 17 00:00:00 2001 From: Alan Post Date: Mon, 30 Sep 2019 12:46:29 +0200 Subject: [PATCH 3/3] Allow rolling sum for window functions Implement removeInput() in some SUM aggregations, to speed up rolling window functions. This requires additional storage in the AggregationState for the input count, so that the aggregator knows when its result should become null. Extracted from: https://github.com/prestodb/presto/pull/8974 --- .../aggregation/DoubleSumAggregation.java | 36 +++++++++++-------- .../aggregation/LongSumAggregation.java | 36 +++++++++++-------- .../operator/TestHashAggregationOperator.java | 14 ++++---- 3 files changed, 52 insertions(+), 34 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/DoubleSumAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/DoubleSumAggregation.java index 2213785a77c2..240f26a54e4e 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/DoubleSumAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/DoubleSumAggregation.java @@ -13,13 +13,14 @@ */ package io.prestosql.operator.aggregation; -import io.prestosql.operator.aggregation.state.NullableDoubleState; +import io.prestosql.operator.aggregation.minmaxby.LongDoubleState; import io.prestosql.spi.block.BlockBuilder; import io.prestosql.spi.function.AggregationFunction; import io.prestosql.spi.function.AggregationState; import io.prestosql.spi.function.CombineFunction; import io.prestosql.spi.function.InputFunction; import io.prestosql.spi.function.OutputFunction; +import io.prestosql.spi.function.RemoveInputFunction; import io.prestosql.spi.function.SqlType; import io.prestosql.spi.type.DoubleType; import io.prestosql.spi.type.StandardTypes; @@ -30,27 +31,34 @@ public final class DoubleSumAggregation private DoubleSumAggregation() {} @InputFunction - public static void sum(@AggregationState NullableDoubleState state, @SqlType(StandardTypes.DOUBLE) double value) + public static void sum(@AggregationState LongDoubleState state, @SqlType(StandardTypes.DOUBLE) double value) { - state.setNull(false); - state.setDouble(state.getDouble() + value); + state.setFirst(state.getFirst() + 1); + state.setSecond(state.getSecond() + value); } - @CombineFunction - public static void combine(@AggregationState NullableDoubleState state, @AggregationState NullableDoubleState otherState) + @RemoveInputFunction + public static void removeInput(@AggregationState LongDoubleState state, @SqlType(StandardTypes.DOUBLE) double value) { - if (state.isNull()) { - state.setNull(false); - state.setDouble(otherState.getDouble()); - return; - } + state.setFirst(state.getFirst() - 1); + state.setSecond(state.getSecond() - value); + } - state.setDouble(state.getDouble() + otherState.getDouble()); + @CombineFunction + public static void combine(@AggregationState LongDoubleState state, @AggregationState LongDoubleState otherState) + { + state.setFirst(state.getFirst() + otherState.getFirst()); + state.setSecond(state.getSecond() + otherState.getSecond()); } @OutputFunction(StandardTypes.DOUBLE) - public static void output(@AggregationState NullableDoubleState state, BlockBuilder out) + public static void output(@AggregationState LongDoubleState state, BlockBuilder out) { - NullableDoubleState.write(DoubleType.DOUBLE, state, out); + if (state.getFirst() == 0) { + out.appendNull(); + } + else { + DoubleType.DOUBLE.writeDouble(out, state.getSecond()); + } } } diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/LongSumAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/LongSumAggregation.java index fa88532cd801..e93b0ece7c21 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/LongSumAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/LongSumAggregation.java @@ -13,13 +13,14 @@ */ package io.prestosql.operator.aggregation; -import io.prestosql.operator.aggregation.state.NullableLongState; +import io.prestosql.operator.aggregation.minmaxby.LongLongState; import io.prestosql.spi.block.BlockBuilder; import io.prestosql.spi.function.AggregationFunction; import io.prestosql.spi.function.AggregationState; import io.prestosql.spi.function.CombineFunction; import io.prestosql.spi.function.InputFunction; import io.prestosql.spi.function.OutputFunction; +import io.prestosql.spi.function.RemoveInputFunction; import io.prestosql.spi.function.SqlType; import io.prestosql.spi.type.BigintType; import io.prestosql.spi.type.StandardTypes; @@ -31,27 +32,34 @@ public final class LongSumAggregation private LongSumAggregation() {} @InputFunction - public static void sum(@AggregationState NullableLongState state, @SqlType(StandardTypes.BIGINT) long value) + public static void sum(@AggregationState LongLongState state, @SqlType(StandardTypes.BIGINT) long value) { - state.setNull(false); - state.setLong(BigintOperators.add(state.getLong(), value)); + state.setFirst(state.getFirst() + 1); + state.setSecond(BigintOperators.add(state.getSecond(), value)); } - @CombineFunction - public static void combine(@AggregationState NullableLongState state, @AggregationState NullableLongState otherState) + @RemoveInputFunction + public static void removeInput(@AggregationState LongLongState state, @SqlType(StandardTypes.BIGINT) long value) { - if (state.isNull()) { - state.setNull(false); - state.setLong(otherState.getLong()); - return; - } + state.setFirst(state.getFirst() - 1); + state.setSecond(BigintOperators.subtract(state.getSecond(), value)); + } - state.setLong(BigintOperators.add(state.getLong(), otherState.getLong())); + @CombineFunction + public static void combine(@AggregationState LongLongState state, @AggregationState LongLongState otherState) + { + state.setFirst(state.getFirst() + otherState.getFirst()); + state.setSecond(BigintOperators.add(state.getSecond(), otherState.getSecond())); } @OutputFunction(StandardTypes.BIGINT) - public static void output(@AggregationState NullableLongState state, BlockBuilder out) + public static void output(@AggregationState LongLongState state, BlockBuilder out) { - NullableLongState.write(BigintType.BIGINT, state, out); + if (state.getFirst() == 0) { + out.appendNull(); + } + else { + BigintType.BIGINT.writeLong(out, state.getSecond()); + } } } diff --git a/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java index 685c6a8d3772..df4be1771976 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java @@ -103,6 +103,8 @@ public class TestHashAggregationOperator new Signature("sum", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); private static final InternalAggregationFunction COUNT = metadata.getAggregateFunctionImplementation( new Signature("count", AGGREGATE, BIGINT.getTypeSignature())); + private static final InternalAggregationFunction LONG_MIN = metadata.getAggregateFunctionImplementation( + new Signature("min", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); private static final int MAX_BLOCK_SIZE_IN_BYTES = 64 * 1024; @@ -240,7 +242,7 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna Step.SINGLE, true, ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()), - LONG_SUM.bind(ImmutableList.of(4), Optional.empty()), + LONG_MIN.bind(ImmutableList.of(4), Optional.empty()), LONG_AVERAGE.bind(ImmutableList.of(4), Optional.empty()), maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty()), countVarcharColumn.bind(ImmutableList.of(0), Optional.empty()), @@ -336,7 +338,7 @@ public void testMemoryLimit(boolean hashEnabled) ImmutableList.of(), Step.SINGLE, ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()), - LONG_SUM.bind(ImmutableList.of(3), Optional.empty()), + LONG_MIN.bind(ImmutableList.of(3), Optional.empty()), LONG_AVERAGE.bind(ImmutableList.of(3), Optional.empty()), maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty())), rowPagesBuilder.getHashChannel(), @@ -517,7 +519,7 @@ public void testMultiplePartialFlushes(boolean hashEnabled) hashChannels, ImmutableList.of(), Step.PARTIAL, - ImmutableList.of(LONG_SUM.bind(ImmutableList.of(0), Optional.empty())), + ImmutableList.of(LONG_MIN.bind(ImmutableList.of(0), Optional.empty())), rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, @@ -599,7 +601,7 @@ public void testMergeWithMemorySpill() ImmutableList.of(), Step.SINGLE, false, - ImmutableList.of(LONG_SUM.bind(ImmutableList.of(0), Optional.empty())), + ImmutableList.of(LONG_MIN.bind(ImmutableList.of(0), Optional.empty())), rowPagesBuilder.getHashChannel(), Optional.empty(), 1, @@ -653,7 +655,7 @@ public void testSpillerFailure() Step.SINGLE, false, ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()), - LONG_SUM.bind(ImmutableList.of(3), Optional.empty()), + LONG_MIN.bind(ImmutableList.of(3), Optional.empty()), LONG_AVERAGE.bind(ImmutableList.of(3), Optional.empty()), maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty())), rowPagesBuilder.getHashChannel(), @@ -700,7 +702,7 @@ private void testMemoryTracking(boolean useSystemMemory) hashChannels, ImmutableList.of(), Step.SINGLE, - ImmutableList.of(LONG_SUM.bind(ImmutableList.of(0), Optional.empty())), + ImmutableList.of(LONG_MIN.bind(ImmutableList.of(0), Optional.empty())), rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000,