diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java index 9811d567d4a6..110f5bb4d6bc 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java @@ -714,13 +714,13 @@ public List getMoreTasks() List result = new ArrayList<>(); while (true) { - boolean includeRemainder = splitSource.isFinished(); + boolean splitSourceFinished = splitSource.isFinished(); result.addAll(getReadyTasks( remotelyAccessibleSplitBuffer, ImmutableList.of(), new NodeRequirements(catalogRequirement, ImmutableSet.of(), taskMemory), - includeRemainder)); + splitSourceFinished)); for (HostAddress remoteHost : locallyAccessibleSplitBuffer.keySet()) { result.addAll(getReadyTasks( locallyAccessibleSplitBuffer.get(remoteHost), @@ -729,10 +729,10 @@ public List getMoreTasks() .map(Map.Entry::getValue) .collect(toImmutableList()), new NodeRequirements(catalogRequirement, ImmutableSet.of(remoteHost), taskMemory), - includeRemainder)); + splitSourceFinished)); } - if (!result.isEmpty() || splitSource.isFinished()) { + if (!result.isEmpty() || splitSourceFinished) { break; } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java index 48600fe0f265..9fd17f2e3a3b 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java @@ -46,6 +46,7 @@ import org.openjdk.jol.info.ClassLayout; import org.testng.annotations.Test; +import java.util.ArrayList; import java.util.Arrays; import java.util.IdentityHashMap; import java.util.List; @@ -58,6 +59,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Multimaps.toMultimap; +import static com.google.common.collect.Streams.findLast; import static io.airlift.slice.SizeOf.estimatedSizeOf; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.units.DataSize.Unit.BYTE; @@ -790,6 +792,30 @@ public void testSourceDistributionTaskSourceWithWeights() assertTrue(taskSource.isFinished()); } + @Test + public void testSourceDistributionTaskSourceLastIncompleteTaskAlwaysCreated() + { + for (int targetSplitsPerTask = 1; targetSplitsPerTask <= 21; targetSplitsPerTask += 5) { + List splits = new ArrayList<>(); + for (int i = 0; i < targetSplitsPerTask + 1 /* to make last task incomplete with only a single split */; i++) { + splits.add(createWeightedSplit(i, STANDARD_WEIGHT)); + } + for (int finishDelayIterations = 1; finishDelayIterations < 20; finishDelayIterations++) { + TaskSource taskSource = createSourceDistributionTaskSource( + new TestingSplitSource(CATALOG, splits, finishDelayIterations), + ImmutableListMultimap.of(), + 1, + targetSplitsPerTask, + STANDARD_WEIGHT * targetSplitsPerTask, + targetSplitsPerTask); + List tasks = readAllTasks(taskSource); + assertThat(tasks).hasSize(2); + TaskDescriptor lastTask = findLast(tasks.stream()).orElseThrow(); + assertThat(lastTask.getSplits()).hasSize(1); + } + } + } + private static SourceDistributionTaskSource createSourceDistributionTaskSource( List splits, ListMultimap replicatedSources, @@ -797,12 +823,29 @@ private static SourceDistributionTaskSource createSourceDistributionTaskSource( int minSplitsPerTask, long splitWeightPerTask, int maxSplitsPerTask) + { + return createSourceDistributionTaskSource( + new TestingSplitSource(CATALOG, splits), + replicatedSources, + splitBatchSize, + minSplitsPerTask, + splitWeightPerTask, + maxSplitsPerTask); + } + + private static SourceDistributionTaskSource createSourceDistributionTaskSource( + SplitSource splitSource, + ListMultimap replicatedSources, + int splitBatchSize, + int minSplitsPerTask, + long splitWeightPerTask, + int maxSplitsPerTask) { return new SourceDistributionTaskSource( new QueryId("query"), PLAN_NODE_1, new TableExecuteContextManager(), - new TestingSplitSource(CATALOG, splits), + splitSource, replicatedSources, splitBatchSize, getSplitsTime -> {}, diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java index f2d30c7a65fa..6a0692d19f9f 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java @@ -33,11 +33,18 @@ public class TestingSplitSource { private final CatalogName catalogName; private final Iterator splits; + private int finishDelayRemainingIterations; public TestingSplitSource(CatalogName catalogName, List splits) + { + this(catalogName, splits, 0); + } + + public TestingSplitSource(CatalogName catalogName, List splits, int finishDelayIterations) { this.catalogName = requireNonNull(catalogName, "catalogName is null"); this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")).iterator(); + this.finishDelayRemainingIterations = finishDelayIterations; } @Override @@ -70,7 +77,7 @@ public void close() @Override public boolean isFinished() { - return !splits.hasNext(); + return !splits.hasNext() && finishDelayRemainingIterations-- <= 0; } @Override