Estimate partition memory usage based on previous attempts #11857

Merged · 5 commits · Apr 9, 2022
SystemSessionProperties.java
@@ -162,6 +162,7 @@ public final class SystemSessionProperties
public static final String FAULT_TOLERANT_EXECUTION_MAX_TASK_SPLIT_COUNT = "fault_tolerant_execution_max_task_split_count";
public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY = "fault_tolerant_execution_task_memory";
public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY_GROWTH_FACTOR = "fault_tolerant_execution_task_memory_growth_factor";
public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY_ESTIMATION_QUANTILE = "fault_tolerant_execution_task_memory_estimation_quantile";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_ENABLED = "adaptive_partial_aggregation_enabled";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS = "adaptive_partial_aggregation_min_rows";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
@@ -777,6 +778,12 @@ public SystemSessionProperties(
"Factor by which estimated task memory is increased if task execution runs out of memory; value is used allocating nodes for tasks execution",
memoryManagerConfig.getFaultTolerantExecutionTaskMemoryGrowthFactor(),
false),
doubleProperty(
FAULT_TOLERANT_EXECUTION_TASK_MEMORY_ESTIMATION_QUANTILE,
"What quantile of memory usage of completed tasks to look at when estimating memory usage for upcoming tasks",
memoryManagerConfig.getFaultTolerantExecutionTaskMemoryEstimationQuantile(),
value -> validateDoubleRange(value, FAULT_TOLERANT_EXECUTION_TASK_MEMORY_ESTIMATION_QUANTILE, 0.0, 1.0),
false),
booleanProperty(
ADAPTIVE_PARTIAL_AGGREGATION_ENABLED,
"When enabled, partial aggregation might be adaptively turned off when it does not provide any performance gain",
@@ -1414,6 +1421,11 @@ public static double getFaultTolerantExecutionTaskMemoryGrowthFactor(Session session)
return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_TASK_MEMORY_GROWTH_FACTOR, Double.class);
}

public static double getFaultTolerantExecutionTaskMemoryEstimationQuantile(Session session)
{
return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_TASK_MEMORY_ESTIMATION_QUANTILE, Double.class);
}

public static boolean isAdaptivePartialAggregationEnabled(Session session)
{
return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_ENABLED, Boolean.class);
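To make the retry arithmetic concrete, here is a small self-contained sketch. The class name and numeric values are illustrative only (the real defaults come from MemoryManagerConfig), but the max-then-multiply step mirrors the getNextRetryMemoryRequirements change further down in this diff.

import io.airlift.units.DataSize;

public class RetrySizingExample
{
    public static void main(String[] args)
    {
        // fault_tolerant_execution_task_memory_growth_factor (illustrative value)
        double growthFactor = 2.0;

        DataSize previousLimit = DataSize.of(5, DataSize.Unit.GIGABYTE);
        DataSize peakUsage = DataSize.of(5, DataSize.Unit.GIGABYTE); // the task OOMed at its limit

        // start from the max of the previous limit and the observed peak,
        // then multiply because the failure was memory-related
        long baseBytes = Math.max(previousLimit.toBytes(), peakUsage.toBytes());
        DataSize nextLimit = DataSize.ofBytes((long) (baseBytes * growthFactor));

        System.out.println(nextLimit); // the next attempt asks for ~10GB
    }
}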
SqlQueryExecution.java
@@ -27,7 +27,7 @@
import io.trino.execution.StateMachine.StateChangeListener;
import io.trino.execution.scheduler.NodeAllocatorService;
import io.trino.execution.scheduler.NodeScheduler;
import io.trino.execution.scheduler.PartitionMemoryEstimator;
import io.trino.execution.scheduler.PartitionMemoryEstimatorFactory;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.execution.scheduler.SqlQueryScheduler;
import io.trino.execution.scheduler.TaskDescriptorStorage;
@@ -102,7 +102,7 @@ public class SqlQueryExecution
private final NodePartitioningManager nodePartitioningManager;
private final NodeScheduler nodeScheduler;
private final NodeAllocatorService nodeAllocatorService;
private final PartitionMemoryEstimator partitionMemoryEstimator;
private final PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory;
private final List<PlanOptimizer> planOptimizers;
private final PlanFragmenter planFragmenter;
private final RemoteTaskFactory remoteTaskFactory;
@@ -137,7 +137,7 @@ private SqlQueryExecution(
NodePartitioningManager nodePartitioningManager,
NodeScheduler nodeScheduler,
NodeAllocatorService nodeAllocatorService,
PartitionMemoryEstimator partitionMemoryEstimator,
PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory,
List<PlanOptimizer> planOptimizers,
PlanFragmenter planFragmenter,
RemoteTaskFactory remoteTaskFactory,
@@ -166,7 +166,7 @@ private SqlQueryExecution(
this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null");
this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null");
this.partitionMemoryEstimator = requireNonNull(partitionMemoryEstimator, "partitionMemoryEstimator is null");
this.partitionMemoryEstimatorFactory = requireNonNull(partitionMemoryEstimatorFactory, "partitionMemoryEstimatorFactory is null");
this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null");
this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null");
this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null");
@@ -506,7 +506,7 @@ private void planDistribution(PlanRoot plan)
nodePartitioningManager,
nodeScheduler,
nodeAllocatorService,
partitionMemoryEstimator,
partitionMemoryEstimatorFactory,
remoteTaskFactory,
plan.isSummarizeTaskInfos(),
scheduleSplitBatchSize,
@@ -709,7 +709,7 @@ public static class SqlQueryExecutionFactory
private final NodePartitioningManager nodePartitioningManager;
private final NodeScheduler nodeScheduler;
private final NodeAllocatorService nodeAllocatorService;
private final PartitionMemoryEstimator partitionMemoryEstimator;
private final PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory;
private final List<PlanOptimizer> planOptimizers;
private final PlanFragmenter planFragmenter;
private final RemoteTaskFactory remoteTaskFactory;
@@ -737,7 +737,7 @@ public static class SqlQueryExecutionFactory
NodePartitioningManager nodePartitioningManager,
NodeScheduler nodeScheduler,
NodeAllocatorService nodeAllocatorService,
PartitionMemoryEstimator partitionMemoryEstimator,
PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory,
PlanOptimizersFactory planOptimizersFactory,
PlanFragmenter planFragmenter,
RemoteTaskFactory remoteTaskFactory,
@@ -766,7 +766,7 @@ public static class SqlQueryExecutionFactory
this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null");
this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null");
this.partitionMemoryEstimator = requireNonNull(partitionMemoryEstimator, "partitionMemoryEstimator is null");
this.partitionMemoryEstimatorFactory = requireNonNull(partitionMemoryEstimatorFactory, "partitionMemoryEstimatorFactory is null");
this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null");
this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null");
@@ -807,7 +807,7 @@ public QueryExecution createQueryExecution(
nodePartitioningManager,
nodeScheduler,
nodeAllocatorService,
partitionMemoryEstimator,
partitionMemoryEstimatorFactory,
planOptimizers,
planFragmenter,
remoteTaskFactory,
ConstantPartitionMemoryEstimator.java
@@ -17,6 +17,8 @@
import io.trino.Session;
import io.trino.spi.ErrorCode;

import java.util.Optional;

public class ConstantPartitionMemoryEstimator
implements PartitionMemoryEstimator
{
@@ -33,4 +35,7 @@ public MemoryRequirements getNextRetryMemoryRequirements(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, ErrorCode errorCode)
{
return previousMemoryRequirements;
}

@Override
public void registerPartitionFinished(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, boolean success, Optional<ErrorCode> errorCode) {}
}
ExponentialGrowthPartitionMemoryEstimator.java
@@ -13,41 +13,104 @@
*/
package io.trino.execution.scheduler;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import com.google.common.collect.Streams;
import io.airlift.stats.TDigest;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.spi.ErrorCode;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTaskMemoryEstimationQuantile;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTaskMemoryGrowthFactor;
import static io.trino.spi.StandardErrorCode.CLUSTER_OUT_OF_MEMORY;
import static io.trino.spi.StandardErrorCode.EXCEEDED_LOCAL_MEMORY_LIMIT;

public class ExponentialGrowthPartitionMemoryEstimator
implements PartitionMemoryEstimator
{
private final TDigest memoryUsageDistribution = new TDigest();

@Override
public MemoryRequirements getInitialMemoryRequirements(Session session, DataSize defaultMemoryLimit)
{
return new MemoryRequirements(
defaultMemoryLimit,
Ordering.natural().max(defaultMemoryLimit, getEstimatedMemoryUsage(session)),
false);
}

@Override
public MemoryRequirements getNextRetryMemoryRequirements(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, ErrorCode errorCode)
{
DataSize previousMemory = previousMemoryRequirements.getRequiredMemory();
DataSize baseMemory = Ordering.natural().max(peakMemoryUsage, previousMemory);
if (shouldIncreaseMemory(errorCode)) {
double growthFactor = SystemSessionProperties.getFaultTolerantExecutionTaskMemoryGrowthFactor(session);
return new MemoryRequirements(DataSize.of((long) (baseMemory.toBytes() * growthFactor), DataSize.Unit.BYTE), false);

// start with the maximum of previously used memory and actual usage
DataSize newMemory = Ordering.natural().max(peakMemoryUsage, previousMemory);
if (isOutOfMemoryError(errorCode)) {
// multiply if we hit an oom error
double growthFactor = getFaultTolerantExecutionTaskMemoryGrowthFactor(session);
newMemory = DataSize.of((long) (newMemory.toBytes() * growthFactor), DataSize.Unit.BYTE);
}
return new MemoryRequirements(baseMemory, false);

// if we are still below current estimate for new partition let's bump further
newMemory = Ordering.natural().max(newMemory, getEstimatedMemoryUsage(session));

return new MemoryRequirements(newMemory, false);
}

private boolean shouldIncreaseMemory(ErrorCode errorCode)
private boolean isOutOfMemoryError(ErrorCode errorCode)
{
return EXCEEDED_LOCAL_MEMORY_LIMIT.toErrorCode().equals(errorCode) // too many tasks from single query on a node
|| CLUSTER_OUT_OF_MEMORY.toErrorCode().equals(errorCode); // too many tasks in general on a node
}

@Override
public synchronized void registerPartitionFinished(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, boolean success, Optional<ErrorCode> errorCode)
{
if (success) {
memoryUsageDistribution.add(peakMemoryUsage.toBytes());
}
if (!success && errorCode.isPresent() && isOutOfMemoryError(errorCode.get())) {
double growthFactor = getFaultTolerantExecutionTaskMemoryGrowthFactor(session);
// take previousRequiredBytes into account when registering failure on oom. It is conservative hence safer (and in-line with getNextRetryMemoryRequirements)
long previousRequiredBytes = previousMemoryRequirements.getRequiredMemory().toBytes();
long previousPeakBytes = peakMemoryUsage.toBytes();
memoryUsageDistribution.add(Math.max(previousRequiredBytes, previousPeakBytes) * growthFactor);
}
}

Member: Why would we add estimated memory usage to the distribution? If the retry succeeds, the actual usage will be added anyway, right? This seems to skew the metric.

I understand your intention, though: if one task consumes a large amount of memory, other tasks may also need a large amount. But this will make the stats collection inaccurate; maybe we should explore some other approach instead.

Member (Author): Yeah - this surely is not exact science and I am not sure how well it will work in practice, but the intention is exactly what you wrote. If we see tasks dying because we gave them too little memory, we want to bump the initial memory for new tasks right away, not wait until one succeeds (which may take a long time).

I first thought about having two separate histograms for successful and unsuccessful attempts, and making the unsuccessful one decay over time so that "newer data is more important" - but I did not come up with a reasonable way to merge the data from both, so I implemented the simple (yet, I agree, not 100% bullet-proof) approach.

Happy to hear suggestions on how to improve it, though :)

BTW: I will add a commit on top with extra debug logging so we can see how it works in practice when testing queries on a cluster.

Member: Hmm. Actually, I guess with some tuning your approach might work well in practice. We can leave it as it is for now.

private synchronized DataSize getEstimatedMemoryUsage(Session session)
{
double estimationQuantile = getFaultTolerantExecutionTaskMemoryEstimationQuantile(session);
double estimation = memoryUsageDistribution.valueAt(estimationQuantile);
if (Double.isNaN(estimation)) {
return DataSize.ofBytes(0);
}
return DataSize.ofBytes((long) estimation);
}

private String memoryUsageDistributionInfo()
{
List<Double> quantiles = ImmutableList.of(0.01, 0.05, 0.1, 0.2, 0.5, 0.8, 0.9, 0.95, 0.99);
List<Double> values;
synchronized (this) {
values = memoryUsageDistribution.valuesAt(quantiles);
}

return Streams.zip(
quantiles.stream(),
values.stream(),
(quantile, value) -> "" + quantile + "=" + value)
.collect(Collectors.joining(", ", "[", "]"));
}

@Override
public String toString()
{
return "memoryUsageDistribution=" + memoryUsageDistributionInfo();
}
}
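To isolate the idea above, here is a stripped-down sketch of the same technique. It is not the production class, and the names and values are illustrative, but the TDigest calls match those used in the diff: successful attempts contribute their actual peak, OOM failures contribute an already-grown conservative value, and new tasks are sized from a configurable quantile.

import io.airlift.stats.TDigest;

// Simplified model of the quantile-based estimation above (illustrative only)
class QuantileMemoryEstimator
{
    private final TDigest distribution = new TDigest();
    private final double growthFactor; // e.g. 2.0
    private final double quantile;     // e.g. 0.9

    QuantileMemoryEstimator(double growthFactor, double quantile)
    {
        this.growthFactor = growthFactor;
        this.quantile = quantile;
    }

    synchronized void recordSuccess(long peakBytes)
    {
        // successful attempts contribute their actual peak usage
        distribution.add(peakBytes);
    }

    synchronized void recordOutOfMemory(long requiredBytes, long peakBytes)
    {
        // OOM failures contribute a grown, conservative value so that future
        // tasks start higher without waiting for a successful retry
        distribution.add(Math.max(requiredBytes, peakBytes) * growthFactor);
    }

    synchronized long initialEstimateBytes(long defaultBytes)
    {
        double estimate = distribution.valueAt(quantile);
        if (Double.isNaN(estimate)) {
            return defaultBytes; // no samples yet: fall back to the stage default
        }
        return Math.max(defaultBytes, (long) estimate);
    }
}

For example, with successful peaks of 1 GB and 2 GB plus one OOM failure registered at max(8 GB, 8 GB) * 2.0 = 16 GB, a 0.9-quantile estimate sits near the 16 GB sample, so subsequent tasks start large immediately rather than waiting for a large attempt to succeed.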
(another PartitionMemoryEstimator implementation)
@@ -17,6 +17,8 @@
import io.trino.Session;
import io.trino.spi.ErrorCode;

import java.util.Optional;

import static io.trino.spi.StandardErrorCode.CLUSTER_OUT_OF_MEMORY;
import static io.trino.spi.StandardErrorCode.EXCEEDED_LOCAL_MEMORY_LIMIT;

@@ -50,4 +52,7 @@ private boolean shouldRescheduleWithFullNode(ErrorCode errorCode)
return EXCEEDED_LOCAL_MEMORY_LIMIT.toErrorCode().equals(errorCode) // too many tasks from single query on a node
|| CLUSTER_OUT_OF_MEMORY.toErrorCode().equals(errorCode); // too many tasks in general on a node
}

@Override
public void registerPartitionFinished(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, boolean success, Optional<ErrorCode> errorCode) {}
}
(the fault-tolerant stage scheduler)
@@ -266,6 +266,7 @@ public synchronized void schedule()
TaskDescriptor taskDescriptor = taskDescriptorOptional.get();

MemoryRequirements memoryRequirements = partitionMemoryRequirements.computeIfAbsent(partition, ignored -> partitionMemoryEstimator.getInitialMemoryRequirements(session, taskDescriptor.getNodeRequirements().getMemory()));
log.debug("Computed initial memory requirements for task from stage %s; requirements=%s; estimator=%s", stage.getStageId(), memoryRequirements, partitionMemoryEstimator);
if (nodeLease == null) {
NodeRequirements nodeRequirements = taskDescriptor.getNodeRequirements();
nodeRequirements = nodeRequirements.withMemory(memoryRequirements.getRequiredMemory());
@@ -515,6 +516,8 @@ private void updateTaskStatus(TaskStatus taskStatus, Optional<ExchangeSinkInstanceHandle> exchangeSinkInstanceHandle)
int partitionId = taskId.getPartitionId();

if (!finishedPartitions.contains(partitionId) && !closed) {
MemoryRequirements memoryLimits = partitionMemoryRequirements.get(partitionId);
verify(memoryLimits != null);
switch (state) {
case FINISHED:
finishedPartitions.add(partitionId);
@@ -523,12 +526,15 @@
sinkExchange.get().sinkFinished(exchangeSinkInstanceHandle.get());
}
partitionToRemoteTaskMap.get(partitionId).forEach(RemoteTask::abort);
partitionMemoryEstimator.registerPartitionFinished(session, memoryLimits, taskStatus.getPeakMemoryReservation(), true, Optional.empty());
break;
case CANCELED:
log.debug("Task cancelled: %s", taskId);
// no need for partitionMemoryEstimator.registerPartitionFinished; task cancelled mid-way
break;
case ABORTED:
log.debug("Task aborted: %s", taskId);
// no need for partitionMemoryEstimator.registerPartitionFinished; task aborted mid-way
break;
case FAILED:
ExecutionFailureInfo failureInfo = taskStatus.getFailures().stream()
@@ -537,6 +543,7 @@ private void updateTaskStatus(TaskStatus taskStatus, Optional<ExchangeSinkInstanceHandle> exchangeSinkInstanceHandle)
.orElse(toFailure(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason")));
log.warn(failureInfo.toException(), "Task failed: %s", taskId);
ErrorCode errorCode = failureInfo.getErrorCode();
partitionMemoryEstimator.registerPartitionFinished(session, memoryLimits, taskStatus.getPeakMemoryReservation(), false, Optional.ofNullable(errorCode));

int taskRemainingAttempts = remainingAttemptsPerTask.getOrDefault(partitionId, maxRetryAttemptsPerTask);
if (remainingRetryAttemptsOverall.get() > 0
@@ -546,9 +553,8 @@ private void updateTaskStatus(TaskStatus taskStatus, Optional<ExchangeSinkInstanceHandle> exchangeSinkInstanceHandle)
remainingAttemptsPerTask.put(partitionId, taskRemainingAttempts - 1);

// update memory limits for next attempt
MemoryRequirements memoryLimits = partitionMemoryRequirements.get(partitionId);
verify(memoryLimits != null);
MemoryRequirements newMemoryLimits = partitionMemoryEstimator.getNextRetryMemoryRequirements(session, memoryLimits, taskStatus.getPeakMemoryReservation(), errorCode);
log.debug("Computed next memory requirements for task from stage %s; previous=%s; new=%s; peak=%s; estimator=%s", stage.getStageId(), memoryLimits, newMemoryLimits, taskStatus.getPeakMemoryReservation(), partitionMemoryEstimator);
partitionMemoryRequirements.put(partitionId, newMemoryLimits);

// reschedule
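The scheduler changes above touch the estimator at three points: initial sizing in schedule(), stats registration in updateTaskStatus(), and re-sizing before a retry. A hypothetical walk-through of that sequence, reusing the QuantileMemoryEstimator sketch from earlier on this page (not actual Trino API):

public class SchedulerFlowExample
{
    public static void main(String[] args)
    {
        // assumes the QuantileMemoryEstimator sketch defined earlier on this page
        QuantileMemoryEstimator estimator = new QuantileMemoryEstimator(2.0, 0.9);
        long defaultBytes = 5L << 30; // 5 GB stage default

        // schedule(): initial requirements for a fresh partition
        long limit = estimator.initialEstimateBytes(defaultBytes);

        // updateTaskStatus(), FAILED with an OOM error code: register the failure,
        // then grow the limit for the retry (never below the current estimate)
        long peak = limit; // the task hit its limit
        estimator.recordOutOfMemory(limit, peak);
        long retryLimit = Math.max(
                (long) (Math.max(limit, peak) * 2.0),
                estimator.initialEstimateBytes(defaultBytes));

        // updateTaskStatus(), FINISHED: the successful peak becomes a sample
        // that raises (or confirms) initial estimates for later partitions
        estimator.recordSuccess(retryLimit / 2); // say the retry peaked at half its limit

        System.out.printf("initial=%d retry=%d%n", limit, retryLimit);
    }
}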
PartitionMemoryEstimator.java
@@ -18,6 +18,7 @@
import io.trino.spi.ErrorCode;

import java.util.Objects;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;
@@ -28,6 +29,8 @@ public interface PartitionMemoryEstimator

MemoryRequirements getNextRetryMemoryRequirements(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, ErrorCode errorCode);

void registerPartitionFinished(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, boolean success, Optional<ErrorCode> errorCode);

class MemoryRequirements
{
private final DataSize requiredMemory;
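Nothing in the interface forces implementers to keep statistics; as the no-op overrides above show, the new callback is just an observation hook. A minimal hypothetical implementer (illustrative, not part of the PR; it uses the two-argument MemoryRequirements constructor seen elsewhere in this diff):

import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.execution.scheduler.PartitionMemoryEstimator;
import io.trino.spi.ErrorCode;

import java.util.Optional;

// Hypothetical: fixed sizing that merely logs completions (illustrative only)
public class LoggingPartitionMemoryEstimator
        implements PartitionMemoryEstimator
{
    @Override
    public MemoryRequirements getInitialMemoryRequirements(Session session, DataSize defaultMemoryLimit)
    {
        return new MemoryRequirements(defaultMemoryLimit, false);
    }

    @Override
    public MemoryRequirements getNextRetryMemoryRequirements(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, ErrorCode errorCode)
    {
        return previousMemoryRequirements; // never grow
    }

    @Override
    public void registerPartitionFinished(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, boolean success, Optional<ErrorCode> errorCode)
    {
        System.out.printf("partition finished: success=%s peak=%s%n", success, peakMemoryUsage);
    }
}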
PartitionMemoryEstimatorFactory.java (new file)
@@ -0,0 +1,20 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.execution.scheduler;

@FunctionalInterface
public interface PartitionMemoryEstimatorFactory
{
PartitionMemoryEstimator createPartitionMemoryEstimator();
}

Member: QQ: why is having a factory beneficial here?

Member (Author): We want a new instance for each stage, so we can base the estimation on the tasks that completed for that particular stage.