Skip to content

Commit

Permalink
Optimize decimal state serializers for small value case
Browse files Browse the repository at this point in the history
Given that many decimal aggregations (sum, avg) stay
in the long range, the aggregation state serializers can
be optimized for this case, significantly reducing the number of
bytes per position (3-4X) at the cost of a
small CPU overhead during serialization and deserialization.
  • Loading branch information
lukasz-stec authored and sopel39 committed Aug 31, 2022
1 parent aee2179 commit 182b44e
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,13 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;

import static io.trino.spi.type.VarbinaryType.VARBINARY;

public class LongDecimalWithOverflowAndLongStateSerializer
implements AccumulatorStateSerializer<LongDecimalWithOverflowAndLongState>
{
private static final int SERIALIZED_SIZE = (Long.BYTES * 2) + Int128.SIZE;

@Override
public Type getSerializedType()
{
Expand All @@ -42,7 +39,27 @@ public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder ou
long overflow = state.getOverflow();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
VARBINARY.writeSlice(out, Slices.wrappedLongArray(count, overflow, decimal[offset], decimal[offset + 1]));
long[] buffer = new long[4];
long high = decimal[offset];
long low = decimal[offset + 1];

buffer[0] = low;
buffer[1] = high;
// if high = 0, the count will overwrite it
int countOffset = 1 + (high == 0 ? 0 : 1);
// append count, overflow
buffer[countOffset] = count;
buffer[countOffset + 1] = overflow;

// cases
// high == 0 (countOffset = 1)
// overflow == 0 & count == 1 -> bufferLength = 1
// overflow != 0 || count != 1 -> bufferLength = 3
// high != 0 (countOffset = 2)
// overflow == 0 & count == 1 -> bufferLength = 2
// overflow != 0 || count != 1 -> bufferLength = 4
int bufferLength = countOffset + ((overflow == 0 & count == 1) ? 0 : 2);
VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, bufferLength));
}
else {
out.appendNull();
Expand All @@ -54,20 +71,33 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt
{
if (!block.isNull(index)) {
Slice slice = VARBINARY.getSlice(block, index);
if (slice.length() != SERIALIZED_SIZE) {
throw new IllegalStateException("Unexpected serialized state size: " + slice.length());
}
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long count = slice.getLong(0);
long overflow = slice.getLong(Long.BYTES);
int sliceLength = slice.length();
long low = slice.getLong(0);
long high = 0;
long overflow = 0;
long count = 1;

state.setLong(count);
switch (sliceLength) {
case 4 * Long.BYTES:
overflow = slice.getLong(Long.BYTES * 3);
count = slice.getLong(Long.BYTES * 2);
// fall through
case 2 * Long.BYTES:
high = slice.getLong(Long.BYTES);
break;
case 3 * Long.BYTES:
overflow = slice.getLong(Long.BYTES * 2);
count = slice.getLong(Long.BYTES);
}

decimal[offset + 1] = low;
decimal[offset] = high;
state.setOverflow(overflow);
state.setLong(count);
state.setNotNull();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
decimal[offset] = slice.getLong(Long.BYTES * 2);
decimal[offset + 1] = slice.getLong(Long.BYTES * 3);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,13 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;

import static io.trino.spi.type.VarbinaryType.VARBINARY;

public class LongDecimalWithOverflowStateSerializer
implements AccumulatorStateSerializer<LongDecimalWithOverflowState>
{
private static final int SERIALIZED_SIZE = Long.BYTES + Int128.SIZE;

@Override
public Type getSerializedType()
{
Expand All @@ -41,7 +38,18 @@ public void serialize(LongDecimalWithOverflowState state, BlockBuilder out)
long overflow = state.getOverflow();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
VARBINARY.writeSlice(out, Slices.wrappedLongArray(overflow, decimal[offset], decimal[offset + 1]));
long[] buffer = new long[3];
long low = decimal[offset + 1];
long high = decimal[offset];
buffer[0] = low;
buffer[1] = high;
buffer[2] = overflow;
// if high == 0 and overflow == 0 we only write low (bufferLength = 1)
// if high != 0 and overflow == 0 we write both low and high (bufferLength = 2)
// if overflow != 0 we write all values (bufferLength = 3)
int decimalsCount = 1 + (high == 0 ? 0 : 1);
int bufferLength = overflow == 0 ? decimalsCount : 3;
VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, bufferLength));
}
else {
out.appendNull();
Expand All @@ -53,18 +61,26 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowState sta
{
if (!block.isNull(index)) {
Slice slice = VARBINARY.getSlice(block, index);
if (slice.length() != SERIALIZED_SIZE) {
throw new IllegalStateException("Unexpected serialized state size: " + slice.length());
}
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long overflow = slice.getLong(0);
long low = slice.getLong(0);
int sliceLength = slice.length();
long high = 0;
long overflow = 0;

switch (sliceLength) {
case 3 * Long.BYTES:
overflow = slice.getLong(Long.BYTES * 2);
// fall through
case 2 * Long.BYTES:
high = slice.getLong(Long.BYTES);
}

decimal[offset + 1] = low;
decimal[offset] = high;
state.setOverflow(overflow);
state.setNotNull();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
decimal[offset] = slice.getLong(Long.BYTES);
decimal[offset + 1] = slice.getLong(Long.BYTES * 2);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.state;

import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.VariableWidthBlockBuilder;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

/**
 * Round-trip tests for {@link LongDecimalWithOverflowAndLongStateSerializer}: a state is
 * serialized into a variable-width block, the size of the encoded value is checked, and
 * the state read back from that block is compared with the values that were written.
 */
public class TestLongDecimalWithOverflowAndLongStateSerializer
{
    private static final LongDecimalWithOverflowAndLongStateFactory STATE_FACTORY = new LongDecimalWithOverflowAndLongStateFactory();

    /**
     * Builds a non-null state from the supplied parts and verifies that every part
     * survives serialization followed by deserialization.
     *
     * @param low value stored at index 1 of the state's decimal array
     * @param high value stored at index 0 of the state's decimal array
     * @param overflow overflow value of the state
     * @param count long value of the state
     * @param expectedLength expected size of the serialized value, counted in longs
     */
    @Test(dataProvider = "input")
    public void testSerde(long low, long high, long overflow, long count, int expectedLength)
    {
        LongDecimalWithOverflowAndLongState original = STATE_FACTORY.createSingleState();
        long[] originalDecimal = original.getDecimalArray();
        originalDecimal[0] = high;
        originalDecimal[1] = low;
        original.setOverflow(overflow);
        original.setLong(count);
        original.setNotNull();

        LongDecimalWithOverflowAndLongState restored = roundTrip(original, expectedLength);

        assertTrue(restored.isNotNull());
        long[] restoredDecimal = restored.getDecimalArray();
        assertEquals(restoredDecimal[0], high);
        assertEquals(restoredDecimal[1], low);
        assertEquals(restored.getOverflow(), overflow);
        assertEquals(restored.getLong(), count);
    }

    /**
     * A state that was never marked not-null must come back from the round trip
     * still reporting that it is null; its serialized length is expected to be 0.
     */
    @Test
    public void testNullSerde()
    {
        // the factory hands out states that start as null
        LongDecimalWithOverflowAndLongState untouched = STATE_FACTORY.createSingleState();

        LongDecimalWithOverflowAndLongState restored = roundTrip(untouched, 0);

        assertFalse(restored.isNotNull());
    }

    /**
     * Serializes {@code state} into a one-position block, asserts the byte length of that
     * position, and deserializes the position into a freshly created state.
     *
     * @param expectedLength expected size of the serialized value, counted in longs
     * @return the state populated by deserialization
     */
    private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAndLongState state, int expectedLength)
    {
        LongDecimalWithOverflowAndLongStateSerializer serde = new LongDecimalWithOverflowAndLongStateSerializer();
        BlockBuilder builder = new VariableWidthBlockBuilder(null, 1, 0);

        serde.serialize(state, builder);

        Block block = builder.build();
        int expectedBytes = expectedLength * Long.BYTES;
        assertEquals(block.getSliceLength(0), expectedBytes);

        LongDecimalWithOverflowAndLongState target = STATE_FACTORY.createSingleState();
        serde.deserialize(block, 0, target);
        return target;
    }

    /**
     * Rows for {@link #testSerde}; the columns are
     * {@code low, high, overflow, count, expectedLength}.
     */
    @DataProvider
    public Object[][] input()
    {
        return new Object[][] {
                // non-zero low
                {3, 0, 0, 1, 1},
                {3, 5, 0, 1, 2},
                {3, 5, 7, 1, 4},
                {3, 0, 0, 2, 3},
                {3, 5, 0, 2, 4},
                {3, 5, 7, 2, 4},
                {3, 0, 7, 1, 3},
                {3, 0, 7, 2, 3},
                // zero low
                {0, 0, 0, 1, 1},
                {0, 5, 0, 1, 2},
                {0, 5, 7, 1, 4},
                {0, 0, 0, 2, 3},
                {0, 5, 0, 2, 4},
                {0, 5, 7, 2, 4},
                {0, 0, 7, 1, 3},
                {0, 0, 7, 2, 3}
        };
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.state;

import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.VariableWidthBlockBuilder;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

/**
 * Round-trip tests for {@link LongDecimalWithOverflowStateSerializer}: a state is
 * serialized into a variable-width block, the size of the encoded value is checked, and
 * the state read back from that block is compared with the values that were written.
 */
public class TestLongDecimalWithOverflowStateSerializer
{
    private static final LongDecimalWithOverflowStateFactory STATE_FACTORY = new LongDecimalWithOverflowStateFactory();

    /**
     * Builds a non-null state from the supplied parts and verifies that every part
     * survives serialization followed by deserialization.
     *
     * @param low value stored at index 1 of the state's decimal array
     * @param high value stored at index 0 of the state's decimal array
     * @param overflow overflow value of the state
     * @param expectedLength expected size of the serialized value, counted in longs
     */
    @Test(dataProvider = "input")
    public void testSerde(long low, long high, long overflow, int expectedLength)
    {
        LongDecimalWithOverflowState original = STATE_FACTORY.createSingleState();
        long[] originalDecimal = original.getDecimalArray();
        originalDecimal[0] = high;
        originalDecimal[1] = low;
        original.setOverflow(overflow);
        original.setNotNull();

        LongDecimalWithOverflowState restored = roundTrip(original, expectedLength);

        assertTrue(restored.isNotNull());
        long[] restoredDecimal = restored.getDecimalArray();
        assertEquals(restoredDecimal[0], high);
        assertEquals(restoredDecimal[1], low);
        assertEquals(restored.getOverflow(), overflow);
    }

    /**
     * A state that was never marked not-null must come back from the round trip
     * still reporting that it is null; its serialized length is expected to be 0.
     */
    @Test
    public void testNullSerde()
    {
        // the factory hands out states that start as null
        LongDecimalWithOverflowState untouched = STATE_FACTORY.createSingleState();

        LongDecimalWithOverflowState restored = roundTrip(untouched, 0);

        assertFalse(restored.isNotNull());
    }

    /**
     * Serializes {@code state} into a one-position block, asserts the byte length of that
     * position, and deserializes the position into a freshly created state.
     *
     * @param expectedLength expected size of the serialized value, counted in longs
     * @return the state populated by deserialization
     */
    private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState state, int expectedLength)
    {
        LongDecimalWithOverflowStateSerializer serde = new LongDecimalWithOverflowStateSerializer();
        BlockBuilder builder = new VariableWidthBlockBuilder(null, 1, 0);

        serde.serialize(state, builder);

        Block block = builder.build();
        int expectedBytes = expectedLength * Long.BYTES;
        assertEquals(block.getSliceLength(0), expectedBytes);

        LongDecimalWithOverflowState target = STATE_FACTORY.createSingleState();
        serde.deserialize(block, 0, target);
        return target;
    }

    /**
     * Rows for {@link #testSerde}; the columns are
     * {@code low, high, overflow, expectedLength}.
     */
    @DataProvider
    public Object[][] input()
    {
        return new Object[][] {
                // non-zero low
                {3, 0, 0, 1},
                {3, 5, 0, 2},
                {3, 5, 7, 3},
                {3, 0, 7, 3},
                // zero low
                {0, 0, 0, 1},
                {0, 5, 0, 2},
                {0, 5, 7, 3},
                {0, 0, 7, 3}
        };
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1400,4 +1400,16 @@ public void testApproxMostFrequentWithStringGroupBy()
assertEquals(actual1.getMaterializedRows().get(2).getFields().get(0), "c");
assertEquals(actual1.getMaterializedRows().get(2).getFields().get(1), ImmutableMap.of("C", 2L));
}

@Test
public void testLongDecimalAggregations()
{
    // Runs avg and sum over two DECIMAL(38, 0) columns: one holding 2^65 and one holding 1,
    // grouped by order key over ten rows of the orders table.
    String query = """
            SELECT avg(value_big), sum(value_big), avg(value_small), sum(value_small)
            FROM (
            SELECT orderkey as id, CAST(power(2, 65) as DECIMAL(38, 0)) as value_big, CAST(1 as DECIMAL(38, 0)) as value_small
            FROM orders
            LIMIT 10)
            GROUP BY id""";
    assertQuery(query);
}
}

0 comments on commit 182b44e

Please sign in to comment.