Add support for adaptive partial aggregation
When partial aggregation is not effectively reducing cardinality, instead send
raw input rows directly to the final aggregation step.

Port of:
github.com/trinodb/trino/pull/11011
github.com/trinodb/trino/pull/17143

Co-authored-by: Lukasz Stec <lukasz.s.stec@gmail.com>
Co-authored-by: Karol Sobczak <sopel39@users.noreply.github.com>
3 people committed Nov 27, 2023
1 parent a67b145 commit c2312bb
Showing 12 changed files with 546 additions and 2 deletions.
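The feature is controlled per query through the two session properties added below. A minimal sketch of enabling it from test code; the builder usage and the 0.9 value are illustrative, not part of this commit:

import com.facebook.presto.Session;

import static com.facebook.presto.SystemSessionProperties.ADAPTIVE_PARTIAL_AGGREGATION;
import static com.facebook.presto.SystemSessionProperties.ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;

Session session = testSessionBuilder()
        // disable partial aggregation adaptively once more than 90% of input rows survive as distinct groups
        .setSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION, "true")
        .setSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD, "0.9")
        .build();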
SystemSessionProperties.java
@@ -172,6 +172,8 @@ public final class SystemSessionProperties
public static final String PREFER_PARTIAL_AGGREGATION = "prefer_partial_aggregation";
public static final String PARTIAL_AGGREGATION_STRATEGY = "partial_aggregation_strategy";
public static final String PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD = "partial_aggregation_byte_reduction_threshold";
public static final String ADAPTIVE_PARTIAL_AGGREGATION = "adaptive_partial_aggregation";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number";
public static final String OPTIMIZE_CASE_EXPRESSION_PREDICATE = "optimize_case_expression_predicate";
public static final String MAX_GROUPING_SETS = "max_grouping_sets";
@@ -960,6 +962,16 @@ public SystemSessionProperties(
"Byte reduction ratio threshold at which to disable partial aggregation",
featuresConfig.getPartialAggregationByteReductionThreshold(),
false),
booleanProperty(
ADAPTIVE_PARTIAL_AGGREGATION,
"Enable adaptive partial aggregation",
featuresConfig.isAdaptivePartialAggregationEnabled(),
false),
doubleProperty(
ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD,
"Rows reduction ratio threshold at which to adaptively disable partial aggregation",
featuresConfig.getAdaptivePartialAggregationRowsReductionRatioThreshold(),
false),
booleanProperty(
OPTIMIZE_TOP_N_ROW_NUMBER,
"Use top N row number optimization",
@@ -2318,6 +2330,16 @@ public static double getPartialAggregationByteReductionThreshold(Session session)
return session.getSystemProperty(PARTIAL_AGGREGATION_BYTE_REDUCTION_THRESHOLD, Double.class);
}

public static boolean isAdaptivePartialAggregationEnabled(Session session)
{
return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION, Boolean.class);
}

public static double getAdaptivePartialAggregationRowsReductionRatioThreshold(Session session)
{
return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD, Double.class);
}

public static boolean isOptimizeTopNRowNumber(Session session)
{
return session.getSystemProperty(OPTIMIZE_TOP_N_ROW_NUMBER, Boolean.class);
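Elsewhere in this commit (in files not shown in this view) these getters feed the operator factory an optional PartialAggregationController. A rough sketch of that wiring; the helper name and the memory-size argument are assumptions, not code from this commit:

import com.facebook.presto.Session;
import com.facebook.presto.operator.aggregation.partial.PartialAggregationController;
import io.airlift.units.DataSize;

import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.getAdaptivePartialAggregationRowsReductionRatioThreshold;
import static com.facebook.presto.SystemSessionProperties.isAdaptivePartialAggregationEnabled;

// Hypothetical helper: build the controller only when the session enables adaptive partial aggregation.
static Optional<PartialAggregationController> createPartialAggregationController(Session session, DataSize maxPartialAggregationMemorySize)
{
    if (!isAdaptivePartialAggregationEnabled(session)) {
        return Optional.empty();
    }
    return Optional.of(new PartialAggregationController(
            maxPartialAggregationMemorySize,
            getAdaptivePartialAggregationRowsReductionRatioThreshold(session)));
}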
CompletedWork.java
@@ -13,24 +13,36 @@
*/
package com.facebook.presto.operator;

import javax.annotation.Nullable;

import static java.util.Objects.requireNonNull;

public final class CompletedWork<T>
implements Work<T>
{
@Nullable
private final T result;

public CompletedWork(T value)
{
this.result = requireNonNull(value);
}

/**
* This constructor can be used when the result is computed immediately and we do not need the yield machinery
*/
public CompletedWork()
{
this.result = null;
}

@Override
public boolean process()
{
return true;
}

@Nullable
@Override
public T getResult()
{
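The new no-argument constructor and @Nullable result let a caller report finished work that carries no value, presumably so the SkipAggregationBuilder added in this commit can acknowledge a consumed page without producing a result. A small illustration:

import com.facebook.presto.operator.CompletedWork;
import com.facebook.presto.operator.Work;

// Finished work with no result: process() returns true immediately and getResult() is null.
Work<Void> consumed = new CompletedWork<>();
boolean done = consumed.process();   // true
Void result = consumed.getResult();  // null, permitted by the new @Nullable annotation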
HashAggregationOperator.java
@@ -22,6 +22,8 @@
import com.facebook.presto.operator.aggregation.builder.HashAggregationBuilder;
import com.facebook.presto.operator.aggregation.builder.InMemoryHashAggregationBuilder;
import com.facebook.presto.operator.aggregation.builder.SpillableHashAggregationBuilder;
import com.facebook.presto.operator.aggregation.partial.PartialAggregationController;
import com.facebook.presto.operator.aggregation.partial.SkipAggregationBuilder;
import com.facebook.presto.operator.scalar.CombineHashFunction;
import com.facebook.presto.spi.function.aggregation.Accumulator;
import com.facebook.presto.spi.plan.AggregationNode.Step;
@@ -38,11 +40,13 @@
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
import java.util.stream.Collectors;

import static com.facebook.presto.operator.aggregation.builder.InMemoryHashAggregationBuilder.toTypes;
import static com.facebook.presto.sql.planner.PlannerUtils.INITIAL_HASH_VALUE;
import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -78,6 +82,7 @@ public static class HashAggregationOperatorFactory
private final SpillerFactory spillerFactory;
private final JoinCompiler joinCompiler;
private final boolean useSystemMemory;
private final Optional<PartialAggregationController> partialAggregationController;

private boolean closed;

@@ -112,6 +117,7 @@ public HashAggregationOperatorFactory(
expectedGroups,
maxPartialMemory,
false,
Optional.empty(),
new DataSize(0, MEGABYTE),
new DataSize(0, MEGABYTE),
(types, spillContext, memoryContext) -> {
@@ -136,6 +142,7 @@ public HashAggregationOperatorFactory(
int expectedGroups,
Optional<DataSize> maxPartialMemory,
boolean spillEnabled,
Optional<PartialAggregationController> partialAggregationController,
DataSize unspillMemoryLimit,
SpillerFactory spillerFactory,
JoinCompiler joinCompiler,
@@ -155,6 +162,7 @@ public HashAggregationOperatorFactory(
expectedGroups,
maxPartialMemory,
spillEnabled,
partialAggregationController,
unspillMemoryLimit,
DataSize.succinctBytes((long) (unspillMemoryLimit.toBytes() * MERGE_WITH_MEMORY_RATIO)),
spillerFactory,
@@ -178,6 +186,7 @@ public HashAggregationOperatorFactory(
int expectedGroups,
Optional<DataSize> maxPartialMemory,
boolean spillEnabled,
Optional<PartialAggregationController> partialAggregationController,
DataSize memoryLimitForMerge,
DataSize memoryLimitForMergeWithMemory,
SpillerFactory spillerFactory,
@@ -198,6 +207,7 @@ public HashAggregationOperatorFactory(
this.expectedGroups = expectedGroups;
this.maxPartialMemory = requireNonNull(maxPartialMemory, "maxPartialMemory is null");
this.spillEnabled = spillEnabled;
this.partialAggregationController = requireNonNull(partialAggregationController, "partialAggregationController is null");
this.memoryLimitForMerge = requireNonNull(memoryLimitForMerge, "memoryLimitForMerge is null");
this.memoryLimitForMergeWithMemory = requireNonNull(memoryLimitForMergeWithMemory, "memoryLimitForMergeWithMemory is null");
this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory is null");
@@ -225,6 +235,7 @@ public Operator createOperator(DriverContext driverContext)
expectedGroups,
maxPartialMemory,
spillEnabled,
partialAggregationController,
memoryLimitForMerge,
memoryLimitForMergeWithMemory,
spillerFactory,
@@ -257,6 +268,7 @@ public OperatorFactory duplicate()
expectedGroups,
maxPartialMemory,
spillEnabled,
partialAggregationController.map(PartialAggregationController::duplicate),
memoryLimitForMerge,
memoryLimitForMergeWithMemory,
spillerFactory,
@@ -278,6 +290,7 @@ public OperatorFactory duplicate()
private final int expectedGroups;
private final Optional<DataSize> maxPartialMemory;
private final boolean spillEnabled;
private final Optional<PartialAggregationController> partialAggregationController;
private final DataSize memoryLimitForMerge;
private final DataSize memoryLimitForMergeWithMemory;
private final SpillerFactory spillerFactory;
@@ -299,6 +312,10 @@ public OperatorFactory duplicate()
// for yield when memory is not available
private Work<?> unfinishedWork;

private long inputBytesProcessed;
private long inputRowsProcessed;
private long uniqueRowsProduced;

public HashAggregationOperator(
OperatorContext operatorContext,
List<Type> groupByTypes,
@@ -313,6 +330,7 @@ public HashAggregationOperator(
int expectedGroups,
Optional<DataSize> maxPartialMemory,
boolean spillEnabled,
Optional<PartialAggregationController> partialAggregationController,
DataSize memoryLimitForMerge,
DataSize memoryLimitForMergeWithMemory,
SpillerFactory spillerFactory,
@@ -337,6 +355,9 @@ public HashAggregationOperator(
this.maxPartialMemory = requireNonNull(maxPartialMemory, "maxPartialMemory is null");
this.types = toTypes(groupByTypes, step, accumulatorFactories, hashChannel);
this.spillEnabled = spillEnabled;
this.partialAggregationController = requireNonNull(partialAggregationController, "partialAggregationController is null");
checkArgument(!partialAggregationController.isPresent() || step.isOutputPartial(),
"partialAggregationController should only be present for partial aggregation");
this.memoryLimitForMerge = requireNonNull(memoryLimitForMerge, "memoryLimitForMerge is null");
this.memoryLimitForMergeWithMemory = requireNonNull(memoryLimitForMergeWithMemory, "memoryLimitForMergeWithMemory is null");
this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory is null");
@@ -402,7 +423,10 @@ public void addInput(Page page)
if (unfinishedWork != null && unfinishedWork.process()) {
unfinishedWork = null;
}

aggregationBuilder.updateMemory();
inputBytesProcessed += page.getSizeInBytes();
inputRowsProcessed += page.getPositionCount();
}

@Override
@@ -470,7 +494,9 @@ public Page getOutput()
return null;
}

return outputPages.getResult();
Page result = outputPages.getResult();
uniqueRowsProduced += result.getPositionCount();
return result;
}

@Override
@@ -534,6 +560,16 @@ private int findLastSegmentStart(PagesHashStrategy pagesHashStrategy, Page page)

private void closeAggregationBuilder()
{
partialAggregationController.ifPresent(
controller -> controller.onFlush(
inputBytesProcessed,
inputRowsProcessed,
// Empty uniqueRowsProduced indicates to PartialAggregationController that partial agg is disabled
aggregationBuilder instanceof SkipAggregationBuilder ? OptionalLong.empty() : OptionalLong.of(uniqueRowsProduced)));
inputBytesProcessed = 0;
inputRowsProcessed = 0;
uniqueRowsProduced = 0;

outputPages = null;
if (aggregationBuilder != null) {
aggregationBuilder.recordHashCollisions(hashCollisionsCounter);
@@ -563,7 +599,18 @@ private void initializeAggregationBuilderIfNeeded()
return;
}

if (step.isOutputPartial() || !spillEnabled) {
boolean partialAggregationDisabled = partialAggregationController
.map(PartialAggregationController::isPartialAggregationDisabled)
.orElse(false);

if (step.isOutputPartial() && partialAggregationDisabled) {
aggregationBuilder = new SkipAggregationBuilder(
groupByChannels,
hashChannel,
accumulatorFactories,
operatorContext.localUserMemoryContext());
}
else if (step.isOutputPartial() || !spillEnabled) {
aggregationBuilder = new InMemoryHashAggregationBuilder(
accumulatorFactories,
step,
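The three counters above are the raw material for the adaptive decision: addInput() accumulates input bytes and rows, getOutput() counts the rows actually emitted, and closeAggregationBuilder() reports all three to the controller before resetting them. A toy calculation of the ratio the controller compares against the session threshold (numbers are made up):

// Counters accumulated by HashAggregationOperator over one flush cycle (see diff above).
long inputBytesProcessed = 30L << 20;  // 30 MB added via addInput(), from page.getSizeInBytes()
long inputRowsProcessed = 2_000_000;   // positions added via addInput()
long uniqueRowsProduced = 1_900_000;   // positions emitted via getOutput()

// PartialAggregationController only evaluates this once enough input bytes have been seen;
// 0.95 unique rows per input row means partial aggregation is barely reducing anything.
double rowsReductionRatio = (double) uniqueRowsProduced / inputRowsProcessed; // 0.95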
PartialAggregationController.java
@@ -0,0 +1,88 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.partial;

import io.airlift.units.DataSize;

import java.util.OptionalLong;

import static java.util.Objects.requireNonNull;

public class PartialAggregationController
{
/**
* Process enough pages to fill up the partial aggregation buffer, before considering disabling partial aggregation.
* With the default 16 MB partial aggregation buffer, this means at least 24 MB of input data is processed before considering disabling partial aggregation.
* We use bytes instead of rows as the floor because of issues with file skew when rows are small; we want to make sure
* the partial aggregation buffer is fully utilized before making the decision to disable partial aggregation.
*/
private static final double DISABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO = 1.5;
/**
* Re-enable partial aggregation periodically, in case later data can be partially aggregated more effectively.
*/
private static final double ENABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO = DISABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO * 200;

private final DataSize maxPartialAggregationMemorySize;
private final double uniqueRowsRatioThreshold;

private volatile boolean partialAggregationDisabled;
private long totalBytesProcessed;
private long totalRowsProcessed;
private long totalUniqueRowsProduced;

public PartialAggregationController(DataSize maxPartialAggregationMemorySize, double uniqueRowsRatioThreshold)
{
this.maxPartialAggregationMemorySize = requireNonNull(maxPartialAggregationMemorySize, "maxPartialAggregationMemorySize is null");
this.uniqueRowsRatioThreshold = uniqueRowsRatioThreshold;
}

public boolean isPartialAggregationDisabled()
{
return partialAggregationDisabled;
}

public synchronized void onFlush(long bytesProcessed, long rowsProcessed, OptionalLong uniqueRowsProduced)
{
if (!partialAggregationDisabled && !uniqueRowsProduced.isPresent()) {
// when partial aggregation has been re-enabled, ignore stats from disabled flushes
return;
}

totalBytesProcessed += bytesProcessed;
totalRowsProcessed += rowsProcessed;
uniqueRowsProduced.ifPresent(value -> totalUniqueRowsProduced += value);

if (!partialAggregationDisabled && shouldDisablePartialAggregation()) {
partialAggregationDisabled = true;
}

if (partialAggregationDisabled && totalBytesProcessed >= maxPartialAggregationMemorySize.toBytes() * ENABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO) {
totalBytesProcessed = 0;
totalRowsProcessed = 0;
totalUniqueRowsProduced = 0;
partialAggregationDisabled = false;
}
}

private boolean shouldDisablePartialAggregation()
{
return totalBytesProcessed >= maxPartialAggregationMemorySize.toBytes() * DISABLE_AGGREGATION_BUFFER_SIZE_TO_INPUT_BYTES_RATIO
&& ((double) totalUniqueRowsProduced / totalRowsProcessed) > uniqueRowsRatioThreshold;
}

public PartialAggregationController duplicate()
{
return new PartialAggregationController(maxPartialAggregationMemorySize, uniqueRowsRatioThreshold);
}
}
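A short trace of the controller's two transitions, using the 16 MB buffer mentioned in the javadoc above and an assumed threshold of 0.8 (not necessarily the shipped default):

import com.facebook.presto.operator.aggregation.partial.PartialAggregationController;
import io.airlift.units.DataSize;

import java.util.OptionalLong;

import static io.airlift.units.DataSize.Unit.MEGABYTE;

public class PartialAggregationControllerTrace
{
    public static void main(String[] args)
    {
        PartialAggregationController controller =
                new PartialAggregationController(new DataSize(16, MEGABYTE), 0.8);

        // Flush #1: 25 MB of input (>= 1.5 * 16 MB) produced 0.9 unique rows per input row (> 0.8),
        // so partial aggregation is disabled and the next cycle uses SkipAggregationBuilder.
        controller.onFlush(25L << 20, 1_000_000, OptionalLong.of(900_000));
        System.out.println(controller.isPartialAggregationDisabled()); // true

        // While disabled, flushes report OptionalLong.empty(). After roughly 300x the buffer size
        // of raw input (here ~5 GB), the counters reset and partial aggregation is re-enabled.
        controller.onFlush(5L << 30, 200_000_000, OptionalLong.empty());
        System.out.println(controller.isPartialAggregationDisabled()); // false
    }
}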