Introduce FaultTolerantPartitioningScheme
To encapsulate partition assignment logic
arhimondr committed Sep 16, 2022
1 parent 65e99a2 commit 45eb8ef
Showing 10 changed files with 243 additions and 149 deletions.
BucketNodeMap.java
@@ -52,4 +52,9 @@ public InternalNode getAssignedNode(Split split)
     {
         return getAssignedNode(getBucket(split));
     }
+
+    public ToIntFunction<Split> getSplitToBucketFunction()
+    {
+        return splitToBucket;
+    }
 }
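An illustrative fragment — not part of the diff — showing what the new accessor enables: a caller can map a split to its bucket directly, without going through node assignment, which is how the new FaultTolerantPartitioningSchemeFactory below uses it. The bucketNodeMap and split values are assumed to be in scope:

import io.trino.metadata.Split;

import java.util.function.ToIntFunction;

// Assumes a BucketNodeMap obtained elsewhere (e.g. from NodePartitioningManager)
static int bucketOf(BucketNodeMap bucketNodeMap, Split split)
{
    ToIntFunction<Split> splitToBucket = bucketNodeMap.getSplitToBucketFunction();
    return splitToBucket.applyAsInt(split); // bucket index, no node assignment involved
}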
FaultTolerantPartitioningScheme.java (new file)
@@ -0,0 +1,84 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.execution.scheduler;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;

import java.util.List;
import java.util.Optional;
import java.util.function.ToIntFunction;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

public class FaultTolerantPartitioningScheme
{
    private final int partitionCount;
    private final Optional<int[]> bucketToPartitionMap;
    private final Optional<ToIntFunction<Split>> splitToBucketFunction;
    private final Optional<List<InternalNode>> partitionToNodeMap;

    @VisibleForTesting
    FaultTolerantPartitioningScheme(
            int partitionCount,
            Optional<int[]> bucketToPartitionMap,
            Optional<ToIntFunction<Split>> splitToBucketFunction,
            Optional<List<InternalNode>> partitionToNodeMap)
    {
        checkArgument(partitionCount > 0, "partitionCount must be greater than zero");
        this.partitionCount = partitionCount;
        this.bucketToPartitionMap = requireNonNull(bucketToPartitionMap, "bucketToPartitionMap is null");
        this.splitToBucketFunction = requireNonNull(splitToBucketFunction, "splitToBucketFunction is null");
        requireNonNull(partitionToNodeMap, "partitionToNodeMap is null");
        partitionToNodeMap.ifPresent(map -> checkArgument(
                map.size() == partitionCount,
                "partitionToNodeMap size (%s) must be equal to partitionCount (%s)",
                map.size(),
                partitionCount));
        this.partitionToNodeMap = partitionToNodeMap.map(ImmutableList::copyOf);
    }

    public int getPartitionCount()
    {
        return partitionCount;
    }

    public Optional<int[]> getBucketToPartitionMap()
    {
        return bucketToPartitionMap;
    }

    public int getPartition(Split split)
    {
        checkState(bucketToPartitionMap.isPresent(), "bucketToPartitionMap is expected to be present");
        checkState(splitToBucketFunction.isPresent(), "splitToBucketFunction is expected to be present");
        int bucket = splitToBucketFunction.get().applyAsInt(split);
        checkState(
                bucketToPartitionMap.get().length > bucket,
                "invalid bucketToPartitionMap size (%s), bucket to partition mapping not found for bucket %s",
                bucketToPartitionMap.get().length,
                bucket);
        return bucketToPartitionMap.get()[bucket];
    }

    public Optional<InternalNode> getNodeRequirement(int partition)
    {
        checkArgument(partition < partitionCount, "partition is expected to be less than %s", partitionCount);
        return partitionToNodeMap.map(map -> map.get(partition));
    }
}
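For illustration only — not part of the commit — a test-style sketch of the new class's contract. It assumes it lives in io.trino.execution.scheduler (to reach the package-private constructor), and the split-to-bucket function is a stand-in:

package io.trino.execution.scheduler;

import io.trino.metadata.Split;

import java.util.Optional;
import java.util.function.ToIntFunction;

final class FaultTolerantPartitioningSchemeExample
{
    static void demo(Split someSplit)
    {
        // Three buckets collapsed onto two partitions: buckets 0 and 2 share partition 0
        ToIntFunction<Split> splitToBucket = split -> 0; // stand-in bucketing
        FaultTolerantPartitioningScheme scheme = new FaultTolerantPartitioningScheme(
                2,
                Optional.of(new int[] {0, 1, 0}),
                Optional.of(splitToBucket),
                Optional.empty());

        int partition = scheme.getPartition(someSplit); // bucket 0 -> partition 0
        scheme.getNodeRequirement(partition);           // Optional.empty(): any node may run it
    }
}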
FaultTolerantPartitioningSchemeFactory.java (new file)
@@ -0,0 +1,90 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.execution.scheduler;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.metadata.InternalNode;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.PartitioningHandle;

import javax.annotation.concurrent.NotThreadSafe;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.IntStream;

import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static java.util.Objects.requireNonNull;

@NotThreadSafe
public class FaultTolerantPartitioningSchemeFactory
{
    private final NodePartitioningManager nodePartitioningManager;
    private final Session session;
    private final int partitionCount;

    private final Map<PartitioningHandle, FaultTolerantPartitioningScheme> cache = new HashMap<>();

    public FaultTolerantPartitioningSchemeFactory(NodePartitioningManager nodePartitioningManager, Session session, int partitionCount)
    {
        this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
        this.session = requireNonNull(session, "session is null");
        this.partitionCount = partitionCount;
    }

    public FaultTolerantPartitioningScheme get(PartitioningHandle handle)
    {
        return cache.computeIfAbsent(handle, this::create);
    }

    private FaultTolerantPartitioningScheme create(PartitioningHandle partitioningHandle)
    {
        if (partitioningHandle.equals(FIXED_HASH_DISTRIBUTION)) {
            return new FaultTolerantPartitioningScheme(
                    partitionCount,
                    Optional.of(IntStream.range(0, partitionCount).toArray()),
                    Optional.empty(),
                    Optional.empty());
        }
        if (partitioningHandle.getCatalogHandle().isPresent()) {
            // TODO This caps the number of partitions to the number of available nodes. Perhaps a better approach is required for fault tolerant execution.
            BucketNodeMap bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle);
            int bucketCount = bucketNodeMap.getBucketCount();
            int[] bucketToPartition = new int[bucketCount];
            // make sure all buckets mapped to the same node map to the same partition, such that locality requirements are respected in scheduling
            Map<InternalNode, Integer> nodeToPartition = new HashMap<>();
            List<InternalNode> partitionToNodeMap = new ArrayList<>();
            for (int bucket = 0; bucket < bucketCount; bucket++) {
                InternalNode node = bucketNodeMap.getAssignedNode(bucket);
                Integer partitionId = nodeToPartition.get(node);
                if (partitionId == null) {
                    partitionId = partitionToNodeMap.size();
                    nodeToPartition.put(node, partitionId);
                    partitionToNodeMap.add(node);
                }
                bucketToPartition[bucket] = partitionId;
            }
            return new FaultTolerantPartitioningScheme(
                    partitionToNodeMap.size(),
                    Optional.of(bucketToPartition),
                    Optional.of(bucketNodeMap.getSplitToBucketFunction()),
                    Optional.of(ImmutableList.copyOf(partitionToNodeMap)));
        }
        return new FaultTolerantPartitioningScheme(1, Optional.empty(), Optional.empty(), Optional.empty());
    }
}
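To make the bucket-collapsing loop concrete, a self-contained sketch (not part of the commit) that mirrors it with plain strings in place of InternalNode; the node names are invented:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class BucketCollapsingDemo
{
    public static void main(String[] args)
    {
        // bucket -> assigned node; buckets 0 and 2 both live on node-a
        String[] bucketToNode = {"node-a", "node-b", "node-a", "node-c"};

        int[] bucketToPartition = new int[bucketToNode.length];
        Map<String, Integer> nodeToPartition = new HashMap<>();
        List<String> partitionToNode = new ArrayList<>();
        for (int bucket = 0; bucket < bucketToNode.length; bucket++) {
            // the first time a node is seen it is assigned the next partition id
            int partitionId = nodeToPartition.computeIfAbsent(bucketToNode[bucket], node -> {
                partitionToNode.add(node);
                return partitionToNode.size() - 1;
            });
            bucketToPartition[bucket] = partitionId;
        }

        System.out.println(Arrays.toString(bucketToPartition)); // [0, 1, 0, 2]
        System.out.println(partitionToNode);                    // [node-a, node-b, node-c]
    }
}

Buckets that share a node collapse onto one partition, so the resulting partition count (3) can be smaller than the bucket count (4) while still honoring scheduling locality.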
FaultTolerantQueryScheduler.java
@@ -34,7 +34,6 @@
 import io.trino.execution.StageInfo;
 import io.trino.execution.TaskId;
 import io.trino.failuredetector.FailureDetector;
-import io.trino.metadata.InternalNode;
 import io.trino.metadata.Metadata;
 import io.trino.operator.RetryPolicy;
 import io.trino.server.DynamicFilterService;
@@ -44,7 +43,6 @@
 import io.trino.spi.exchange.ExchangeManager;
 import io.trino.spi.exchange.ExchangeSourceHandle;
 import io.trino.sql.planner.NodePartitioningManager;
-import io.trino.sql.planner.PartitioningHandle;
 import io.trino.sql.planner.PlanFragment;
 import io.trino.sql.planner.SubPlan;
 import io.trino.sql.planner.plan.PlanFragmentId;
@@ -60,8 +58,6 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
-import java.util.function.Function;
-import java.util.stream.IntStream;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Ticker.systemTicker;
@@ -76,7 +72,6 @@
 import static io.trino.SystemSessionProperties.getRetryPolicy;
 import static io.trino.execution.QueryState.FINISHING;
 import static io.trino.operator.RetryPolicy.TASK;
-import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
 import static java.util.Objects.requireNonNull;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 import static java.util.concurrent.TimeUnit.SECONDS;
@@ -214,8 +209,10 @@ private Scheduler createScheduler()
         });
 
         Session session = queryStateMachine.getSession();
-        int partitionCount = getFaultTolerantExecutionPartitionCount(session);
-        Function<PartitioningHandle, BucketToPartition> bucketToPartitionCache = createBucketToPartitionCache(nodePartitioningManager, session, partitionCount);
+        FaultTolerantPartitioningSchemeFactory partitioningSchemeFactory = new FaultTolerantPartitioningSchemeFactory(
+                nodePartitioningManager,
+                session,
+                getFaultTolerantExecutionPartitionCount(session));
 
         ImmutableList.Builder<FaultTolerantStageScheduler> schedulers = ImmutableList.builder();
         Map<PlanFragmentId, Exchange> exchanges = new HashMap<>();
@@ -235,9 +232,10 @@ private Scheduler createScheduler()
 
             boolean outputStage = stageManager.getOutputStage().getStageId().equals(stage.getStageId());
             ExchangeContext exchangeContext = new ExchangeContext(session.getQueryId(), new ExchangeId("external-exchange-" + stage.getStageId().getId()));
+            FaultTolerantPartitioningScheme sinkPartitioningScheme = partitioningSchemeFactory.get(fragment.getPartitioningScheme().getPartitioning().getHandle());
             Exchange exchange = exchangeManager.createExchange(
                     exchangeContext,
-                    partitionCount,
+                    sinkPartitioningScheme.getPartitionCount(),
                     // order of output records for coordinator consumed stages must be preserved as the stage
                     // may produce sorted dataset (for example an output of a global OrderByOperator)
                     outputStage);
@@ -256,7 +254,7 @@ private Scheduler createScheduler()
                 sourceExchanges.put(childFragmentId, sourceExchange);
             }
 
-            BucketToPartition inputBucketToPartition = bucketToPartitionCache.apply(fragment.getPartitioning());
+            FaultTolerantPartitioningScheme sourcePartitioningScheme = partitioningSchemeFactory.get(fragment.getPartitioning());
             FaultTolerantStageScheduler scheduler = new FaultTolerantStageScheduler(
                     session,
                     stage,
@@ -269,10 +267,9 @@ private Scheduler createScheduler()
                     (future, delay) -> scheduledExecutorService.schedule(() -> future.set(null), delay.toMillis(), MILLISECONDS),
                     systemTicker(),
                     exchange,
-                    bucketToPartitionCache.apply(fragment.getPartitioningScheme().getPartitioning().getHandle()).getBucketToPartitionMap(),
+                    sinkPartitioningScheme,
                     sourceExchanges.buildOrThrow(),
-                    inputBucketToPartition.getBucketToPartitionMap(),
-                    inputBucketToPartition.getBucketNodeMap(),
+                    sourcePartitioningScheme,
                     remainingTaskRetryAttemptsOverall,
                     taskRetryAttemptsPerTask,
                     maxTasksWaitingForNodePerStage,
@@ -511,68 +508,6 @@ private void closeNodeAllocator()
         }
     }
 
-    private static Function<PartitioningHandle, BucketToPartition> createBucketToPartitionCache(NodePartitioningManager nodePartitioningManager, Session session, int partitionCount)
-    {
-        Map<PartitioningHandle, BucketToPartition> cachingMap = new HashMap<>();
-        return partitioningHandle ->
-                cachingMap.computeIfAbsent(
-                        partitioningHandle,
-                        handle -> createBucketToPartitionMap(session, partitionCount, handle, nodePartitioningManager));
-    }
-
-    private static BucketToPartition createBucketToPartitionMap(
-            Session session,
-            int partitionCount,
-            PartitioningHandle partitioningHandle,
-            NodePartitioningManager nodePartitioningManager)
-    {
-        if (partitioningHandle.equals(FIXED_HASH_DISTRIBUTION)) {
-            return new BucketToPartition(Optional.of(IntStream.range(0, partitionCount).toArray()), Optional.empty());
-        }
-        if (partitioningHandle.getCatalogHandle().isPresent()) {
-            BucketNodeMap bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle);
-            int bucketCount = bucketNodeMap.getBucketCount();
-            int[] bucketToPartition = new int[bucketCount];
-            // make sure all buckets mapped to the same node map to the same partition, such that locality requirements are respected in scheduling
-            Map<InternalNode, Integer> nodeToPartition = new HashMap<>();
-            int nextPartitionId = 0;
-            for (int bucket = 0; bucket < bucketCount; bucket++) {
-                InternalNode node = bucketNodeMap.getAssignedNode(bucket);
-                Integer partitionId = nodeToPartition.get(node);
-                if (partitionId == null) {
-                    partitionId = nextPartitionId;
-                    nextPartitionId++;
-                    nodeToPartition.put(node, partitionId);
-                }
-                bucketToPartition[bucket] = partitionId;
-            }
-            return new BucketToPartition(Optional.of(bucketToPartition), Optional.of(bucketNodeMap));
-        }
-        return new BucketToPartition(Optional.empty(), Optional.empty());
-    }
-
-    private static class BucketToPartition
-    {
-        private final Optional<int[]> bucketToPartitionMap;
-        private final Optional<BucketNodeMap> bucketNodeMap;
-
-        private BucketToPartition(Optional<int[]> bucketToPartitionMap, Optional<BucketNodeMap> bucketNodeMap)
-        {
-            this.bucketToPartitionMap = requireNonNull(bucketToPartitionMap, "bucketToPartitionMap is null");
-            this.bucketNodeMap = requireNonNull(bucketNodeMap, "bucketNodeMap is null");
-        }
-
-        public Optional<int[]> getBucketToPartitionMap()
-        {
-            return bucketToPartitionMap;
-        }
-
-        public Optional<BucketNodeMap> getBucketNodeMap()
-        {
-            return bucketNodeMap;
-        }
-    }
-
     private static boolean isFinishingOrDone(QueryStateMachine queryStateMachine)
     {
         QueryState queryState = queryStateMachine.getQueryState();
FaultTolerantStageScheduler.java
@@ -115,11 +115,10 @@ public class FaultTolerantStageScheduler
     private final int maxTasksWaitingForNodePerStage;
 
     private final Exchange sinkExchange;
-    private final Optional<int[]> sinkBucketToPartitionMap;
+    private final FaultTolerantPartitioningScheme sinkPartitioningScheme;
 
     private final Map<PlanFragmentId, Exchange> sourceExchanges;
-    private final Optional<int[]> sourceBucketToPartitionMap;
-    private final Optional<BucketNodeMap> sourceBucketNodeMap;
+    private final FaultTolerantPartitioningScheme sourcePartitioningScheme;
 
     private final DelayedFutureCompletor futureCompletor;
 
@@ -187,10 +186,9 @@ public FaultTolerantStageScheduler(
             DelayedFutureCompletor futureCompletor,
             Ticker ticker,
             Exchange sinkExchange,
-            Optional<int[]> sinkBucketToPartitionMap,
+            FaultTolerantPartitioningScheme sinkPartitioningScheme,
             Map<PlanFragmentId, Exchange> sourceExchanges,
-            Optional<int[]> sourceBucketToPartitionMap,
-            Optional<BucketNodeMap> sourceBucketNodeMap,
+            FaultTolerantPartitioningScheme sourcePartitioningScheme,
             AtomicInteger remainingRetryAttemptsOverall,
             int taskRetryAttemptsPerTask,
             int maxTasksWaitingForNodePerStage,
@@ -206,10 +204,9 @@ public FaultTolerantStageScheduler(
         this.taskExecutionStats = requireNonNull(taskExecutionStats, "taskExecutionStats is null");
         this.futureCompletor = requireNonNull(futureCompletor, "futureCompletor is null");
         this.sinkExchange = requireNonNull(sinkExchange, "sinkExchange is null");
-        this.sinkBucketToPartitionMap = requireNonNull(sinkBucketToPartitionMap, "sinkBucketToPartitionMap is null");
+        this.sinkPartitioningScheme = requireNonNull(sinkPartitioningScheme, "sinkPartitioningScheme is null");
         this.sourceExchanges = ImmutableMap.copyOf(requireNonNull(sourceExchanges, "sourceExchanges is null"));
-        this.sourceBucketToPartitionMap = requireNonNull(sourceBucketToPartitionMap, "sourceBucketToPartitionMap is null");
-        this.sourceBucketNodeMap = requireNonNull(sourceBucketNodeMap, "sourceBucketNodeMap is null");
+        this.sourcePartitioningScheme = requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null");
        this.remainingRetryAttemptsOverall = requireNonNull(remainingRetryAttemptsOverall, "remainingRetryAttemptsOverall is null");
         this.maxRetryAttemptsPerTask = taskRetryAttemptsPerTask;
         this.maxTasksWaitingForNodePerStage = maxTasksWaitingForNodePerStage;
@@ -277,8 +274,7 @@ public synchronized void schedule()
                     stage.getFragment(),
                     exchangeSources,
                     stage::recordGetSplitTime,
-                    sourceBucketToPartitionMap,
-                    sourceBucketNodeMap);
+                    sourcePartitioningScheme);
         }
 
         while (!pendingPartitions.isEmpty() || !queuedPartitions.isEmpty() || !taskSource.isFinished()) {
@@ -393,7 +389,7 @@ private void startTask(int partition, NodeAllocator.NodeLease nodeLease, MemoryR
                     node,
                     partition,
                     attemptId,
-                    sinkBucketToPartitionMap,
+                    sinkPartitioningScheme.getBucketToPartitionMap(),
                     outputBuffers,
                     taskSplits,
                     allSourcePlanNodeIds,