From 6f3a89b3d3ff3dc372219cfa7d53f4cd1fa85cd0 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Thu, 3 Nov 2022 23:22:56 -0700 Subject: [PATCH] Fix incorrect min/max with limit in window functions Window function evaluation has an optimization whereby, if the frame is unchanged, it avoids resetting the function and just calls output() again. The output implementation of min and max with N (e.g. max(x, 3)) is destructive: it empties the heap, which causes subsequent calls to return empty results. In this change, we update the output method to walk over a sorted copy of the heap's contents instead of clearing it. --- .../minmaxn/MaxNAggregationFunction.java | 2 +- .../aggregation/minmaxn/MinMaxNState.java | 3 +- .../minmaxn/MinMaxNStateFactory.java | 10 +-- .../minmaxn/MinNAggregationFunction.java | 2 +- .../aggregation/minmaxn/MinNStateFactory.java | 1 + .../aggregation/minmaxn/TypedHeap.java | 40 +++------- .../aggregation/minmaxn/TestTypedHeap.java | 4 +- .../io/trino/sql/query/TestMinMaxNWindow.java | 73 +++++++++++++++++++ 8 files changed, 94 insertions(+), 41 deletions(-) create mode 100644 core/trino-main/src/test/java/io/trino/sql/query/TestMinMaxNWindow.java diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java index ef6b86d6cd89..65d00a67f51f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java @@ -55,6 +55,6 @@ public static void combine( @OutputFunction("array(E)") public static void output(@AggregationState("E") MaxNState state, BlockBuilder out) { - state.popAll(out); + state.writeAll(out); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java index
2860f00a53f0..8ac9d36daa0b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java @@ -40,9 +40,8 @@ public interface MinMaxNState /** * Writes all values to the supplied block builder as an array entry. - * After this method is called, the current state will be empty. */ - void popAll(BlockBuilder out); + void writeAll(BlockBuilder out); /** * Write this state to the specified block builder. diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java index 193e086fceac..3eb594f92825 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java @@ -115,7 +115,7 @@ public final void merge(MinMaxNState other) } @Override - public final void popAll(BlockBuilder out) + public final void writeAll(BlockBuilder out) { TypedHeap typedHeap = getTypedHeap(); if (typedHeap == null || typedHeap.isEmpty()) { @@ -125,9 +125,7 @@ public final void popAll(BlockBuilder out) BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); - size -= typedHeap.getEstimatedSize(); - typedHeap.popAllReverse(arrayBlockBuilder); - size += typedHeap.getEstimatedSize(); + typedHeap.writeAll(arrayBlockBuilder); out.closeEntry(); } @@ -234,7 +232,7 @@ public final void merge(MinMaxNState other) } @Override - public final void popAll(BlockBuilder out) + public final void writeAll(BlockBuilder out) { if (typedHeap == null || typedHeap.isEmpty()) { out.appendNull(); @@ -242,7 +240,7 @@ public final void popAll(BlockBuilder out) } BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); - typedHeap.popAllReverse(arrayBlockBuilder); + typedHeap.writeAll(arrayBlockBuilder); out.closeEntry(); } 
diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java index aae0eefced43..91bac2856f00 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java @@ -55,6 +55,6 @@ public static void combine( @OutputFunction("array(E)") public static void output(@AggregationState("E") MinNState state, BlockBuilder out) { - state.popAll(out); + state.writeAll(out); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java index fa432a9ce66b..9727aabc8930 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java @@ -56,6 +56,7 @@ public MinNStateFactory( "second argument of min_n must be less than or equal to %s; found %s", MAX_NUMBER_OF_VALUES, n); + return new TypedHeap(true, compare, elementType, toIntExact(n)); }; deserializer = rowBlock -> TypedHeap.deserialize(true, compare, elementType, rowBlock); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java index 27f7fbc563fb..3a280e9368b0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java @@ -18,6 +18,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; +import it.unimi.dsi.fastutil.ints.IntArrays; import org.openjdk.jol.info.ClassLayout; 
import java.lang.invoke.MethodHandle; @@ -111,41 +112,17 @@ public static TypedHeap deserialize(boolean min, MethodHandle compare, Type elem return new TypedHeap(min, compare, elementType, capacity, heapBlock.getPositionCount(), heapIndex, heapBlockBuilder); } - public void popAllReverse(BlockBuilder resultBlockBuilder) + public void writeAll(BlockBuilder resultBlockBuilder) { int[] indexes = new int[positionCount]; - while (positionCount > 0) { - indexes[positionCount - 1] = heapIndex[0]; - positionCount--; - heapIndex[0] = heapIndex[positionCount]; - siftDown(); - } + System.arraycopy(heapIndex, 0, indexes, 0, positionCount); + IntArrays.quickSort(indexes, (a, b) -> compare(heapBlockBuilder, a, heapBlockBuilder, b)); for (int index : indexes) { elementType.appendTo(heapBlockBuilder, index, resultBlockBuilder); } } - public void popAll(BlockBuilder resultBlockBuilder) - { - while (positionCount > 0) { - pop(resultBlockBuilder); - } - } - - public void pop(BlockBuilder resultBlockBuilder) - { - elementType.appendTo(heapBlockBuilder, heapIndex[0], resultBlockBuilder); - remove(); - } - - private void remove() - { - positionCount--; - heapIndex[0] = heapIndex[positionCount]; - siftDown(); - } - public void add(Block block, int position) { checkArgument(!block.isNull(position)); @@ -237,11 +214,11 @@ private void compactIfNecessary() heapBlockBuilder = newHeapBlockBuilder; } - private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) + private int compare(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) { try { long result = (long) compare.invokeExact(leftBlock, leftPosition, rightBlock, rightPosition); - return min ? result < 0 : result > 0; + return (int) (min ? 
result : -result); } catch (Throwable throwable) { Throwables.throwIfUnchecked(throwable); @@ -249,6 +226,11 @@ private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block r } } + private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) + { + return compare(leftBlock, leftPosition, rightBlock, rightPosition) < 0; + } + public TypedHeap copy() { BlockBuilder heapBlockBuilderCopy = null; diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java index 6e74e4a70108..8426123033c5 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java @@ -97,11 +97,11 @@ private static void test(IntStream inputStream, boolean min, MethodHandle compar heap.addAll(blockBuilder); BlockBuilder resultBlockBuilder = BIGINT.createBlockBuilder(null, OUTPUT_SIZE); - heap.popAll(resultBlockBuilder); + heap.writeAll(resultBlockBuilder); Block resultBlock = resultBlockBuilder.build(); assertEquals(resultBlock.getPositionCount(), OUTPUT_SIZE); - for (int i = 0; i < OUTPUT_SIZE; i++) { + for (int i = OUTPUT_SIZE - 1; i >= 0; i--) { assertEquals(BIGINT.getLong(resultBlock, i), outputIterator.nextInt()); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestMinMaxNWindow.java b/core/trino-main/src/test/java/io/trino/sql/query/TestMinMaxNWindow.java new file mode 100644 index 000000000000..ac05dd5cb87b --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestMinMaxNWindow.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.query; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@TestInstance(PER_CLASS) +public class TestMinMaxNWindow +{ + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testMax() + { + assertThat(assertions.query(""" + SELECT max(x, 3) OVER () FROM (VALUES 1, 2, 3, 4, 5) t(x) + """)) + .matches(""" + VALUES + (ARRAY[5, 4, 3]), + (ARRAY[5, 4, 3]), + (ARRAY[5, 4, 3]), + (ARRAY[5, 4, 3]), + (ARRAY[5, 4, 3]) + """); + } + + @Test + public void testMin() + { + assertThat(assertions.query(""" + SELECT min(x, 3) OVER () FROM (VALUES 1, 2, 3, 4, 5) t(x) + """)) + .matches(""" + VALUES + (ARRAY[1, 2, 3]), + (ARRAY[1, 2, 3]), + (ARRAY[1, 2, 3]), + (ARRAY[1, 2, 3]), + (ARRAY[1, 2, 3]) + """); + } +}