Skip to content

Commit

Permalink
Change temporary data storage in MultiChannelGroupByHash
Browse files Browse the repository at this point in the history
Temporary output data used to be kept in a list of variable-size pages. That
required using a two-level address to point an exact row. Since the hash table
can only store 2^30 positions, the number of groups can be stored in a 4-byte
int variable, given that the page size is fixed. This way we can get rid of
some data structures used to map between group id and output row position
skipping some calculations and reducing memory footprint from 8+8+4+1 bytes
per hash bucket to 4+1 bytes which improves memory locality.
Since the memory footprint of MultiChannelGroupByHash is now smaller,
tests using testMemoryReservationYield needed to change because
we yield less often due to QueryContext.GUARANTEED_MEMORY threshold
and reserve less memory.
  • Loading branch information
skrzypo987 authored and sopel39 committed May 27, 2022
1 parent 1cf9bc6 commit 6c5eb59
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.array.LongBigArray;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.TrinoException;
Expand All @@ -41,8 +40,6 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.trino.operator.SyntheticAddress.decodePosition;
import static io.trino.operator.SyntheticAddress.decodeSliceIndex;
import static io.trino.operator.SyntheticAddress.encodeSyntheticAddress;
import static io.trino.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES;
import static io.trino.spi.type.BigintType.BIGINT;
Expand All @@ -62,6 +59,9 @@ public class MultiChannelGroupByHash
private static final float FILL_RATIO = 0.75f;
// Max (page value count / cumulative dictionary size) to trigger the low cardinality case
private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = .25;
private static final int VALUES_PAGE_BITS = 14; // 16k positions
private static final int VALUES_PAGE_MAX_ROW_COUNT = 1 << VALUES_PAGE_BITS;
private static final int VALUES_PAGE_MASK = VALUES_PAGE_MAX_ROW_COUNT - 1;

private final List<Type> types;
private final List<Type> hashTypes;
Expand All @@ -80,12 +80,11 @@ public class MultiChannelGroupByHash
private int hashCapacity;
private int maxFill;
private int mask;
private long[] groupAddressByHash;
// Group ids are assigned incrementally. Therefore, since values page size is constant and power of two,
// the group id is also an address (slice index and position within slice) to group row in channelBuilders.
private int[] groupIdsByHash;
private byte[] rawHashByHashPosition;

private final LongBigArray groupAddressByGroupId;

private int nextGroupId;
private DictionaryLookBack dictionaryLookBack;
private long hashCollisions;
Expand Down Expand Up @@ -147,15 +146,9 @@ public MultiChannelGroupByHash(

maxFill = calculateMaxFill(hashCapacity);
mask = hashCapacity - 1;
groupAddressByHash = new long[hashCapacity];
Arrays.fill(groupAddressByHash, -1);

rawHashByHashPosition = new byte[hashCapacity];

groupIdsByHash = new int[hashCapacity];

groupAddressByGroupId = new LongBigArray();
groupAddressByGroupId.ensureCapacity(maxFill);
Arrays.fill(groupIdsByHash, -1);

// This interface is used for actively reserving memory (push model) for rehash.
// The caller can also query memory usage on this object (pull model)
Expand All @@ -165,9 +158,8 @@ public MultiChannelGroupByHash(
@Override
public long getRawHash(int groupId)
{
long address = groupAddressByGroupId.get(groupId);
int blockIndex = decodeSliceIndex(address);
int position = decodePosition(address);
int blockIndex = groupId >> VALUES_PAGE_BITS;
int position = groupId & VALUES_PAGE_MASK;
return hashStrategy.hashPosition(blockIndex, position);
}

Expand All @@ -178,9 +170,7 @@ public long getEstimatedSize()
(sizeOf(channelBuilders.get(0).elements()) * channelBuilders.size()) +
completedPagesMemorySize +
currentPageBuilder.getRetainedSizeInBytes() +
sizeOf(groupAddressByHash) +
sizeOf(groupIdsByHash) +
groupAddressByGroupId.sizeOf() +
sizeOf(rawHashByHashPosition) +
preallocatedMemoryInBytes;
}
Expand Down Expand Up @@ -212,9 +202,8 @@ public int getGroupCount()
@Override
public void appendValuesTo(int groupId, PageBuilder pageBuilder)
{
long address = groupAddressByGroupId.get(groupId);
int blockIndex = decodeSliceIndex(address);
int position = decodePosition(address);
int blockIndex = groupId >> VALUES_PAGE_BITS;
int position = groupId & VALUES_PAGE_MASK;
hashStrategy.appendTo(blockIndex, position, pageBuilder, 0);
}

Expand Down Expand Up @@ -265,8 +254,8 @@ public boolean contains(int position, Page page, int[] hashChannels, long rawHas
int hashPosition = getHashPosition(rawHash, mask);

// look for a slot containing this key
while (groupAddressByHash[hashPosition] != -1) {
if (positionNotDistinctFromCurrentRow(groupAddressByHash[hashPosition], hashPosition, position, page, (byte) rawHash, hashChannels)) {
while (groupIdsByHash[hashPosition] != -1) {
if (positionNotDistinctFromCurrentRow(groupIdsByHash[hashPosition], hashPosition, position, page, (byte) rawHash, hashChannels)) {
// found an existing slot for this key
return true;
}
Expand Down Expand Up @@ -296,8 +285,8 @@ private int putIfAbsent(int position, Page page, long rawHash)

// look for an empty slot or a slot containing this key
int groupId = -1;
while (groupAddressByHash[hashPosition] != -1) {
if (positionNotDistinctFromCurrentRow(groupAddressByHash[hashPosition], hashPosition, position, page, (byte) rawHash, channels)) {
while (groupIdsByHash[hashPosition] != -1) {
if (positionNotDistinctFromCurrentRow(groupIdsByHash[hashPosition], hashPosition, position, page, (byte) rawHash, channels)) {
// found an existing slot for this key
groupId = groupIdsByHash[hashPosition];

Expand Down Expand Up @@ -336,13 +325,11 @@ private int addNewGroup(int hashPosition, int position, Page page, long rawHash)
// record group id in hash
int groupId = nextGroupId++;

groupAddressByHash[hashPosition] = address;
rawHashByHashPosition[hashPosition] = (byte) rawHash;
groupIdsByHash[hashPosition] = groupId;
groupAddressByGroupId.set(groupId, address);

// create new page builder if this page is full
if (currentPageBuilder.isFull()) {
if (currentPageBuilder.getPositionCount() == VALUES_PAGE_MAX_ROW_COUNT) {
startNewPage();
}

Expand All @@ -362,6 +349,7 @@ private void startNewPage()
{
if (currentPageBuilder != null) {
completedPagesMemorySize += currentPageBuilder.getRetainedSizeInBytes();
// TODO: (https://github.com/trinodb/trino/issues/12484) pre-size new PageBuilder to OUTPUT_PAGE_SIZE
currentPageBuilder = currentPageBuilder.newPageBuilderLike();
}
else {
Expand All @@ -382,10 +370,9 @@ private boolean tryRehash()
int newCapacity = toIntExact(newCapacityLong);

// An estimate of how much extra memory is needed before we can go ahead and expand the hash table.
// This includes the new capacity for groupAddressByHash, rawHashByHashPosition, groupIdsByHash, and groupAddressByGroupId as well as the size of the current page
preallocatedMemoryInBytes = (newCapacity - hashCapacity) * (long) (Long.BYTES + Integer.BYTES + Byte.BYTES) +
(long) (calculateMaxFill(newCapacity) - maxFill) * Long.BYTES +
currentPageSizeInBytes;
// This includes the new capacity for rawHashByHashPosition, groupIdsByHash as well as the size of the current page
preallocatedMemoryInBytes = (newCapacity - hashCapacity) * (long) (Integer.BYTES + Byte.BYTES)
+ currentPageSizeInBytes;
if (!updateMemory.update()) {
// reserved memory but has exceeded the limit
return false;
Expand All @@ -395,67 +382,61 @@ private boolean tryRehash()
expectedHashCollisions += estimateNumberOfHashCollisions(getGroupCount(), hashCapacity);

int newMask = newCapacity - 1;
long[] newKey = new long[newCapacity];
byte[] rawHashes = new byte[newCapacity];
Arrays.fill(newKey, -1);
int[] newValue = new int[newCapacity];
int[] newGroupIdByHash = new int[newCapacity];
Arrays.fill(newGroupIdByHash, -1);

int oldIndex = 0;
for (int groupId = 0; groupId < nextGroupId; groupId++) {
for (int i = 0; i < hashCapacity; i++) {
// seek to the next used slot
while (groupAddressByHash[oldIndex] == -1) {
oldIndex++;
int groupId = groupIdsByHash[i];
if (groupId == -1) {
continue;
}

// get the address for this slot
long address = groupAddressByHash[oldIndex];

long rawHash = hashPosition(address);
long rawHash = hashPosition(groupId);
// find an empty slot for the address
int pos = getHashPosition(rawHash, newMask);
while (newKey[pos] != -1) {
while (newGroupIdByHash[pos] != -1) {
pos = (pos + 1) & newMask;
hashCollisions++;
}

// record the mapping
newKey[pos] = address;
rawHashes[pos] = (byte) rawHash;
newValue[pos] = groupIdsByHash[oldIndex];
oldIndex++;
newGroupIdByHash[pos] = groupId;
}

this.mask = newMask;
this.hashCapacity = newCapacity;
this.maxFill = calculateMaxFill(newCapacity);
this.groupAddressByHash = newKey;
this.rawHashByHashPosition = rawHashes;
this.groupIdsByHash = newValue;
groupAddressByGroupId.ensureCapacity(maxFill);
this.groupIdsByHash = newGroupIdByHash;
return true;
}

private long hashPosition(long sliceAddress)
private long hashPosition(int groupId)
{
int sliceIndex = decodeSliceIndex(sliceAddress);
int position = decodePosition(sliceAddress);
int blockIndex = groupId >> VALUES_PAGE_BITS;
int blockPosition = groupId & VALUES_PAGE_MASK;
if (precomputedHashChannel.isPresent()) {
return getRawHash(sliceIndex, position, precomputedHashChannel.getAsInt());
return getRawHash(blockIndex, blockPosition, precomputedHashChannel.getAsInt());
}
return hashStrategy.hashPosition(sliceIndex, position);
return hashStrategy.hashPosition(blockIndex, blockPosition);
}

private long getRawHash(int sliceIndex, int position, int hashChannel)
{
return channelBuilders.get(hashChannel).get(sliceIndex).getLong(position, 0);
}

private boolean positionNotDistinctFromCurrentRow(long address, int hashPosition, int position, Page page, byte rawHash, int[] hashChannels)
private boolean positionNotDistinctFromCurrentRow(int groupId, int hashPosition, int position, Page page, byte rawHash, int[] hashChannels)
{
if (rawHashByHashPosition[hashPosition] != rawHash) {
return false;
}
return hashStrategy.positionNotDistinctFromRow(decodeSliceIndex(address), decodePosition(address), position, page, hashChannels);
int blockIndex = groupId >> VALUES_PAGE_BITS;
int blockPosition = groupId & VALUES_PAGE_MASK;
return hashStrategy.positionNotDistinctFromRow(blockIndex, blockPosition, position, page, hashChannels);
}

private static int getHashPosition(long rawHash, int mask)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<
expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES) + page.getRetainedSizeInBytes();
}
else {
// groupAddressByHash, groupIdsByHash, and rawHashByHashPosition double by hashCapacity; while groupAddressByGroupId double by maxFill = hashCapacity / 0.75
expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES + Byte.BYTES) + page.getRetainedSizeInBytes();
// groupIdsByHash, and rawHashByHashPosition double by hashCapacity
expectedReservedExtraBytes = oldCapacity * (long) (Integer.BYTES + Byte.BYTES);
}
assertBetweenInclusive(actualIncreasedMemory, expectedReservedExtraBytes, expectedReservedExtraBytes + additionalMemoryInBytes);

Expand All @@ -190,10 +190,24 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<

// Assert the estimated reserved memory before rehash is very close to the one after rehash
long rehashedMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
assertBetweenInclusive(rehashedMemoryUsage * 1.0 / newMemoryUsage, 0.99, 1.01);
double memoryUsageErrorUpperBound = 1.01;
double memoryUsageError = rehashedMemoryUsage * 1.0 / newMemoryUsage;
if (memoryUsageError > memoryUsageErrorUpperBound) {
// Usually the error is < 1%, but since MultiChannelGroupByHash.getEstimatedSize
// accounts for changes in completedPagesMemorySize, which is increased if new page is
// added by addNewGroup (an even that cannot be predicted as it depends on the number of unique groups
// in the current page being processed), the difference includes size of the added new page.
// Lower bound is 1% lower than normal because additionalMemoryInBytes includes also aggregator state.
assertBetweenInclusive(rehashedMemoryUsage * 1.0 / (newMemoryUsage + additionalMemoryInBytes), 0.98, memoryUsageErrorUpperBound,
"rehashedMemoryUsage " + rehashedMemoryUsage + ", newMemoryUsage: " + newMemoryUsage);
}
else {
assertBetweenInclusive(memoryUsageError, 0.99, memoryUsageErrorUpperBound);
}

// unblocked
assertTrue(operator.needsInput());
assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import java.util.concurrent.ScheduledExecutorService;

import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.testing.Assertions.assertGreaterThan;
import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
import static io.trino.RowPagesBuilder.rowPagesBuilder;
import static io.trino.SessionTestUtils.TEST_SESSION;
import static io.trino.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys;
Expand Down Expand Up @@ -167,9 +167,9 @@ public void testMemoryReservationYield(Type type)
joinCompiler,
blockTypeOperators);

GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((DistinctLimitOperator) operator).getCapacity(), 1_400_000);
assertGreaterThan(result.getYieldCount(), 5);
assertGreaterThan(result.getMaxReservedBytes(), 20L << 20);
GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((DistinctLimitOperator) operator).getCapacity(), 450_000);
assertGreaterThanOrEqual(result.getYieldCount(), 5);
assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20);
assertEquals(result.getOutput().stream().mapToInt(Page::getPositionCount).sum(), 6_000 * 600);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder;
import static io.airlift.testing.Assertions.assertGreaterThan;
import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static io.airlift.units.DataSize.succinctBytes;
Expand Down Expand Up @@ -410,9 +411,9 @@ public void testMemoryReservationYield(Type type)

// get result with yield; pick a relatively small buffer for aggregator's memory usage
GroupByHashYieldResult result;
result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, this::getHashCapacity, 1_400_000);
assertGreaterThan(result.getYieldCount(), 5);
assertGreaterThan(result.getMaxReservedBytes(), 20L << 20);
result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, this::getHashCapacity, 450_000);
assertGreaterThanOrEqual(result.getYieldCount(), 5);
assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20);

int count = 0;
for (Page page : result.getOutput()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

import static com.google.common.collect.Iterables.concat;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.testing.Assertions.assertGreaterThan;
import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
import static io.trino.RowPagesBuilder.rowPagesBuilder;
import static io.trino.SessionTestUtils.TEST_SESSION;
Expand Down Expand Up @@ -244,10 +243,10 @@ public void testSemiJoinMemoryReservationYield(Type type)
type,
setBuilderOperatorFactory,
operator -> ((SetBuilderOperator) operator).getCapacity(),
1_400_000);
450_000);

assertGreaterThanOrEqual(result.getYieldCount(), 5);
assertGreaterThan(result.getMaxReservedBytes(), 20L << 20);
assertGreaterThanOrEqual(result.getYieldCount(), 4);
assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 19);
assertEquals(result.getOutput().stream().mapToInt(Page::getPositionCount).sum(), 0);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

import static com.google.common.base.Throwables.throwIfUnchecked;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.testing.Assertions.assertGreaterThan;
import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
import static io.airlift.testing.Assertions.assertInstanceOf;
import static io.trino.RowPagesBuilder.rowPagesBuilder;
import static io.trino.SessionTestUtils.TEST_SESSION;
Expand Down Expand Up @@ -176,9 +176,9 @@ public void testMemoryReservationYield(Type type)
OperatorFactory operatorFactory = new MarkDistinctOperatorFactory(0, new PlanNodeId("test"), ImmutableList.of(type), ImmutableList.of(0), Optional.of(1), joinCompiler, blockTypeOperators);

// get result with yield; pick a relatively small buffer for partitionRowCount's memory usage
GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((MarkDistinctOperator) operator).getCapacity(), 1_400_000);
assertGreaterThan(result.getYieldCount(), 5);
assertGreaterThan(result.getMaxReservedBytes(), 20L << 20);
GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((MarkDistinctOperator) operator).getCapacity(), 450_000);
assertGreaterThanOrEqual(result.getYieldCount(), 5);
assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20);

int count = 0;
for (Page page : result.getOutput()) {
Expand Down
Loading

0 comments on commit 6c5eb59

Please sign in to comment.