From b8d6287e16f6c3c1244dbeee135dfb9c98e4a32c Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sun, 26 Dec 2021 13:48:42 -0800 Subject: [PATCH] Add support for multiple state variables in annotated aggregations --- .../AggregationFromAnnotationsParser.java | 155 +++++++++++++----- .../aggregation/ParametricAggregation.java | 24 ++- .../TestAnnotationEngineForAggregates.java | 10 +- 3 files changed, 137 insertions(+), 52 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index 21c9e2c54518..c4bce1d1ff89 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -44,6 +44,7 @@ import java.util.Optional; import java.util.Set; import java.util.stream.IntStream; +import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.emptyToNull; @@ -51,7 +52,6 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.operator.aggregation.AggregationImplementation.Parser.parseImplementation; import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription; -import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; public final class AggregationFromAnnotationsParser @@ -67,8 +67,8 @@ public static List parseFunctionDefinitions(Class aggr ImmutableList.Builder functions = ImmutableList.builder(); - // There must be a single state class and combine function - AccumulatorStateDetails stateDetails = getStateDetails(aggregationDefinition); + // There must be a single set of state classes and a single combine function + List stateDetails = getStateDetails(aggregationDefinition); Optional 
combineFunction = getCombineFunction(aggregationDefinition, stateDetails); // Each output function defines a new aggregation function @@ -114,7 +114,7 @@ else if (combineFunction.isPresent()) { private static List buildFunctions( String name, AggregationHeader header, - AccumulatorStateDetails stateDetails, + List stateDetails, List exactImplementations, List nonExactImplementations) { @@ -183,48 +183,95 @@ private static List getAliases(AggregationFunction aggregationAnnotation return ImmutableList.copyOf(aggregationAnnotation.alias()); } - private static Optional getCombineFunction(Class clazz, AccumulatorStateDetails stateDetails) + private static Optional getCombineFunction(Class clazz, List stateDetails) { List combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class); - for (Method combineFunction : combineFunctions) { - // verify parameter types - List> parameterTypes = getNonDependencyParameterTypes(combineFunction); - List> expectedParameterTypes = nCopies(2, stateDetails.getStateClass()); - checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected combine function non-dependency parameters to be %s: %s", expectedParameterTypes, combineFunction); + if (combineFunctions.isEmpty()) { + return Optional.empty(); + } + checkArgument(combineFunctions.size() == 1, "There must be only one @CombineFunction in class %s", clazz.toGenericString()); + Method combineFunction = getOnlyElement(combineFunctions); + + // verify parameter types + List> parameterTypes = getNonDependencyParameterTypes(combineFunction); + List> expectedParameterTypes = Stream.concat(stateDetails.stream(), stateDetails.stream()) + .map(AccumulatorStateDetails::getStateClass) + .collect(toImmutableList()); + checkArgument(parameterTypes.equals(expectedParameterTypes), + "Expected combine function non-dependency parameters to be %s: %s", + expectedParameterTypes, + combineFunction); + + // legacy combine functions did not require 
parameters to be fully annotated + if (stateDetails.size() > 1) { + List> parameterAnnotations = getNonDependencyParameterAnnotations(combineFunction); + List actualStateDetails = new ArrayList<>(); + for (int parameterIndex = 0; parameterIndex < parameterTypes.size(); parameterIndex++) { + actualStateDetails.add(toAccumulatorStateDetails(parameterTypes.get(parameterIndex), parameterAnnotations.get(parameterIndex), combineFunction, true)); + } + List expectedStateDetails = ImmutableList.builder().addAll(stateDetails).addAll(stateDetails).build(); + checkArgument(actualStateDetails.equals(expectedStateDetails), "Expected combine function to have state parameters %s, but has %s", expectedStateDetails, actualStateDetails); } - checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateDetails.getStateClass().toGenericString()); - return combineFunctions.stream().findFirst(); + return Optional.of(combineFunction); } - private static List getOutputFunctions(Class clazz, AccumulatorStateDetails stateDetails) + private static List getOutputFunctions(Class clazz, List stateDetails) { List outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class); for (Method outputFunction : outputFunctions) { // verify parameter types List> parameterTypes = getNonDependencyParameterTypes(outputFunction); List> expectedParameterTypes = ImmutableList.>builder() - .add(stateDetails.getStateClass()) + .addAll(stateDetails.stream().map(AccumulatorStateDetails::getStateClass).collect(toImmutableList())) .add(BlockBuilder.class) .build(); checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected output function non-dependency parameters to be %s: %s", expectedParameterTypes.stream().map(Class::getSimpleName).collect(toImmutableList()), outputFunction); + + // legacy output functions did not require parameters to be fully annotated + if 
(stateDetails.size() > 1) { + List> parameterAnnotations = getNonDependencyParameterAnnotations(outputFunction); + + List actualStateDetails = new ArrayList<>(); + for (int parameterIndex = 0; parameterIndex < stateDetails.size(); parameterIndex++) { + actualStateDetails.add(toAccumulatorStateDetails(parameterTypes.get(parameterIndex), parameterAnnotations.get(parameterIndex), outputFunction, true)); + } + checkArgument(actualStateDetails.equals(stateDetails), "Expected output function to have state parameters %s, but has %s", stateDetails, actualStateDetails); + } } checkArgument(!outputFunctions.isEmpty(), "Aggregation has no output functions"); return outputFunctions; } - private static List getInputFunctions(Class clazz, AccumulatorStateDetails stateDetails) + private static List getInputFunctions(Class clazz, List stateDetails) { List inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class); for (Method inputFunction : inputFunctions) { - // verify state parameter is first non-dependency parameter - Class actualStateType = getNonDependencyParameterTypes(inputFunction).get(0); - checkArgument(stateDetails.getStateClass().equals(actualStateType), - "Expected input function non-dependency parameters to begin with state type %s: %s", - stateDetails.getStateClass().getSimpleName(), + // verify state parameter types + List> parameterTypes = getNonDependencyParameterTypes(inputFunction) + .subList(0, stateDetails.size()); + List> expectedParameterTypes = ImmutableList.>builder() + .addAll(stateDetails.stream().map(AccumulatorStateDetails::getStateClass).collect(toImmutableList())) + .build() + .subList(0, stateDetails.size()); + checkArgument(parameterTypes.equals(expectedParameterTypes), + "Expected input function non-dependency parameters to begin with state types %s: %s", + expectedParameterTypes.stream().map(Class::getSimpleName).collect(toImmutableList()), inputFunction); + + // legacy input functions did not require 
parameters to be fully annotated + if (stateDetails.size() > 1) { + List> parameterAnnotations = getNonDependencyParameterAnnotations(inputFunction) + .subList(0, stateDetails.size()); + + List actualStateDetails = new ArrayList<>(); + for (int parameterIndex = 0; parameterIndex < stateDetails.size(); parameterIndex++) { + actualStateDetails.add(toAccumulatorStateDetails(parameterTypes.get(parameterIndex), parameterAnnotations.get(parameterIndex), inputFunction, false)); + } + checkArgument(actualStateDetails.equals(stateDetails), "Expected input function to have state parameters %s, but has %s", stateDetails, actualStateDetails); + } } checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions"); @@ -249,6 +296,14 @@ private static List> getNonDependencyParameterTypes(Method function) .collect(toImmutableList()); } + private static List> getNonDependencyParameterAnnotations(Method function) + { + Annotation[][] parameterAnnotations = function.getParameterAnnotations(); + return getNonDependencyParameters(function) + .mapToObj(index -> ImmutableList.copyOf(parameterAnnotations[index])) + .collect(toImmutableList()); + } + private static Optional getRemoveInputFunction(Class clazz, Method inputFunction) { // Only include methods which take the same parameters as the corresponding input function @@ -258,29 +313,51 @@ private static Optional getRemoveInputFunction(Class clazz, Method in .collect(MoreCollectors.toOptional()); } - private static AccumulatorStateDetails getStateDetails(Class clazz) + private static List getStateDetails(Class clazz) { - ImmutableSet.Builder builder = ImmutableSet.builder(); + ImmutableSet.Builder> builder = ImmutableSet.builder(); for (Method inputFunction : FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class)) { - checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters"); - int aggregationStateParamIndex = 
AggregationImplementation.Parser.findAggregationStateParamId(inputFunction); - Class stateClass = inputFunction.getParameterTypes()[aggregationStateParamIndex].asSubclass(AccumulatorState.class); - - Optional stateType = Arrays.stream(inputFunction.getParameterAnnotations()[aggregationStateParamIndex]) - .filter(AggregationState.class::isInstance) - .map(AggregationState.class::cast) - .findFirst() - .map(AggregationState::value) - .filter(type -> !type.isEmpty()) - .map(TypeSignature::new); - - builder.add(new AccumulatorStateDetails(stateClass, stateType)); + List> parameterTypes = getNonDependencyParameterTypes(inputFunction); + checkArgument(!parameterTypes.isEmpty(), "Input function has no parameters"); + List> parameterAnnotations = getNonDependencyParameterAnnotations(inputFunction); + + ImmutableList.Builder stateParameters = ImmutableList.builder(); + for (int parameterIndex = 0; parameterIndex < parameterTypes.size(); parameterIndex++) { + Class parameterType = parameterTypes.get(parameterIndex); + if (!AccumulatorState.class.isAssignableFrom(parameterType)) { + continue; + } + + stateParameters.add(toAccumulatorStateDetails(parameterType, parameterAnnotations.get(parameterIndex), inputFunction, false)); + } + List states = stateParameters.build(); + checkArgument(!states.isEmpty(), "Input function must have at least one state parameter"); + builder.add(states); + } + Set> functionStateClasses = builder.build(); + checkArgument(!functionStateClasses.isEmpty(), "No input functions found"); + checkArgument(functionStateClasses.size() == 1, "There must be exactly one set of @AccumulatorState in class %s", clazz.toGenericString()); + + return getOnlyElement(functionStateClasses); + } + + private static AccumulatorStateDetails toAccumulatorStateDetails(Class parameterType, List parameterAnnotations, Method method, boolean requireAnnotation) + { + Optional state = parameterAnnotations.stream() + .filter(AggregationState.class::isInstance) + 
.map(AggregationState.class::cast) + .findFirst(); + + if (requireAnnotation) { + checkArgument(state.isPresent(), "AggregationState must be present on AccumulatorState parameters: %s", method); } - Set stateClasses = builder.build(); - checkArgument(!stateClasses.isEmpty(), "No input functions found"); - checkArgument(stateClasses.size() == 1, "There must be exactly one @AccumulatorState in class %s", clazz.toGenericString()); - return getOnlyElement(stateClasses); + Optional stateSqlType = state.map(AggregationState::value) + .filter(type -> !type.isEmpty()) + .map(TypeSignature::new); + + AccumulatorStateDetails accumulatorStateDetails = new AccumulatorStateDetails(parameterType.asSubclass(AccumulatorState.class), stateSqlType); + return accumulatorStateDetails; } public static class AccumulatorStateDetails diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java index 3609065d6e97..63a6778ac4f3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java @@ -47,6 +47,7 @@ import java.util.StringJoiner; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.operator.ParametricFunctionHelpers.bindDependencies; import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; import static io.trino.operator.aggregation.state.StateCompiler.generateInOutStateFactory; @@ -62,18 +63,18 @@ public class ParametricAggregation extends SqlAggregationFunction { private final ParametricImplementationsGroup implementations; - private final AccumulatorStateDetails stateDetails; + private final List stateDetails; public ParametricAggregation( Signature signature, AggregationHeader 
details, - AccumulatorStateDetails stateDetails, + List stateDetails, ParametricImplementationsGroup implementations) { super( createFunctionMetadata(signature, details, implementations.getFunctionNullability()), createAggregationFunctionMetadata(details, stateDetails)); - this.stateDetails = requireNonNull(stateDetails, "stateDetails is null"); + this.stateDetails = ImmutableList.copyOf(requireNonNull(stateDetails, "stateDetails is null")); checkArgument(implementations.getFunctionNullability().isReturnNullable(), "currently aggregates are required to be nullable"); this.implementations = requireNonNull(implementations, "implementations is null"); } @@ -106,14 +107,16 @@ private static FunctionMetadata createFunctionMetadata(Signature signature, Aggr return functionMetadata.build(); } - private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, AccumulatorStateDetails stateDetails) + private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, List stateDetails) { AggregationFunctionMetadataBuilder builder = AggregationFunctionMetadata.builder(); if (details.isOrderSensitive()) { builder.orderSensitive(); } if (details.isDecomposable()) { - builder.intermediateType(getSerializedType(stateDetails)); + for (AccumulatorStateDetails stateDetail : stateDetails) { + builder.intermediateType(getSerializedType(stateDetail)); + } } return builder.build(); } @@ -150,7 +153,9 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep AggregationImplementation concreteImplementation = findMatchingImplementation(boundSignature); // Build state factory and serializer - AccumulatorStateDescriptor accumulatorStateDescriptor = generateAccumulatorStateDescriptor(getFunctionMetadata().getSignature(), boundSignature, stateDetails); + List> accumulatorStateDescriptors = stateDetails.stream() + .map(state -> 
generateAccumulatorStateDescriptor(getFunctionMetadata().getSignature(), boundSignature, state)) + .collect(toImmutableList()); // Bind provided dependencies to aggregation method handlers FunctionMetadata metadata = getFunctionMetadata(); @@ -179,7 +184,7 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep removeInputHandle, combineHandle, outputHandle, - ImmutableList.of(accumulatorStateDescriptor)); + accumulatorStateDescriptors); } private static AccumulatorStateDescriptor generateAccumulatorStateDescriptor(Signature signature, BoundSignature boundSignature, AccumulatorStateDetails stateDetails) @@ -226,9 +231,10 @@ private static AccumulatorStateDescriptor genera generateStateFactory(stateClass)); } - public Class getStateClass() + @VisibleForTesting + public List getStateDetails() { - return stateDetails.getStateClass(); + return stateDetails; } @VisibleForTesting diff --git a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java index 333f8f2b2e3f..c0f6a81b7c17 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java @@ -27,6 +27,7 @@ import io.trino.metadata.ResolvedFunction; import io.trino.metadata.Signature; import io.trino.metadata.SqlAggregationFunction; +import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails; import io.trino.operator.aggregation.AggregationImplementation; import io.trino.operator.aggregation.ParametricAggregation; import io.trino.operator.aggregation.state.LongState; @@ -64,6 +65,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ 
-174,7 +176,7 @@ public static void output(@AggregationState NullableDoubleState state, BlockBuil public void testInputParameterOrderEnforced() { assertThatThrownBy(() -> parseFunctionDefinitions(InputParametersWrongOrder.class)) - .hasMessage("Expected input function non-dependency parameters to begin with state type NullableDoubleState: " + + .hasMessage("Expected input function non-dependency parameters to begin with state types [NullableDoubleState]: " + "public static void io.trino.operator.TestAnnotationEngineForAggregates$InputParametersWrongOrder.input(double,io.trino.operator.aggregation.state.NullableDoubleState)"); } @@ -342,7 +344,7 @@ public void testSimpleGenericAggregationFunctionParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with two generic implementations"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - assertEquals(aggregation.getStateClass(), NullableLongState.class); + assertEquals(aggregation.getStateDetails(), ImmutableList.of(new AccumulatorStateDetails(NullableLongState.class, Optional.empty()))); ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 0, 0, 2); AggregationImplementation implementationDouble = implementations.getGenericImplementations().stream() @@ -1007,7 +1009,7 @@ public void testFixedTypeParameterInjectionAggregateFunctionParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with fixed parameter type injected"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - assertEquals(aggregation.getStateClass(), NullableDoubleState.class); + assertEquals(aggregation.getStateDetails(), ImmutableList.of(new AccumulatorStateDetails(NullableDoubleState.class, Optional.empty()))); 
ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 1, 0, 0); AggregationImplementation implementationDouble = implementations.getExactImplementations().get(expectedSignature); @@ -1071,7 +1073,7 @@ public void testPartiallyFixedTypeParameterInjectionAggregateFunctionParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with fixed parameter type injected"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - assertEquals(aggregation.getStateClass(), NullableDoubleState.class); + assertEquals(aggregation.getStateDetails(), ImmutableList.of(new AccumulatorStateDetails(NullableDoubleState.class, Optional.empty()))); ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 0, 0, 1); AggregationImplementation implementationDouble = getOnlyElement(implementations.getGenericImplementations());