Join intermediate tasks for small partitions
If the partitions produced by upstream tasks are small, it is
suboptimal to create a separate task for each partition.
With this commit, a single task can read data from multiple input
partitions; the target input size is configured via
fault-tolerant-execution-target-task-input-size.
If the task also reads source data (which can be the case, e.g., when
a join with a bucketed table has a join key matching the bucketing),
the task sizing also takes input split weights into account
(configured via fault-tolerant-execution-target-task-split-count).
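
For illustration, these limits could be set in the coordinator's
config.properties; the property names are taken from the commit message,
while the values below are hypothetical examples, not defaults:

  fault-tolerant-execution-target-task-input-size=4GB
  fault-tolerant-execution-target-task-split-count=64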
losipiuk committed Mar 9, 2022
1 parent 8d0ff22 commit c9f6e77
Showing 2 changed files with 213 additions and 15 deletions.
@@ -22,6 +22,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
@@ -141,11 +142,14 @@ else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnect
session,
fragment,
splitSourceFactory,
sourceExchanges,
exchangeSourceHandles,
splitBatchSize,
getSplitTimeRecorder,
bucketToPartitionMap.orElseThrow(() -> new IllegalArgumentException("bucketToPartitionMap is expected to be present for hash distributed stages")),
bucketNodeMap);
bucketNodeMap,
getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(),
getFaultTolerantExecutionTargetTaskInputSize(session));
}
else if (partitioning.equals(SOURCE_DISTRIBUTION)) {
return SourceDistributionTaskSource.create(
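
A note on the HashDistributionTaskSource.create call above (an
illustrative sketch, not code from this diff): the split-weight budget is
the configured target split count expressed in raw split-weight units, so
a task is considered full once its splits weigh as much as that many
standard splits:

  // hypothetical value of fault-tolerant-execution-target-task-split-count
  int targetTaskSplitCount = 64;
  long targetPartitionSplitWeight = targetTaskSplitCount * SplitWeight.standard().getRawValue();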
@@ -324,13 +328,17 @@ public static class HashDistributionTaskSource
implements TaskSource
{
private final Map<PlanNodeId, SplitSource> splitSources;
private final IdentityHashMap<ExchangeSourceHandle, Exchange> exchangeForHandle;
private final Multimap<PlanNodeId, ExchangeSourceHandle> partitionedExchangeSourceHandles;
private final Multimap<PlanNodeId, ExchangeSourceHandle> replicatedExchangeSourceHandles;

private final int splitBatchSize;
private final LongConsumer getSplitTimeRecorder;
private final int[] bucketToPartitionMap;
private final Optional<BucketNodeMap> bucketNodeMap;
private final Optional<CatalogName> catalogRequirement;
private final long targetPartitionSourceSizeInBytes; // compared against data read from ExchangeSources
private final long targetPartitionSplitWeight; // compared against splits from SplitSources

private boolean finished;
private boolean closed;
@@ -339,36 +347,47 @@ public static HashDistributionTaskSource create(
Session session,
PlanFragment fragment,
SplitSourceFactory splitSourceFactory,
Map<PlanFragmentId, Exchange> sourceExchanges,
Multimap<PlanFragmentId, ExchangeSourceHandle> exchangeSourceHandles,
int splitBatchSize,
LongConsumer getSplitTimeRecorder,
int[] bucketToPartitionMap,
Optional<BucketNodeMap> bucketNodeMap)
Optional<BucketNodeMap> bucketNodeMap,
long targetPartitionSplitWeight,
DataSize targetPartitionSourceSize)
{
checkArgument(bucketNodeMap.isPresent() || fragment.getPartitionedSources().isEmpty(), "bucketNodeMap is expected to be set when the fragment reads partitioned sources (tables)");
Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(session, fragment);

return new HashDistributionTaskSource(
splitSources,
getExchangeForHandleMap(sourceExchanges, exchangeSourceHandles),
getPartitionedExchangeSourceHandles(fragment, exchangeSourceHandles),
getReplicatedExchangeSourceHandles(fragment, exchangeSourceHandles),
splitBatchSize,
getSplitTimeRecorder,
bucketToPartitionMap,
bucketNodeMap,
fragment.getPartitioning().getConnectorId());
fragment.getPartitioning().getConnectorId(),
targetPartitionSplitWeight,
targetPartitionSourceSize);
}

public HashDistributionTaskSource(
Map<PlanNodeId, SplitSource> splitSources,
IdentityHashMap<ExchangeSourceHandle, Exchange> exchangeForHandle,
Multimap<PlanNodeId, ExchangeSourceHandle> partitionedExchangeSourceHandles,
Multimap<PlanNodeId, ExchangeSourceHandle> replicatedExchangeSourceHandles,
int splitBatchSize,
LongConsumer getSplitTimeRecorder,
int[] bucketToPartitionMap,
Optional<BucketNodeMap> bucketNodeMap,
Optional<CatalogName> catalogRequirement)
Optional<CatalogName> catalogRequirement,
long targetPartitionSplitWeight,
DataSize targetPartitionSourceSize)
{
this.splitSources = ImmutableMap.copyOf(requireNonNull(splitSources, "splitSources is null"));
this.exchangeForHandle = new IdentityHashMap<>();
this.exchangeForHandle.putAll(exchangeForHandle);
this.partitionedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(partitionedExchangeSourceHandles, "partitionedExchangeSourceHandles is null"));
this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(replicatedExchangeSourceHandles, "replicatedExchangeSourceHandles is null"));
this.splitBatchSize = splitBatchSize;
@@ -377,6 +396,8 @@ public HashDistributionTaskSource(
this.bucketNodeMap = requireNonNull(bucketNodeMap, "bucketNodeMap is null");
checkArgument(bucketNodeMap.isPresent() || splitSources.isEmpty(), "bucketNodeMap is expected to be set when the fragment reads partitioned sources (tables)");
this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null");
this.targetPartitionSourceSizeInBytes = requireNonNull(targetPartitionSourceSize, "targetPartitionSourceSize is null").toBytes();
this.targetPartitionSplitWeight = targetPartitionSplitWeight;
}

@Override
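
(Illustrative note, not part of the diff: io.airlift.units.DataSize uses
binary units, so a configured target of DataSize.of(4, GIGABYTE) yields
4 * 1024^3 = 4294967296 from toBytes(), which is the value stored in
targetPartitionSourceSizeInBytes by the constructor above.)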
@@ -430,20 +451,84 @@ public List<TaskDescriptor> getMoreTasks()
}

int taskPartitionId = 0;
ImmutableList.Builder<TaskDescriptor> result = ImmutableList.builder();
ImmutableList.Builder<TaskDescriptor> partitionTasks = ImmutableList.builder();
for (Integer partition : union(partitionToSplitsMap.keySet(), partitionToExchangeSourceHandlesMap.keySet())) {
ListMultimap<PlanNodeId, Split> splits = partitionToSplitsMap.getOrDefault(partition, ImmutableListMultimap.of());
ListMultimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles = ImmutableListMultimap.<PlanNodeId, ExchangeSourceHandle>builder()
.putAll(partitionToExchangeSourceHandlesMap.getOrDefault(partition, ImmutableMultimap.of()))
.putAll(replicatedExchangeSourceHandles)
// replicated exchange sources will be added in postprocessTasks below
.build();
HostAddress host = partitionToNodeMap.get(partition);
Set<HostAddress> hostRequirement = host == null ? ImmutableSet.of() : ImmutableSet.of(host);
result.add(new TaskDescriptor(taskPartitionId++, splits, exchangeSourceHandles, new NodeRequirements(catalogRequirement, hostRequirement)));
partitionTasks.add(new TaskDescriptor(taskPartitionId++, splits, exchangeSourceHandles, new NodeRequirements(catalogRequirement, hostRequirement)));
}

List<TaskDescriptor> result = postprocessTasks(partitionTasks.build());

finished = true;
return result.build();
return result;
}

private List<TaskDescriptor> postprocessTasks(List<TaskDescriptor> tasks)
{
ListMultimap<NodeRequirements, TaskDescriptor> taskGroups = groupCompatibleTasks(tasks);
ImmutableList.Builder<TaskDescriptor> joinedTasks = ImmutableList.builder();
long replicatedExchangeSourcesSize = replicatedExchangeSourceHandles.values().stream().mapToLong(this::sourceHandleSize).sum();
int taskPartitionId = 0;
for (Map.Entry<NodeRequirements, Collection<TaskDescriptor>> taskGroup : taskGroups.asMap().entrySet()) {
NodeRequirements groupNodeRequirements = taskGroup.getKey();
Collection<TaskDescriptor> groupTasks = taskGroup.getValue();

ImmutableListMultimap.Builder<PlanNodeId, Split> splits = ImmutableListMultimap.builder();
ImmutableListMultimap.Builder<PlanNodeId, ExchangeSourceHandle> exchangeSources = ImmutableListMultimap.builder();
long splitsWeight = 0;
long exchangeSourcesSize = 0;

for (TaskDescriptor task : groupTasks) {
ListMultimap<PlanNodeId, Split> taskSplits = task.getSplits();
ListMultimap<PlanNodeId, ExchangeSourceHandle> taskExchangeSources = task.getExchangeSourceHandles();
long taskSplitWeight = taskSplits.values().stream().mapToLong(split -> split.getSplitWeight().getRawValue()).sum();
long taskExchangeSourcesSize = taskExchangeSources.values().stream().mapToLong(this::sourceHandleSize).sum();

if ((splitsWeight > 0 || exchangeSourcesSize > 0)
&& ((splitsWeight + taskSplitWeight) > targetPartitionSplitWeight || (exchangeSourcesSize + taskExchangeSourcesSize + replicatedExchangeSourcesSize) > targetPartitionSourceSizeInBytes)) {
exchangeSources.putAll(replicatedExchangeSourceHandles); // add replicated exchanges
joinedTasks.add(new TaskDescriptor(taskPartitionId++, splits.build(), exchangeSources.build(), groupNodeRequirements));
splits = ImmutableListMultimap.builder();
exchangeSources = ImmutableListMultimap.builder();
splitsWeight = 0;
exchangeSourcesSize = 0;
}

splits.putAll(taskSplits);
exchangeSources.putAll(taskExchangeSources);
splitsWeight += taskSplitWeight;
exchangeSourcesSize += taskExchangeSourcesSize;
}

ImmutableListMultimap<PlanNodeId, Split> remainderSplits = splits.build();
ImmutableListMultimap<PlanNodeId, ExchangeSourceHandle> remainderExchangeSources = exchangeSources.build();
if (!remainderSplits.isEmpty() || !remainderExchangeSources.isEmpty()) {
remainderExchangeSources = ImmutableListMultimap.<PlanNodeId, ExchangeSourceHandle>builder()
.putAll(remainderExchangeSources)
.putAll(replicatedExchangeSourceHandles) // add replicated exchanges
.build();
joinedTasks.add(new TaskDescriptor(taskPartitionId++, remainderSplits, remainderExchangeSources, groupNodeRequirements));
}
}
return joinedTasks.build();
}
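
A worked example of the packing loop above (all numbers hypothetical):
with a 4GB targetPartitionSourceSizeInBytes, replicated exchange sources
totalling 1GB, and per-partition exchange sizes of 1GB, 1GB, 2GB and 1GB,
the first two partitions share a task (1GB + 1GB + 1GB replicated <= 4GB,
while adding the 2GB partition would push it to 5GB), so they are flushed;
the 2GB and 1GB partitions then accumulate to 3GB + 1GB replicated = 4GB,
which does not exceed the target, and are emitted together by the
remainder branch.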

private long sourceHandleSize(ExchangeSourceHandle handle)
{
Exchange exchange = exchangeForHandle.get(handle);
ExchangeSourceStatistics exchangeSourceStatistics = exchange.getExchangeSourceStatistics(handle);
return exchangeSourceStatistics.getSizeInBytes();
}

private ListMultimap<NodeRequirements, TaskDescriptor> groupCompatibleTasks(List<TaskDescriptor> tasks)
{
return Multimaps.index(tasks, TaskDescriptor::getNodeRequirements);
}
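
For reference, Guava's Multimaps.index groups elements by a key function
while preserving encounter order, e.g. (illustrative only):

  Multimaps.index(ImmutableList.of("ab", "cd", "e"), String::length)
  // => {2=[ab, cd], 1=[e]}

so only tasks whose NodeRequirements are identical (same catalog and host
constraints) can be merged into one task.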

private int getPartitionForBucket(int bucket)
