Account for memory usage of parquet column readers
Added accounting of memory usage to the batched column readers.
This covers the dictionary retained for decoding dictionary-encoded
data pages and the extra memory occupied by decompressed values
from data pages.
raunaqmorarka committed Jan 3, 2023
1 parent 49c0762 commit 91515ea
Showing 22 changed files with 410 additions and 52 deletions.
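Note: the mechanism this commit relies on is Trino's memory-context hierarchy in io.trino.memory.context. Below is a minimal sketch, assuming nothing beyond that package's public API, of how bytes reported by a column reader propagate to the root context (the class name and byte counts are illustrative):

    import io.trino.memory.context.AggregatedMemoryContext;
    import io.trino.memory.context.LocalMemoryContext;

    import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;

    public class MemoryContextSketch
    {
        public static void main(String[] args)
        {
            // Root context, e.g. the one the engine hands to ParquetReader
            AggregatedMemoryContext readerContext = newSimpleAggregatedMemoryContext();
            // Child context scoped to the current row group
            AggregatedMemoryContext rowGroupContext = readerContext.newAggregatedMemoryContext();
            // Each column reader reports through its own local context
            LocalMemoryContext columnContext = rowGroupContext.newLocalMemoryContext("ColumnReader");

            columnContext.setBytes(1_000_000); // e.g. decoded dictionary + decompressed page
            System.out.println(readerContext.getBytes()); // 1000000 -- usage aggregates to the root

            rowGroupContext.close(); // switching row groups releases all column usage at once
            System.out.println(readerContext.getBytes()); // 0
        }
    }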
ColumnReaderFactory.java
@@ -13,6 +13,8 @@
  */
 package io.trino.parquet.reader;
 
+import io.trino.memory.context.AggregatedMemoryContext;
+import io.trino.memory.context.LocalMemoryContext;
 import io.trino.parquet.PrimitiveField;
 import io.trino.parquet.reader.decoders.TransformingValueDecoders;
 import io.trino.parquet.reader.decoders.ValueDecoders;
@@ -78,81 +80,84 @@ public final class ColumnReaderFactory
 {
     private ColumnReaderFactory() {}
 
-    public static ColumnReader create(PrimitiveField field, DateTimeZone timeZone, boolean useBatchedColumnReaders)
+    public static ColumnReader create(PrimitiveField field, DateTimeZone timeZone, AggregatedMemoryContext aggregatedMemoryContext, boolean useBatchedColumnReaders)
     {
         Type type = field.getType();
         PrimitiveTypeName primitiveType = field.getDescriptor().getPrimitiveType().getPrimitiveTypeName();
         LogicalTypeAnnotation annotation = field.getDescriptor().getPrimitiveType().getLogicalTypeAnnotation();
+        LocalMemoryContext memoryContext = aggregatedMemoryContext.newLocalMemoryContext(ColumnReader.class.getSimpleName());
         if (useBatchedColumnReaders && field.getDescriptor().getPath().length == 1) {
             if (BOOLEAN.equals(type) && primitiveType == PrimitiveTypeName.BOOLEAN) {
-                return new FlatColumnReader<>(field, ValueDecoders::getBooleanDecoder, BOOLEAN_ADAPTER);
+                return new FlatColumnReader<>(field, ValueDecoders::getBooleanDecoder, BOOLEAN_ADAPTER, memoryContext);
             }
             if (TINYINT.equals(type) && primitiveType == INT32) {
                 if (isIntegerAnnotation(annotation)) {
-                    return new FlatColumnReader<>(field, ValueDecoders::getByteDecoder, BYTE_ADAPTER);
+                    return new FlatColumnReader<>(field, ValueDecoders::getByteDecoder, BYTE_ADAPTER, memoryContext);
                 }
                 throw unsupportedException(type, field);
             }
             if (SMALLINT.equals(type) && primitiveType == INT32) {
                 if (isIntegerAnnotation(annotation)) {
-                    return new FlatColumnReader<>(field, ValueDecoders::getShortDecoder, SHORT_ADAPTER);
+                    return new FlatColumnReader<>(field, ValueDecoders::getShortDecoder, SHORT_ADAPTER, memoryContext);
                 }
                 throw unsupportedException(type, field);
             }
             if (DATE.equals(type) && primitiveType == INT32) {
                 if (annotation == null || annotation instanceof DateLogicalTypeAnnotation) {
-                    return new FlatColumnReader<>(field, ValueDecoders::getIntDecoder, INT_ADAPTER);
+                    return new FlatColumnReader<>(field, ValueDecoders::getIntDecoder, INT_ADAPTER, memoryContext);
                 }
                 throw unsupportedException(type, field);
             }
             if (type instanceof AbstractIntType && primitiveType == INT32) {
                 if (isIntegerAnnotation(annotation)) {
-                    return new FlatColumnReader<>(field, ValueDecoders::getIntDecoder, INT_ADAPTER);
+                    return new FlatColumnReader<>(field, ValueDecoders::getIntDecoder, INT_ADAPTER, memoryContext);
                 }
                 throw unsupportedException(type, field);
             }
             if (type instanceof AbstractLongType && primitiveType == INT32) {
                 if (isIntegerAnnotation(annotation)) {
-                    return new FlatColumnReader<>(field, ValueDecoders::getIntToLongDecoder, LONG_ADAPTER);
+                    return new FlatColumnReader<>(field, ValueDecoders::getIntToLongDecoder, LONG_ADAPTER, memoryContext);
                 }
                 throw unsupportedException(type, field);
             }
             if (type instanceof TimeType && primitiveType == INT64) {
                 if (annotation instanceof TimeLogicalTypeAnnotation timeAnnotation && timeAnnotation.getUnit() == MICROS) {
-                    return new FlatColumnReader<>(field, TransformingValueDecoders::getTimeMicrosDecoder, LONG_ADAPTER);
+                    return new FlatColumnReader<>(field, TransformingValueDecoders::getTimeMicrosDecoder, LONG_ADAPTER, memoryContext);
                 }
                 throw unsupportedException(type, field);
             }
             if (type instanceof AbstractLongType && primitiveType == INT64) {
                 if (BIGINT.equals(type) && annotation instanceof TimestampLogicalTypeAnnotation) {
-                    return new FlatColumnReader<>(field, ValueDecoders::getLongDecoder, LONG_ADAPTER);
+                    return new FlatColumnReader<>(field, ValueDecoders::getLongDecoder, LONG_ADAPTER, memoryContext);
                 }
                 if (isIntegerAnnotation(annotation)) {
-                    return new FlatColumnReader<>(field, ValueDecoders::getLongDecoder, LONG_ADAPTER);
+                    return new FlatColumnReader<>(field, ValueDecoders::getLongDecoder, LONG_ADAPTER, memoryContext);
                 }
                 throw unsupportedException(type, field);
             }
             if (REAL.equals(type) && primitiveType == FLOAT) {
-                return new FlatColumnReader<>(field, ValueDecoders::getRealDecoder, INT_ADAPTER);
+                return new FlatColumnReader<>(field, ValueDecoders::getRealDecoder, INT_ADAPTER, memoryContext);
             }
             if (DOUBLE.equals(type) && primitiveType == PrimitiveTypeName.DOUBLE) {
-                return new FlatColumnReader<>(field, ValueDecoders::getDoubleDecoder, LONG_ADAPTER);
+                return new FlatColumnReader<>(field, ValueDecoders::getDoubleDecoder, LONG_ADAPTER, memoryContext);
            }
             if (type instanceof TimestampType timestampType && primitiveType == INT96) {
                 if (timestampType.isShort()) {
                     return new FlatColumnReader<>(
                             field,
                             (encoding, primitiveField) -> getInt96ToShortTimestampDecoder(encoding, primitiveField, timeZone),
-                            LONG_ADAPTER);
+                            LONG_ADAPTER,
+                            memoryContext);
                 }
                 return new FlatColumnReader<>(
                         field,
                         (encoding, primitiveField) -> getInt96ToLongTimestampDecoder(encoding, primitiveField, timeZone),
-                        INT96_ADAPTER);
+                        INT96_ADAPTER,
+                        memoryContext);
             }
             if (type instanceof TimestampWithTimeZoneType timestampWithTimeZoneType && primitiveType == INT96) {
                 if (timestampWithTimeZoneType.isShort()) {
-                    return new FlatColumnReader<>(field, TransformingValueDecoders::getInt96ToShortTimestampWithTimeZoneDecoder, LONG_ADAPTER);
+                    return new FlatColumnReader<>(field, TransformingValueDecoders::getInt96ToShortTimestampWithTimeZoneDecoder, LONG_ADAPTER, memoryContext);
                 }
                 throw unsupportedException(type, field);
             }
@@ -162,15 +167,15 @@ public static ColumnReader create(PrimitiveField field, DateTimeZone timeZone, boolean useBatchedColumnReaders)
             }
             if (timestampType.isShort()) {
                 return switch (timestampAnnotation.getUnit()) {
-                    case MILLIS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMillsToShortTimestampDecoder, LONG_ADAPTER);
-                    case MICROS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMicrosToShortTimestampDecoder, LONG_ADAPTER);
-                    case NANOS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampNanosToShortTimestampDecoder, LONG_ADAPTER);
+                    case MILLIS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMillsToShortTimestampDecoder, LONG_ADAPTER, memoryContext);
+                    case MICROS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMicrosToShortTimestampDecoder, LONG_ADAPTER, memoryContext);
+                    case NANOS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampNanosToShortTimestampDecoder, LONG_ADAPTER, memoryContext);
                 };
             }
             return switch (timestampAnnotation.getUnit()) {
-                case MILLIS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMillisToLongTimestampDecoder, INT96_ADAPTER);
-                case MICROS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMicrosToLongTimestampDecoder, INT96_ADAPTER);
-                case NANOS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampNanosToLongTimestampDecoder, INT96_ADAPTER);
+                case MILLIS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMillisToLongTimestampDecoder, INT96_ADAPTER, memoryContext);
+                case MICROS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMicrosToLongTimestampDecoder, INT96_ADAPTER, memoryContext);
+                case NANOS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampNanosToLongTimestampDecoder, INT96_ADAPTER, memoryContext);
             };
         }
         if (type instanceof TimestampWithTimeZoneType timestampWithTimeZoneType && primitiveType == INT64) {
@@ -179,42 +184,42 @@ public static ColumnReader create(PrimitiveField field, DateTimeZone timeZone, boolean useBatchedColumnReaders)
             }
             if (timestampWithTimeZoneType.isShort()) {
                 return switch (timestampAnnotation.getUnit()) {
-                    case MILLIS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMillsToShortTimestampWithTimeZoneDecoder, LONG_ADAPTER);
-                    case MICROS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMicrosToShortTimestampWithTimeZoneDecoder, LONG_ADAPTER);
+                    case MILLIS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMillsToShortTimestampWithTimeZoneDecoder, LONG_ADAPTER, memoryContext);
+                    case MICROS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMicrosToShortTimestampWithTimeZoneDecoder, LONG_ADAPTER, memoryContext);
                     case NANOS -> throw unsupportedException(type, field);
                 };
             }
             return switch (timestampAnnotation.getUnit()) {
                 case MILLIS, NANOS -> throw unsupportedException(type, field);
-                case MICROS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMicrosToLongTimestampWithTimeZoneDecoder, INT96_ADAPTER);
+                case MICROS -> new FlatColumnReader<>(field, TransformingValueDecoders::getInt64TimestampMicrosToLongTimestampWithTimeZoneDecoder, INT96_ADAPTER, memoryContext);
             };
         }
         if (type instanceof DecimalType decimalType && decimalType.isShort()
                 && (primitiveType == INT32 || primitiveType == INT64 || primitiveType == FIXED_LEN_BYTE_ARRAY)) {
             if (annotation instanceof DecimalLogicalTypeAnnotation decimalAnnotation && !isDecimalRescaled(decimalAnnotation, decimalType)) {
-                return new FlatColumnReader<>(field, ValueDecoders::getShortDecimalDecoder, LONG_ADAPTER);
+                return new FlatColumnReader<>(field, ValueDecoders::getShortDecimalDecoder, LONG_ADAPTER, memoryContext);
             }
         }
         if (type instanceof DecimalType decimalType && !decimalType.isShort()
                 && (primitiveType == BINARY || primitiveType == FIXED_LEN_BYTE_ARRAY)) {
             if (annotation instanceof DecimalLogicalTypeAnnotation decimalAnnotation && !isDecimalRescaled(decimalAnnotation, decimalType)) {
-                return new FlatColumnReader<>(field, ValueDecoders::getLongDecimalDecoder, INT128_ADAPTER);
+                return new FlatColumnReader<>(field, ValueDecoders::getLongDecimalDecoder, INT128_ADAPTER, memoryContext);
             }
         }
         if (type instanceof VarcharType varcharType && !varcharType.isUnbounded() && primitiveType == BINARY) {
-            return new FlatColumnReader<>(field, ValueDecoders::getBoundedVarcharBinaryDecoder, BINARY_ADAPTER);
+            return new FlatColumnReader<>(field, ValueDecoders::getBoundedVarcharBinaryDecoder, BINARY_ADAPTER, memoryContext);
         }
         if (type instanceof CharType && primitiveType == BINARY) {
-            return new FlatColumnReader<>(field, ValueDecoders::getCharBinaryDecoder, BINARY_ADAPTER);
+            return new FlatColumnReader<>(field, ValueDecoders::getCharBinaryDecoder, BINARY_ADAPTER, memoryContext);
         }
         if (type instanceof AbstractVariableWidthType && primitiveType == BINARY) {
-            return new FlatColumnReader<>(field, ValueDecoders::getBinaryDecoder, BINARY_ADAPTER);
+            return new FlatColumnReader<>(field, ValueDecoders::getBinaryDecoder, BINARY_ADAPTER, memoryContext);
         }
         if (UUID.equals(type) && primitiveType == FIXED_LEN_BYTE_ARRAY) {
             // Iceberg 0.11.1 writes UUID as FIXED_LEN_BYTE_ARRAY without logical type annotation (see https://github.com/apache/iceberg/pull/2913)
             // To support such files, we bet on the logical type to be UUID based on the Trino UUID type check.
             if (annotation == null || isLogicalUuid(annotation)) {
-                return new FlatColumnReader<>(field, ValueDecoders::getUuidDecoder, INT128_ADAPTER);
+                return new FlatColumnReader<>(field, ValueDecoders::getUuidDecoder, INT128_ADAPTER, memoryContext);
             }
             throw unsupportedException(type, field);
         }
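Note: the factory change follows one pattern throughout: every batched reader gets its own LocalMemoryContext (tagged with the ColumnReader class name) carved out of the row-group-scoped AggregatedMemoryContext. A hypothetical sketch of how a reader would then report usage — the field and method names are illustrative, only setBytes is the real API:

    // Inside a column reader: re-report retained bytes whenever buffers change
    private void updateMemoryUsage()
    {
        memoryContext.setBytes(dictionaryRetainedBytes + decompressedPagesRetainedBytes);
    }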
PageReader.java
@@ -172,6 +172,11 @@ public void skipNextPage()
         compressedPages.next();
     }
 
+    public boolean arePagesCompressed()
+    {
+        return codec != CompressionCodecName.UNCOMPRESSED;
+    }
+
     private void verifyDictionaryPageRead()
     {
         if (hasDictionaryPage) {
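Note: arePagesCompressed() lets callers count decompression cost only when it actually allocates new buffers; with CompressionCodecName.UNCOMPRESSED the page data typically aliases the raw column chunk already accounted for elsewhere. A hedged usage sketch (the variable names are illustrative):

    long retainedBytes = dictionaryRetainedBytes;
    if (pageReader.arePagesCompressed()) {
        // decompressed values live in a separate buffer on top of the raw chunk
        retainedBytes += decompressedPageRetainedBytes;
    }
    memoryContext.setBytes(retainedBytes);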
ParquetReader.java
@@ -120,6 +120,7 @@ public class ParquetReader
     private final ParquetReaderOptions options;
     private int maxBatchSize;
 
+    private AggregatedMemoryContext currentRowGroupMemoryContext;
     private final Map<ChunkKey, ChunkedInputStream> chunkReaders;
     private final List<Optional<ColumnIndexStore>> columnIndexStore;
     private final Optional<ParquetWriteValidation> writeValidation;
@@ -170,6 +171,7 @@ public ParquetReader(
         this.dataSource = requireNonNull(dataSource, "dataSource is null");
         this.timeZone = requireNonNull(timeZone, "timeZone is null");
         this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
+        this.currentRowGroupMemoryContext = memoryContext.newAggregatedMemoryContext();
         this.options = requireNonNull(options, "options is null");
         this.maxBatchSize = options.getMaxReadBlockRowCount();
         this.columnReaders = new HashMap<>();
@@ -242,6 +244,10 @@ public ParquetReader(
     public void close()
             throws IOException
     {
+        // Release memory usage from column readers
+        columnReaders.clear();
+        currentRowGroupMemoryContext.close();
+
         for (ChunkedInputStream chunkedInputStream : chunkReaders.values()) {
             chunkedInputStream.close();
         }
@@ -299,6 +305,8 @@ private int nextBatch()
     private boolean advanceToNextRowGroup()
             throws IOException
     {
+        currentRowGroupMemoryContext.close();
+        currentRowGroupMemoryContext = memoryContext.newAggregatedMemoryContext();
         freeCurrentRowGroupBuffers();
 
         if (currentRowGroup >= 0 && rowGroupStatisticsValidation.isPresent()) {
@@ -472,7 +480,9 @@ private ColumnChunkMetaData getColumnChunkMetaData(BlockMetaData blockMetaData,
     private void initializeColumnReaders()
     {
         for (PrimitiveField field : primitiveFields) {
-            columnReaders.put(field.getId(), ColumnReaderFactory.create(field, timeZone, options.useBatchColumnReaders()));
+            columnReaders.put(
+                    field.getId(),
+                    ColumnReaderFactory.create(field, timeZone, currentRowGroupMemoryContext, options.useBatchColumnReaders()));
         }
     }
 
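Note: these ParquetReader changes tie column-reader memory to row-group scope. advanceToNextRowGroup() closes the old child context (zeroing every reader's reservation) before creating a fresh one, initializeColumnReaders() threads the fresh context into the new readers, and close() clears the readers first so nothing reports into a closed context. A sketch of the resulting lifecycle, assuming a hypothetical createParquetReader helper that wires up the reader's many constructor arguments:

    AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext();
    try (ParquetReader reader = createParquetReader(dataSource, memoryContext)) { // hypothetical helper
        Page page;
        while ((page = reader.nextPage()) != null) {
            // memoryContext.getBytes() now reflects dictionaries and decompressed
            // buffers retained by the current row group's column readers
        }
    }
    // after close(): memoryContext.getBytes() == 0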
ValueDecoders.java
@@ -13,6 +13,7 @@
  */
 package io.trino.parquet.reader.decoders;
 
+import io.airlift.slice.Slice;
 import io.trino.parquet.DictionaryPage;
 import io.trino.parquet.ParquetEncoding;
 import io.trino.parquet.PrimitiveField;
@@ -259,9 +260,10 @@ public static <T> DictionaryDecoder<T> getDictionaryDecoder(
     {
         int size = dictionaryPage.getDictionarySize();
         T dictionary = columnAdapter.createBuffer(size);
-        plainValuesDecoder.init(new SimpleSliceInputStream(dictionaryPage.getSlice()));
+        Slice dictionarySlice = dictionaryPage.getSlice();
+        plainValuesDecoder.init(new SimpleSliceInputStream(dictionarySlice));
         plainValuesDecoder.read(dictionary, 0, size);
-        return new DictionaryDecoder<>(dictionary, columnAdapter);
+        return new DictionaryDecoder<>(dictionary, columnAdapter, columnAdapter.getSizeInBytes(dictionary));
     }
 
     private static ValuesReader getApacheParquetReader(ParquetEncoding encoding, PrimitiveField field)
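Note: with this change a dictionary's retained size is computed once, at decode time, and carried by the DictionaryDecoder rather than recomputed per batch. For a fixed-width column the new third constructor argument reduces to airlift's sizeOf; an illustration with made-up figures:

    int[] dictionary = new int[1000];              // decoded dictionary buffer
    long retainedSize = SizeOf.sizeOf(dictionary); // io.airlift.slice.SizeOf:
                                                   // roughly a 16-byte array header + 4 bytes per entry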
BinaryBuffer.java
@@ -19,6 +19,7 @@
 import java.util.ArrayList;
 import java.util.List;
 
+import static io.airlift.slice.SizeOf.sizeOf;
 import static java.util.Objects.requireNonNull;
 
 /**
@@ -107,4 +108,13 @@ public int getValueCount()
     {
         return offsets.length - 1;
     }
+
+    public long getRetainedSize()
+    {
+        long chunksSizeInBytes = 0;
+        for (Slice slice : chunks) {
+            chunksSizeInBytes += slice.getRetainedSize();
+        }
+        return sizeOf(offsets) + chunksSizeInBytes;
+    }
 }
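Note: getRetainedSize() deliberately sums Slice.getRetainedSize() rather than logical lengths: a slice that is a view over a larger buffer keeps the whole backing array alive, and that is what must be charged against the memory context. A small illustration:

    Slice base = Slices.allocate(1024 * 1024);
    Slice view = base.slice(0, 10);
    // view.length() == 10, but view.getRetainedSize() >= 1024 * 1024,
    // because the 1 MB backing array cannot be collected while the view is reachable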
BinaryColumnAdapter.java
@@ -104,4 +104,10 @@ public void decodeDictionaryIds(BinaryBuffer values, int offset, int length, int[] ids, BinaryBuffer dictionary)
 
         values.addChunk(Slices.wrappedBuffer(outputChunk));
     }
+
+    @Override
+    public long getSizeInBytes(BinaryBuffer values)
+    {
+        return values.getRetainedSize();
+    }
 }
BooleanColumnAdapter.java
@@ -18,6 +18,8 @@
 
 import java.util.Optional;
 
+import static io.airlift.slice.SizeOf.sizeOf;
+
 public class BooleanColumnAdapter
         implements ColumnAdapter<byte[]>
 {
@@ -54,4 +56,10 @@ public void decodeDictionaryIds(byte[] values, int offset, int length, int[] ids, byte[] dictionary)
             values[offset + i] = dictionary[ids[i]];
         }
     }
+
+    @Override
+    public long getSizeInBytes(byte[] values)
+    {
+        return sizeOf(values);
+    }
 }
ByteColumnAdapter.java
@@ -18,6 +18,8 @@
 
 import java.util.Optional;
 
+import static io.airlift.slice.SizeOf.sizeOf;
+
 public class ByteColumnAdapter
         implements ColumnAdapter<byte[]>
 {
@@ -54,4 +56,10 @@ public void decodeDictionaryIds(byte[] values, int offset, int length, int[] ids, byte[] dictionary)
             values[offset + i] = dictionary[ids[i]];
         }
     }
+
+    @Override
+    public long getSizeInBytes(byte[] values)
+    {
+        return sizeOf(values);
+    }
 }
ColumnAdapter.java
@@ -47,4 +47,6 @@ default void unpackNullValues(BufferType source, BufferType destination, boolean
     }
 
     void decodeDictionaryIds(BufferType values, int offset, int length, int[] ids, BufferType dictionary);
+
+    long getSizeInBytes(BufferType values);
 }
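Note: getSizeInBytes completes the ColumnAdapter contract consumed by getDictionaryDecoder above. Fixed-width adapters can simply delegate to airlift's sizeOf; a sketch for an int-buffer adapter, by analogy with the byte[] adapters in this commit:

    @Override
    public long getSizeInBytes(int[] values)
    {
        return sizeOf(values); // static import io.airlift.slice.SizeOf.sizeOf
    }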