diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 7043021a7e3a..88d7ea4f7fb1 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -168,7 +168,6 @@ public final class SystemSessionProperties public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY_GROWTH_FACTOR = "fault_tolerant_execution_task_memory_growth_factor"; public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY_ESTIMATION_QUANTILE = "fault_tolerant_execution_task_memory_estimation_quantile"; public static final String FAULT_TOLERANT_EXECUTION_PARTITION_COUNT = "fault_tolerant_execution_partition_count"; - public static final String FAULT_TOLERANT_EXECUTION_PRESERVE_INPUT_PARTITIONS_IN_WRITE_STAGE = "fault_tolerant_execution_preserve_input_partitions_in_write_stage"; public static final String ADAPTIVE_PARTIAL_AGGREGATION_ENABLED = "adaptive_partial_aggregation_enabled"; public static final String ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS = "adaptive_partial_aggregation_min_rows"; public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold"; @@ -177,6 +176,7 @@ public final class SystemSessionProperties public static final String FORCE_SPILLING_JOIN = "force_spilling_join"; public static final String FAULT_TOLERANT_EXECUTION_EVENT_DRIVEN_SCHEDULER_ENABLED = "fault_tolerant_execution_event_driven_scheduler_enabled"; public static final String FORCE_FIXED_DISTRIBUTION_FOR_PARTITIONED_OUTPUT_OPERATOR_ENABLED = "force_fixed_distribution_for_partitioned_output_operator_enabled"; + public static final String FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED = "fault_tolerant_execution_force_preferred_write_partitioning_enabled"; private final List> sessionProperties; @@ -832,11 +832,6 @@ public 
SystemSessionProperties( "Number of partitions for distributed joins and aggregations executed with fault tolerant execution enabled", queryManagerConfig.getFaultTolerantExecutionPartitionCount(), false), - booleanProperty( - FAULT_TOLERANT_EXECUTION_PRESERVE_INPUT_PARTITIONS_IN_WRITE_STAGE, - "Ensure single task reads single hash partitioned input partition for stages which write table data", - queryManagerConfig.getFaultTolerantPreserveInputPartitionsInWriteStage(), - false), booleanProperty( ADAPTIVE_PARTIAL_AGGREGATION_ENABLED, "When enabled, partial aggregation might be adaptively turned off when it does not provide any performance gain", @@ -877,6 +872,11 @@ public SystemSessionProperties( FORCE_FIXED_DISTRIBUTION_FOR_PARTITIONED_OUTPUT_OPERATOR_ENABLED, "Force partitioned output operator to be run with fixed distribution", optimizerConfig.isForceFixedDistributionForPartitionedOutputOperatorEnabled(), + true), + booleanProperty( + FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED, + "Force preferred write partitioning for fault tolerant execution", + queryManagerConfig.isFaultTolerantExecutionForcePreferredWritePartitioningEnabled(), true)); } @@ -1521,11 +1521,6 @@ public static double getFaultTolerantExecutionTaskMemoryEstimationQuantile(Sessi return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_TASK_MEMORY_ESTIMATION_QUANTILE, Double.class); } - public static boolean getFaultTolerantPreserveInputPartitionsInWriteStage(Session session) - { - return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_PRESERVE_INPUT_PARTITIONS_IN_WRITE_STAGE, Boolean.class); - } - public static int getFaultTolerantExecutionPartitionCount(Session session) { return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_PARTITION_COUNT, Integer.class); @@ -1570,4 +1565,13 @@ public static boolean isForceFixedDistributionForPartitionedOutputOperatorEnable { return session.getSystemProperty(FORCE_FIXED_DISTRIBUTION_FOR_PARTITIONED_OUTPUT_OPERATOR_ENABLED, 
Boolean.class); } + + public static boolean isFaultTolerantExecutionForcePreferredWritePartitioningEnabled(Session session) + { + if (!isFaultTolerantExecutionEventDriverSchedulerEnabled(session)) { + // supported only in event driven scheduler + return false; + } + return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED, Boolean.class); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java index a8cdd8e4d8cf..226b6f38d4a6 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java @@ -97,8 +97,8 @@ public class QueryManagerConfig private int faultTolerantExecutionMaxTaskSplitCount = 256; private DataSize faultTolerantExecutionTaskDescriptorStorageMaxMemory = DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.15)); private int faultTolerantExecutionPartitionCount = 50; - private boolean faultTolerantPreserveInputPartitionsInWriteStage = true; private boolean faultTolerantExecutionEventDrivenSchedulerEnabled = true; + private boolean faultTolerantExecutionForcePreferredWritePartitioningEnabled = true; @Min(1) public int getScheduleSplitBatchSize() @@ -620,28 +620,27 @@ public QueryManagerConfig setFaultTolerantExecutionPartitionCount(int faultToler return this; } - public boolean getFaultTolerantPreserveInputPartitionsInWriteStage() + public boolean isFaultTolerantExecutionEventDrivenSchedulerEnabled() { - return faultTolerantPreserveInputPartitionsInWriteStage; + return faultTolerantExecutionEventDrivenSchedulerEnabled; } - @Config("fault-tolerant-execution-preserve-input-partitions-in-write-stage") - @ConfigDescription("Ensure single task reads single hash partitioned input partition for stages which write table data") - public QueryManagerConfig 
setFaultTolerantPreserveInputPartitionsInWriteStage(boolean faultTolerantPreserveInputPartitionsInWriteStage) + @Config("experimental.fault-tolerant-execution-event-driven-scheduler-enabled") + public QueryManagerConfig setFaultTolerantExecutionEventDrivenSchedulerEnabled(boolean faultTolerantExecutionEventDrivenSchedulerEnabled) { - this.faultTolerantPreserveInputPartitionsInWriteStage = faultTolerantPreserveInputPartitionsInWriteStage; + this.faultTolerantExecutionEventDrivenSchedulerEnabled = faultTolerantExecutionEventDrivenSchedulerEnabled; return this; } - public boolean isFaultTolerantExecutionEventDrivenSchedulerEnabled() + public boolean isFaultTolerantExecutionForcePreferredWritePartitioningEnabled() { - return faultTolerantExecutionEventDrivenSchedulerEnabled; + return faultTolerantExecutionForcePreferredWritePartitioningEnabled; } - @Config("experimental.fault-tolerant-execution-event-driven-scheduler-enabled") - public QueryManagerConfig setFaultTolerantExecutionEventDrivenSchedulerEnabled(boolean faultTolerantExecutionEventDrivenSchedulerEnabled) + @Config("experimental.fault-tolerant-execution-force-preferred-write-partitioning-enabled") + public QueryManagerConfig setFaultTolerantExecutionForcePreferredWritePartitioningEnabled(boolean faultTolerantExecutionForcePreferredWritePartitioningEnabled) { - this.faultTolerantExecutionEventDrivenSchedulerEnabled = faultTolerantExecutionEventDrivenSchedulerEnabled; + this.faultTolerantExecutionForcePreferredWritePartitioningEnabled = faultTolerantExecutionForcePreferredWritePartitioningEnabled; return this; } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java index c8556526de35..c6c324d498eb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java +++ 
b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java @@ -31,9 +31,7 @@ import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.sql.planner.plan.PlanVisitor; import io.trino.sql.planner.plan.RemoteSourceNode; -import io.trino.sql.planner.plan.TableWriterNode; import javax.inject.Inject; @@ -48,10 +46,10 @@ import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxTaskSplitCount; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskInputSize; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskSplitCount; -import static io.trino.SystemSessionProperties.getFaultTolerantPreserveInputPartitionsInWriteStage; import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; @@ -186,43 +184,19 @@ private SplitAssigner createSplitAssigner( maxArbitraryDistributionTaskSplitCount); } if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() || - (partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) { - return new HashDistributionSplitAssigner( + (partitioning.getConnectorHandle() instanceof MergePartitioningHandle) || + partitioning.equals(SCALED_WRITER_HASH_DISTRIBUTION)) { + return HashDistributionSplitAssigner.create( partitioning.getCatalogHandle(), 
partitionedSources, replicatedSources, - getFaultTolerantExecutionTargetTaskInputSize(session).toBytes(), - outputDataSizeEstimates, sourcePartitioningScheme, - getFaultTolerantPreserveInputPartitionsInWriteStage(session) && isWriteFragment(fragment)); + outputDataSizeEstimates, + fragment, + getFaultTolerantExecutionTargetTaskInputSize(session).toBytes()); } // other partitioning handles are not expected to be set as a fragment partitioning throw new IllegalArgumentException("Unexpected partitioning: " + partitioning); } - - private static boolean isWriteFragment(PlanFragment fragment) - { - PlanVisitor visitor = new PlanVisitor<>() - { - @Override - protected Boolean visitPlan(PlanNode node, Void context) - { - for (PlanNode child : node.getSources()) { - if (child.accept(this, context)) { - return true; - } - } - return false; - } - - @Override - public Boolean visitTableWriter(TableWriterNode node, Void context) - { - return true; - } - }; - - return fragment.getRoot().accept(visitor, null); - } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningSchemeFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningSchemeFactory.java index 23e9e41ae014..d8b2f6078cfb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningSchemeFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningSchemeFactory.java @@ -30,6 +30,7 @@ import java.util.stream.IntStream; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; import static java.util.Objects.requireNonNull; @NotThreadSafe @@ -55,7 +56,7 @@ public FaultTolerantPartitioningScheme get(PartitioningHandle handle) private FaultTolerantPartitioningScheme create(PartitioningHandle partitioningHandle) { - if 
(partitioningHandle.equals(FIXED_HASH_DISTRIBUTION)) { + if (partitioningHandle.equals(FIXED_HASH_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_HASH_DISTRIBUTION)) { return new FaultTolerantPartitioningScheme( partitionCount, Optional.of(IntStream.range(0, partitionCount).toArray()), diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/HashDistributionSplitAssigner.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/HashDistributionSplitAssigner.java index f536412d26ff..e2bea4e18b12 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/HashDistributionSplitAssigner.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/HashDistributionSplitAssigner.java @@ -13,6 +13,7 @@ */ package io.trino.execution.scheduler; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -24,7 +25,11 @@ import io.trino.metadata.InternalNode; import io.trino.metadata.Split; import io.trino.spi.HostAddress; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.PlanVisitor; +import io.trino.sql.planner.plan.TableWriterNode; import java.util.HashSet; import java.util.List; @@ -34,12 +39,16 @@ import java.util.PriorityQueue; import java.util.Set; import java.util.function.Function; +import java.util.function.Predicate; import java.util.stream.IntStream; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static 
io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; class HashDistributionSplitAssigner @@ -57,14 +66,43 @@ class HashDistributionSplitAssigner private int nextTaskPartitionId; - HashDistributionSplitAssigner( + public static HashDistributionSplitAssigner create( Optional catalogRequirement, Set partitionedSources, Set replicatedSources, - long targetPartitionSizeInBytes, + FaultTolerantPartitioningScheme sourcePartitioningScheme, Map outputDataSizeEstimates, + PlanFragment fragment, + long targetPartitionSizeInBytes) + { + if (fragment.getPartitioning().equals(SCALED_WRITER_HASH_DISTRIBUTION)) { + verify( + + fragment.getPartitionedSources().isEmpty() && fragment.getRemoteSourceNodes().size() == 1, + "SCALED_WRITER_HASH_DISTRIBUTION fragments are expected to have exactly one remote source and no table scans"); + } + return new HashDistributionSplitAssigner( + catalogRequirement, + partitionedSources, + replicatedSources, + sourcePartitioningScheme, + createOutputPartitionToTaskPartition( + sourcePartitioningScheme, + partitionedSources, + outputDataSizeEstimates, + targetPartitionSizeInBytes, + sourceId -> fragment.getPartitioning().equals(SCALED_WRITER_HASH_DISTRIBUTION), + // never merge partitions for table write to avoid running into the maximum writers limit per task + !isWriteFragment(fragment))); + } + + @VisibleForTesting + HashDistributionSplitAssigner( + Optional catalogRequirement, + Set partitionedSources, + Set replicatedSources, FaultTolerantPartitioningScheme sourcePartitioningScheme, - boolean preserveOutputPartitioning) + Map outputPartitionToTaskPartition) { this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); this.replicatedSources = ImmutableSet.copyOf(requireNonNull(replicatedSources, "replicatedSources is null")); @@ -73,12 +111,7 @@ class HashDistributionSplitAssigner 
.addAll(replicatedSources) .build(); this.sourcePartitioningScheme = requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null"); - outputPartitionToTaskPartition = createOutputPartitionToTaskPartition( - sourcePartitioningScheme, - partitionedSources, - outputDataSizeEstimates, - preserveOutputPartitioning, - targetPartitionSizeInBytes); + this.outputPartitionToTaskPartition = ImmutableMap.copyOf(requireNonNull(outputPartitionToTaskPartition, "outputPartitionToTaskPartition is null")); } @Override @@ -93,33 +126,43 @@ public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap { TaskPartition taskPartition = outputPartitionToTaskPartition.get(outputPartitionId); verify(taskPartition != null, "taskPartition not found for outputPartitionId: %s", outputPartitionId); - if (!taskPartition.isIdAssigned()) { - // Assigns lazily to ensure task ids are incremental and with no gaps. - // Gaps can occur when scanning over a bucketed table as some buckets may contain no data. 
- taskPartition.assignId(nextTaskPartitionId++); + + List subPartitions; + if (taskPartition.getSplitBy().isPresent() && taskPartition.getSplitBy().get().equals(planNodeId)) { + subPartitions = ImmutableList.of(taskPartition.getNextSubPartition()); } - int taskPartitionId = taskPartition.getId(); - if (!createdTaskPartitions.contains(taskPartitionId)) { - Set hostRequirement = sourcePartitioningScheme.getNodeRequirement(outputPartitionId) - .map(InternalNode::getHostAndPort) - .map(ImmutableSet::of) - .orElse(ImmutableSet.of()); - assignment.addPartition(new Partition( - taskPartitionId, - new NodeRequirements(catalogRequirement, hostRequirement))); - for (PlanNodeId replicatedSource : replicatedSplits.keySet()) { - assignment.updatePartition(new PartitionUpdate(taskPartitionId, replicatedSource, replicatedSplits.get(replicatedSource), completedSources.contains(replicatedSource))); - } - for (PlanNodeId completedSource : completedSources) { - assignment.updatePartition(new PartitionUpdate(taskPartitionId, completedSource, ImmutableList.of(), true)); + else { + subPartitions = taskPartition.getSubPartitions(); + } + + for (SubPartition subPartition : subPartitions) { + if (!subPartition.isIdAssigned()) { + int taskPartitionId = nextTaskPartitionId++; + // Assigns lazily to ensure task ids are incremental and with no gaps. + // Gaps can occur when scanning over a bucketed table as some buckets may contain no data. 
+ subPartition.assignId(taskPartitionId); + Set hostRequirement = sourcePartitioningScheme.getNodeRequirement(outputPartitionId) + .map(InternalNode::getHostAndPort) + .map(ImmutableSet::of) + .orElse(ImmutableSet.of()); + assignment.addPartition(new Partition( + taskPartitionId, + new NodeRequirements(catalogRequirement, hostRequirement))); + for (PlanNodeId replicatedSource : replicatedSplits.keySet()) { + assignment.updatePartition(new PartitionUpdate(taskPartitionId, replicatedSource, replicatedSplits.get(replicatedSource), completedSources.contains(replicatedSource))); + } + for (PlanNodeId completedSource : completedSources) { + assignment.updatePartition(new PartitionUpdate(taskPartitionId, completedSource, ImmutableList.of(), true)); + } + createdTaskPartitions.add(taskPartitionId); } - createdTaskPartitions.add(taskPartitionId); + + assignment.updatePartition(new PartitionUpdate(subPartition.getId(), planNodeId, ImmutableList.of(split), false)); } - assignment.updatePartition(new PartitionUpdate(taskPartitionId, planNodeId, splits.get(outputPartitionId), false)); - } + }); } if (noMoreSplits) { @@ -158,22 +201,23 @@ public AssignmentResult finish() return AssignmentResult.builder().build(); } - private static Map createOutputPartitionToTaskPartition( + @VisibleForTesting + static Map createOutputPartitionToTaskPartition( FaultTolerantPartitioningScheme sourcePartitioningScheme, Set partitionedSources, Map outputDataSizeEstimates, - boolean preserveOutputPartitioning, - long targetPartitionSizeInBytes) + long targetPartitionSizeInBytes, + Predicate canSplit, + boolean canMerge) { int partitionCount = sourcePartitioningScheme.getPartitionCount(); if (sourcePartitioningScheme.isExplicitPartitionToNodeMappingPresent() || partitionedSources.isEmpty() || - !outputDataSizeEstimates.keySet().containsAll(partitionedSources) || - preserveOutputPartitioning) { + !outputDataSizeEstimates.keySet().containsAll(partitionedSources)) { // if bucket scheme is set explicitly 
or if estimates are missing create one task partition per output partition return IntStream.range(0, partitionCount) .boxed() - .collect(toImmutableMap(Function.identity(), (key) -> new TaskPartition())); + .collect(toImmutableMap(Function.identity(), (key) -> new TaskPartition(1, Optional.empty()))); } List partitionedSourcesEstimates = outputDataSizeEstimates.entrySet().stream() @@ -183,20 +227,59 @@ private static Map createOutputPartitionToTaskPartition( OutputDataSizeEstimate mergedEstimate = OutputDataSizeEstimate.merge(partitionedSourcesEstimates); ImmutableMap.Builder result = ImmutableMap.builder(); PriorityQueue assignments = new PriorityQueue<>(); - assignments.add(new PartitionAssignment(new TaskPartition(), 0)); - for (int outputPartitionId = 0; outputPartitionId < partitionCount; outputPartitionId++) { - long outputPartitionSize = mergedEstimate.getPartitionSizeInBytes(outputPartitionId); - if (assignments.peek().assignedDataSizeInBytes() + outputPartitionSize > targetPartitionSizeInBytes - && assignments.size() < partitionCount) { - assignments.add(new PartitionAssignment(new TaskPartition(), 0)); + for (int partitionId = 0; partitionId < partitionCount; partitionId++) { + long partitionSizeInBytes = mergedEstimate.getPartitionSizeInBytes(partitionId); + if (assignments.isEmpty() || assignments.peek().assignedDataSizeInBytes() + partitionSizeInBytes > targetPartitionSizeInBytes || !canMerge) { + TaskPartition taskPartition = createTaskPartition( + partitionSizeInBytes, + targetPartitionSizeInBytes, + partitionedSources, + outputDataSizeEstimates, + partitionId, + canSplit); + result.put(partitionId, taskPartition); + assignments.add(new PartitionAssignment(taskPartition, partitionSizeInBytes)); + } + else { + PartitionAssignment assignment = assignments.poll(); + result.put(partitionId, assignment.taskPartition()); + assignments.add(new PartitionAssignment(assignment.taskPartition(), assignment.assignedDataSizeInBytes() + partitionSizeInBytes)); } - 
PartitionAssignment assignment = assignments.poll(); - result.put(outputPartitionId, assignment.taskPartition()); - assignments.add(new PartitionAssignment(assignment.taskPartition(), assignment.assignedDataSizeInBytes() + outputPartitionSize)); } return result.buildOrThrow(); } + private static TaskPartition createTaskPartition( + long partitionSizeInBytes, + long targetPartitionSizeInBytes, + Set partitionedSources, + Map outputDataSizeEstimates, + int partitionId, + Predicate canSplit) + { + if (partitionSizeInBytes > targetPartitionSizeInBytes) { + // try to assign multiple sub-partitions if possible + Map sourceSizes = getSourceSizes(partitionedSources, outputDataSizeEstimates, partitionId); + PlanNodeId largestSource = sourceSizes.entrySet().stream() + .max(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .orElseThrow(); + long largestSourceSizeInBytes = sourceSizes.get(largestSource); + long remainingSourcesSizeInBytes = partitionSizeInBytes - largestSourceSizeInBytes; + if (remainingSourcesSizeInBytes <= targetPartitionSizeInBytes / 4 && canSplit.test(largestSource)) { + long targetLargestSourceSizeInBytes = targetPartitionSizeInBytes - remainingSourcesSizeInBytes; + return new TaskPartition(toIntExact(largestSourceSizeInBytes / targetLargestSourceSizeInBytes) + 1, Optional.of(largestSource)); + } + } + return new TaskPartition(1, Optional.empty()); + } + + private static Map getSourceSizes(Set partitionedSources, Map outputDataSizeEstimates, int partitionId) + { + return partitionedSources.stream() + .collect(toImmutableMap(Function.identity(), source -> outputDataSizeEstimates.get(source).getPartitionSizeInBytes(partitionId))); + } + private record PartitionAssignment(TaskPartition taskPartition, long assignedDataSizeInBytes) implements Comparable { @@ -213,12 +296,50 @@ public int compareTo(PartitionAssignment other) } } - private static class TaskPartition + @VisibleForTesting + static class TaskPartition + { + private final List subPartitions; 
+ private final Optional splitBy; + + private int nextSubPartition; + + private TaskPartition(int subPartitionCount, Optional splitBy) + { + checkArgument(subPartitionCount > 0, "subPartitionCount is expected to be greater than zero"); + subPartitions = IntStream.range(0, subPartitionCount) + .mapToObj(i -> new SubPartition()) + .collect(toImmutableList()); + checkArgument(subPartitionCount == 1 || splitBy.isPresent(), "splitBy is expected to be present when subPartitionCount is greater than 1"); + this.splitBy = requireNonNull(splitBy, "splitBy is null"); + } + + public SubPartition getNextSubPartition() + { + SubPartition result = subPartitions.get(nextSubPartition); + nextSubPartition = (nextSubPartition + 1) % subPartitions.size(); + return result; + } + + public List getSubPartitions() + { + return subPartitions; + } + + public Optional getSplitBy() + { + return splitBy; + } + } + + @VisibleForTesting + static class SubPartition { private OptionalInt id = OptionalInt.empty(); public void assignId(int id) { + checkState(this.id.isEmpty(), "id is already assigned"); this.id = OptionalInt.of(id); } @@ -233,4 +354,29 @@ public int getId() return id.getAsInt(); } } + + private static boolean isWriteFragment(PlanFragment fragment) + { + PlanVisitor visitor = new PlanVisitor<>() + { + @Override + protected Boolean visitPlan(PlanNode node, Void context) + { + for (PlanNode child : node.getSources()) { + if (child.accept(this, context)) { + return true; + } + } + return false; + } + + @Override + public Boolean visitTableWriter(TableWriterNode node, Void context) + { + return true; + } + }; + + return fragment.getRoot().accept(visitor, null); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java index fc05d82fbab9..77f72877467d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java 
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java @@ -135,6 +135,8 @@ import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; @@ -847,7 +849,10 @@ public static DistributedStagesScheduler create( Map partitioningCacheMap = new HashMap<>(); Function partitioningCache = partitioningHandle -> - partitioningCacheMap.computeIfAbsent(partitioningHandle, handle -> nodePartitioningManager.getNodePartitioningMap(queryStateMachine.getSession(), handle)); + partitioningCacheMap.computeIfAbsent(partitioningHandle, handle -> nodePartitioningManager.getNodePartitioningMap( + queryStateMachine.getSession(), + // TODO: support hash distributed writer scaling (https://github.com/trinodb/trino/issues/10791) + handle.equals(SCALED_WRITER_HASH_DISTRIBUTION) ? 
FIXED_HASH_DISTRIBUTION : handle)); Map> bucketToPartitionMap = createBucketToPartitionMap( coordinatorStagesScheduler.getBucketToPartitionForStagesConsumedByCoordinator(), diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java index d4d22ad896e4..a306ca46fcd7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java @@ -91,7 +91,6 @@ import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMinTaskSplitCount; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskInputSize; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskSplitCount; -import static io.trino.SystemSessionProperties.getFaultTolerantPreserveInputPartitionsInWriteStage; import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; @@ -178,7 +177,6 @@ public TaskSource create( sourcePartitioningScheme, getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(), getFaultTolerantExecutionTargetTaskInputSize(session), - getFaultTolerantPreserveInputPartitionsInWriteStage(session), executor); } if (partitioning.equals(SOURCE_DISTRIBUTION)) { @@ -383,7 +381,6 @@ public static HashDistributionTaskSource create( FaultTolerantPartitioningScheme sourcePartitioningScheme, long targetPartitionSplitWeight, DataSize targetPartitionSourceSize, - boolean preserveInputPartitionsInWriteStage, Executor executor) { Map splitSources = splitSourceFactory.createSplitSources(session, fragment); @@ -396,7 +393,7 @@ public static HashDistributionTaskSource create( 
sourcePartitioningScheme, fragment.getPartitioning().getCatalogHandle(), targetPartitionSplitWeight, - (preserveInputPartitionsInWriteStage && isWriteFragment(fragment)) ? DataSize.of(0, BYTE) : targetPartitionSourceSize, + isWriteFragment(fragment) ? DataSize.of(0, BYTE) : targetPartitionSourceSize, executor); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java index d55b8cf7f614..b28cf94c3689 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java @@ -34,7 +34,6 @@ import io.trino.spi.type.Type; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.ExplainAnalyzeNode; -import io.trino.sql.planner.plan.MergeProcessorNode; import io.trino.sql.planner.plan.MergeWriterNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanFragmentId; @@ -46,6 +45,7 @@ import io.trino.sql.planner.plan.SimpleTableExecuteNode; import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableDeleteNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -72,6 +72,8 @@ import static io.trino.spi.connector.StandardWarningCode.TOO_MANY_STAGES; import static io.trino.sql.planner.SchedulingOrderVisitor.scheduleOrder; import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static 
io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE; @@ -323,24 +325,21 @@ public PlanNode visitRefreshMaterializedView(RefreshMaterializedViewNode node, R @Override public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) { - if (node.getPartitioningScheme().isPresent()) { - context.get().setDistribution(node.getPartitioningScheme().get().getPartitioning().getHandle(), metadata, session); - } + node.getPartitioningScheme().ifPresent(scheme -> context.get().setDistribution(scheme.getPartitioning().getHandle(), metadata, session)); return context.defaultRewrite(node, context.get()); } @Override - public PlanNode visitMergeWriter(MergeWriterNode node, RewriteContext context) + public PlanNode visitTableExecute(TableExecuteNode node, RewriteContext context) { - if (node.getPartitioningScheme().isPresent()) { - context.get().setDistribution(node.getPartitioningScheme().get().getPartitioning().getHandle(), metadata, session); - } + node.getPartitioningScheme().ifPresent(scheme -> context.get().setDistribution(scheme.getPartitioning().getHandle(), metadata, session)); return context.defaultRewrite(node, context.get()); } @Override - public PlanNode visitMergeProcessor(MergeProcessorNode node, RewriteContext context) + public PlanNode visitMergeWriter(MergeWriterNode node, RewriteContext context) { + node.getPartitioningScheme().ifPresent(scheme -> context.get().setDistribution(scheme.getPartitioning().getHandle(), metadata, session)); return context.defaultRewrite(node, context.get()); } @@ -470,21 +469,26 @@ public FragmentProperties setDistribution(PartitioningHandle distribution, Metad PartitioningHandle currentPartitioning = this.partitioningHandle.get(); - if (isCompatibleSystemPartitioning(distribution)) { + if (currentPartitioning.equals(distribution)) { return this; } - if (currentPartitioning.equals(SOURCE_DISTRIBUTION)) { - this.partitioningHandle = Optional.of(distribution); + // If already system SINGLE or COORDINATOR_ONLY, leave it as 
is (this is for single-node execution) + if (currentPartitioning.isSingleNode()) { return this; } - // If already system SINGLE or COORDINATOR_ONLY, leave it as is (this is for single-node execution) - if (currentPartitioning.isSingleNode()) { + if (isCompatibleSystemPartitioning(distribution)) { return this; } - if (currentPartitioning.equals(distribution)) { + if (isCompatibleScaledWriterPartitioning(currentPartitioning, distribution)) { + this.partitioningHandle = Optional.of(distribution); + return this; + } + + if (currentPartitioning.equals(SOURCE_DISTRIBUTION)) { + this.partitioningHandle = Optional.of(distribution); return this; } @@ -512,6 +516,19 @@ private boolean isCompatibleSystemPartitioning(PartitioningHandle distribution) return false; } + private static boolean isCompatibleScaledWriterPartitioning(PartitioningHandle current, PartitioningHandle suggested) + { + if (current.equals(FIXED_HASH_DISTRIBUTION) && suggested.equals(SCALED_WRITER_HASH_DISTRIBUTION)) { + return true; + } + PartitioningHandle currentWithScaledWritersEnabled = new PartitioningHandle( + current.getCatalogHandle(), + current.getTransactionHandle(), + current.getConnectorHandle(), + true); + return currentWithScaledWritersEnabled.equals(suggested); + } + public FragmentProperties setCoordinatorOnlyDistribution() { if (partitioningHandle.isPresent() && partitioningHandle.get().isCoordinatorOnly()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableExecutePartitioning.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableExecutePartitioning.java index f4a139428bb9..285929c303a8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableExecutePartitioning.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableExecutePartitioning.java @@ -16,12 +16,15 @@ import io.trino.Session; import io.trino.matching.Captures; import 
io.trino.matching.Pattern; +import io.trino.operator.RetryPolicy; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.TableExecuteNode; import java.util.Optional; import static io.trino.SystemSessionProperties.getPreferredWritePartitioningMinNumberOfPartitions; +import static io.trino.SystemSessionProperties.getRetryPolicy; +import static io.trino.SystemSessionProperties.isFaultTolerantExecutionForcePreferredWritePartitioningEnabled; import static io.trino.SystemSessionProperties.isUsePreferredWritePartitioning; import static io.trino.cost.AggregationStatsRule.getRowsCount; import static io.trino.sql.planner.plan.Patterns.tableExecute; @@ -52,6 +55,11 @@ public boolean isEnabled(Session session) @Override public Result apply(TableExecuteNode node, Captures captures, Context context) { + if (getRetryPolicy(context.getSession()) == RetryPolicy.TASK && isFaultTolerantExecutionForcePreferredWritePartitioningEnabled(context.getSession())) { + // Choosing preferred partitioning introduces a risk of running into a skew (for example when writing to only a single partition). + // Fault tolerant execution can detect a potential skew automatically (based on runtime statistics) and mitigate it by splitting skewed partitions. 
+ return enable(node); + } int minimumNumberOfPartitions = getPreferredWritePartitioningMinNumberOfPartitions(context.getSession()); if (minimumNumberOfPartitions <= 1) { // Force 'preferred write partitioning' even if stats are missing or broken diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableWriterPartitioning.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableWriterPartitioning.java index 669bb790f3c2..9745122240fe 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableWriterPartitioning.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableWriterPartitioning.java @@ -16,12 +16,15 @@ import io.trino.Session; import io.trino.matching.Captures; import io.trino.matching.Pattern; +import io.trino.operator.RetryPolicy; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.TableWriterNode; import java.util.Optional; import static io.trino.SystemSessionProperties.getPreferredWritePartitioningMinNumberOfPartitions; +import static io.trino.SystemSessionProperties.getRetryPolicy; +import static io.trino.SystemSessionProperties.isFaultTolerantExecutionForcePreferredWritePartitioningEnabled; import static io.trino.SystemSessionProperties.isUsePreferredWritePartitioning; import static io.trino.cost.AggregationStatsRule.getRowsCount; import static io.trino.sql.planner.plan.Patterns.tableWriterNode; @@ -57,6 +60,12 @@ public boolean isEnabled(Session session) @Override public Result apply(TableWriterNode node, Captures captures, Context context) { + if (getRetryPolicy(context.getSession()) == RetryPolicy.TASK && isFaultTolerantExecutionForcePreferredWritePartitioningEnabled(context.getSession())) { + // Choosing preferred partitioning introduces a risk of running into a skew (for example when writing to only a single partition). 
+ // Fault tolerant execution can detect a potential skew automatically (based on runtime statistics) and mitigate it by splitting skewed partitions. + return enable(node); + } + int minimumNumberOfPartitions = getPreferredWritePartitioningMinNumberOfPartitions(context.getSession()); if (minimumNumberOfPartitions <= 1) { // Force 'preferred write partitioning' even if stats are missing or broken diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index bebc24758799..667ab168b37c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -31,10 +31,12 @@ import io.trino.sql.PlannerContext; import io.trino.sql.planner.DomainTranslator; import io.trino.sql.planner.Partitioning; +import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; +import io.trino.sql.planner.SystemPartitioningHandle; import io.trino.sql.planner.TypeAnalyzer; import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.iterative.rule.PushPredicateIntoTableScan; @@ -106,6 +108,7 @@ import static io.trino.sql.planner.FragmentTableScanCounter.hasMultipleSources; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static 
io.trino.sql.planner.optimizations.ActualProperties.Global.partitionedOn; @@ -650,6 +653,28 @@ else if (redistributeWrites) { partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols())); } } + else if (scaleWriters + && writerTarget.supportsReportingWrittenBytes(plannerContext.getMetadata(), session) + && writerTarget.supportsMultipleWritersPerPartition(plannerContext.getMetadata(), session) + // do not insert an exchange if partitioning is compatible + && !newSource.getProperties().isCompatibleTablePartitioningWith(partitioningScheme.get().getPartitioning(), false, plannerContext.getMetadata(), session)) { + if (partitioningScheme.get().getPartitioning().getHandle().equals(FIXED_HASH_DISTRIBUTION)) { + partitioningScheme = Optional.of(partitioningScheme.get().withPartitioningHandle(SCALED_WRITER_HASH_DISTRIBUTION)); + } + else { + PartitioningHandle partitioningHandle = partitioningScheme.get().getPartitioning().getHandle(); + verify(!(partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle)); + verify( + partitioningScheme.get().getPartitioning().getArguments().stream().noneMatch(Partitioning.ArgumentBinding::isConstant), + "Table writer partitioning has constant arguments"); + partitioningScheme = Optional.of(partitioningScheme.get().withPartitioningHandle( + new PartitioningHandle( + partitioningHandle.getCatalogHandle(), + partitioningHandle.getTransactionHandle(), + partitioningHandle.getConnectorHandle(), + true))); + } + } if (partitioningScheme.isPresent() && !newSource.getProperties().isCompatibleTablePartitioningWith(partitioningScheme.get().getPartitioning(), false, plannerContext.getMetadata(), session)) { newSource = withDerivedProperties( partitionedExchange( diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java 
index bcdefa55a00a..ccd4060f0a18 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java @@ -77,8 +77,8 @@ public void testDefaults() .setFaultTolerantExecutionMaxTaskSplitCount(256) .setFaultTolerantExecutionTaskDescriptorStorageMaxMemory(DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.15))) .setFaultTolerantExecutionPartitionCount(50) - .setFaultTolerantPreserveInputPartitionsInWriteStage(true) - .setFaultTolerantExecutionEventDrivenSchedulerEnabled(true)); + .setFaultTolerantExecutionEventDrivenSchedulerEnabled(true) + .setFaultTolerantExecutionForcePreferredWritePartitioningEnabled(true)); } @Test @@ -123,8 +123,8 @@ public void testExplicitPropertyMappings() .put("fault-tolerant-execution-max-task-split-count", "22") .put("fault-tolerant-execution-task-descriptor-storage-max-memory", "3GB") .put("fault-tolerant-execution-partition-count", "123") - .put("fault-tolerant-execution-preserve-input-partitions-in-write-stage", "false") .put("experimental.fault-tolerant-execution-event-driven-scheduler-enabled", "false") + .put("experimental.fault-tolerant-execution-force-preferred-write-partitioning-enabled", "false") .buildOrThrow(); QueryManagerConfig expected = new QueryManagerConfig() @@ -166,8 +166,8 @@ public void testExplicitPropertyMappings() .setFaultTolerantExecutionMaxTaskSplitCount(22) .setFaultTolerantExecutionTaskDescriptorStorageMaxMemory(DataSize.of(3, GIGABYTE)) .setFaultTolerantExecutionPartitionCount(123) - .setFaultTolerantPreserveInputPartitionsInWriteStage(false) - .setFaultTolerantExecutionEventDrivenSchedulerEnabled(false); + .setFaultTolerantExecutionEventDrivenSchedulerEnabled(false) + .setFaultTolerantExecutionForcePreferredWritePartitioningEnabled(false); assertFullMapping(properties, expected); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestHashDistributionSplitAssigner.java 
b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestHashDistributionSplitAssigner.java index 589cd42d0563..d9137eff6863 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestHashDistributionSplitAssigner.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestHashDistributionSplitAssigner.java @@ -13,16 +13,18 @@ */ package io.trino.execution.scheduler; -import com.google.common.collect.HashMultimap; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; +import com.google.common.collect.Multimaps; import com.google.common.collect.SetMultimap; import com.google.common.primitives.ImmutableLongArray; import io.trino.client.NodeVersion; import io.trino.connector.CatalogHandle; +import io.trino.execution.scheduler.HashDistributionSplitAssigner.TaskPartition; import io.trino.metadata.InternalNode; import io.trino.metadata.Split; import io.trino.sql.planner.plan.PlanNodeId; @@ -30,6 +32,7 @@ import java.net.URI; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -39,13 +42,17 @@ import java.util.function.Function; import java.util.stream.IntStream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Sets.difference; +import static com.google.common.collect.ImmutableSetMultimap.toImmutableSetMultimap; import static io.trino.connector.CatalogHandle.createRootCatalogHandle; +import static 
io.trino.execution.scheduler.HashDistributionSplitAssigner.createOutputPartitionToTaskPartition; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.fail; public class TestHashDistributionSplitAssigner { @@ -63,337 +70,399 @@ public class TestHashDistributionSplitAssigner @Test public void testEmpty() { - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of(new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)), - 10, - Optional.empty(), - 1024, - ImmutableMap.of(), - false, - 1); - testAssigner( - ImmutableSet.of(), - ImmutableSet.of(REPLICATED_1), - ImmutableList.of(new SplitBatch(REPLICATED_1, ImmutableListMultimap.of(), true)), - 1, - Optional.empty(), - 1024, - ImmutableMap.of(REPLICATED_1, new OutputDataSizeEstimate(ImmutableLongArray.builder().add(0).build())), - false, - 1); - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(REPLICATED_1), - ImmutableList.of( + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits(new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)) + .withSplitPartitionCount(10) + .withTargetPartitionSizeInBytes(1024) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); + testAssigner() + .withReplicatedSources(REPLICATED_1) + .withSplits(new SplitBatch(REPLICATED_1, ImmutableListMultimap.of(), true)) + .withSplitPartitionCount(1) + .withTargetPartitionSizeInBytes(1024) + .withOutputDataSizeEstimates(ImmutableMap.of(REPLICATED_1, new OutputDataSizeEstimate(ImmutableLongArray.builder().add(0).build()))) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withReplicatedSources(REPLICATED_1) + .withSplits( new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true), - new SplitBatch(REPLICATED_1, ImmutableListMultimap.of(), true)), - 10, 
- Optional.empty(), - 1024, - ImmutableMap.of(), - false, - 1); - testAssigner( - ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), - ImmutableSet.of(REPLICATED_1, REPLICATED_2), - ImmutableList.of( + new SplitBatch(REPLICATED_1, ImmutableListMultimap.of(), true)) + .withSplitPartitionCount(10) + .withTargetPartitionSizeInBytes(1024) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); + testAssigner() + .withPartitionedSources(PARTITIONED_1, PARTITIONED_2) + .withReplicatedSources(REPLICATED_1, REPLICATED_2) + .withSplits( new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true), new SplitBatch(REPLICATED_1, ImmutableListMultimap.of(), true), new SplitBatch(PARTITIONED_2, ImmutableListMultimap.of(), true), - new SplitBatch(REPLICATED_2, ImmutableListMultimap.of(), true)), - 10, - Optional.empty(), - 1024, - ImmutableMap.of(), - false, - 1); + new SplitBatch(REPLICATED_2, ImmutableListMultimap.of(), true)) + .withSplitPartitionCount(10) + .withTargetPartitionSizeInBytes(1024) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); } @Test public void testExplicitPartitionToNodeMap() { - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of( + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)), - 3, - Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3)), - 1000, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - false, - 3); + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)) + .withSplitPartitionCount(3) + .withPartitionToNodeMap(Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3))) + .withTargetPartitionSizeInBytes(1000) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new 
OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(3) + .run(); // some partitions missing - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of( - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), true)), - 3, - Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3)), - 1000, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - false, - 1); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), true)) + .withSplitPartitionCount(3) + .withPartitionToNodeMap(Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3))) + .withTargetPartitionSizeInBytes(1000) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); // no splits - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of( - new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)), - 3, - Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3)), - 1000, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - false, - 1); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( + new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)) + .withSplitPartitionCount(3) + .withPartitionToNodeMap(Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3))) + .withTargetPartitionSizeInBytes(1000) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); } @Test - public void testPreserveOutputPartitioning() + public void testMergeNotAllowed() { - testAssigner( - ImmutableSet.of(PARTITIONED_1), - 
ImmutableSet.of(), - ImmutableList.of( + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)), - 3, - Optional.empty(), - 1000, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - true, - 3); + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(1000) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(false) + .withExpectedTaskCount(3) + .run(); // some partitions missing - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of( - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), true)), - 3, - Optional.empty(), - 1000, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - true, - 1); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(1000) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(false) + .withExpectedTaskCount(1) + .run(); // no splits - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of( - new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)), - 3, - Optional.empty(), - 1000, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - true, - 1); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( + new SplitBatch(PARTITIONED_1, 
ImmutableListMultimap.of(), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(1000) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(false) + .withExpectedTaskCount(1) + .run(); } @Test public void testMissingEstimates() { - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of( + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)), - 3, - Optional.empty(), - 1000, - ImmutableMap.of(), - false, - 3); + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)) + .withSplitPartitionCount(3) + .withPartitionToNodeMap(Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3))) + .withTargetPartitionSizeInBytes(1000) + .withMergeAllowed(true) + .withExpectedTaskCount(3) + .run(); // some partitions missing - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of( - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), true)), - 3, - Optional.empty(), - 1000, - ImmutableMap.of(), - false, - 1); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), true)) + .withSplitPartitionCount(3) + .withPartitionToNodeMap(Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3))) + .withTargetPartitionSizeInBytes(1000) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); // no splits - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of( - new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)), - 3, - Optional.empty(), - 1000, - ImmutableMap.of(), - false, - 1); + testAssigner() + 
.withPartitionedSources(PARTITIONED_1) + .withSplits( + new SplitBatch(PARTITIONED_1, ImmutableListMultimap.of(), true)) + .withSplitPartitionCount(3) + .withPartitionToNodeMap(Optional.of(ImmutableList.of(NODE_1, NODE_2, NODE_3))) + .withTargetPartitionSizeInBytes(1000) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); } @Test public void testHappyPath() { - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(), - ImmutableList.of( + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)), - 3, - Optional.empty(), - 3, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - false, - 1); - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(REPLICATED_1), - ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 1), createSplit(3, 2)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(3) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withReplicatedSources(REPLICATED_1) + .withSplits( new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), - 3, - Optional.empty(), - 3, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - false, - 1); - testAssigner( - ImmutableSet.of(PARTITIONED_1), - 
ImmutableSet.of(REPLICATED_1), - ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(3) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withReplicatedSources(REPLICATED_1) + .withSplits( new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), - 3, - Optional.empty(), - 1, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - false, - 3); - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(REPLICATED_1), - ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(1) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(3) + .run(); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withReplicatedSources(REPLICATED_1) + .withSplits( new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), - 3, - Optional.empty(), - 1, - 
ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - false, - 3); - testAssigner( - ImmutableSet.of(PARTITIONED_1), - ImmutableSet.of(REPLICATED_1, REPLICATED_2), - ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(1) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(3) + .run(); + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withReplicatedSources(REPLICATED_1, REPLICATED_2) + .withSplits( new SplitBatch(REPLICATED_2, createSplitMap(createSplit(11, 1), createSplit(12, 100)), true), new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), - 3, - Optional.empty(), - 1, - ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - false, - 3); - testAssigner( - ImmutableSet.of(PARTITIONED_1, PARTITIONED_2), - ImmutableSet.of(REPLICATED_1, REPLICATED_2), - ImmutableList.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(1) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(3) + .run(); + testAssigner() + .withPartitionedSources(PARTITIONED_1, PARTITIONED_2) + .withReplicatedSources(REPLICATED_1, REPLICATED_2) + .withSplits( new SplitBatch(REPLICATED_2, createSplitMap(createSplit(11, 1), createSplit(12, 
100)), true), new SplitBatch(REPLICATED_1, createSplitMap(createSplit(2, 0), createSplit(3, 2)), false), new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), new SplitBatch(PARTITIONED_2, createSplitMap(), true), new SplitBatch(REPLICATED_1, createSplitMap(createSplit(4, 1), createSplit(5, 100)), true), - new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)), - 3, - Optional.empty(), - 1, - ImmutableMap.of( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(6, 1), createSplit(7, 2)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(1) + .withOutputDataSizeEstimates(ImmutableMap.of( PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)), - PARTITIONED_2, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1))), - false, - 3); + PARTITIONED_2, new OutputDataSizeEstimate(ImmutableLongArray.of(1, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(3) + .run(); } - private static void testAssigner( - Set partitionedSources, - Set replicatedSources, - List batches, - int splitPartitionCount, - Optional> partitionToNodeMap, - long targetPartitionSizeInBytes, - Map outputDataSizeEstimates, - boolean preserveOutputPartitioning, - int expectedTaskCount) + @Test + public void testPartitionSplitting() { - FaultTolerantPartitioningScheme partitioningScheme = createPartitioningScheme(splitPartitionCount, partitionToNodeMap); - HashDistributionSplitAssigner assigner = new HashDistributionSplitAssigner( - Optional.of(TESTING_CATALOG_HANDLE), - partitionedSources, - replicatedSources, - targetPartitionSizeInBytes, - outputDataSizeEstimates, - partitioningScheme, - preserveOutputPartitioning); - TestingTaskSourceCallback callback = new TestingTaskSourceCallback(); - SetMultimap partitionedSplitIds = HashMultimap.create(); - Set replicatedSplitIds = new HashSet<>(); - for (SplitBatch batch : batches) { - assigner.assign(batch.getPlanNodeId(), 
batch.getSplits(), batch.isNoMoreSplits()).update(callback); - boolean replicated = replicatedSources.contains(batch.getPlanNodeId()); - callback.checkContainsSplits(batch.getPlanNodeId(), batch.getSplits().values(), replicated); - for (Map.Entry entry : batch.getSplits().entries()) { - int splitId = TestingConnectorSplit.getSplitId(entry.getValue()); - if (replicated) { - assertThat(replicatedSplitIds).doesNotContain(splitId); - replicatedSplitIds.add(splitId); - } - else { - partitionedSplitIds.put(entry.getKey(), splitId); - } - } - } - assigner.finish().update(callback); - List taskDescriptors = callback.getTaskDescriptors(); - assertThat(taskDescriptors).hasSize(expectedTaskCount); - for (TaskDescriptor taskDescriptor : taskDescriptors) { - int partitionId = taskDescriptor.getPartitionId(); - NodeRequirements nodeRequirements = taskDescriptor.getNodeRequirements(); - assertEquals(nodeRequirements.getCatalogHandle(), Optional.of(TESTING_CATALOG_HANDLE)); - partitionToNodeMap.ifPresent(partitionToNode -> { - if (!taskDescriptor.getSplits().isEmpty()) { - InternalNode node = partitionToNode.get(partitionId); - assertThat(nodeRequirements.getAddresses()).containsExactly(node.getHostAndPort()); - } - }); - Set taskDescriptorSplitIds = taskDescriptor.getSplits().values().stream() - .map(TestingConnectorSplit::getSplitId) - .collect(toImmutableSet()); - assertThat(taskDescriptorSplitIds).containsAll(replicatedSplitIds); - Set taskDescriptorPartitionedSplitIds = difference(taskDescriptorSplitIds, replicatedSplitIds); - Set taskDescriptorSplitPartitions = new HashSet<>(); - for (Split split : taskDescriptor.getSplits().values()) { - int splitId = TestingConnectorSplit.getSplitId(split); - if (taskDescriptorPartitionedSplitIds.contains(splitId)) { - int splitPartition = partitioningScheme.getPartition(split); - taskDescriptorSplitPartitions.add(splitPartition); - } - } - for (Integer splitPartition : taskDescriptorSplitPartitions) { - 
assertThat(taskDescriptorPartitionedSplitIds).containsAll(partitionedSplitIds.get(splitPartition)); - } - } + // single splittable source + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 0), createSplit(3, 0)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(3) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(5, 1, 1)))) + .withSplittableSources(PARTITIONED_1) + .withMergeAllowed(true) + .withExpectedTaskCount(2) + .run(); + + // largest source is not splittable + testAssigner() + .withPartitionedSources(PARTITIONED_1) + .withSplits( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 0), createSplit(3, 0)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(3) + .withOutputDataSizeEstimates(ImmutableMap.of(PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(5, 1, 1)))) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); + // multiple sources + testAssigner() + .withPartitionedSources(PARTITIONED_1, PARTITIONED_2) + .withSplits( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 0), createSplit(3, 0)), true), + new SplitBatch(PARTITIONED_2, createSplitMap(createSplit(4, 0), createSplit(5, 1)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(30) + .withOutputDataSizeEstimates(ImmutableMap.of( + PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(50, 1, 1)), + PARTITIONED_2, new OutputDataSizeEstimate(ImmutableLongArray.of(2, 1, 1)))) + .withSplittableSources(PARTITIONED_1) + .withMergeAllowed(true) + 
.withExpectedTaskCount(3) + .run(); + testAssigner() + .withPartitionedSources(PARTITIONED_1, PARTITIONED_2) + .withSplits( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 0), createSplit(3, 0)), true), + new SplitBatch(PARTITIONED_2, createSplitMap(createSplit(4, 0), createSplit(5, 1)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(30) + .withOutputDataSizeEstimates(ImmutableMap.of( + PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(50, 1, 1)), + PARTITIONED_2, new OutputDataSizeEstimate(ImmutableLongArray.of(2, 1, 1)))) + .withSplittableSources(PARTITIONED_1, PARTITIONED_2) + .withMergeAllowed(true) + .withExpectedTaskCount(3) + .run(); + testAssigner() + .withPartitionedSources(PARTITIONED_1, PARTITIONED_2) + .withSplits( + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(0, 0), createSplit(1, 0)), false), + new SplitBatch(PARTITIONED_1, createSplitMap(createSplit(2, 0), createSplit(3, 0)), true), + new SplitBatch(PARTITIONED_2, createSplitMap(createSplit(4, 0), createSplit(5, 0)), true)) + .withSplitPartitionCount(3) + .withTargetPartitionSizeInBytes(30) + .withOutputDataSizeEstimates(ImmutableMap.of( + PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(50, 1, 1)), + PARTITIONED_2, new OutputDataSizeEstimate(ImmutableLongArray.of(2, 1, 1)))) + .withSplittableSources(PARTITIONED_2) + .withMergeAllowed(true) + .withExpectedTaskCount(1) + .run(); + } + + @Test + public void testCreateOutputPartitionToTaskPartition() + { + testPartitionMapping() + .withSplitPartitionCount(3) + .withPartitionedSources(PARTITIONED_1) + .withOutputDataSizeEstimates(ImmutableMap.of( + PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(50, 1, 1)))) + .withTargetPartitionSizeInBytes(25) + .withSplittableSources(PARTITIONED_1) + .withMergeAllowed(true) + .withExpectedMappings( + new 
PartitionMapping(ImmutableSet.of(0), 3), + new PartitionMapping(ImmutableSet.of(1, 2), 1)) + .run(); + testPartitionMapping() + .withSplitPartitionCount(3) + .withPartitionedSources(PARTITIONED_1) + .withOutputDataSizeEstimates(ImmutableMap.of( + PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(50, 1, 1)))) + .withTargetPartitionSizeInBytes(25) + .withMergeAllowed(true) + .withExpectedMappings( + new PartitionMapping(ImmutableSet.of(0), 1), + new PartitionMapping(ImmutableSet.of(1, 2), 1)) + .run(); + testPartitionMapping() + .withSplitPartitionCount(3) + .withPartitionedSources(PARTITIONED_1) + .withOutputDataSizeEstimates(ImmutableMap.of( + PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(50, 1, 1)))) + .withTargetPartitionSizeInBytes(25) + .withMergeAllowed(false) + .withExpectedMappings( + new PartitionMapping(ImmutableSet.of(0), 1), + new PartitionMapping(ImmutableSet.of(1), 1), + new PartitionMapping(ImmutableSet.of(2), 1)) + .run(); + testPartitionMapping() + .withSplitPartitionCount(3) + .withPartitionedSources(PARTITIONED_1) + .withOutputDataSizeEstimates(ImmutableMap.of( + PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(50, 1, 1)))) + .withTargetPartitionSizeInBytes(25) + .withMergeAllowed(false) + .withSplittableSources(PARTITIONED_1) + .withExpectedMappings( + new PartitionMapping(ImmutableSet.of(0), 3), + new PartitionMapping(ImmutableSet.of(1), 1), + new PartitionMapping(ImmutableSet.of(2), 1)) + .run(); + testPartitionMapping() + .withSplitPartitionCount(4) + .withPartitionedSources(PARTITIONED_1) + .withOutputDataSizeEstimates(ImmutableMap.of( + PARTITIONED_1, new OutputDataSizeEstimate(ImmutableLongArray.of(0, 0, 0, 60)))) + .withTargetPartitionSizeInBytes(25) + .withMergeAllowed(false) + .withSplittableSources(PARTITIONED_1) + .withExpectedMappings( + new PartitionMapping(ImmutableSet.of(0), 1), + new PartitionMapping(ImmutableSet.of(1), 1), + new PartitionMapping(ImmutableSet.of(2), 1), + new 
PartitionMapping(ImmutableSet.of(3), 3)) + .run(); } private static ListMultimap createSplitMap(Split... splits) @@ -444,4 +513,262 @@ public boolean isNoMoreSplits() return noMoreSplits; } } + + public static AssignerTester testAssigner() + { + return new AssignerTester(); + } + + private static class AssignerTester + { + private Set partitionedSources = ImmutableSet.of(); + private Set replicatedSources = ImmutableSet.of(); + private List splits = ImmutableList.of(); + private int splitPartitionCount; + private Optional> partitionToNodeMap = Optional.empty(); + private long targetPartitionSizeInBytes; + private Map outputDataSizeEstimates = ImmutableMap.of(); + private Set splittableSources = ImmutableSet.of(); + private boolean mergeAllowed; + private int expectedTaskCount; + + public AssignerTester withPartitionedSources(PlanNodeId... sources) + { + partitionedSources = ImmutableSet.copyOf(sources); + return this; + } + + public AssignerTester withReplicatedSources(PlanNodeId... sources) + { + replicatedSources = ImmutableSet.copyOf(sources); + return this; + } + + public AssignerTester withSplits(SplitBatch... splits) + { + this.splits = ImmutableList.copyOf(splits); + return this; + } + + public AssignerTester withSplitPartitionCount(int splitPartitionCount) + { + this.splitPartitionCount = splitPartitionCount; + return this; + } + + public AssignerTester withPartitionToNodeMap(Optional> partitionToNodeMap) + { + this.partitionToNodeMap = partitionToNodeMap; + return this; + } + + public AssignerTester withTargetPartitionSizeInBytes(long targetPartitionSizeInBytes) + { + this.targetPartitionSizeInBytes = targetPartitionSizeInBytes; + return this; + } + + public AssignerTester withOutputDataSizeEstimates(Map outputDataSizeEstimates) + { + this.outputDataSizeEstimates = outputDataSizeEstimates; + return this; + } + + public AssignerTester withSplittableSources(PlanNodeId... 
sources) + { + splittableSources = ImmutableSet.copyOf(sources); + return this; + } + + public AssignerTester withMergeAllowed(boolean mergeAllowed) + { + this.mergeAllowed = mergeAllowed; + return this; + } + + public AssignerTester withExpectedTaskCount(int expectedTaskCount) + { + this.expectedTaskCount = expectedTaskCount; + return this; + } + + public void run() + { + FaultTolerantPartitioningScheme partitioningScheme = createPartitioningScheme(splitPartitionCount, partitionToNodeMap); + Map outputPartitionToTaskPartition = createOutputPartitionToTaskPartition( + partitioningScheme, + partitionedSources, + outputDataSizeEstimates, + targetPartitionSizeInBytes, + splittableSources::contains, + mergeAllowed); + HashDistributionSplitAssigner assigner = new HashDistributionSplitAssigner( + Optional.of(TESTING_CATALOG_HANDLE), + partitionedSources, + replicatedSources, + partitioningScheme, + outputPartitionToTaskPartition); + TestingTaskSourceCallback callback = new TestingTaskSourceCallback(); + Map> partitionedSplitIds = new HashMap<>(); + Set replicatedSplitIds = new HashSet<>(); + for (SplitBatch batch : splits) { + assigner.assign(batch.getPlanNodeId(), batch.getSplits(), batch.isNoMoreSplits()).update(callback); + boolean replicated = replicatedSources.contains(batch.getPlanNodeId()); + callback.checkContainsSplits(batch.getPlanNodeId(), batch.getSplits().values(), replicated); + for (Map.Entry entry : batch.getSplits().entries()) { + int splitId = TestingConnectorSplit.getSplitId(entry.getValue()); + if (replicated) { + assertThat(replicatedSplitIds).doesNotContain(splitId); + replicatedSplitIds.add(splitId); + } + else { + partitionedSplitIds.computeIfAbsent(entry.getKey(), key -> ArrayListMultimap.create()).put(batch.getPlanNodeId(), splitId); + } + } + } + assigner.finish().update(callback); + Map taskDescriptors = callback.getTaskDescriptors().stream() + .collect(toImmutableMap(TaskDescriptor::getPartitionId, Function.identity())); + 
assertThat(taskDescriptors).hasSize(expectedTaskCount); + + // validate node requirements and replicated splits + for (TaskDescriptor taskDescriptor : taskDescriptors.values()) { + int partitionId = taskDescriptor.getPartitionId(); + NodeRequirements nodeRequirements = taskDescriptor.getNodeRequirements(); + assertEquals(nodeRequirements.getCatalogHandle(), Optional.of(TESTING_CATALOG_HANDLE)); + partitionToNodeMap.ifPresent(partitionToNode -> { + if (!taskDescriptor.getSplits().isEmpty()) { + InternalNode node = partitionToNode.get(partitionId); + assertThat(nodeRequirements.getAddresses()).containsExactly(node.getHostAndPort()); + } + }); + Set taskDescriptorSplitIds = taskDescriptor.getSplits().values().stream() + .map(TestingConnectorSplit::getSplitId) + .collect(toImmutableSet()); + assertThat(taskDescriptorSplitIds).containsAll(replicatedSplitIds); + } + + // validate partitioned splits + partitionedSplitIds.forEach((partitionId, sourceSplits) -> { + sourceSplits.forEach((source, splitId) -> { + List descriptors = outputPartitionToTaskPartition.get(partitionId).getSubPartitions().stream() + .filter(HashDistributionSplitAssigner.SubPartition::isIdAssigned) + .map(HashDistributionSplitAssigner.SubPartition::getId) + .map(taskDescriptors::get) + .collect(toImmutableList()); + for (TaskDescriptor descriptor : descriptors) { + Set taskDescriptorSplitIds = descriptor.getSplits().values().stream() + .map(TestingConnectorSplit::getSplitId) + .collect(toImmutableSet()); + if (taskDescriptorSplitIds.contains(splitId) && splittableSources.contains(source)) { + return; + } + if (!taskDescriptorSplitIds.contains(splitId) && !splittableSources.contains(source)) { + fail("expected split not found: ." + splitId); + } + } + if (splittableSources.contains(source)) { + fail("expected split not found: ." 
+ splitId); + } + }); + }); + } + } + + private static PartitionMappingTester testPartitionMapping() + { + return new PartitionMappingTester(); + } + + private static class PartitionMappingTester + { + private Set partitionedSources = ImmutableSet.of(); + private int splitPartitionCount; + private Optional> partitionToNodeMap = Optional.empty(); + private long targetPartitionSizeInBytes; + private Map outputDataSizeEstimates = ImmutableMap.of(); + private Set splittableSources = ImmutableSet.of(); + private boolean mergeAllowed; + private Set expectedMappings = ImmutableSet.of(); + + public PartitionMappingTester withPartitionedSources(PlanNodeId... sources) + { + partitionedSources = ImmutableSet.copyOf(sources); + return this; + } + + public PartitionMappingTester withSplitPartitionCount(int splitPartitionCount) + { + this.splitPartitionCount = splitPartitionCount; + return this; + } + + public PartitionMappingTester withPartitionToNodeMap(Optional> partitionToNodeMap) + { + this.partitionToNodeMap = partitionToNodeMap; + return this; + } + + public PartitionMappingTester withTargetPartitionSizeInBytes(long targetPartitionSizeInBytes) + { + this.targetPartitionSizeInBytes = targetPartitionSizeInBytes; + return this; + } + + public PartitionMappingTester withOutputDataSizeEstimates(Map outputDataSizeEstimates) + { + this.outputDataSizeEstimates = outputDataSizeEstimates; + return this; + } + + public PartitionMappingTester withSplittableSources(PlanNodeId... sources) + { + splittableSources = ImmutableSet.copyOf(sources); + return this; + } + + public PartitionMappingTester withMergeAllowed(boolean mergeAllowed) + { + this.mergeAllowed = mergeAllowed; + return this; + } + + public PartitionMappingTester withExpectedMappings(PartitionMapping... 
mappings) + { + expectedMappings = ImmutableSet.copyOf(mappings); + return this; + } + + public void run() + { + FaultTolerantPartitioningScheme partitioningScheme = createPartitioningScheme(splitPartitionCount, partitionToNodeMap); + Map actual = createOutputPartitionToTaskPartition( + partitioningScheme, + partitionedSources, + outputDataSizeEstimates, + targetPartitionSizeInBytes, + splittableSources::contains, + mergeAllowed); + Set actualGroups = extractMappings(actual); + assertEquals(actualGroups, expectedMappings); + } + + private static Set extractMappings(Map outputPartitionToTaskPartition) + { + SetMultimap grouped = outputPartitionToTaskPartition.entrySet().stream() + .collect(toImmutableSetMultimap(Map.Entry::getValue, Map.Entry::getKey)); + return Multimaps.asMap(grouped).entrySet().stream() + .map(entry -> new PartitionMapping(entry.getValue(), entry.getKey().getSubPartitions().size())) + .collect(toImmutableSet()); + } + } + + @SuppressWarnings("unused") + private record PartitionMapping(Set sourcePartitions, int taskPartitionCount) + { + private PartitionMapping + { + sourcePartitions = ImmutableSet.copyOf(requireNonNull(sourcePartitions, "sourcePartitions is null")); + } + } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java index f63afe6d9f9c..f041f9ba2dce 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java @@ -2338,10 +2338,6 @@ private Optional getTableHandleForOptimize(Connecto throw new TrinoException(NOT_SUPPORTED, "OPTIMIZE procedure must be explicitly enabled via " + NON_TRANSACTIONAL_OPTIMIZE_ENABLED + " session property"); } - if (retryMode != NO_RETRIES) { - throw new TrinoException(NOT_SUPPORTED, "OPTIMIZE procedure is not supported with query retries enabled"); - } - HiveTableHandle hiveTableHandle = (HiveTableHandle) 
tableHandle; SchemaTableName tableName = hiveTableHandle.getSchemaTableName(); @@ -3406,6 +3402,32 @@ public Optional getNewTableLayout(ConnectorSession session multipleWritersPerPartitionSupported)); } + @Override + public Optional getLayoutForTableExecute(ConnectorSession session, ConnectorTableExecuteHandle executeHandle) + { + HiveTableExecuteHandle hiveExecuteHandle = (HiveTableExecuteHandle) executeHandle; + SchemaTableName tableName = hiveExecuteHandle.getSchemaTableName(); + Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()) + .orElseThrow(() -> new TableNotFoundException(tableName)); + + if (table.getStorage().getBucketProperty().isPresent()) { + throw new TrinoException(NOT_SUPPORTED, format("Optimizing bucketed Hive table %s is not supported", tableName)); + } + if (isTransactionalTable(table.getParameters())) { + throw new TrinoException(NOT_SUPPORTED, format("Optimizing transactional Hive table %s is not supported", tableName)); + } + + List partitionColumns = table.getPartitionColumns(); + if (partitionColumns.isEmpty()) { + return Optional.empty(); + } + + return Optional.of(new ConnectorTableLayout( + partitionColumns.stream() + .map(Column::getName) + .collect(toImmutableList()))); + } + @Override public TableStatisticsMetadata getStatisticsCollectionMetadataForWrite(ConnectorSession session, ConnectorTableMetadata tableMetadata) { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index 102c262d3c1f..87ca80865b5c 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -1928,6 +1928,7 @@ public void testTargetMaxFileSizePartitioned() .setSystemProperty("task_scale_writers_enabled", "false") .setSystemProperty("scale_writers", "false") 
.setSystemProperty("redistribute_writes", "false") + .setSystemProperty("use_preferred_write_partitioning", "false") .build(); assertUpdate(session, createTableSql, 1000000); assertThat(computeActual(selectFileInfo).getRowCount()).isEqualTo(3); @@ -1941,6 +1942,7 @@ public void testTargetMaxFileSizePartitioned() .setSystemProperty("task_partitioned_writer_count", "1") // task scale writers should be disabled since we want to write with a single task writer .setSystemProperty("task_scale_writers_enabled", "false") + .setSystemProperty("use_preferred_write_partitioning", "false") .setCatalogSessionProperty("hive", "target_max_file_size", maxSize.toString()) .build(); @@ -7992,7 +7994,9 @@ private Session optimizeEnabledSession() private void insertNationNTimes(String tableName, int times) { - assertUpdate("INSERT INTO " + tableName + "(nationkey, name, regionkey, comment) " + join(" UNION ALL ", nCopies(times, "SELECT * FROM tpch.sf1.nation")), times * 25); + for (int i = 0; i < times; i++) { + assertUpdate("INSERT INTO " + tableName + "(nationkey, name, regionkey, comment) SELECT * FROM tpch.sf1.nation", 25); + } } private void assertNationNTimes(String tableName, int times) diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/BaseFaultTolerantExecutionTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/BaseFaultTolerantExecutionTest.java new file mode 100644 index 000000000000..3c71aa0ec448 --- /dev/null +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/BaseFaultTolerantExecutionTest.java @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.faulttolerant; + +import io.trino.Session; +import io.trino.testing.AbstractTestQueryFramework; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertEquals; + +public abstract class BaseFaultTolerantExecutionTest + extends AbstractTestQueryFramework +{ + private final String partitioningTablePropertyName; + + protected BaseFaultTolerantExecutionTest(String partitioningTablePropertyName) + { + this.partitioningTablePropertyName = requireNonNull(partitioningTablePropertyName, "partitioningTablePropertyName is null"); + } + + @Test + public void testTableWritePreferredWritePartitioningSkewMitigation() + { + @Language("SQL") String createTableSql = """ + CREATE TABLE test_table_writer_skew_mitigation WITH (%s = ARRAY['returnflag']) AS + SELECT orderkey, partkey, suppkey, linenumber, quantity, extendedprice, discount, tax, linestatus, shipdate, commitdate, receiptdate, shipinstruct, shipmode, comment, returnflag + FROM tpch.sf1.lineitem + WHERE returnflag = 'N' + LIMIT 1000000""".formatted(partitioningTablePropertyName); + @Language("SQL") String selectFileInfo = "SELECT distinct \"$path\" FROM test_table_writer_skew_mitigation"; + + Session session = withSingleWriterPerTask(getSession()); + + // force single writer task to verify there is exactly one writer per task + assertUpdate(withUnlimitedTargetTaskInputSize(session), createTableSql, 1000000); + 
assertEquals(computeActual(selectFileInfo).getRowCount(), 1); + assertUpdate("DROP TABLE test_table_writer_skew_mitigation"); + + assertUpdate(withDisabledPreferredWritePartitioning(session), createTableSql, 1000000); + int expectedNumberOfFiles = computeActual(selectFileInfo).getRowCount(); + assertUpdate("DROP TABLE test_table_writer_skew_mitigation"); + assertThat(expectedNumberOfFiles).isGreaterThan(1); + + assertUpdate(withEnabledPreferredWritePartitioning(session), createTableSql, 1000000); + int actualNumberOfFiles = computeActual(selectFileInfo).getRowCount(); + assertUpdate("DROP TABLE test_table_writer_skew_mitigation"); + assertEquals(actualNumberOfFiles, expectedNumberOfFiles); + } + + @Test + public void testExecutePreferredWritePartitioningSkewMitigation() + { + @Language("SQL") String createTableSql = """ + CREATE TABLE test_execute_skew_mitigation WITH (%s = ARRAY['returnflag']) AS + SELECT orderkey, partkey, suppkey, linenumber, quantity, extendedprice, discount, tax, linestatus, shipdate, commitdate, receiptdate, shipinstruct, shipmode, comment, returnflag + FROM tpch.sf1.lineitem + WHERE returnflag = 'N' + LIMIT 1000000""".formatted(partitioningTablePropertyName); + assertUpdate(createTableSql, 1000000); + + @Language("SQL") String executeSql = "ALTER TABLE test_execute_skew_mitigation EXECUTE optimize"; + @Language("SQL") String selectFileInfo = "SELECT distinct \"$path\" FROM test_execute_skew_mitigation"; + + Session session = withSingleWriterPerTask(getSession()); + + // force single writer task to verify there is exactly one writer per task + assertUpdate(withUnlimitedTargetTaskInputSize(session), executeSql); + assertEquals(computeActual(selectFileInfo).getRowCount(), 1); + + assertUpdate(withDisabledPreferredWritePartitioning(session), executeSql); + int expectedNumberOfFiles = computeActual(selectFileInfo).getRowCount(); + assertThat(expectedNumberOfFiles) + .withFailMessage("optimize is expected to generate more than a single file per 
partition") + .isGreaterThan(1); + + assertUpdate(withEnabledPreferredWritePartitioning(session), executeSql); + int actualNumberOfFiles = computeActual(selectFileInfo).getRowCount(); + assertEquals(actualNumberOfFiles, expectedNumberOfFiles); + + // verify no data is lost in process + assertQuery("SELECT count(*) FROM test_execute_skew_mitigation", "SELECT 1000000"); + + assertUpdate("DROP TABLE test_execute_skew_mitigation"); + } + + private static Session withSingleWriterPerTask(Session session) + { + return Session.builder(session) + // one writer per partition per task + .setSystemProperty("task_writer_count", "1") + .setSystemProperty("task_partitioned_writer_count", "1") + .setSystemProperty("task_scale_writers_enabled", "false") + .build(); + } + + private static Session withUnlimitedTargetTaskInputSize(Session session) + { + return Session.builder(session) + .setSystemProperty("fault_tolerant_execution_target_task_input_size", "1PB") + .build(); + } + + private static Session withDisabledPreferredWritePartitioning(Session session) + { + return Session.builder(session) + .setSystemProperty("use_preferred_write_partitioning", "false") + .build(); + } + + private static Session withEnabledPreferredWritePartitioning(Session session) + { + return Session.builder(session) + .setSystemProperty("use_preferred_write_partitioning", "true") + .build(); + } +} diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/delta/TestDeltaFaultTolerantExecutionTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/delta/TestDeltaFaultTolerantExecutionTest.java new file mode 100644 index 000000000000..ce6f0ab63225 --- /dev/null +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/delta/TestDeltaFaultTolerantExecutionTest.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.faulttolerant.delta; + +import com.google.common.collect.ImmutableMap; +import io.trino.faulttolerant.BaseFaultTolerantExecutionTest; +import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin; +import io.trino.plugin.exchange.filesystem.containers.MinioStorage; +import io.trino.plugin.hive.containers.HiveMinioDataLake; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.FaultTolerantExecutionConnectorTestHelper; +import io.trino.testing.QueryRunner; + +import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.DELTA_CATALOG; +import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.createS3DeltaLakeQueryRunner; +import static io.trino.plugin.exchange.filesystem.containers.MinioStorage.getExchangeManagerProperties; +import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestDeltaFaultTolerantExecutionTest + extends BaseFaultTolerantExecutionTest +{ + private static final String SCHEMA = "fte_preferred_write_partitioning"; + private static final String BUCKET_NAME = "test-fte-preferred-write-partitioning-" + randomTableSuffix(); + + public TestDeltaFaultTolerantExecutionTest() + { + super("partitioned_by"); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + HiveMinioDataLake hiveMinioDataLake = closeAfterClass(new HiveMinioDataLake(BUCKET_NAME)); + hiveMinioDataLake.start(); + MinioStorage minioStorage = closeAfterClass(new 
MinioStorage(BUCKET_NAME)); + minioStorage.start(); + + DistributedQueryRunner runner = createS3DeltaLakeQueryRunner( + DELTA_CATALOG, + SCHEMA, + FaultTolerantExecutionConnectorTestHelper.getExtraProperties(), + ImmutableMap.of(), + ImmutableMap.of("delta.enable-non-concurrent-writes", "true"), + hiveMinioDataLake.getMinioAddress(), + hiveMinioDataLake.getHiveHadoop(), + instance -> { + instance.installPlugin(new FileSystemExchangePlugin()); + instance.loadExchangeManager("filesystem", getExchangeManagerProperties(minioStorage)); + }); + runner.execute(format("CREATE SCHEMA %s WITH (location = 's3://%s/%s')", SCHEMA, BUCKET_NAME, SCHEMA)); + return runner; + } + + @Override + public void testExecutePreferredWritePartitioningSkewMitigation() + { + assertThatThrownBy(super::testExecutePreferredWritePartitioningSkewMitigation) + .hasMessage("optimize is expected to generate more than a single file per partition"); + } +} diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveFaultTolerantExecutionConnectorTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveFaultTolerantExecutionConnectorTest.java index 7924aa86b472..b857aaa2594f 100644 --- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveFaultTolerantExecutionConnectorTest.java +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveFaultTolerantExecutionConnectorTest.java @@ -27,7 +27,6 @@ import static io.trino.plugin.exchange.filesystem.containers.MinioStorage.getExchangeManagerProperties; import static io.trino.testing.FaultTolerantExecutionConnectorTestHelper.getExtraProperties; import static io.trino.testing.TestingNames.randomNameSuffix; -import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestHiveFaultTolerantExecutionConnectorTest extends BaseHiveConnectorTest @@ -70,34 +69,6 @@ public void 
testWritersAcrossMultipleWorkersWhenScaleWritersIsEnabled() // Not applicable for fault-tolerant mode. } - @Override - public void testOptimize() - { - assertThatThrownBy(super::testOptimize) - .hasMessageContaining("OPTIMIZE procedure is not supported with query retries enabled"); - } - - @Override - public void testOptimizeWithWriterScaling() - { - assertThatThrownBy(super::testOptimizeWithWriterScaling) - .hasMessageContaining("OPTIMIZE procedure is not supported with query retries enabled"); - } - - @Override - public void testOptimizeWithPartitioning() - { - assertThatThrownBy(super::testOptimizeWithPartitioning) - .hasMessageContaining("OPTIMIZE procedure is not supported with query retries enabled"); - } - - @Override - public void testOptimizeWithBucketing() - { - assertThatThrownBy(super::testOptimizeWithBucketing) - .hasMessageContaining("OPTIMIZE procedure is not supported with query retries enabled"); - } - @Test public void testMaxOutputPartitionCountCheck() { diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveFaultTolerantExecutionTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveFaultTolerantExecutionTest.java new file mode 100644 index 000000000000..8f2716111e47 --- /dev/null +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveFaultTolerantExecutionTest.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.faulttolerant.hive; + +import io.trino.Session; +import io.trino.faulttolerant.BaseFaultTolerantExecutionTest; +import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin; +import io.trino.plugin.exchange.filesystem.containers.MinioStorage; +import io.trino.plugin.hive.HiveQueryRunner; +import io.trino.testing.FaultTolerantExecutionConnectorTestHelper; +import io.trino.testing.QueryRunner; + +import static io.trino.plugin.exchange.filesystem.containers.MinioStorage.getExchangeManagerProperties; +import static io.trino.testing.sql.TestTable.randomTableSuffix; + +public class TestHiveFaultTolerantExecutionTest + extends BaseFaultTolerantExecutionTest +{ + public TestHiveFaultTolerantExecutionTest() + { + super("partitioned_by"); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + MinioStorage minioStorage = closeAfterClass(new MinioStorage("test-exchange-spooling-" + randomTableSuffix())); + minioStorage.start(); + + return HiveQueryRunner.builder() + .setExtraProperties(FaultTolerantExecutionConnectorTestHelper.getExtraProperties()) + .setAdditionalSetup(runner -> { + runner.installPlugin(new FileSystemExchangePlugin()); + runner.loadExchangeManager("filesystem", getExchangeManagerProperties(minioStorage)); + }) + .build(); + } + + @Override + protected Session getSession() + { + Session session = super.getSession(); + return Session.builder(session) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "non_transactional_optimize_enabled", "true") + .build(); + } +} diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergFaultTolerantExecutionTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergFaultTolerantExecutionTest.java new file mode 100644 index 000000000000..fee2f3c09eec --- 
/dev/null +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergFaultTolerantExecutionTest.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.faulttolerant.iceberg; + +import io.trino.faulttolerant.BaseFaultTolerantExecutionTest; +import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin; +import io.trino.plugin.exchange.filesystem.containers.MinioStorage; +import io.trino.plugin.iceberg.IcebergQueryRunner; +import io.trino.testing.FaultTolerantExecutionConnectorTestHelper; +import io.trino.testing.QueryRunner; + +import java.util.Map; + +import static io.trino.plugin.exchange.filesystem.containers.MinioStorage.getExchangeManagerProperties; +import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Iceberg variant of {@link BaseFaultTolerantExecutionTest}: the inherited tests run against an + * {@link IcebergQueryRunner} that uses the fault-tolerant execution extra properties and a + * "filesystem" exchange manager configured from a {@link MinioStorage} instance. + * <p> + * The base class is constructed with {@code "partitioning"}. + * NOTE(review): presumably this is the name of the Iceberg table property used to declare partitioning; confirm against the base class. + */ +public class TestIcebergFaultTolerantExecutionTest + extends BaseFaultTolerantExecutionTest +{ + public TestIcebergFaultTolerantExecutionTest() + { + super("partitioning"); + } + + /** + * Builds the Iceberg query runner used by this test class. + * <p> + * A {@link MinioStorage} named {@code "test-exchange-spooling-"} plus a random suffix is registered + * with {@code closeAfterClass} and started before the runner is built. The runner is given the + * Iceberg property {@code iceberg.experimental.extended-statistics.enabled=true}, and the additional + * setup step installs {@link FileSystemExchangePlugin} and loads the "filesystem" exchange manager + * with the properties derived from that storage. + * + * @return the query runner produced by the {@link IcebergQueryRunner} builder + * @throws Exception declared by the overridden method; propagated from setup + */ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + MinioStorage minioStorage = closeAfterClass(new MinioStorage("test-exchange-spooling-" + randomTableSuffix())); + minioStorage.start(); + + return IcebergQueryRunner.builder() + .setExtraProperties(FaultTolerantExecutionConnectorTestHelper.getExtraProperties()) +
 .setIcebergProperties(Map.of("iceberg.experimental.extended-statistics.enabled", "true")) + .setAdditionalSetup(runner -> { + runner.installPlugin(new FileSystemExchangePlugin()); + runner.loadExchangeManager("filesystem", getExchangeManagerProperties(minioStorage)); + }) + .build(); + } + + /** + * Expects the inherited test to fail with exactly the message + * "optimize is expected to generate more than a single file per partition". + * NOTE(review): the reason the inherited expectation does not hold for Iceberg is not visible here; + * if the inherited test starts passing, this override will fail and should be removed. + */ + @Override + public void testExecutePreferredWritePartitioningSkewMitigation() + { + assertThatThrownBy(super::testExecutePreferredWritePartitioningSkewMitigation) + .hasMessage("optimize is expected to generate more than a single file per partition"); + } +}