From c9f6e77e446251a2ce7e07dd595a5196e6527663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Fri, 11 Feb 2022 14:29:12 +0100 Subject: [PATCH] Join intermediate tasks for small partitions If the partitions produced by upstream tasks are small, it is suboptimal to create a separate task for each partition. With this commit, a single task can read data from multiple input partitions; the target input size is configured via fault-tolerant-execution-target-task-input-size. If the task is also reading source data (as may be the case, e.g., when a join with a bucketed table has a join key that matches the bucketing), task sizing additionally takes input split weights into account (configured via fault-tolerant-execution-target-task-split-count). --- .../scheduler/StageTaskSourceFactory.java | 101 ++++++++++++-- .../scheduler/TestStageTaskSourceFactory.java | 127 +++++++++++++++++- 2 files changed, 213 insertions(+), 15 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java index f22c0de90305..84ae94b7826c 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; +import com.google.common.collect.Multimaps; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.log.Logger; import io.airlift.units.DataSize; @@ -141,11 +142,14 @@ else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnect session, fragment, splitSourceFactory, + sourceExchanges, exchangeSourceHandles, splitBatchSize, getSplitTimeRecorder, bucketToPartitionMap.orElseThrow(() -> new IllegalArgumentException("bucketToPartitionMap is expected to be present for hash distributed stages")), - bucketNodeMap); + bucketNodeMap, + getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(), + getFaultTolerantExecutionTargetTaskInputSize(session)); } else if (partitioning.equals(SOURCE_DISTRIBUTION)) { return SourceDistributionTaskSource.create( @@ -324,13 +328,17 @@ public static class HashDistributionTaskSource implements TaskSource { private final Map splitSources; + private final IdentityHashMap exchangeForHandle; private final Multimap partitionedExchangeSourceHandles; private final Multimap replicatedExchangeSourceHandles; + private final int splitBatchSize; private final LongConsumer getSplitTimeRecorder; private final int[] bucketToPartitionMap; private final Optional bucketNodeMap; private final Optional catalogRequirement; + private final long targetPartitionSourceSizeInBytes; // compared against data read from ExchangeSources + private final long targetPartitionSplitWeight; // compared against splits from SplitSources private boolean finished; private boolean closed; @@ -339,36 +347,47 @@ public static HashDistributionTaskSource create( Session session, PlanFragment fragment, SplitSourceFactory splitSourceFactory, + Map sourceExchanges, Multimap exchangeSourceHandles, int splitBatchSize, LongConsumer getSplitTimeRecorder, int[] bucketToPartitionMap, - Optional bucketNodeMap) + Optional bucketNodeMap, + long targetPartitionSplitWeight, + DataSize targetPartitionSourceSize) { checkArgument(bucketNodeMap.isPresent() ||
fragment.getPartitionedSources().isEmpty(), "bucketNodeMap is expected to be set when the fragment reads partitioned sources (tables)"); Map splitSources = splitSourceFactory.createSplitSources(session, fragment); + return new HashDistributionTaskSource( splitSources, + getExchangeForHandleMap(sourceExchanges, exchangeSourceHandles), getPartitionedExchangeSourceHandles(fragment, exchangeSourceHandles), getReplicatedExchangeSourceHandles(fragment, exchangeSourceHandles), splitBatchSize, getSplitTimeRecorder, bucketToPartitionMap, bucketNodeMap, - fragment.getPartitioning().getConnectorId()); + fragment.getPartitioning().getConnectorId(), + targetPartitionSplitWeight, targetPartitionSourceSize); } public HashDistributionTaskSource( Map splitSources, + IdentityHashMap exchangeForHandle, Multimap partitionedExchangeSourceHandles, Multimap replicatedExchangeSourceHandles, int splitBatchSize, LongConsumer getSplitTimeRecorder, int[] bucketToPartitionMap, Optional bucketNodeMap, - Optional catalogRequirement) + Optional catalogRequirement, + long targetPartitionSplitWeight, + DataSize targetPartitionSourceSize) { this.splitSources = ImmutableMap.copyOf(requireNonNull(splitSources, "splitSources is null")); + this.exchangeForHandle = new IdentityHashMap<>(); + this.exchangeForHandle.putAll(exchangeForHandle); this.partitionedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(partitionedExchangeSourceHandles, "partitionedExchangeSourceHandles is null")); this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(replicatedExchangeSourceHandles, "replicatedExchangeSourceHandles is null")); this.splitBatchSize = splitBatchSize; @@ -377,6 +396,8 @@ public HashDistributionTaskSource( this.bucketNodeMap = requireNonNull(bucketNodeMap, "bucketNodeMap is null"); checkArgument(bucketNodeMap.isPresent() || splitSources.isEmpty(), "bucketNodeMap is expected to be set when the fragment reads partitioned sources (tables)"); this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); + this.targetPartitionSourceSizeInBytes = requireNonNull(targetPartitionSourceSize, "targetPartitionSourceSize is null").toBytes(); + this.targetPartitionSplitWeight = targetPartitionSplitWeight; } @Override @@ -430,20 +451,84 @@ public List getMoreTasks() } int taskPartitionId = 0; - ImmutableList.Builder result = ImmutableList.builder(); + ImmutableList.Builder partitionTasks = ImmutableList.builder(); for (Integer partition : union(partitionToSplitsMap.keySet(), partitionToExchangeSourceHandlesMap.keySet())) { ListMultimap splits = partitionToSplitsMap.getOrDefault(partition, ImmutableListMultimap.of()); ListMultimap exchangeSourceHandles = ImmutableListMultimap.builder() .putAll(partitionToExchangeSourceHandlesMap.getOrDefault(partition, ImmutableMultimap.of())) - .putAll(replicatedExchangeSourceHandles) + // replicated exchange source will be added in postprocessTasks below .build(); HostAddress host = partitionToNodeMap.get(partition); Set hostRequirement = host == null ? 
ImmutableSet.of() : ImmutableSet.of(host); - result.add(new TaskDescriptor(taskPartitionId++, splits, exchangeSourceHandles, new NodeRequirements(catalogRequirement, hostRequirement))); + partitionTasks.add(new TaskDescriptor(taskPartitionId++, splits, exchangeSourceHandles, new NodeRequirements(catalogRequirement, hostRequirement))); } + List result = postprocessTasks(partitionTasks.build()); + finished = true; - return result.build(); + return result; + } + + private List postprocessTasks(List tasks) + { + ListMultimap taskGroups = groupCompatibleTasks(tasks); + ImmutableList.Builder joinedTasks = ImmutableList.builder(); + long replicatedExchangeSourcesSize = replicatedExchangeSourceHandles.values().stream().mapToLong(this::sourceHandleSize).sum(); + int taskPartitionId = 0; + for (Map.Entry> taskGroup : taskGroups.asMap().entrySet()) { + NodeRequirements groupNodeRequirements = taskGroup.getKey(); + Collection groupTasks = taskGroup.getValue(); + + ImmutableListMultimap.Builder splits = ImmutableListMultimap.builder(); + ImmutableListMultimap.Builder exchangeSources = ImmutableListMultimap.builder(); + long splitsWeight = 0; + long exchangeSourcesSize = 0; + + for (TaskDescriptor task : groupTasks) { + ListMultimap taskSplits = task.getSplits(); + ListMultimap taskExchangeSources = task.getExchangeSourceHandles(); + long taskSplitWeight = taskSplits.values().stream().mapToLong(split -> split.getSplitWeight().getRawValue()).sum(); + long taskExchangeSourcesSize = taskExchangeSources.values().stream().mapToLong(this::sourceHandleSize).sum(); + + if ((splitsWeight > 0 || exchangeSourcesSize > 0) + && ((splitsWeight + taskSplitWeight) > targetPartitionSplitWeight || (exchangeSourcesSize + taskExchangeSourcesSize + replicatedExchangeSourcesSize) > targetPartitionSourceSizeInBytes)) { + exchangeSources.putAll(replicatedExchangeSourceHandles); // add replicated exchanges + joinedTasks.add(new TaskDescriptor(taskPartitionId++, splits.build(), exchangeSources.build(), groupNodeRequirements)); + splits = ImmutableListMultimap.builder(); + exchangeSources = ImmutableListMultimap.builder(); + splitsWeight = 0; + exchangeSourcesSize = 0; + } + + splits.putAll(taskSplits); + exchangeSources.putAll(taskExchangeSources); + splitsWeight += taskSplitWeight; + exchangeSourcesSize += taskExchangeSourcesSize; + } + + ImmutableListMultimap remainderSplits = splits.build(); + ImmutableListMultimap remainderExchangeSources = exchangeSources.build(); + if (!remainderSplits.isEmpty() || !remainderExchangeSources.isEmpty()) { + remainderExchangeSources = ImmutableListMultimap.builder() + .putAll(remainderExchangeSources) + .putAll(replicatedExchangeSourceHandles) // add replicated exchanges + .build(); + joinedTasks.add(new TaskDescriptor(taskPartitionId++, remainderSplits, remainderExchangeSources, groupNodeRequirements)); + } + } + return joinedTasks.build(); + } + + private long sourceHandleSize(ExchangeSourceHandle handle) + { + Exchange exchange = exchangeForHandle.get(handle); + ExchangeSourceStatistics exchangeSourceStatistics = exchange.getExchangeSourceStatistics(handle); + return exchangeSourceStatistics.getSizeInBytes(); + } + + private ListMultimap groupCompatibleTasks(List tasks) + { + return Multimaps.index(tasks, TaskDescriptor::getNodeRequirements); } private int getPartitionForBucket(int bucket) diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java index 
8c95ff66548e..49276ce8360a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java @@ -61,6 +61,7 @@ import static io.airlift.slice.SizeOf.estimatedSizeOf; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.units.DataSize.Unit.BYTE; +import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.trino.spi.exchange.ExchangeId.createRandomExchangeId; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -299,7 +300,9 @@ public void testHashDistributionTaskSource() ImmutableListMultimap.of(), 1, new int[] {0, 1, 2, 3}, - Optional.empty()); + Optional.empty(), + 0, + DataSize.of(3, BYTE)); assertFalse(taskSource.isFinished()); assertEquals(taskSource.getMoreTasks(), ImmutableList.of()); assertTrue(taskSource.isFinished()); @@ -315,7 +318,9 @@ PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1)), PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), 1, new int[] {0, 1, 2, 3}, - Optional.empty()); + Optional.empty(), + 0, + DataSize.of(0, BYTE)); assertFalse(taskSource.isFinished()); assertEquals(taskSource.getMoreTasks(), ImmutableList.of( new TaskDescriptor(0, ImmutableListMultimap.of(), ImmutableListMultimap.of( @@ -344,7 +349,9 @@ PLAN_NODE_5, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit4))), PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), 1, new int[] {0, 1, 2, 3}, - Optional.of(getTestingBucketNodeMap(4))); + Optional.of(getTestingBucketNodeMap(4)), + 0, + DataSize.of(0, BYTE)); assertFalse(taskSource.isFinished()); assertEquals(taskSource.getMoreTasks(), ImmutableList.of( new TaskDescriptor( @@ -386,7 +393,9 @@ PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1)), PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), 1, new int[] {0, 1, 2, 3}, - Optional.of(getTestingBucketNodeMap(4))); + Optional.of(getTestingBucketNodeMap(4)), + 0, + DataSize.of(0, BYTE)); assertFalse(taskSource.isFinished()); assertEquals(taskSource.getMoreTasks(), ImmutableList.of( new TaskDescriptor( @@ -431,7 +440,8 @@ PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1)), PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), 2, new int[] {0, 1, 0, 1}, - Optional.of(getTestingBucketNodeMap(4))); + Optional.of(getTestingBucketNodeMap(4)), + 0, DataSize.of(0, BYTE)); assertFalse(taskSource.isFinished()); assertEquals(taskSource.getMoreTasks(), ImmutableList.of( new TaskDescriptor( @@ -452,6 +462,98 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Option PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); assertTrue(taskSource.isFinished()); + + // join based on target split weight + taskSource = createHashDistributionTaskSource( + ImmutableMap.of( + PLAN_NODE_4, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), + PLAN_NODE_5, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit4))), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1)), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), + 2, + new 
int[] {0, 1, 2, 3}, + Optional.of(getTestingBucketNodeMap(4)), + 2 * STANDARD_WEIGHT, + DataSize.of(100, GIGABYTE)); + assertFalse(taskSource.isFinished()); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor( + 0, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit1, + PLAN_NODE_5, bucketedSplit4), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(1, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + new TaskDescriptor( + 1, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit2, + PLAN_NODE_4, bucketedSplit3), + ImmutableListMultimap.of( + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1), + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), + PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); + assertTrue(taskSource.isFinished()); + + // join based on target exchange size + taskSource = createHashDistributionTaskSource( + ImmutableMap.of( + PLAN_NODE_4, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), + PLAN_NODE_5, new TestingSplitSource(CATALOG, ImmutableList.of(bucketedSplit4))), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 20), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 30), + PLAN_NODE_2, new TestingExchangeSourceHandle(1, 20), + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 99), + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 30)), + ImmutableListMultimap.of( + PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), + 2, + new int[] {0, 1, 2, 3}, + Optional.of(getTestingBucketNodeMap(4)), + 100 * STANDARD_WEIGHT, + DataSize.of(100, BYTE)); + assertFalse(taskSource.isFinished()); + assertEquals(taskSource.getMoreTasks(), ImmutableList.of( + new TaskDescriptor( + 0, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit1, + PLAN_NODE_5, bucketedSplit4), + ImmutableListMultimap.of( + PLAN_NODE_1, new TestingExchangeSourceHandle(0, 20), + PLAN_NODE_1, new TestingExchangeSourceHandle(1, 30), + PLAN_NODE_2, new TestingExchangeSourceHandle(1, 20), + PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + new TaskDescriptor( + 1, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit2), + ImmutableListMultimap.of( + PLAN_NODE_2, new TestingExchangeSourceHandle(2, 99), + PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + new TaskDescriptor( + 2, + ImmutableListMultimap.of( + PLAN_NODE_4, bucketedSplit3), + ImmutableListMultimap.of( + PLAN_NODE_2, new TestingExchangeSourceHandle(3, 30), + PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); + assertTrue(taskSource.isFinished()); } private static HashDistributionTaskSource createHashDistributionTaskSource( Multimap replicatedExchangeSources, int splitBatchSize, int[] bucketToPartitionMap, - Optional bucketNodeMap) + Optional bucketNodeMap, + long targetPartitionSplitWeight, + DataSize targetPartitionSourceSize) { + // Craft exchangeSourceHandle -> Exchange map. 
Any TestingExchange instance will do - we need it only for getExchangeSourceStatistics + TestingExchange exchange = new TestingExchange(false); + IdentityHashMap exchangeForHandleMap = new IdentityHashMap<>(); + partitionedExchangeSources.values().forEach(handle -> exchangeForHandleMap.put(handle, exchange)); + replicatedExchangeSources.values().forEach(handle -> exchangeForHandleMap.put(handle, exchange)); + return new HashDistributionTaskSource( splitSources, + exchangeForHandleMap, partitionedExchangeSources, replicatedExchangeSources, splitBatchSize, getSplitsTime -> {}, bucketToPartitionMap, bucketNodeMap, - Optional.of(CATALOG)); + Optional.of(CATALOG), + targetPartitionSplitWeight, + targetPartitionSourceSize); } @Test
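The merging strategy that postprocessTasks implements in the patch is easiest to see in isolation. Below is a minimal, self-contained sketch of the same greedy packing under simplified assumptions: PartitionWork, MergedTask, and merge are hypothetical stand-ins rather than Trino classes, and the replicated-exchange overhead and the grouping by node requirements that the real code also handles are omitted.

import java.util.ArrayList;
import java.util.List;

public class TaskMergeSketch
{
    // One upstream partition: bytes readable from exchange sources plus total split weight.
    record PartitionWork(int partitionId, long exchangeBytes, long splitWeight) {}

    // A task that reads one or more merged partitions.
    record MergedTask(List<PartitionWork> partitions) {}

    // Greedily pack partitions into tasks; close the current task when adding the next
    // partition would exceed either the byte target or the split-weight target.
    static List<MergedTask> merge(List<PartitionWork> partitions, long targetBytes, long targetWeight)
    {
        List<MergedTask> tasks = new ArrayList<>();
        List<PartitionWork> current = new ArrayList<>();
        long bytes = 0;
        long weight = 0;
        for (PartitionWork work : partitions) {
            boolean overflow = bytes + work.exchangeBytes() > targetBytes
                    || weight + work.splitWeight() > targetWeight;
            if (!current.isEmpty() && overflow) {
                tasks.add(new MergedTask(current));
                current = new ArrayList<>();
                bytes = 0;
                weight = 0;
            }
            current.add(work);
            bytes += work.exchangeBytes();
            weight += work.splitWeight();
        }
        if (!current.isEmpty()) {
            tasks.add(new MergedTask(current)); // flush the remainder
        }
        return tasks;
    }

    public static void main(String[] args)
    {
        // Mirrors the shape of the "join based on target exchange size" test above:
        // with a 100-byte target, sizes 10 and 20 merge into one task,
        // while 90 and 30 each get their own task.
        List<PartitionWork> partitions = List.of(
                new PartitionWork(0, 10, 1),
                new PartitionWork(1, 20, 1),
                new PartitionWork(2, 90, 1),
                new PartitionWork(3, 30, 1));
        merge(partitions, 100, 100).forEach(task -> System.out.println(task.partitions()));
    }
}

As in the patch, a task is never left empty: a partition larger than the target still becomes its own single-partition task rather than being split further.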