From cb8c031092cd02fde8842679e3e750b1ee97aeb9 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sun, 26 Dec 2021 13:59:41 -0800 Subject: [PATCH] Convert min/max_by aggregation to annotated function --- .../trino/metadata/SystemFunctionBundle.java | 8 +- .../minmaxby/AbstractMinMaxBy.java | 400 ------------------ .../minmaxby/MaxByAggregationFunction.java | 83 +++- .../minmaxby/MinByAggregationFunction.java | 83 +++- .../checkMathFunctionsRegistered.result | 4 +- 5 files changed, 163 insertions(+), 415 deletions(-) delete mode 100644 core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java diff --git a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java index 98a21c28dd89..a7ab95938f6d 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java @@ -79,7 +79,9 @@ import io.trino.operator.aggregation.VarianceAggregation; import io.trino.operator.aggregation.histogram.Histogram; import io.trino.operator.aggregation.listagg.ListaggAggregationFunction; +import io.trino.operator.aggregation.minmaxby.MaxByAggregationFunction; import io.trino.operator.aggregation.minmaxby.MaxByNAggregationFunction; +import io.trino.operator.aggregation.minmaxby.MinByAggregationFunction; import io.trino.operator.aggregation.minmaxby.MinByNAggregationFunction; import io.trino.operator.aggregation.multimapagg.MultimapAggregationFunction; import io.trino.operator.scalar.ArrayAllMatchFunction; @@ -270,8 +272,6 @@ import static io.trino.operator.aggregation.RealAverageAggregation.REAL_AVERAGE_AGGREGATION; import static io.trino.operator.aggregation.ReduceAggregationFunction.REDUCE_AGG; import static io.trino.operator.aggregation.arrayagg.ArrayAggregationFunction.ARRAY_AGG; -import static io.trino.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY; -import static io.trino.operator.aggregation.minmaxby.MinByAggregationFunction.MIN_BY; import static io.trino.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION; import static io.trino.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR; import static io.trino.operator.scalar.ArrayFlattenFunction.ARRAY_FLATTEN_FUNCTION; @@ -544,9 +544,11 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators .aggregates(ChecksumAggregationFunction.class) .aggregates(ArbitraryAggregationFunction.class) .functions(GREATEST, LEAST) - .functions(MAX_BY, MIN_BY, new MaxByNAggregationFunction(blockTypeOperators), new MinByNAggregationFunction(blockTypeOperators)) + .functions(new MaxByNAggregationFunction(blockTypeOperators), new MinByNAggregationFunction(blockTypeOperators)) .aggregates(MinAggregationFunction.class) .aggregates(MaxAggregationFunction.class) + .aggregates(MinByAggregationFunction.class) + .aggregates(MaxByAggregationFunction.class) .functions(new MaxNAggregationFunction(blockTypeOperators), new MinNAggregationFunction(blockTypeOperators)) .aggregates(CountColumn.class) .functions(JSON_TO_ROW, JSON_STRING_TO_ROW, ROW_TO_ROW_CAST) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java deleted file mode 100644 index 3d84ef00f1a2..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/AbstractMinMaxBy.java +++ /dev/null @@ -1,400 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation.minmaxby; - -import com.google.common.collect.ImmutableList; -import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; -import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; -import io.trino.operator.aggregation.state.BlockPositionState; -import io.trino.operator.aggregation.state.BlockPositionStateSerializer; -import io.trino.operator.aggregation.state.NullableBooleanState; -import io.trino.operator.aggregation.state.NullableBooleanStateSerializer; -import io.trino.operator.aggregation.state.NullableDoubleState; -import io.trino.operator.aggregation.state.NullableDoubleStateSerializer; -import io.trino.operator.aggregation.state.NullableLongState; -import io.trino.operator.aggregation.state.NullableLongStateSerializer; -import io.trino.operator.aggregation.state.NullableState; -import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.function.AccumulatorState; -import io.trino.spi.function.AccumulatorStateSerializer; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; -import io.trino.util.MinMaxCompare; - -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.util.Optional; - -import static io.trino.operator.aggregation.state.StateCompiler.generateStateFactory; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.simpleConvention; -import static io.trino.util.MinMaxCompare.getMinMaxCompareFunctionDependencies; -import static io.trino.util.MinMaxCompare.getMinMaxCompareOperatorType; -import static java.lang.invoke.MethodHandles.explicitCastArguments; -import static java.lang.invoke.MethodHandles.insertArguments; -import static java.lang.invoke.MethodHandles.lookup; -import static java.lang.invoke.MethodType.methodType; - -public abstract class AbstractMinMaxBy - extends SqlAggregationFunction -{ - private final boolean min; - - protected AbstractMinMaxBy(boolean min, String description) - { - super( - FunctionMetadata.aggregateBuilder() - .signature(Signature.builder() - .name((min ? "min" : "max") + "_by") - .orderableTypeParameter("K") - .typeVariable("V") - .returnType(new TypeSignature("V")) - .argumentType(new TypeSignature("V")) - .argumentType(new TypeSignature("K")) - .build()) - .argumentNullability(true, false) - .description(description) - .build(), - AggregationFunctionMetadata.builder() - .intermediateType(new TypeSignature("K")) - .intermediateType(new TypeSignature("V")) - .build()); - this.min = min; - } - - @Override - public FunctionDependencyDeclaration getFunctionDependencies() - { - return getMinMaxCompareFunctionDependencies(new TypeSignature("K"), min); - } - - @Override - public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) - { - try { - Type keyType = boundSignature.getArgumentType(1); - Type valueType = boundSignature.getArgumentType(0); - - MethodHandle inputMethod = generateInput(keyType, valueType, functionDependencies); - MethodHandle combineMethod = generateCombine(keyType, valueType, functionDependencies); - MethodHandle outputMethod = generateOutput(keyType, valueType); - - return new AggregationMetadata( - inputMethod, - Optional.empty(), - Optional.of(combineMethod), - outputMethod, - ImmutableList.of( - getAccumulatorStateDescriptor(keyType), - getAccumulatorStateDescriptor(valueType))); - } - catch (ReflectiveOperationException e) { - throw new RuntimeException(e); - } - } - - private static AccumulatorStateDescriptor getAccumulatorStateDescriptor(Type type) - { - Class stateClass = getStateClass(type); - if (stateClass.equals(BlockPositionState.class)) { - return new AccumulatorStateDescriptor<>( - BlockPositionState.class, - new BlockPositionStateSerializer(type), - generateStateFactory(BlockPositionState.class)); - } - return getAccumulatorStateDescriptor(stateClass, type); - } - - private static AccumulatorStateDescriptor getAccumulatorStateDescriptor(Class stateClass, Type type) - { - return new AccumulatorStateDescriptor<>( - stateClass, - getStateSerializer(stateClass, type), - generateStateFactory(stateClass)); - } - - private MethodHandle generateInput(Type keyType, Type valueType, FunctionDependencies functionDependencies) - throws ReflectiveOperationException - { - MethodHandle input = lookup().findStatic( - AbstractMinMaxBy.class, - "input", - methodType(void.class, - MethodHandle.class, - MethodHandle.class, - MethodHandle.class, - NullableState.class, - NullableState.class, - Block.class, - Block.class, - int.class)); - - Class keyState = getStateClass(keyType); - Class valueState = getStateClass(valueType); - - MethodHandle compareStateBlockPosition = generateCompareStateBlockPosition(keyType, functionDependencies, keyState); - MethodHandle setKeyState = getSetStateValue(keyType, keyState); - MethodHandle setValueState = getSetStateValue(valueType, valueState); - input = insertArguments(input, 0, compareStateBlockPosition, setKeyState, setValueState); - return explicitCastArguments(input, methodType(void.class, keyState, valueState, Block.class, Block.class, int.class)); - } - - private static void input( - MethodHandle compareStateBlockPosition, - MethodHandle setKeyState, - MethodHandle setValueState, - NullableState keyState, - NullableState valueState, - Block value, - Block key, - int position) - throws Throwable - { - if (keyState.isNull() || (boolean) compareStateBlockPosition.invoke(key, position, keyState)) { - setKeyState.invoke(keyState, key, position); - setValueState.invoke(valueState, value, position); - } - } - - private MethodHandle generateCombine(Type keyType, Type valueType, FunctionDependencies functionDependencies) - throws ReflectiveOperationException - { - MethodHandle combine = lookup().findStatic( - AbstractMinMaxBy.class, - "combine", - methodType(void.class, - MethodHandle.class, - MethodHandle.class, - MethodHandle.class, - NullableState.class, - NullableState.class, - NullableState.class, - NullableState.class)); - - Class keyState = getStateClass(keyType); - Class valueState = getStateClass(valueType); - - MethodHandle compareStateBlockPosition = generateCompareStateState(keyType, functionDependencies, keyState); - MethodHandle setKeyState = lookup().findVirtual(keyState, "set", methodType(void.class, keyState)); - MethodHandle setValueState = lookup().findVirtual(valueState, "set", methodType(void.class, valueState)); - combine = insertArguments(combine, 0, compareStateBlockPosition, setKeyState, setValueState); - return explicitCastArguments(combine, methodType(void.class, keyState, valueState, keyState, valueState)); - } - - private static void combine( - MethodHandle compareStateState, - MethodHandle setKeyState, - MethodHandle setValueState, - NullableState keyState, - NullableState valueState, - NullableState otherKeyState, - NullableState otherValueState) - throws Throwable - { - if (otherKeyState.isNull()) { - return; - } - if (keyState.isNull() || (boolean) compareStateState.invoke(otherKeyState, keyState)) { - setKeyState.invoke(keyState, otherKeyState); - setValueState.invoke(valueState, otherValueState); - } - } - - private static MethodHandle generateOutput(Type keyType, Type valueType) - throws ReflectiveOperationException - { - Class keyState = getStateClass(keyType); - Class valueState = getStateClass(valueType); - MethodHandle writeState = lookup().findStatic(AbstractMinMaxBy.class, "writeState", methodType(void.class, Type.class, valueState, BlockBuilder.class)) - .bindTo(valueType); - MethodHandle output = lookup().findStatic( - AbstractMinMaxBy.class, - "output", - methodType(void.class, MethodHandle.class, NullableState.class, NullableState.class, BlockBuilder.class)); - output = output.bindTo(writeState); - return explicitCastArguments(output, methodType(void.class, keyState, valueState, BlockBuilder.class)); - } - - private static void output( - MethodHandle valueWriter, - NullableState keyState, - NullableState valueState, - BlockBuilder blockBuilder) - throws Throwable - { - if (keyState.isNull() || valueState.isNull()) { - blockBuilder.appendNull(); - return; - } - valueWriter.invoke(valueState, blockBuilder); - } - - @UsedByGeneratedCode - private static void writeState(Type type, NullableLongState state, BlockBuilder output) - { - type.writeLong(output, state.getValue()); - } - - @UsedByGeneratedCode - private static void writeState(Type type, NullableDoubleState state, BlockBuilder output) - { - type.writeDouble(output, state.getValue()); - } - - @UsedByGeneratedCode - private static void writeState(Type type, NullableBooleanState state, BlockBuilder output) - { - type.writeBoolean(output, state.getValue()); - } - - @UsedByGeneratedCode - private static void writeState(Type type, BlockPositionState state, BlockBuilder output) - { - type.appendTo(state.getBlock(), state.getPosition(), output); - } - - private static Class getStateClass(Type type) - { - if (type.getJavaType().equals(long.class)) { - return NullableLongState.class; - } - if (type.getJavaType().equals(double.class)) { - return NullableDoubleState.class; - } - if (type.getJavaType().equals(boolean.class)) { - return NullableBooleanState.class; - } - return BlockPositionState.class; - } - - @SuppressWarnings("unchecked") - private static AccumulatorStateSerializer getStateSerializer(Class state, Type type) - { - if (NullableLongState.class.equals(state)) { - return (AccumulatorStateSerializer) new NullableLongStateSerializer(type); - } - if (NullableDoubleState.class.equals(state)) { - return (AccumulatorStateSerializer) new NullableDoubleStateSerializer(type); - } - if (NullableBooleanState.class.equals(state)) { - return (AccumulatorStateSerializer) new NullableBooleanStateSerializer(type); - } - if (BlockPositionState.class.equals(state)) { - return (AccumulatorStateSerializer) new BlockPositionStateSerializer(type); - } - throw new IllegalArgumentException("Unsupported state class: " + state); - } - - private static MethodHandle getSetStateValue(Type type, Class stateClass) - throws ReflectiveOperationException - { - if (stateClass.equals(BlockPositionState.class)) { - return lookup().findStatic(AbstractMinMaxBy.class, "setStateValue", methodType(void.class, BlockPositionState.class, Block.class, int.class)); - } - return lookup().findStatic(AbstractMinMaxBy.class, "setStateValue", methodType(void.class, Type.class, stateClass, Block.class, int.class)) - .bindTo(type); - } - - @UsedByGeneratedCode - private static void setStateValue(BlockPositionState state, Block block, int position) - { - state.setBlock(block); - state.setPosition(position); - } - - @UsedByGeneratedCode - private static void setStateValue(Type valueType, NullableLongState state, Block block, int position) - { - if (block.isNull(position)) { - state.setNull(true); - } - else { - state.setNull(false); - state.setValue(valueType.getLong(block, position)); - } - } - - @UsedByGeneratedCode - private static void setStateValue(Type valueType, NullableDoubleState state, Block block, int position) - { - if (block.isNull(position)) { - state.setNull(true); - } - else { - state.setNull(false); - state.setValue(valueType.getDouble(block, position)); - } - } - - @UsedByGeneratedCode - private static void setStateValue(Type valueType, NullableBooleanState state, Block block, int position) - { - if (block.isNull(position)) { - state.setNull(true); - } - else { - state.setNull(false); - state.setValue(valueType.getBoolean(block, position)); - } - } - - private MethodHandle generateCompareStateBlockPosition(Type type, FunctionDependencies functionDependencies, Class state) - throws ReflectiveOperationException - { - if (state.equals(BlockPositionState.class)) { - MethodHandle comparisonMethod = lookup().findStatic(AbstractMinMaxBy.class, "compareStateBlockPosition", methodType(long.class, MethodHandle.class, Block.class, int.class, BlockPositionState.class)) - .bindTo(functionDependencies.getOperatorInvoker(getMinMaxCompareOperatorType(min), ImmutableList.of(type, type), simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION)).getMethodHandle()); - return MinMaxCompare.comparisonToMinMaxResult(min, comparisonMethod); - } - MethodHandle minMaxMethod = MinMaxCompare.getMinMaxCompare(functionDependencies, type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, NEVER_NULL), min); - MethodHandle stateGetValue = lookup().findVirtual(state, "getValue", methodType(type.getJavaType())); - return MethodHandles.filterArguments(minMaxMethod, 2, stateGetValue); - } - - private static long compareStateBlockPosition(MethodHandle blockPositionBlockPositionOperator, Block left, int leftPosition, BlockPositionState state) - throws Throwable - { - return (long) blockPositionBlockPositionOperator.invokeExact(left, leftPosition, state.getBlock(), state.getPosition()); - } - - private MethodHandle generateCompareStateState(Type type, FunctionDependencies functionDependencies, Class state) - throws ReflectiveOperationException - { - if (state.equals(BlockPositionState.class)) { - MethodHandle comparisonMethod = lookup().findStatic(AbstractMinMaxBy.class, "compareStateState", methodType(long.class, MethodHandle.class, BlockPositionState.class, BlockPositionState.class)) - .bindTo(functionDependencies.getOperatorInvoker(getMinMaxCompareOperatorType(min), ImmutableList.of(type, type), simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION)).getMethodHandle()); - return MinMaxCompare.comparisonToMinMaxResult(min, comparisonMethod); - } - MethodHandle maxMaxMethod = MinMaxCompare.getMinMaxCompare(functionDependencies, type, simpleConvention(FAIL_ON_NULL, NEVER_NULL, NEVER_NULL), min); - MethodHandle stateGetValue = lookup().findVirtual(state, "getValue", methodType(type.getJavaType())); - return MethodHandles.filterArguments(maxMaxMethod, 0, stateGetValue, stateGetValue); - } - - private static long compareStateState(MethodHandle blockPositionBlockPositionOperator, BlockPositionState state, BlockPositionState otherState) - throws Throwable - { - return (long) blockPositionBlockPositionOperator.invokeExact(state.getBlock(), state.getPosition(), otherState.getBlock(), otherState.getPosition()); - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByAggregationFunction.java index 460463c26364..1ae2dc13a717 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MaxByAggregationFunction.java @@ -13,13 +13,86 @@ */ package io.trino.operator.aggregation.minmaxby; -public class MaxByAggregationFunction - extends AbstractMinMaxBy +import io.trino.operator.aggregation.NullablePosition; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InOut; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; + +@AggregationFunction("max_by") +@Description("Returns the value of the first argument, associated with the maximum value of the second argument") +public final class MaxByAggregationFunction { - public static final MaxByAggregationFunction MAX_BY = new MaxByAggregationFunction(); + private MaxByAggregationFunction() {} + + @InputFunction + @TypeParameter("V") + @TypeParameter("K") + public static void input( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_FIRST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockPosition @SqlType("K") Block keyBlock, + @BlockIndex int position) + throws Throwable + { + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) > 0) { + keyState.set(keyBlock, position); + valueState.set(valueBlock, position); + } + } + + @CombineFunction + public static void combine( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_FIRST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {IN_OUT, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + @AggregationState("K") InOut otherKeyState, + @AggregationState("V") InOut otherValueState) + throws Throwable + { + if (otherKeyState.isNull()) { + return; + } + if (keyState.isNull() || ((long) compare.invokeExact(otherKeyState, keyState)) > 0) { + keyState.set(otherKeyState); + valueState.set(otherValueState); + } + } - public MaxByAggregationFunction() + @OutputFunction("V") + public static void output( + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + BlockBuilder out) { - super(false, "Returns the value of the first argument, associated with the maximum value of the second argument"); + valueState.get(out); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByAggregationFunction.java index f813ac3bab93..9c9327a3ac77 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxby/MinByAggregationFunction.java @@ -13,13 +13,86 @@ */ package io.trino.operator.aggregation.minmaxby; -public class MinByAggregationFunction - extends AbstractMinMaxBy +import io.trino.operator.aggregation.NullablePosition; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.Convention; +import io.trino.spi.function.Description; +import io.trino.spi.function.InOut; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OperatorDependency; +import io.trino.spi.function.OperatorType; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; + +@AggregationFunction("min_by") +@Description("Returns the value of the first argument, associated with the minimum value of the second argument") +public final class MinByAggregationFunction { - public static final MinByAggregationFunction MIN_BY = new MinByAggregationFunction(); + private MinByAggregationFunction() {} + + @InputFunction + @TypeParameter("V") + @TypeParameter("K") + public static void input( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_LAST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {BLOCK_POSITION, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + @NullablePosition @BlockPosition @SqlType("V") Block valueBlock, + @BlockPosition @SqlType("K") Block keyBlock, + @BlockIndex int position) + throws Throwable + { + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) < 0) { + keyState.set(keyBlock, position); + valueState.set(valueBlock, position); + } + } + + @CombineFunction + public static void combine( + @OperatorDependency( + operator = OperatorType.COMPARISON_UNORDERED_LAST, + argumentTypes = {"K", "K"}, + convention = @Convention(arguments = {IN_OUT, IN_OUT}, result = FAIL_ON_NULL)) + MethodHandle compare, + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + @AggregationState("K") InOut otherKeyState, + @AggregationState("V") InOut otherValueState) + throws Throwable + { + if (otherKeyState.isNull()) { + return; + } + if (keyState.isNull() || ((long) compare.invokeExact(otherKeyState, keyState)) < 0) { + keyState.set(otherKeyState); + valueState.set(otherValueState); + } + } - public MinByAggregationFunction() + @OutputFunction("V") + public static void output( + @AggregationState("K") InOut keyState, + @AggregationState("V") InOut valueState, + BlockBuilder out) { - super(true, "Returns the value of the first argument, associated with the minimum value of the second argument"); + valueState.get(out); } } diff --git a/testing/trino-product-tests/src/main/resources/sql-tests/testcases/math_functions/checkMathFunctionsRegistered.result b/testing/trino-product-tests/src/main/resources/sql-tests/testcases/math_functions/checkMathFunctionsRegistered.result index 1ecebd427ec0..f635924115bf 100644 --- a/testing/trino-product-tests/src/main/resources/sql-tests/testcases/math_functions/checkMathFunctionsRegistered.result +++ b/testing/trino-product-tests/src/main/resources/sql-tests/testcases/math_functions/checkMathFunctionsRegistered.result @@ -47,9 +47,9 @@ log10 | double | double | scalar | true | Logarithm to base 10 | log2 | double | double | scalar | true | Logarithm to base 2 | max | t | t | aggregate | true | Returns the maximum value of the argument | - max_by | V | V, K | aggregate | true | Returns the value of the first argument, associated with the maximum value of the second argument | + max_by | v | v, k | aggregate | true | Returns the value of the first argument, associated with the maximum value of the second argument | min | t | t | aggregate | true | Returns the minimum value of the argument | - min_by | V | V, K | aggregate | true | Returns the value of the first argument, associated with the minimum value of the second argument | + min_by | v | v, k | aggregate | true | Returns the value of the first argument, associated with the minimum value of the second argument | mod | bigint | bigint, bigint | scalar | true | Remainder of given quotient | mod | double | double, double | scalar | true | Remainder of given quotient | nan | double | | scalar | true | Constant representing not-a-number |