From 23de11f7d312866101fbf3b0465703a1736cffef Mon Sep 17 00:00:00 2001 From: Ying Su Date: Fri, 26 Oct 2018 00:37:00 -0700 Subject: [PATCH] Lazily build hashtable for MapBlock Presto builds hashtable for MapBlocks eagerly when constructing the MapBlock even it's not needed in the query. Building a hashtable could take up to 40% CPU of the scan cost on a map column. This commit defers the hashtable build to the time it's needed in SeekKey(). Note that we only do this to the MapBlock, not the MapBlockBuilder to avoid complex synchronization problems. The MapBlockBuilder will always build the hashtable. As the result MergingPageOutput and PartitionOutputOperator will still rebuild the hashtables when needed. The measurements shows there will be less than 10% pages for MergingPageOutput to build the hashtables. We will have a seperate PR to improve PartitionOutput and avoid rebuilding the pages so as to avoid hashtable rebuilding. Simple select checsum queries show over 40% CPU gain: Test | After | Before | Improvement select 2 map columns checksum | 11.69d | 20.06d | 42% Select 1 map column checksum | 9.67d | 17.73d | 45% --- .../presto/spi/block/AbstractMapBlock.java | 79 +++++++++---- .../facebook/presto/spi/block/MapBlock.java | 108 +++++++++++++----- .../presto/spi/block/MapBlockBuilder.java | 58 +++++++--- .../presto/spi/block/MapBlockEncoding.java | 30 +++-- .../presto/spi/block/SingleMapBlock.java | 99 ++++++++-------- .../spi/block/SingleMapBlockEncoding.java | 48 ++++++-- .../com/facebook/presto/spi/type/MapType.java | 4 +- 7 files changed, 296 insertions(+), 130 deletions(-) diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java index 87acaf4c59f55..8b76c672fec11 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java @@ -16,6 +16,8 @@ import com.facebook.presto.spi.type.Type; +import javax.annotation.Nullable; + import java.lang.invoke.MethodHandle; import java.util.Arrays; import java.util.Optional; @@ -36,20 +38,23 @@ public abstract class AbstractMapBlock protected final Type keyType; protected final MethodHandle keyNativeHashCode; protected final MethodHandle keyBlockNativeEquals; + protected final MethodHandle keyBlockHashCode; - public AbstractMapBlock(Type keyType, MethodHandle keyNativeHashCode, MethodHandle keyBlockNativeEquals) + public AbstractMapBlock(Type keyType, MethodHandle keyNativeHashCode, MethodHandle keyBlockNativeEquals, MethodHandle keyBlockHashCode) { this.keyType = requireNonNull(keyType, "keyType is null"); // keyNativeHashCode can only be null due to map block kill switch. deprecated.new-map-block this.keyNativeHashCode = keyNativeHashCode; // keyBlockNativeEquals can only be null due to map block kill switch. deprecated.new-map-block this.keyBlockNativeEquals = keyBlockNativeEquals; + this.keyBlockHashCode = requireNonNull(keyBlockHashCode, "keyBlockHashCode is null"); } protected abstract Block getRawKeyBlock(); protected abstract Block getRawValueBlock(); + @Nullable protected abstract int[] getHashTables(); /** @@ -66,6 +71,8 @@ public AbstractMapBlock(Type keyType, MethodHandle keyNativeHashCode, MethodHand protected abstract boolean[] getMapIsNull(); + protected abstract void ensureHashTableLoaded(); + int getOffset(int position) { return getOffsets()[position + getOffsetBase()]; @@ -108,21 +115,35 @@ public Block copyPositions(int[] positions, int offset, int length) } int[] hashTable = getHashTables(); - int[] newHashTable = new int[newOffsets[newOffsets.length - 1] * HASH_MULTIPLIER]; - int newHashIndex = 0; - for (int i = offset; i < offset + length; ++i) { - int position = positions[i]; - int entriesStartOffset = getOffset(position); - int entriesEndOffset = getOffset(position + 1); - for (int hashIndex = entriesStartOffset * HASH_MULTIPLIER; hashIndex < entriesEndOffset * HASH_MULTIPLIER; hashIndex++) { - newHashTable[newHashIndex] = hashTable[hashIndex]; - newHashIndex++; + int[] newHashTable = null; + if (hashTable != null) { + newHashTable = new int[newOffsets[newOffsets.length - 1] * HASH_MULTIPLIER]; + int newHashIndex = 0; + for (int i = offset; i < offset + length; ++i) { + int position = positions[i]; + int entriesStartOffset = getOffset(position); + int entriesEndOffset = getOffset(position + 1); + for (int hashIndex = entriesStartOffset * HASH_MULTIPLIER; hashIndex < entriesEndOffset * HASH_MULTIPLIER; hashIndex++) { + newHashTable[newHashIndex] = hashTable[hashIndex]; + newHashIndex++; + } } } Block newKeys = getRawKeyBlock().copyPositions(entriesPositions.elements(), 0, entriesPositions.size()); Block newValues = getRawValueBlock().copyPositions(entriesPositions.elements(), 0, entriesPositions.size()); - return createMapBlockInternal(0, length, Optional.of(newMapIsNull), newOffsets, newKeys, newValues, newHashTable, keyType, keyBlockNativeEquals, keyNativeHashCode); + return createMapBlockInternal( + 0, + length, + Optional.of(newMapIsNull), + newOffsets, + newKeys, + newValues, + Optional.ofNullable(newHashTable), + keyType, + keyBlockNativeEquals, + keyNativeHashCode, + keyBlockHashCode); } @Override @@ -138,10 +159,11 @@ public Block getRegion(int position, int length) getOffsets(), getRawKeyBlock(), getRawValueBlock(), - getHashTables(), + Optional.ofNullable(getHashTables()), keyType, keyBlockNativeEquals, - keyNativeHashCode); + keyNativeHashCode, + keyBlockHashCode); } @Override @@ -174,7 +196,12 @@ public Block copyRegion(int position, int length) int[] newOffsets = compactOffsets(getOffsets(), position + getOffsetBase(), length); boolean[] mapIsNull = getMapIsNull(); boolean[] newMapIsNull = mapIsNull == null ? null : compactArray(mapIsNull, position + getOffsetBase(), length); - int[] newHashTable = compactArray(getHashTables(), startValueOffset * HASH_MULTIPLIER, (endValueOffset - startValueOffset) * HASH_MULTIPLIER); + + int[] hashTables = getHashTables(); + int[] newHashTable = null; + if (hashTables != null) { + newHashTable = compactArray(hashTables, startValueOffset * HASH_MULTIPLIER, (endValueOffset - startValueOffset) * HASH_MULTIPLIER); + } if (newKeys == getRawKeyBlock() && newValues == getRawValueBlock() && newOffsets == getOffsets() && newMapIsNull == mapIsNull && newHashTable == getHashTables()) { return this; @@ -186,10 +213,11 @@ public Block copyRegion(int position, int length) newOffsets, newKeys, newValues, - newHashTable, + Optional.ofNullable(newHashTable), keyType, keyBlockNativeEquals, - keyNativeHashCode); + keyNativeHashCode, + keyBlockHashCode); } @Override @@ -205,12 +233,7 @@ public T getObject(int position, Class clazz) return clazz.cast(new SingleMapBlock( startEntryOffset * 2, (endEntryOffset - startEntryOffset) * 2, - getRawKeyBlock(), - getRawValueBlock(), - getHashTables(), - keyType, - keyNativeHashCode, - keyBlockNativeEquals)); + this)); } @Override @@ -230,7 +253,12 @@ public Block getSingleValueBlock(int position) int valueLength = endValueOffset - startValueOffset; Block newKeys = getRawKeyBlock().copyRegion(startValueOffset, valueLength); Block newValues = getRawValueBlock().copyRegion(startValueOffset, valueLength); - int[] newHashTable = Arrays.copyOfRange(getHashTables(), startValueOffset * HASH_MULTIPLIER, endValueOffset * HASH_MULTIPLIER); + + int[] hashTables = getHashTables(); + int[] newHashTable = null; + if (hashTables != null) { + newHashTable = Arrays.copyOfRange(hashTables, startValueOffset * HASH_MULTIPLIER, endValueOffset * HASH_MULTIPLIER); + } return createMapBlockInternal( 0, @@ -239,10 +267,11 @@ public Block getSingleValueBlock(int position) new int[] {0, valueLength}, newKeys, newValues, - newHashTable, + Optional.ofNullable(newHashTable), keyType, keyBlockNativeEquals, - keyNativeHashCode); + keyNativeHashCode, + keyBlockHashCode); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java index 0a51354f1b46c..53a832ba5432e 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java @@ -27,6 +27,7 @@ import static com.facebook.presto.spi.block.MapBlockBuilder.buildHashTable; import static io.airlift.slice.SizeOf.sizeOf; +import static io.airlift.slice.SizeOf.sizeOfIntArray; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -42,7 +43,7 @@ public class MapBlock private final int[] offsets; private final Block keyBlock; private final Block valueBlock; - private final int[] hashTables; // hash to location in map; + private volatile int[] hashTables; // hash to location in map. Writes to the field is protected by "this" monitor. private volatile long sizeInBytes; private final long retainedSizeInBytes; @@ -68,20 +69,6 @@ public static MapBlock fromKeyValueBlock( validateConstructorArguments(0, offsets.length - 1, mapIsNull.orElse(null), offsets, keyBlock, valueBlock, mapType.getKeyType(), keyBlockNativeEquals, keyNativeHashCode); int mapCount = offsets.length - 1; - int elementCount = keyBlock.getPositionCount(); - int[] hashTables = new int[elementCount * HASH_MULTIPLIER]; - Arrays.fill(hashTables, -1); - for (int i = 0; i < mapCount; i++) { - int keyOffset = offsets[i]; - int keyCount = offsets[i + 1] - keyOffset; - if (keyCount < 0) { - throw new IllegalArgumentException(format("Offset is not monotonically ascending. offsets[%s]=%s, offsets[%s]=%s", i, offsets[i], i + 1, offsets[i + 1])); - } - if (mapIsNull.isPresent() && mapIsNull.get()[i] && keyCount != 0) { - throw new IllegalArgumentException("A null map must have zero entries"); - } - buildHashTable(keyBlock, keyOffset, keyCount, keyBlockHashCode, hashTables, keyOffset * HASH_MULTIPLIER, keyCount * HASH_MULTIPLIER); - } return createMapBlockInternal( 0, @@ -90,10 +77,11 @@ public static MapBlock fromKeyValueBlock( offsets, keyBlock, valueBlock, - hashTables, + Optional.empty(), mapType.getKeyType(), keyBlockNativeEquals, - keyNativeHashCode); + keyNativeHashCode, + keyBlockHashCode); } /** @@ -112,13 +100,25 @@ public static MapBlock createMapBlockInternal( int[] offsets, Block keyBlock, Block valueBlock, - int[] hashTables, + Optional hashTables, Type keyType, MethodHandle keyBlockNativeEquals, - MethodHandle keyNativeHashCode) + MethodHandle keyNativeHashCode, + MethodHandle keyBlockHashCode) { validateConstructorArguments(startOffset, positionCount, mapIsNull.orElse(null), offsets, keyBlock, valueBlock, keyType, keyBlockNativeEquals, keyNativeHashCode); - return new MapBlock(startOffset, positionCount, mapIsNull.orElse(null), offsets, keyBlock, valueBlock, hashTables, keyType, keyBlockNativeEquals, keyNativeHashCode); + return new MapBlock( + startOffset, + positionCount, + mapIsNull.orElse(null), + offsets, + keyBlock, + valueBlock, + hashTables.orElse(null), + keyType, + keyBlockNativeEquals, + keyNativeHashCode, + keyBlockHashCode); } private static void validateConstructorArguments( @@ -171,15 +171,15 @@ private MapBlock( int[] offsets, Block keyBlock, Block valueBlock, - int[] hashTables, + @Nullable int[] hashTables, Type keyType, MethodHandle keyBlockNativeEquals, - MethodHandle keyNativeHashCode) + MethodHandle keyNativeHashCode, + MethodHandle keyBlockHashCode) { - super(keyType, keyNativeHashCode, keyBlockNativeEquals); + super(keyType, keyNativeHashCode, keyBlockNativeEquals, keyBlockHashCode); - requireNonNull(hashTables, "hashTables is null"); - if (hashTables.length < keyBlock.getPositionCount() * HASH_MULTIPLIER) { + if (hashTables != null && hashTables.length < keyBlock.getPositionCount() * HASH_MULTIPLIER) { throw new IllegalArgumentException(format("keyBlock/valueBlock size does not match hash table size: %s %s", keyBlock.getPositionCount(), hashTables.length)); } @@ -192,7 +192,16 @@ private MapBlock( this.hashTables = hashTables; this.sizeInBytes = -1; - this.retainedSizeInBytes = INSTANCE_SIZE + keyBlock.getRetainedSizeInBytes() + valueBlock.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(mapIsNull) + sizeOf(hashTables); + + // We will add the hashtable size to the retained size even if it's not built yet. This could be overestimating + // but is necessary to avoid reliability issues. Currently the memory counting framework only pull the retained + // size once for each operator so updating in the middle of the processing would not work. + this.retainedSizeInBytes = INSTANCE_SIZE + + keyBlock.getRetainedSizeInBytes() + + valueBlock.getRetainedSizeInBytes() + + sizeOf(offsets) + + sizeOf(mapIsNull) + + sizeOfIntArray(keyBlock.getPositionCount() * HASH_MULTIPLIER); // hashtable size if it was built } @Override @@ -303,9 +312,52 @@ public Block getLoadedBlock() offsets, keyBlock, loadedValueBlock, - hashTables, + Optional.ofNullable(hashTables), keyType, keyBlockNativeEquals, - keyNativeHashCode); + keyNativeHashCode, + keyBlockHashCode); + } + + @Override + protected void ensureHashTableLoaded() + { + if (this.hashTables != null) { + return; + } + + // This can only happen for MapBlock, not MapBlockBuilder because the latter always has non-null hashtables + synchronized (this) { + if (this.hashTables != null) { + return; + } + + int[] offsets = getOffsets(); + int elementCount = getRawKeyBlock().getPositionCount(); + int mapCount = getPositionCount(); + boolean[] mapIsNull = getMapIsNull(); + + int[] hashTables = new int[elementCount * HASH_MULTIPLIER]; + Arrays.fill(hashTables, -1); + for (int i = 0; i < mapCount; i++) { + int keyOffset = offsets[i]; + int keyCount = offsets[i + 1] - keyOffset; + if (keyCount < 0) { + throw new IllegalArgumentException(format("Offset is not monotonically ascending. offsets[%s]=%s, offsets[%s]=%s", i, offsets[i], i + 1, offsets[i + 1])); + } + if (mapIsNull != null && mapIsNull[i] && keyCount != 0) { + throw new IllegalArgumentException("A null map must have zero entries"); + } + buildHashTable( + getRawKeyBlock(), + keyOffset, + keyCount, + keyBlockHashCode, + hashTables, + keyOffset * HASH_MULTIPLIER, + keyCount * HASH_MULTIPLIER); + } + this.hashTables = hashTables; + } } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java index 5b6ffcfda3188..fba99d6ed9099 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java @@ -39,7 +39,6 @@ public class MapBlockBuilder private static final int INSTANCE_SIZE = ClassLayout.parseClass(MapBlockBuilder.class).instanceSize(); private final MethodHandle keyBlockEquals; - private final MethodHandle keyBlockHashCode; @Nullable private final BlockBuilderStatus blockBuilderStatus; @@ -90,10 +89,9 @@ private MapBlockBuilder( boolean[] mapIsNull, int[] hashTables) { - super(keyType, keyNativeHashCode, keyBlockNativeEquals); + super(keyType, keyNativeHashCode, keyBlockNativeEquals, keyBlockHashCode); this.keyBlockEquals = requireNonNull(keyBlockEquals, "keyBlockEquals is null"); - this.keyBlockHashCode = requireNonNull(keyBlockHashCode, "keyBlockHashCode is null"); this.blockBuilderStatus = blockBuilderStatus; this.positionCount = 0; @@ -157,7 +155,12 @@ public long getSizeInBytes() @Override public long getRetainedSizeInBytes() { - long size = INSTANCE_SIZE + keyBlockBuilder.getRetainedSizeInBytes() + valueBlockBuilder.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(mapIsNull) + sizeOf(hashTables); + long size = INSTANCE_SIZE + + keyBlockBuilder.getRetainedSizeInBytes() + + valueBlockBuilder.getRetainedSizeInBytes() + + sizeOf(offsets) + + sizeOf(mapIsNull) + + sizeOf(hashTables); if (blockBuilderStatus != null) { size += BlockBuilderStatus.INSTANCE_SIZE; } @@ -199,7 +202,14 @@ public BlockBuilder closeEntry() int previousAggregatedEntryCount = offsets[positionCount - 1]; int aggregatedEntryCount = offsets[positionCount]; int entryCount = aggregatedEntryCount - previousAggregatedEntryCount; - buildHashTable(keyBlockBuilder, previousAggregatedEntryCount, entryCount, keyBlockHashCode, hashTables, previousAggregatedEntryCount * HASH_MULTIPLIER, entryCount * HASH_MULTIPLIER); + buildHashTable( + keyBlockBuilder, + previousAggregatedEntryCount, + entryCount, + keyBlockHashCode, + hashTables, + previousAggregatedEntryCount * HASH_MULTIPLIER, + entryCount * HASH_MULTIPLIER); return this; } @@ -236,7 +246,7 @@ public BlockBuilder closeEntryStrict() return this; } - private BlockBuilder closeEntry(int[] providedHashTable, int providedHashTableOffset) + private BlockBuilder closeEntry(@Nullable int[] providedHashTable, int providedHashTableOffset) { if (!currentEntryOpened) { throw new IllegalStateException("Expected entry to be opened but was closed"); @@ -246,14 +256,29 @@ private BlockBuilder closeEntry(int[] providedHashTable, int providedHashTableOf currentEntryOpened = false; ensureHashTableSize(); + int previousAggregatedEntryCount = offsets[positionCount - 1]; + int aggregatedEntryCount = offsets[positionCount]; - // Directly copy instead of building hashtable - int hashTableOffset = offsets[positionCount - 1] * HASH_MULTIPLIER; - int hashTableSize = (offsets[positionCount] - offsets[positionCount - 1]) * HASH_MULTIPLIER; - for (int i = 0; i < hashTableSize; i++) { - hashTables[hashTableOffset + i] = providedHashTable[providedHashTableOffset + i]; + if (providedHashTable != null) { + // Directly copy instead of building hashtable if providedHashTable is not null + int hashTableOffset = previousAggregatedEntryCount * HASH_MULTIPLIER; + int hashTableSize = (aggregatedEntryCount - previousAggregatedEntryCount) * HASH_MULTIPLIER; + for (int i = 0; i < hashTableSize; i++) { + hashTables[hashTableOffset + i] = providedHashTable[providedHashTableOffset + i]; + } + } + else { + // Build hash table for this map entry. + int entryCount = aggregatedEntryCount - previousAggregatedEntryCount; + buildHashTable( + keyBlockBuilder, + previousAggregatedEntryCount, + entryCount, + keyBlockHashCode, + hashTables, + previousAggregatedEntryCount * HASH_MULTIPLIER, + entryCount * HASH_MULTIPLIER); } - return this; } @@ -311,9 +336,11 @@ public Block build() offsets, keyBlockBuilder.build(), valueBlockBuilder.build(), - Arrays.copyOf(hashTables, offsets[positionCount] * HASH_MULTIPLIER), + Optional.of(Arrays.copyOf(hashTables, offsets[positionCount] * HASH_MULTIPLIER)), keyType, - keyBlockNativeEquals, keyNativeHashCode); + keyBlockNativeEquals, + keyNativeHashCode, + keyBlockHashCode); } @Override @@ -410,6 +437,9 @@ public BlockBuilder newBlockBuilderLike(BlockBuilderStatus blockBuilderStatus) newNegativeOneFilledArray(newSize * HASH_MULTIPLIER)); } + @Override + protected void ensureHashTableLoaded() {} + private static int[] newNegativeOneFilledArray(int size) { int[] hashTable = new int[size]; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java index cc0b43537b653..d03904f5af3ba 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java @@ -64,8 +64,15 @@ public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceO blockEncodingSerde.writeBlock(sliceOutput, mapBlock.getRawKeyBlock().getRegion(entriesStartOffset, entriesEndOffset - entriesStartOffset)); blockEncodingSerde.writeBlock(sliceOutput, mapBlock.getRawValueBlock().getRegion(entriesStartOffset, entriesEndOffset - entriesStartOffset)); - sliceOutput.appendInt((entriesEndOffset - entriesStartOffset) * HASH_MULTIPLIER); - sliceOutput.writeBytes(wrappedIntArray(hashTable, entriesStartOffset * HASH_MULTIPLIER, (entriesEndOffset - entriesStartOffset) * HASH_MULTIPLIER)); + if (hashTable != null) { + int hashTableLength = (entriesEndOffset - entriesStartOffset) * HASH_MULTIPLIER; + sliceOutput.appendInt(hashTableLength); // hashtable length + sliceOutput.writeBytes(wrappedIntArray(hashTable, entriesStartOffset * HASH_MULTIPLIER, hashTableLength)); + } + else { + // if the hashTable is null, we write the length -1 + sliceOutput.appendInt(-1); // hashtable length + } sliceOutput.appendInt(positionCount); for (int position = 0; position < positionCount + 1; position++) { @@ -82,18 +89,27 @@ public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceIn Block keyBlock = blockEncodingSerde.readBlock(sliceInput); Block valueBlock = blockEncodingSerde.readBlock(sliceInput); - int[] hashTable = new int[sliceInput.readInt()]; - sliceInput.readBytes(wrappedIntArray(hashTable)); + int hashTableLength = sliceInput.readInt(); + int[] hashTable = null; + if (hashTableLength >= 0) { + hashTable = new int[hashTableLength]; + sliceInput.readBytes(wrappedIntArray(hashTable)); + } + + if (keyBlock.getPositionCount() != valueBlock.getPositionCount()) { + throw new IllegalArgumentException( + format("Deserialized MapBlock violates invariants: key %d, value %d", keyBlock.getPositionCount(), valueBlock.getPositionCount())); + } - if (keyBlock.getPositionCount() != valueBlock.getPositionCount() || keyBlock.getPositionCount() * HASH_MULTIPLIER != hashTable.length) { + if (hashTable != null && keyBlock.getPositionCount() * HASH_MULTIPLIER != hashTable.length) { throw new IllegalArgumentException( - format("Deserialized MapBlock violates invariants: key %d, value %d, hash %d", keyBlock.getPositionCount(), valueBlock.getPositionCount(), hashTable.length)); + format("Deserialized MapBlock violates invariants: expected hashtable size %d, actual hashtable size %d", keyBlock.getPositionCount() * HASH_MULTIPLIER, hashTable.length)); } int positionCount = sliceInput.readInt(); int[] offsets = new int[positionCount + 1]; sliceInput.readBytes(wrappedIntArray(offsets)); Optional mapIsNull = EncoderUtil.decodeNullBits(sliceInput, positionCount); - return MapType.createMapBlockInternal(typeManager, keyType, 0, positionCount, mapIsNull, offsets, keyBlock, valueBlock, hashTable); + return MapType.createMapBlockInternal(typeManager, keyType, 0, positionCount, mapIsNull, offsets, keyBlock, valueBlock, Optional.ofNullable(hashTable)); } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java index e4aeabc870592..72762b5ba934c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java @@ -19,7 +19,6 @@ import io.airlift.slice.Slice; import org.openjdk.jol.info.ClassLayout; -import java.lang.invoke.MethodHandle; import java.util.function.BiConsumer; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; @@ -35,24 +34,14 @@ public class SingleMapBlock private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMapBlock.class).instanceSize(); private final int offset; - private final int positionCount; - private final Block keyBlock; - private final Block valueBlock; - private final int[] hashTable; - final Type keyType; - private final MethodHandle keyNativeHashCode; - private final MethodHandle keyBlockNativeEquals; - - SingleMapBlock(int offset, int positionCount, Block keyBlock, Block valueBlock, int[] hashTable, Type keyType, MethodHandle keyNativeHashCode, MethodHandle keyBlockNativeEquals) + private final int positionCount; // The number of keys in this single map * 2 + private final AbstractMapBlock mapBlock; + + SingleMapBlock(int offset, int positionCount, AbstractMapBlock mapBlock) { this.offset = offset; this.positionCount = positionCount; - this.keyBlock = keyBlock; - this.valueBlock = valueBlock; - this.hashTable = hashTable; - this.keyType = keyType; - this.keyNativeHashCode = keyNativeHashCode; - this.keyBlockNativeEquals = keyBlockNativeEquals; + this.mapBlock = mapBlock; } @Override @@ -64,23 +53,23 @@ public int getPositionCount() @Override public long getSizeInBytes() { - return keyBlock.getRegionSizeInBytes(offset / 2, positionCount / 2) + - valueBlock.getRegionSizeInBytes(offset / 2, positionCount / 2) + + return mapBlock.getRawKeyBlock().getRegionSizeInBytes(offset / 2, positionCount / 2) + + mapBlock.getRawValueBlock().getRegionSizeInBytes(offset / 2, positionCount / 2) + sizeOfIntArray(positionCount / 2 * HASH_MULTIPLIER); } @Override public long getRetainedSizeInBytes() { - return INSTANCE_SIZE + keyBlock.getRetainedSizeInBytes() + valueBlock.getRetainedSizeInBytes() + sizeOf(hashTable); + return INSTANCE_SIZE + mapBlock.getRawKeyBlock().getRetainedSizeInBytes() + mapBlock.getRawValueBlock().getRetainedSizeInBytes() + sizeOf(mapBlock.getHashTables()); } @Override public void retainedBytesForEachPart(BiConsumer consumer) { - consumer.accept(keyBlock, keyBlock.getRetainedSizeInBytes()); - consumer.accept(valueBlock, valueBlock.getRetainedSizeInBytes()); - consumer.accept(hashTable, sizeOf(hashTable)); + consumer.accept(mapBlock.getRawKeyBlock(), mapBlock.getRawKeyBlock().getRetainedSizeInBytes()); + consumer.accept(mapBlock.getRawValueBlock(), mapBlock.getRawValueBlock().getRetainedSizeInBytes()); + consumer.accept(mapBlock.getHashTables(), sizeOf(mapBlock.getHashTables())); consumer.accept(this, (long) INSTANCE_SIZE); } @@ -99,13 +88,13 @@ public int getOffset() @Override Block getRawKeyBlock() { - return keyBlock; + return mapBlock.getRawKeyBlock(); } @Override Block getRawValueBlock() { - return valueBlock; + return mapBlock.getRawValueBlock(); } @Override @@ -117,29 +106,29 @@ public String toString() @Override public Block getLoadedBlock() { - if (keyBlock != keyBlock.getLoadedBlock()) { + if (mapBlock.getRawKeyBlock() != mapBlock.getRawKeyBlock().getLoadedBlock()) { // keyBlock has to be loaded since MapBlock constructs hash table eagerly. throw new IllegalStateException(); } - Block loadedValueBlock = valueBlock.getLoadedBlock(); - if (loadedValueBlock == valueBlock) { + Block loadedValueBlock = mapBlock.getRawValueBlock().getLoadedBlock(); + if (loadedValueBlock == mapBlock.getRawValueBlock()) { return this; } return new SingleMapBlock( offset, positionCount, - keyBlock, - loadedValueBlock, - hashTable, - keyType, - keyNativeHashCode, - keyBlockNativeEquals); + mapBlock); } int[] getHashTable() { - return hashTable; + return mapBlock.getHashTables(); + } + + Type getKeyType() + { + return mapBlock.keyType; } /** @@ -151,9 +140,12 @@ public int seekKey(Object nativeValue) return -1; } + mapBlock.ensureHashTableLoaded(); + int[] hashTable = mapBlock.getHashTables(); + long hashCode; try { - hashCode = (long) keyNativeHashCode.invoke(nativeValue); + hashCode = (long) mapBlock.keyNativeHashCode.invoke(nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -170,7 +162,7 @@ public int seekKey(Object nativeValue) Boolean match; try { // assuming maps with indeterminate keys are not supported - match = (Boolean) keyBlockNativeEquals.invoke(keyBlock, offset / 2 + keyPosition, nativeValue); + match = (Boolean) mapBlock.keyBlockNativeEquals.invoke(mapBlock.getRawKeyBlock(), offset / 2 + keyPosition, nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -195,9 +187,12 @@ public int seekKeyExact(long nativeValue) return -1; } + mapBlock.ensureHashTableLoaded(); + int[] hashTable = mapBlock.getHashTables(); + long hashCode; try { - hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + hashCode = (long) mapBlock.keyNativeHashCode.invokeExact(nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -214,7 +209,7 @@ public int seekKeyExact(long nativeValue) Boolean match; try { // assuming maps with indeterminate keys are not supported - match = (Boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + match = (Boolean) mapBlock.keyBlockNativeEquals.invokeExact(mapBlock.getRawKeyBlock(), offset / 2 + keyPosition, nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -236,9 +231,12 @@ public int seekKeyExact(boolean nativeValue) return -1; } + mapBlock.ensureHashTableLoaded(); + int[] hashTable = mapBlock.getHashTables(); + long hashCode; try { - hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + hashCode = (long) mapBlock.keyNativeHashCode.invokeExact(nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -255,7 +253,7 @@ public int seekKeyExact(boolean nativeValue) Boolean match; try { // assuming maps with indeterminate keys are not supported - match = (Boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + match = (Boolean) mapBlock.keyBlockNativeEquals.invokeExact(mapBlock.getRawKeyBlock(), offset / 2 + keyPosition, nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -277,9 +275,12 @@ public int seekKeyExact(double nativeValue) return -1; } + mapBlock.ensureHashTableLoaded(); + int[] hashTable = mapBlock.getHashTables(); + long hashCode; try { - hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + hashCode = (long) mapBlock.keyNativeHashCode.invokeExact(nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -296,7 +297,7 @@ public int seekKeyExact(double nativeValue) Boolean match; try { // assuming maps with indeterminate keys are not supported - match = (Boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + match = (Boolean) mapBlock.keyBlockNativeEquals.invokeExact(mapBlock.getRawKeyBlock(), offset / 2 + keyPosition, nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -318,9 +319,12 @@ public int seekKeyExact(Slice nativeValue) return -1; } + mapBlock.ensureHashTableLoaded(); + int[] hashTable = mapBlock.getHashTables(); + long hashCode; try { - hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + hashCode = (long) mapBlock.keyNativeHashCode.invokeExact(nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -337,7 +341,7 @@ public int seekKeyExact(Slice nativeValue) Boolean match; try { // assuming maps with indeterminate keys are not supported - match = (Boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + match = (Boolean) mapBlock.keyBlockNativeEquals.invokeExact(mapBlock.getRawKeyBlock(), offset / 2 + keyPosition, nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -359,9 +363,12 @@ public int seekKeyExact(Block nativeValue) return -1; } + mapBlock.ensureHashTableLoaded(); + int[] hashTable = mapBlock.getHashTables(); + long hashCode; try { - hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + hashCode = (long) mapBlock.keyNativeHashCode.invokeExact(nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); @@ -378,7 +385,7 @@ public int seekKeyExact(Block nativeValue) Boolean match; try { // assuming maps with indeterminate keys are not supported - match = (Boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + match = (Boolean) mapBlock.keyBlockNativeEquals.invokeExact(mapBlock.getRawKeyBlock(), offset / 2 + keyPosition, nativeValue); } catch (Throwable throwable) { throw handleThrowable(throwable); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java index 0b28dd4b728f3..6eee30661e0fb 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java @@ -22,6 +22,7 @@ import io.airlift.slice.SliceOutput; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.spi.block.AbstractMapBlock.HASH_MULTIPLIER; import static com.facebook.presto.spi.block.MethodHandleUtil.compose; @@ -54,15 +55,23 @@ public String getName() public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { SingleMapBlock singleMapBlock = (SingleMapBlock) block; - TypeSerde.writeType(sliceOutput, singleMapBlock.keyType); + TypeSerde.writeType(sliceOutput, singleMapBlock.getKeyType()); int offset = singleMapBlock.getOffset(); int positionCount = singleMapBlock.getPositionCount(); blockEncodingSerde.writeBlock(sliceOutput, singleMapBlock.getRawKeyBlock().getRegion(offset / 2, positionCount / 2)); blockEncodingSerde.writeBlock(sliceOutput, singleMapBlock.getRawValueBlock().getRegion(offset / 2, positionCount / 2)); int[] hashTable = singleMapBlock.getHashTable(); - sliceOutput.appendInt(positionCount / 2 * HASH_MULTIPLIER); - sliceOutput.writeBytes(wrappedIntArray(hashTable, offset / 2 * HASH_MULTIPLIER, positionCount / 2 * HASH_MULTIPLIER)); + + if (hashTable != null) { + int hashTableLength = positionCount / 2 * HASH_MULTIPLIER; + sliceOutput.appendInt(hashTableLength); // hashtable length + sliceOutput.writeBytes(wrappedIntArray(hashTable, offset / 2 * HASH_MULTIPLIER, hashTableLength)); + } + else { + // if the hashTable is null, we write the length -1 + sliceOutput.appendInt(-1); + } } @Override @@ -72,18 +81,41 @@ public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceIn MethodHandle keyNativeEquals = typeManager.resolveOperator(OperatorType.EQUAL, asList(keyType, keyType)); MethodHandle keyBlockNativeEquals = compose(keyNativeEquals, nativeValueGetter(keyType)); MethodHandle keyNativeHashCode = typeManager.resolveOperator(OperatorType.HASH_CODE, singletonList(keyType)); + MethodHandle keyBlockHashCode = compose(keyNativeHashCode, nativeValueGetter(keyType)); Block keyBlock = blockEncodingSerde.readBlock(sliceInput); Block valueBlock = blockEncodingSerde.readBlock(sliceInput); - int[] hashTable = new int[sliceInput.readInt()]; - sliceInput.readBytes(wrappedIntArray(hashTable)); + int hashTableLength = sliceInput.readInt(); + int[] hashTable = null; + if (hashTableLength >= 0) { + hashTable = new int[hashTableLength]; + sliceInput.readBytes(wrappedIntArray(hashTable)); + } + + if (keyBlock.getPositionCount() != valueBlock.getPositionCount()) { + throw new IllegalArgumentException( + format("Deserialized SingleMapBlock violates invariants: key %d, value %d", keyBlock.getPositionCount(), valueBlock.getPositionCount())); + } - if (keyBlock.getPositionCount() != valueBlock.getPositionCount() || keyBlock.getPositionCount() * HASH_MULTIPLIER != hashTable.length) { + if (hashTable != null && keyBlock.getPositionCount() * HASH_MULTIPLIER != hashTable.length) { throw new IllegalArgumentException( - format("Deserialized SingleMapBlock violates invariants: key %d, value %d, hash %d", keyBlock.getPositionCount(), valueBlock.getPositionCount(), hashTable.length)); + format("Deserialized SingleMapBlock violates invariants: expected hashtable size %d, actual hashtable size %d", keyBlock.getPositionCount() * HASH_MULTIPLIER, hashTable.length)); } - return new SingleMapBlock(0, keyBlock.getPositionCount() * 2, keyBlock, valueBlock, hashTable, keyType, keyNativeHashCode, keyBlockNativeEquals); + MapBlock mapBlock = MapBlock.createMapBlockInternal( + 0, + 1, + Optional.empty(), + new int[] {0, keyBlock.getPositionCount()}, + keyBlock, + valueBlock, + Optional.ofNullable(hashTable), + keyType, + keyBlockNativeEquals, + keyNativeHashCode, + keyBlockHashCode); + + return new SingleMapBlock(0, keyBlock.getPositionCount() * 2, mapBlock); } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java index 8a6e4b940f687..2da884b04aae6 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java @@ -278,11 +278,11 @@ public static Block createMapBlockInternal( int[] offsets, Block keyBlock, Block valueBlock, - int[] hashTables) + Optional hashTables) { // TypeManager caches types. Therefore, it is important that we go through it instead of coming up with the MethodHandles directly. // BIGINT is chosen arbitrarily here. Any type will do. MapType mapType = (MapType) typeManager.getType(new TypeSignature(StandardTypes.MAP, TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(BIGINT.getTypeSignature()))); - return MapBlock.createMapBlockInternal(startOffset, positionCount, mapIsNull, offsets, keyBlock, valueBlock, hashTables, keyType, mapType.keyBlockNativeEquals, mapType.keyNativeHashCode); + return MapBlock.createMapBlockInternal(startOffset, positionCount, mapIsNull, offsets, keyBlock, valueBlock, hashTables, keyType, mapType.keyBlockNativeEquals, mapType.keyNativeHashCode, mapType.keyBlockHashCode); } }