Skip to content

Commit

Permalink
Allow some tasks waiting for node per stage before blocking scheduling
Browse files Browse the repository at this point in the history
  • Loading branch information
losipiuk committed Apr 9, 2022
1 parent 153ddd4 commit 6a2e3bd
Show file tree
Hide file tree
Showing 8 changed files with 281 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ public final class SystemSessionProperties
public static final String QUERY_RETRY_ATTEMPTS = "query_retry_attempts";
public static final String TASK_RETRY_ATTEMPTS_OVERALL = "task_retry_attempts_overall";
public static final String TASK_RETRY_ATTEMPTS_PER_TASK = "task_retry_attempts_per_task";
public static final String MAX_TASKS_WAITING_FOR_NODE_PER_STAGE = "max_tasks_waiting_for_node_per_stage";
public static final String RETRY_INITIAL_DELAY = "retry_initial_delay";
public static final String RETRY_MAX_DELAY = "retry_max_delay";
public static final String HIDE_INACCESSIBLE_COLUMNS = "hide_inaccessible_columns";
Expand Down Expand Up @@ -732,6 +733,11 @@ public SystemSessionProperties(
"Maximum number of task retry attempts per single task",
queryManagerConfig.getTaskRetryAttemptsPerTask(),
false),
integerProperty(
MAX_TASKS_WAITING_FOR_NODE_PER_STAGE,
"Maximum possible number of tasks waiting for node allocation per stage before scheduling of new tasks for stage is paused",
queryManagerConfig.getMaxTasksWaitingForNodePerStage(),
false),
durationProperty(
RETRY_INITIAL_DELAY,
"Initial delay before initiating a retry attempt. Delay increases exponentially for each subsequent attempt up to 'retry_max_delay'",
Expand Down Expand Up @@ -1376,6 +1382,11 @@ public static int getTaskRetryAttemptsPerTask(Session session)
return session.getSystemProperty(TASK_RETRY_ATTEMPTS_PER_TASK, Integer.class);
}

/**
 * Reads the {@code max_tasks_waiting_for_node_per_stage} system property from the given session.
 * The property is described as the maximum number of tasks that may wait for node allocation
 * per stage before scheduling of new tasks for that stage is paused.
 *
 * NOTE(review): the visible registration uses a plain integerProperty with no range check,
 * while the QueryManagerConfig getter is annotated with @Min(1) - confirm whether non-positive
 * session-level values need to be rejected.
 *
 * @param session session whose system properties are consulted
 * @return the property value for this session
 */
public static int getMaxTasksWaitingForNodePerStage(Session session)
{
    Integer maxWaitingTasks = session.getSystemProperty(MAX_TASKS_WAITING_FOR_NODE_PER_STAGE, Integer.class);
    return maxWaitingTasks;
}

public static Duration getRetryInitialDelay(Session session)
{
return session.getSystemProperty(RETRY_INITIAL_DELAY, Duration.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ public class QueryManagerConfig
private Duration retryInitialDelay = new Duration(10, SECONDS);
private Duration retryMaxDelay = new Duration(1, MINUTES);

private int maxTasksWaitingForNodePerStage = 5;

private DataSize faultTolerantExecutionTargetTaskInputSize = DataSize.of(1, GIGABYTE);

private int faultTolerantExecutionMinTaskSplitCount = 16;
Expand Down Expand Up @@ -483,6 +485,20 @@ public QueryManagerConfig setRetryMaxDelay(Duration retryMaxDelay)
return this;
}

/**
 * Returns the configured limit on tasks that may be waiting for node allocation per stage
 * before scheduling of new tasks for the stage is paused. The returned value is constrained
 * by the {@code @Min(1)} annotation on this getter.
 *
 * @return the configured per-stage limit of tasks waiting for a node
 */
@Min(1)
public int getMaxTasksWaitingForNodePerStage()
{
    return this.maxTasksWaitingForNodePerStage;
}

/**
 * Sets the limit on tasks that may be waiting for node allocation per stage before scheduling
 * of new tasks for the stage is paused. Bound to the
 * {@code max-tasks-waiting-for-node-per-stage} configuration key.
 *
 * No range check happens here; the {@code @Min(1)} constraint is declared on the getter.
 *
 * @param maxTasksWaitingForNodePerStage new per-stage limit of tasks waiting for a node
 * @return this config object, so that setter calls can be chained
 */
@Config("max-tasks-waiting-for-node-per-stage")
@ConfigDescription("Maximum possible number of tasks waiting for node allocation per stage before scheduling of new tasks for stage is paused")
public QueryManagerConfig setMaxTasksWaitingForNodePerStage(int maxTasksWaitingForNodePerStage)
{
this.maxTasksWaitingForNodePerStage = maxTasksWaitingForNodePerStage;
return this;
}

@NotNull
public DataSize getFaultTolerantExecutionTargetTaskInputSize()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.concurrent.MoreFutures;
import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.execution.ExecutionFailureInfo;
Expand Down Expand Up @@ -52,8 +53,10 @@
import javax.annotation.concurrent.GuardedBy;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -98,6 +101,7 @@ public class FaultTolerantStageScheduler
private final TaskDescriptorStorage taskDescriptorStorage;
private final PartitionMemoryEstimator partitionMemoryEstimator;
private final int maxRetryAttemptsPerTask;
private final int maxTasksWaitingForNodePerStage;

private final TaskLifecycleListener taskLifecycleListener;
// empty when the results are consumed via a direct exchange
Expand All @@ -111,8 +115,6 @@ public class FaultTolerantStageScheduler
@GuardedBy("this")
private ListenableFuture<Void> blocked = immediateVoidFuture();

@GuardedBy("this")
private NodeAllocator.NodeLease nodeLease;
@GuardedBy("this")
private SettableFuture<Void> taskFinishedFuture;

Expand All @@ -131,6 +133,8 @@ public class FaultTolerantStageScheduler
@GuardedBy("this")
private final Queue<Integer> queuedPartitions = new ArrayDeque<>();
@GuardedBy("this")
private final Queue<PendingPartition> pendingPartitions = new ArrayDeque<>();
@GuardedBy("this")
private final Set<Integer> finishedPartitions = new HashSet<>();
@GuardedBy("this")
private final AtomicInteger remainingRetryAttemptsOverall;
Expand Down Expand Up @@ -159,7 +163,8 @@ public FaultTolerantStageScheduler(
Optional<int[]> sourceBucketToPartitionMap,
Optional<BucketNodeMap> sourceBucketNodeMap,
AtomicInteger remainingRetryAttemptsOverall,
int taskRetryAttemptsPerTask)
int taskRetryAttemptsPerTask,
int maxTasksWaitingForNodePerStage)
{
checkArgument(!stage.getFragment().getStageExecutionDescriptor().isStageGroupedExecution(), "grouped execution is expected to be disabled");

Expand All @@ -178,6 +183,7 @@ public FaultTolerantStageScheduler(
this.sourceBucketNodeMap = requireNonNull(sourceBucketNodeMap, "sourceBucketNodeMap is null");
this.remainingRetryAttemptsOverall = requireNonNull(remainingRetryAttemptsOverall, "remainingRetryAttemptsOverall is null");
this.maxRetryAttemptsPerTask = taskRetryAttemptsPerTask;
this.maxTasksWaitingForNodePerStage = maxTasksWaitingForNodePerStage;
}

public StageId getStageId()
Expand Down Expand Up @@ -236,8 +242,8 @@ public synchronized void schedule()
sourceBucketNodeMap);
}

while (!queuedPartitions.isEmpty() || !taskSource.isFinished()) {
while (queuedPartitions.isEmpty() && !taskSource.isFinished()) {
while (!pendingPartitions.isEmpty() || !queuedPartitions.isEmpty() || !taskSource.isFinished()) {
while (queuedPartitions.isEmpty() && pendingPartitions.size() < maxTasksWaitingForNodePerStage && !taskSource.isFinished()) {
List<TaskDescriptor> tasks = taskSource.getMoreTasks();
for (TaskDescriptor task : tasks) {
queuedPartitions.add(task.getPartitionId());
Expand All @@ -253,95 +259,119 @@ public synchronized void schedule()
}
}

if (queuedPartitions.isEmpty()) {
break;
Iterator<PendingPartition> pendingPartitionsIterator = pendingPartitions.iterator();
boolean startedTask = false;
while (pendingPartitionsIterator.hasNext()) {
PendingPartition pendingPartition = pendingPartitionsIterator.next();
if (pendingPartition.getNodeLease().getNode().isDone()) {
startTask(pendingPartition.getPartition(), pendingPartition.getNodeLease());
startedTask = true;
pendingPartitionsIterator.remove();
}
}

int partition = queuedPartitions.peek();
Optional<TaskDescriptor> taskDescriptorOptional = taskDescriptorStorage.get(stage.getStageId(), partition);
if (taskDescriptorOptional.isEmpty()) {
// query has been terminated
return;
if (!startedTask && (queuedPartitions.isEmpty() || pendingPartitions.size() >= maxTasksWaitingForNodePerStage)) {
break;
}
TaskDescriptor taskDescriptor = taskDescriptorOptional.get();

MemoryRequirements memoryRequirements = partitionMemoryRequirements.computeIfAbsent(partition, ignored -> partitionMemoryEstimator.getInitialMemoryRequirements(session, taskDescriptor.getNodeRequirements().getMemory()));
log.debug("Computed initial memory requirements for task from stage %s; requirements=%s; estimator=%s", stage.getStageId(), memoryRequirements, partitionMemoryEstimator);
if (nodeLease == null) {
while (pendingPartitions.size() < maxTasksWaitingForNodePerStage && !queuedPartitions.isEmpty()) {
int partition = queuedPartitions.poll();
Optional<TaskDescriptor> taskDescriptorOptional = taskDescriptorStorage.get(stage.getStageId(), partition);
if (taskDescriptorOptional.isEmpty()) {
// query has been terminated
return;
}
TaskDescriptor taskDescriptor = taskDescriptorOptional.get();

MemoryRequirements memoryRequirements = partitionMemoryRequirements.computeIfAbsent(partition, ignored -> partitionMemoryEstimator.getInitialMemoryRequirements(session, taskDescriptor.getNodeRequirements().getMemory()));
log.debug("Computed initial memory requirements for task from stage %s; requirements=%s; estimator=%s", stage.getStageId(), memoryRequirements, partitionMemoryEstimator);
NodeRequirements nodeRequirements = taskDescriptor.getNodeRequirements();
nodeRequirements = nodeRequirements.withMemory(memoryRequirements.getRequiredMemory());
nodeLease = nodeAllocator.acquire(nodeRequirements);
}
if (!nodeLease.getNode().isDone()) {
blocked = asVoid(nodeLease.getNode());
return;
}
InternalNode node = getFutureValue(nodeLease.getNode());
NodeAllocator.NodeLease nodeLease = nodeAllocator.acquire(nodeRequirements);

queuedPartitions.poll();
pendingPartitions.add(new PendingPartition(partition, nodeLease));
}
}

Multimap<PlanNodeId, Split> tableScanSplits = taskDescriptor.getSplits();
Multimap<PlanNodeId, Split> remoteSplits = createRemoteSplits(taskDescriptor.getExchangeSourceHandles());
List<ListenableFuture<?>> futures = new ArrayList<>();
if (taskFinishedFuture != null && !taskFinishedFuture.isDone()) {
futures.add(taskFinishedFuture);
}
for (PendingPartition pendingPartition : pendingPartitions) {
futures.add(pendingPartition.getNodeLease().getNode());
}
if (!futures.isEmpty()) {
blocked = asVoid(MoreFutures.whenAnyComplete(futures));
}
}

Multimap<PlanNodeId, Split> taskSplits = ImmutableListMultimap.<PlanNodeId, Split>builder()
.putAll(tableScanSplits)
.putAll(remoteSplits)
.build();
private void startTask(int partition, NodeAllocator.NodeLease nodeLease)
{
Optional<TaskDescriptor> taskDescriptorOptional = taskDescriptorStorage.get(stage.getStageId(), partition);
if (taskDescriptorOptional.isEmpty()) {
// query has been terminated
return;
}
TaskDescriptor taskDescriptor = taskDescriptorOptional.get();

int attemptId = getNextAttemptIdForPartition(partition);
InternalNode node = getFutureValue(nodeLease.getNode());

OutputBuffers outputBuffers;
Optional<ExchangeSinkInstanceHandle> exchangeSinkInstanceHandle;
if (sinkExchange.isPresent()) {
ExchangeSinkHandle sinkHandle = partitionToExchangeSinkHandleMap.get(partition);
exchangeSinkInstanceHandle = Optional.of(sinkExchange.get().instantiateSink(sinkHandle, attemptId));
outputBuffers = createSpoolingExchangeOutputBuffers(exchangeSinkInstanceHandle.get());
}
else {
exchangeSinkInstanceHandle = Optional.empty();
// stage will be consumed by the coordinator using direct exchange
outputBuffers = createInitialEmptyOutputBuffers(PARTITIONED)
.withBuffer(new OutputBuffers.OutputBufferId(0), 0)
.withNoMoreBufferIds();
}
Multimap<PlanNodeId, Split> tableScanSplits = taskDescriptor.getSplits();
Multimap<PlanNodeId, Split> remoteSplits = createRemoteSplits(taskDescriptor.getExchangeSourceHandles());

Set<PlanNodeId> allSourcePlanNodeIds = ImmutableSet.<PlanNodeId>builder()
.addAll(stage.getFragment().getPartitionedSources())
.addAll(stage.getFragment()
.getRemoteSourceNodes().stream()
.map(RemoteSourceNode::getId)
.iterator())
.build();

RemoteTask task = stage.createTask(
node,
partition,
attemptId,
sinkBucketToPartitionMap,
outputBuffers,
taskSplits,
allSourcePlanNodeIds.stream()
.collect(toImmutableListMultimap(Function.identity(), planNodeId -> Lifespan.taskWide())),
allSourcePlanNodeIds).orElseThrow(() -> new VerifyException("stage execution is expected to be active"));

partitionToRemoteTaskMap.put(partition, task);
runningTasks.put(task.getTaskId(), task);
runningNodes.put(task.getTaskId(), nodeLease);
nodeLease = null;

if (taskFinishedFuture == null) {
taskFinishedFuture = SettableFuture.create();
}
Multimap<PlanNodeId, Split> taskSplits = ImmutableListMultimap.<PlanNodeId, Split>builder()
.putAll(tableScanSplits)
.putAll(remoteSplits)
.build();

taskLifecycleListener.taskCreated(stage.getFragment().getId(), task);
int attemptId = getNextAttemptIdForPartition(partition);

task.addStateChangeListener(taskStatus -> updateTaskStatus(taskStatus, exchangeSinkInstanceHandle));
task.start();
OutputBuffers outputBuffers;
Optional<ExchangeSinkInstanceHandle> exchangeSinkInstanceHandle;
if (sinkExchange.isPresent()) {
ExchangeSinkHandle sinkHandle = partitionToExchangeSinkHandleMap.get(partition);
exchangeSinkInstanceHandle = Optional.of(sinkExchange.get().instantiateSink(sinkHandle, attemptId));
outputBuffers = createSpoolingExchangeOutputBuffers(exchangeSinkInstanceHandle.get());
}
else {
exchangeSinkInstanceHandle = Optional.empty();
// stage will be consumed by the coordinator using direct exchange
outputBuffers = createInitialEmptyOutputBuffers(PARTITIONED)
.withBuffer(new OutputBuffers.OutputBufferId(0), 0)
.withNoMoreBufferIds();
}

if (taskFinishedFuture != null && !taskFinishedFuture.isDone()) {
blocked = taskFinishedFuture;
Set<PlanNodeId> allSourcePlanNodeIds = ImmutableSet.<PlanNodeId>builder()
.addAll(stage.getFragment().getPartitionedSources())
.addAll(stage.getFragment()
.getRemoteSourceNodes().stream()
.map(RemoteSourceNode::getId)
.iterator())
.build();

RemoteTask task = stage.createTask(
node,
partition,
attemptId,
sinkBucketToPartitionMap,
outputBuffers,
taskSplits,
allSourcePlanNodeIds.stream()
.collect(toImmutableListMultimap(Function.identity(), planNodeId -> Lifespan.taskWide())),
allSourcePlanNodeIds).orElseThrow(() -> new VerifyException("stage execution is expected to be active"));

partitionToRemoteTaskMap.put(partition, task);
runningTasks.put(task.getTaskId(), task);
runningNodes.put(task.getTaskId(), nodeLease);

if (taskFinishedFuture == null) {
taskFinishedFuture = SettableFuture.create();
}

taskLifecycleListener.taskCreated(stage.getFragment().getId(), task);

task.addStateChangeListener(taskStatus -> updateTaskStatus(taskStatus, exchangeSinkInstanceHandle));
task.start();
}

public synchronized boolean isFinished()
Expand Down Expand Up @@ -383,7 +413,7 @@ private void close(boolean abort)
if (!closed) {
cancelRunningTasks(abort);
cancelBlockedFuture();
releaseAcquiredNode();
releasePendingNodes();
closeTaskSource();
closeSinkExchange();
}
Expand Down Expand Up @@ -415,15 +445,17 @@ private void cancelBlockedFuture()
}
}

private void releaseAcquiredNode()
private void releasePendingNodes()
{
verify(!Thread.holdsLock(this));
NodeAllocator.NodeLease lease;
List<NodeAllocator.NodeLease> leases = new ArrayList<>();
synchronized (this) {
lease = nodeLease;
nodeLease = null;
for (PendingPartition pendingPartition : pendingPartitions) {
leases.add(pendingPartition.getNodeLease());
}
pendingPartitions.clear();
}
if (lease != null) {
for (NodeAllocator.NodeLease lease : leases) {
lease.release();
}
}
Expand Down Expand Up @@ -599,4 +631,26 @@ private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo execut
REMOTE_HOST_GONE.toErrorCode(),
executionFailureInfo.getRemoteHost());
}

/**
 * Immutable pairing of a partition id with the node lease that was requested for it.
 * The scheduler keeps these in its pending queue until the lease's node future completes,
 * at which point the task for the partition is started.
 */
private static class PendingPartition
{
    private final NodeAllocator.NodeLease nodeLease;
    private final int partition;

    /**
     * @param partition id of the partition waiting for a node
     * @param nodeLease lease acquired for the partition; must not be null
     */
    public PendingPartition(int partition, NodeAllocator.NodeLease nodeLease)
    {
        this.nodeLease = requireNonNull(nodeLease, "nodeLease is null");
        this.partition = partition;
    }

    /** Returns the lease requested for this partition. */
    public NodeAllocator.NodeLease getNodeLease()
    {
        return this.nodeLease;
    }

    /** Returns the id of the partition that is waiting for a node. */
    public int getPartition()
    {
        return this.partition;
    }
}
}
Loading

0 comments on commit 6a2e3bd

Please sign in to comment.