diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java index ef6b86d6cd89..65d00a67f51f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java @@ -55,6 +55,6 @@ public static void combine( @OutputFunction("array(E)") public static void output(@AggregationState("E") MaxNState state, BlockBuilder out) { - state.popAll(out); + state.writeAll(out); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java index 2860f00a53f0..8ac9d36daa0b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java @@ -40,9 +40,8 @@ public interface MinMaxNState /** * Writes all values to the supplied block builder as an array entry. - * After this method is called, the current state will be empty. */ - void popAll(BlockBuilder out); + void writeAll(BlockBuilder out); /** * Write this state to the specified block builder. diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java index 193e086fceac..3eb594f92825 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java @@ -115,7 +115,7 @@ public final void merge(MinMaxNState other) } @Override - public final void popAll(BlockBuilder out) + public final void writeAll(BlockBuilder out) { TypedHeap typedHeap = getTypedHeap(); if (typedHeap == null || typedHeap.isEmpty()) { @@ -125,9 +125,7 @@ public final void popAll(BlockBuilder out) BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); - size -= typedHeap.getEstimatedSize(); - typedHeap.popAllReverse(arrayBlockBuilder); - size += typedHeap.getEstimatedSize(); + typedHeap.writeAll(arrayBlockBuilder); out.closeEntry(); } @@ -234,7 +232,7 @@ public final void merge(MinMaxNState other) } @Override - public final void popAll(BlockBuilder out) + public final void writeAll(BlockBuilder out) { if (typedHeap == null || typedHeap.isEmpty()) { out.appendNull(); @@ -242,7 +240,7 @@ public final void popAll(BlockBuilder out) } BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); - typedHeap.popAllReverse(arrayBlockBuilder); + typedHeap.writeAll(arrayBlockBuilder); out.closeEntry(); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java index aae0eefced43..91bac2856f00 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java @@ -55,6 +55,6 @@ public static void combine( @OutputFunction("array(E)") public static void output(@AggregationState("E") MinNState state, BlockBuilder out) { - state.popAll(out); + state.writeAll(out); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java index fa432a9ce66b..9727aabc8930 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java @@ -56,6 +56,7 @@ public MinNStateFactory( "second argument of min_n must be less than or equal to %s; found %s", MAX_NUMBER_OF_VALUES, n); + return new TypedHeap(true, compare, elementType, toIntExact(n)); }; deserializer = rowBlock -> TypedHeap.deserialize(true, compare, elementType, rowBlock); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java index 27f7fbc563fb..3a280e9368b0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java @@ -18,6 +18,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; +import it.unimi.dsi.fastutil.ints.IntArrays; import org.openjdk.jol.info.ClassLayout; import java.lang.invoke.MethodHandle; @@ -111,41 +112,17 @@ public static TypedHeap deserialize(boolean min, MethodHandle compare, Type elem return new TypedHeap(min, compare, elementType, capacity, heapBlock.getPositionCount(), heapIndex, heapBlockBuilder); } - public void popAllReverse(BlockBuilder resultBlockBuilder) + public void writeAll(BlockBuilder resultBlockBuilder) { int[] indexes = new int[positionCount]; - while (positionCount > 0) { - indexes[positionCount - 1] = heapIndex[0]; - positionCount--; - heapIndex[0] = heapIndex[positionCount]; - siftDown(); - } + System.arraycopy(heapIndex, 0, indexes, 0, positionCount); + IntArrays.quickSort(indexes, (a, b) -> compare(heapBlockBuilder, a, heapBlockBuilder, b)); for (int index : indexes) { elementType.appendTo(heapBlockBuilder, index, resultBlockBuilder); } } - public void popAll(BlockBuilder resultBlockBuilder) - { - while (positionCount > 0) { - pop(resultBlockBuilder); - } - } - - public void pop(BlockBuilder resultBlockBuilder) - { - elementType.appendTo(heapBlockBuilder, heapIndex[0], resultBlockBuilder); - remove(); - } - - private void remove() - { - positionCount--; - heapIndex[0] = heapIndex[positionCount]; - siftDown(); - } - public void add(Block block, int position) { checkArgument(!block.isNull(position)); @@ -237,11 +214,11 @@ private void compactIfNecessary() heapBlockBuilder = newHeapBlockBuilder; } - private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) + private int compare(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) { try { long result = (long) compare.invokeExact(leftBlock, leftPosition, rightBlock, rightPosition); - return min ? result < 0 : result > 0; + return (int) (min ? result : -result); } catch (Throwable throwable) { Throwables.throwIfUnchecked(throwable); @@ -249,6 +226,11 @@ private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block r } } + private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) + { + return compare(leftBlock, leftPosition, rightBlock, rightPosition) < 0; + } + public TypedHeap copy() { BlockBuilder heapBlockBuilderCopy = null; diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java index 6e74e4a70108..8426123033c5 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java @@ -97,11 +97,11 @@ private static void test(IntStream inputStream, boolean min, MethodHandle compar heap.addAll(blockBuilder); BlockBuilder resultBlockBuilder = BIGINT.createBlockBuilder(null, OUTPUT_SIZE); - heap.popAll(resultBlockBuilder); + heap.writeAll(resultBlockBuilder); Block resultBlock = resultBlockBuilder.build(); assertEquals(resultBlock.getPositionCount(), OUTPUT_SIZE); - for (int i = 0; i < OUTPUT_SIZE; i++) { + for (int i = OUTPUT_SIZE - 1; i >= 0; i--) { assertEquals(BIGINT.getLong(resultBlock, i), outputIterator.nextInt()); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestMinMaxNWindow.java b/core/trino-main/src/test/java/io/trino/sql/query/TestMinMaxNWindow.java new file mode 100644 index 000000000000..ac05dd5cb87b --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestMinMaxNWindow.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.query; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@TestInstance(PER_CLASS) +public class TestMinMaxNWindow +{ + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testMax() + { + assertThat(assertions.query(""" + SELECT max(x, 3) OVER () FROM (VALUES 1, 2, 3, 4, 5) t(x) + """)) + .matches(""" + VALUES + (ARRAY[5, 4, 3]), + (ARRAY[5, 4, 3]), + (ARRAY[5, 4, 3]), + (ARRAY[5, 4, 3]), + (ARRAY[5, 4, 3]) + """); + } + + @Test + public void testMin() + { + assertThat(assertions.query(""" + SELECT min(x, 3) OVER () FROM (VALUES 1, 2, 3, 4, 5) t(x) + """)) + .matches(""" + VALUES + (ARRAY[1, 2, 3]), + (ARRAY[1, 2, 3]), + (ARRAY[1, 2, 3]), + (ARRAY[1, 2, 3]), + (ARRAY[1, 2, 3]) + """); + } +}