Optimize DELTA_BYTE_ARRAY decoder for UUID in parquet
Benchmark                        Mode  Cnt   Before           After           Units
BenchmarkUuidColumnReader.read  thrpt   20   22.372 ± 1.186   79.362 ± 5.013  ops/s
raunaqmorarka committed Feb 2, 2023
1 parent 515a52d commit 6ddeb52
Showing 6 changed files with 238 additions and 43 deletions.
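This commit replaces the per-value UuidApacheParquetValueDecoder, which allocated a fresh byte[] and performed two VarHandle reads for every UUID, with a batched decoder: BinaryDeltaByteArrayDecoder first decodes the whole batch into a single contiguous BinaryBuffer, and the fixed 16-byte values are then read back as pairs of longs from one slice. The old decoder survives only in the tests, as a reference implementation. Per the benchmark above, throughput improves roughly 3.5x. A minimal usage sketch of the new code path follows; pageBody and positionCount are hypothetical stand-ins for an already-loaded DELTA_BYTE_ARRAY page body and its value count:

// Hypothetical driver code, not part of the commit.
ValueDecoder<long[]> decoder = TransformingValueDecoders.getDeltaUuidDecoder(ParquetEncoding.DELTA_BYTE_ARRAY);
decoder.init(new SimpleSliceInputStream(pageBody));

long[] values = new long[positionCount * 2]; // two longs (128 bits) per UUID
decoder.read(values, 0, positionCount);
// values[2 * i] and values[(2 * i) + 1] now hold the i-th UUID's halves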
ApacheParquetValueDecoders.java
@@ -20,10 +20,7 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import static io.trino.parquet.ParquetReaderUtils.castToByte;
import static io.trino.parquet.ParquetTimestampUtils.decodeInt96Timestamp;
@@ -77,42 +74,6 @@ public void skip(int n)
}
}

public static final class UuidApacheParquetValueDecoder
implements ValueDecoder<long[]>
{
private static final VarHandle LONG_ARRAY_HANDLE = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN);

private final ValuesReader delegate;

public UuidApacheParquetValueDecoder(ValuesReader delegate)
{
this.delegate = requireNonNull(delegate, "delegate is null");
}

@Override
public void init(SimpleSliceInputStream input)
{
initialize(input, delegate);
}

@Override
public void read(long[] values, int offset, int length)
{
int endOffset = (offset + length) * 2;
for (int currentOutputOffset = offset * 2; currentOutputOffset < endOffset; currentOutputOffset += 2) {
byte[] data = delegate.readBytes().getBytes();
values[currentOutputOffset] = (long) LONG_ARRAY_HANDLE.get(data, 0);
values[currentOutputOffset + 1] = (long) LONG_ARRAY_HANDLE.get(data, Long.BYTES);
}
}

@Override
public void skip(int n)
{
delegate.skip(n);
}
}

public static final class Int96ApacheParquetValueDecoder
implements ValueDecoder<Int96Buffer>
{
TransformingValueDecoders.java
@@ -722,6 +722,40 @@ public static ValueDecoder<byte[]> getShortDecimalToByteDecoder(ParquetEncoding
return new LongToByteTransformDecoder(getShortDecimalDecoder(encoding, field));
}

public static ValueDecoder<long[]> getDeltaUuidDecoder(ParquetEncoding encoding)
{
checkArgument(encoding.equals(DELTA_BYTE_ARRAY), "encoding %s is not DELTA_BYTE_ARRAY", encoding);
ValueDecoder<BinaryBuffer> delegate = new BinaryDeltaByteArrayDecoder();
return new ValueDecoder<>()
{
@Override
public void init(SimpleSliceInputStream input)
{
delegate.init(input);
}

@Override
public void read(long[] values, int offset, int length)
{
BinaryBuffer buffer = new BinaryBuffer(length);
delegate.read(buffer, 0, length);
SimpleSliceInputStream binaryInput = new SimpleSliceInputStream(buffer.asSlice());

int endOffset = (offset + length) * 2;
for (int outputOffset = offset * 2; outputOffset < endOffset; outputOffset += 2) {
values[outputOffset] = binaryInput.readLong();
values[outputOffset + 1] = binaryInput.readLong();
}
}

@Override
public void skip(int n)
{
delegate.skip(n);
}
};
}

private static class LongToIntTransformDecoder
implements ValueDecoder<int[]>
{
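For clarity, the hot loop in getDeltaUuidDecoder above is equivalent to the stand-alone routine below: by the time it runs, the DELTA_BYTE_ARRAY delegate has already concatenated every decoded 16-byte value into one contiguous buffer, so each UUID is just two sequential little-endian long reads (the same byte order the removed VarHandle-based decoder used). Illustrative sketch only, assuming java.nio.ByteBuffer and java.nio.ByteOrder:

// Unpack `length` 16-byte UUIDs from a contiguous buffer into values[]
// starting at logical position `offset`; mirrors the readLong() loop above.
static void unpackUuids(byte[] contiguous, long[] values, int offset, int length)
{
    ByteBuffer input = ByteBuffer.wrap(contiguous).order(ByteOrder.LITTLE_ENDIAN);
    for (int i = offset * 2; i < (offset + length) * 2; i++) {
        values[i] = input.getLong();
    }
}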
ValueDecoders.java
@@ -31,7 +31,6 @@
import static io.trino.parquet.ValuesType.VALUES;
import static io.trino.parquet.reader.decoders.ApacheParquetValueDecoders.BooleanApacheParquetValueDecoder;
import static io.trino.parquet.reader.decoders.ApacheParquetValueDecoders.Int96ApacheParquetValueDecoder;
import static io.trino.parquet.reader.decoders.ApacheParquetValueDecoders.UuidApacheParquetValueDecoder;
import static io.trino.parquet.reader.decoders.DeltaBinaryPackedDecoders.DeltaBinaryPackedByteDecoder;
import static io.trino.parquet.reader.decoders.DeltaBinaryPackedDecoders.DeltaBinaryPackedIntDecoder;
import static io.trino.parquet.reader.decoders.DeltaBinaryPackedDecoders.DeltaBinaryPackedLongDecoder;
@@ -57,6 +56,7 @@
import static io.trino.parquet.reader.decoders.TransformingValueDecoders.getBinaryShortDecimalDecoder;
import static io.trino.parquet.reader.decoders.TransformingValueDecoders.getDeltaFixedWidthLongDecimalDecoder;
import static io.trino.parquet.reader.decoders.TransformingValueDecoders.getDeltaFixedWidthShortDecimalDecoder;
import static io.trino.parquet.reader.decoders.TransformingValueDecoders.getDeltaUuidDecoder;
import static io.trino.parquet.reader.decoders.TransformingValueDecoders.getInt32ToLongDecoder;
import static io.trino.parquet.reader.decoders.TransformingValueDecoders.getInt64ToByteDecoder;
import static io.trino.parquet.reader.decoders.TransformingValueDecoders.getInt64ToIntDecoder;
@@ -120,8 +120,7 @@ public static ValueDecoder<long[]> getUuidDecoder(ParquetEncoding encoding, PrimitiveField field)
{
return switch (encoding) {
case PLAIN -> new UuidPlainValueDecoder();
case DELTA_BYTE_ARRAY ->
new UuidApacheParquetValueDecoder(getApacheParquetReader(encoding, field));
case DELTA_BYTE_ARRAY -> getDeltaUuidDecoder(encoding);
default -> throw wrongEncoding(encoding, field);
};
}
BenchmarkUuidColumnReader.java
@@ -0,0 +1,101 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.parquet.reader;

import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.parquet.ParquetEncoding;
import io.trino.parquet.PrimitiveField;
import io.trino.spi.type.UuidType;
import org.apache.parquet.bytes.HeapByteBufferAllocator;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.values.ValuesWriter;
import org.apache.parquet.column.values.deltastrings.DeltaByteArrayWriter;
import org.apache.parquet.column.values.plain.FixedLenByteArrayPlainValuesWriter;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Types;
import org.openjdk.jmh.annotations.Param;

import java.util.UUID;

import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.parquet.ParquetEncoding.DELTA_BYTE_ARRAY;
import static io.trino.parquet.ParquetEncoding.PLAIN;
import static java.lang.String.format;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY;

public class BenchmarkUuidColumnReader
extends AbstractColumnReaderBenchmark<long[]>
{
private static final int LENGTH = 2 * SIZE_OF_LONG;

@Param({
"PLAIN",
"DELTA_BYTE_ARRAY",
})
public ParquetEncoding encoding;

@Override
protected PrimitiveField createPrimitiveField()
{
PrimitiveType parquetType = Types.optional(FIXED_LEN_BYTE_ARRAY)
.length(LENGTH)
.as(LogicalTypeAnnotation.uuidType())
.named("name");
return new PrimitiveField(
UuidType.UUID,
true,
new ColumnDescriptor(new String[] {"test"}, parquetType, 0, 0),
0);
}

@Override
protected ValuesWriter createValuesWriter(int bufferSize)
{
if (encoding.equals(PLAIN)) {
return new FixedLenByteArrayPlainValuesWriter(LENGTH, bufferSize, bufferSize, HeapByteBufferAllocator.getInstance());
}
else if (encoding.equals(DELTA_BYTE_ARRAY)) {
return new DeltaByteArrayWriter(bufferSize, bufferSize, HeapByteBufferAllocator.getInstance());
}
throw new UnsupportedOperationException(format("encoding %s is not supported", encoding));
}

@Override
protected void writeValue(ValuesWriter writer, long[] batch, int index)
{
Slice slice = Slices.wrappedLongArray(batch, index * 2, 2);
writer.writeBytes(Binary.fromConstantByteArray(slice.getBytes()));
}

@Override
protected long[] generateDataBatch(int size)
{
long[] batch = new long[size * 2];
for (int i = 0; i < size; i++) {
UUID uuid = UUID.randomUUID();
batch[i * 2] = uuid.getMostSignificantBits();
batch[(i * 2) + 1] = uuid.getLeastSignificantBits();
}
return batch;
}

public static void main(String[] args)
throws Exception
{
run(BenchmarkUuidColumnReader.class);
}
}
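The benchmark packs each UUID as a (most-significant, least-significant) long pair, writes it through the Parquet ValuesWriter, and decodes it back into the same layout. A decoded batch can be mapped back to java.util.UUID with a small helper like this (illustrative, not part of the commit):

// Rebuild the i-th UUID from a decoded long[] batch, following the
// layout produced by generateDataBatch() above.
static UUID toUuid(long[] values, int i)
{
    return new UUID(values[2 * i], values[(2 * i) + 1]);
}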
@@ -108,4 +108,16 @@ public void testLongDecimalColumnReaderBenchmark()
benchmark.read();
}
}

@Test
public void testUuidColumnReaderBenchmark()
throws IOException
{
for (ParquetEncoding encoding : ImmutableList.of(PLAIN, DELTA_BYTE_ARRAY)) {
BenchmarkUuidColumnReader benchmark = new BenchmarkUuidColumnReader();
benchmark.encoding = encoding;
benchmark.setup();
benchmark.read();
}
}
}
@@ -13,21 +13,28 @@
*/
package io.trino.parquet.reader.decoders;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slices;
import io.trino.parquet.ParquetEncoding;
import io.trino.parquet.PrimitiveField;
import io.trino.parquet.reader.SimpleSliceInputStream;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.UuidType;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.column.values.ValuesWriter;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.LogicalTypeAnnotation;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.math.BigInteger;
import java.nio.ByteOrder;
import java.util.OptionalInt;
import java.util.Random;
import java.util.UUID;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
@@ -61,7 +68,8 @@ protected Object[][] tests()
generateShortDecimalTests(RLE_DICTIONARY),
generateLongDecimalTests(PLAIN),
generateLongDecimalTests(DELTA_BYTE_ARRAY),
generateLongDecimalTests(RLE_DICTIONARY));
generateLongDecimalTests(RLE_DICTIONARY),
generateUuidTests());
}

private static Object[][] generateShortDecimalTests(ParquetEncoding encoding)
@@ -87,6 +95,16 @@ private static Object[][] generateLongDecimalTests(ParquetEncoding encoding)
.toArray(Object[][]::new);
}

private static Object[][] generateUuidTests()
{
return ImmutableList.of(PLAIN, DELTA_BYTE_ARRAY).stream()
.map(encoding -> new Object[] {
createUuidTestType(),
encoding,
new UuidInputProvider()})
.toArray(Object[][]::new);
}

private static TestType<long[]> createShortDecimalTestType(int typeLength, int precision)
{
DecimalType decimalType = DecimalType.createDecimalType(precision, 2);
@@ -111,6 +129,16 @@ private static TestType<long[]> createLongDecimalTestType(int typeLength)
(actual, expected) -> assertThat(actual).isEqualTo(expected));
}

private static TestType<long[]> createUuidTestType()
{
return new TestType<>(
createField(FIXED_LEN_BYTE_ARRAY, OptionalInt.of(16), UuidType.UUID),
ValueDecoders::getUuidDecoder,
UuidApacheParquetValueDecoder::new,
INT128_ADAPTER,
(actual, expected) -> assertThat(actual).isEqualTo(expected));
}

private static InputDataProvider createShortDecimalInputDataProvider(int typeLength, int precision)
{
return new InputDataProvider() {
@@ -157,6 +185,30 @@ public String toString()
};
}

private static class UuidInputProvider
implements InputDataProvider
{
@Override
public DataBuffer write(ValuesWriter valuesWriter, int dataSize)
{
byte[][] bytes = new byte[dataSize][];
for (int i = 0; i < dataSize; i++) {
UUID uuid = UUID.randomUUID();
bytes[i] = Slices.wrappedLongArray(
uuid.getMostSignificantBits(),
uuid.getLeastSignificantBits())
.getBytes();
}
return writeBytes(valuesWriter, bytes);
}

@Override
public String toString()
{
return "uuid";
}
}

private static DataBuffer writeBytes(ValuesWriter valuesWriter, byte[][] input)
{
for (byte[] value : input) {
@@ -249,4 +301,40 @@ public void skip(int n)
delegate.skip(n);
}
}

private static final class UuidApacheParquetValueDecoder
implements ValueDecoder<long[]>
{
private static final VarHandle LONG_ARRAY_HANDLE = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN);

private final ValuesReader delegate;

private UuidApacheParquetValueDecoder(ValuesReader delegate)
{
this.delegate = requireNonNull(delegate, "delegate is null");
}

@Override
public void init(SimpleSliceInputStream input)
{
initialize(input, delegate);
}

@Override
public void read(long[] values, int offset, int length)
{
int endOffset = (offset + length) * 2;
for (int currentOutputOffset = offset * 2; currentOutputOffset < endOffset; currentOutputOffset += 2) {
byte[] data = delegate.readBytes().getBytes();
values[currentOutputOffset] = (long) LONG_ARRAY_HANDLE.get(data, 0);
values[currentOutputOffset + 1] = (long) LONG_ARRAY_HANDLE.get(data, Long.BYTES);
}
}

@Override
public void skip(int n)
{
delegate.skip(n);
}
}
}
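UuidApacheParquetValueDecoder, deleted from the production ApacheParquetValueDecoders in this commit, reappears above as a private test-only class: createUuidTestType() supplies it alongside ValueDecoders::getUuidDecoder, keeping the Apache implementation as a reference against which the optimized decoder is validated for both PLAIN and DELTA_BYTE_ARRAY input.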
