Skip to content

Commit

Permalink
Fix incorrect min/max with limit in window functions
Browse files Browse the repository at this point in the history
Window functions have an optimization whereby if the frame is the
same, it avoids resetting the function and just calls output().

The implementation of output in min and max with N is destructive,
which causes subsequent calls to return empty results.

In this change, we update the output method to walk over the sorted
contents of the heap instead of clearing it.
  • Loading branch information
martint committed Nov 5, 2022
1 parent 55656b7 commit 6f3a89b
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,6 @@ public static void combine(
@OutputFunction("array(E)")
public static void output(@AggregationState("E") MaxNState state, BlockBuilder out)
{
state.popAll(out);
state.writeAll(out);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ public interface MinMaxNState

/**
* Writes all values to the supplied block builder as an array entry.
* After this method is called, the current state will be empty.
*/
void popAll(BlockBuilder out);
void writeAll(BlockBuilder out);

/**
* Write this state to the specified block builder.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public final void merge(MinMaxNState other)
}

@Override
public final void popAll(BlockBuilder out)
public final void writeAll(BlockBuilder out)
{
TypedHeap typedHeap = getTypedHeap();
if (typedHeap == null || typedHeap.isEmpty()) {
Expand All @@ -125,9 +125,7 @@ public final void popAll(BlockBuilder out)

BlockBuilder arrayBlockBuilder = out.beginBlockEntry();

size -= typedHeap.getEstimatedSize();
typedHeap.popAllReverse(arrayBlockBuilder);
size += typedHeap.getEstimatedSize();
typedHeap.writeAll(arrayBlockBuilder);

out.closeEntry();
}
Expand Down Expand Up @@ -234,15 +232,15 @@ public final void merge(MinMaxNState other)
}

@Override
public final void popAll(BlockBuilder out)
public final void writeAll(BlockBuilder out)
{
if (typedHeap == null || typedHeap.isEmpty()) {
out.appendNull();
return;
}

BlockBuilder arrayBlockBuilder = out.beginBlockEntry();
typedHeap.popAllReverse(arrayBlockBuilder);
typedHeap.writeAll(arrayBlockBuilder);
out.closeEntry();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,6 @@ public static void combine(
@OutputFunction("array(E)")
public static void output(@AggregationState("E") MinNState state, BlockBuilder out)
{
state.popAll(out);
state.writeAll(out);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public MinNStateFactory(
"second argument of min_n must be less than or equal to %s; found %s",
MAX_NUMBER_OF_VALUES,
n);

return new TypedHeap(true, compare, elementType, toIntExact(n));
};
deserializer = rowBlock -> TypedHeap.deserialize(true, compare, elementType, rowBlock);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import it.unimi.dsi.fastutil.ints.IntArrays;
import org.openjdk.jol.info.ClassLayout;

import java.lang.invoke.MethodHandle;
Expand Down Expand Up @@ -111,41 +112,17 @@ public static TypedHeap deserialize(boolean min, MethodHandle compare, Type elem
return new TypedHeap(min, compare, elementType, capacity, heapBlock.getPositionCount(), heapIndex, heapBlockBuilder);
}

public void popAllReverse(BlockBuilder resultBlockBuilder)
public void writeAll(BlockBuilder resultBlockBuilder)
{
int[] indexes = new int[positionCount];
while (positionCount > 0) {
indexes[positionCount - 1] = heapIndex[0];
positionCount--;
heapIndex[0] = heapIndex[positionCount];
siftDown();
}
System.arraycopy(heapIndex, 0, indexes, 0, positionCount);
IntArrays.quickSort(indexes, (a, b) -> compare(heapBlockBuilder, a, heapBlockBuilder, b));

for (int index : indexes) {
elementType.appendTo(heapBlockBuilder, index, resultBlockBuilder);
}
}

public void popAll(BlockBuilder resultBlockBuilder)
{
while (positionCount > 0) {
pop(resultBlockBuilder);
}
}

public void pop(BlockBuilder resultBlockBuilder)
{
elementType.appendTo(heapBlockBuilder, heapIndex[0], resultBlockBuilder);
remove();
}

private void remove()
{
positionCount--;
heapIndex[0] = heapIndex[positionCount];
siftDown();
}

public void add(Block block, int position)
{
checkArgument(!block.isNull(position));
Expand Down Expand Up @@ -237,18 +214,23 @@ private void compactIfNecessary()
heapBlockBuilder = newHeapBlockBuilder;
}

private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition)
private int compare(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition)
{
try {
long result = (long) compare.invokeExact(leftBlock, leftPosition, rightBlock, rightPosition);
return min ? result < 0 : result > 0;
return (int) (min ? result : -result);
}
catch (Throwable throwable) {
Throwables.throwIfUnchecked(throwable);
throw new RuntimeException(throwable);
}
}

private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition)
{
return compare(leftBlock, leftPosition, rightBlock, rightPosition) < 0;
}

public TypedHeap copy()
{
BlockBuilder heapBlockBuilderCopy = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ private static void test(IntStream inputStream, boolean min, MethodHandle compar
heap.addAll(blockBuilder);

BlockBuilder resultBlockBuilder = BIGINT.createBlockBuilder(null, OUTPUT_SIZE);
heap.popAll(resultBlockBuilder);
heap.writeAll(resultBlockBuilder);

Block resultBlock = resultBlockBuilder.build();
assertEquals(resultBlock.getPositionCount(), OUTPUT_SIZE);
for (int i = 0; i < OUTPUT_SIZE; i++) {
for (int i = OUTPUT_SIZE - 1; i >= 0; i--) {
assertEquals(BIGINT.getLong(resultBlock, i), outputIterator.nextInt());
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.query;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;

@TestInstance(PER_CLASS)
public class TestMinMaxNWindow
{
private QueryAssertions assertions;

@BeforeAll
public void init()
{
assertions = new QueryAssertions();
}

@AfterAll
public void teardown()
{
assertions.close();
assertions = null;
}

@Test
public void testMax()
{
assertThat(assertions.query("""
SELECT max(x, 3) OVER () FROM (VALUES 1, 2, 3, 4, 5) t(x)
"""))
.matches("""
VALUES
(ARRAY[5, 4, 3]),
(ARRAY[5, 4, 3]),
(ARRAY[5, 4, 3]),
(ARRAY[5, 4, 3]),
(ARRAY[5, 4, 3])
""");
}

@Test
public void testMin()
{
assertThat(assertions.query("""
SELECT min(x, 3) OVER () FROM (VALUES 1, 2, 3, 4, 5) t(x)
"""))
.matches("""
VALUES
(ARRAY[1, 2, 3]),
(ARRAY[1, 2, 3]),
(ARRAY[1, 2, 3]),
(ARRAY[1, 2, 3]),
(ARRAY[1, 2, 3])
""");
}
}

0 comments on commit 6f3a89b

Please sign in to comment.