From dc34cda3b59b24a52d6808f1387341768d4371a8 Mon Sep 17 00:00:00 2001
From: Gaurav Sehgal <gaurav.sehgal8297@gmail.com>
Date: Wed, 2 Nov 2022 15:24:47 +0530
Subject: [PATCH 1/4] Add multipleWritersPerPartitionSupported flag in
 ConnectorTableLayout

This new flag lets the engine determine whether
writer scaling per partition is allowed.
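
For example, a connector that can safely run multiple writers
against the same partition opts in through the new constructor
argument (illustrative sketch, mirroring the BlackHole connector
change in this patch):

    // third argument: multiple writers may write to one partition
    return Optional.of(new ConnectorTableLayout(
            partitioningHandle, partitionColumns, true));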
---
 .../optimizations/BeginTableWrite.java        | 12 ++-
 .../sql/planner/plan/TableWriterNode.java     | 92 ++++++++++++++++++-
 .../operator/TestTableWriterOperator.java     |  1 +
 .../sql/planner/TestingWriterTarget.java      |  6 ++
 .../iterative/rule/test/PlanBuilder.java      |  5 +-
 .../TestValidateScaledWritersUsage.java       | 14 +--
 .../spi/connector/ConnectorTableLayout.java   | 15 +++
 .../plugin/blackhole/BlackHoleMetadata.java   |  2 +-
 .../plugin/deltalake/DeltaLakeMetadata.java   |  2 +-
 .../io/trino/plugin/hive/HiveMetadata.java    | 13 ++-
 .../trino/plugin/iceberg/IcebergMetadata.java |  2 +-
 11 files changed, 146 insertions(+), 18 deletions(-)

diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java
index bcf53e0aa941..3f6c0a3be8ae 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java
@@ -288,11 +288,19 @@ private WriterTarget createWriterTarget(WriterTarget target)
             // TODO: we shouldn't need to store the schemaTableName in the handles, but there isn't a good way to pass this around with the current architecture
             if (target instanceof CreateReference) {
                 CreateReference create = (CreateReference) target;
-                return new CreateTarget(metadata.beginCreateTable(session, create.getCatalog(), create.getTableMetadata(), create.getLayout()), create.getTableMetadata().getTable(), target.supportsReportingWrittenBytes(metadata, session));
+                return new CreateTarget(
+                        metadata.beginCreateTable(session, create.getCatalog(), create.getTableMetadata(), create.getLayout()),
+                        create.getTableMetadata().getTable(),
+                        target.supportsReportingWrittenBytes(metadata, session),
+                        target.supportsMultipleWritersPerPartition(metadata, session));
             }
             if (target instanceof InsertReference) {
                 InsertReference insert = (InsertReference) target;
-                return new InsertTarget(metadata.beginInsert(session, insert.getHandle(), insert.getColumns()), metadata.getTableMetadata(session, insert.getHandle()).getTable(), target.supportsReportingWrittenBytes(metadata, session));
+                return new InsertTarget(
+                        metadata.beginInsert(session, insert.getHandle(), insert.getColumns()),
+                        metadata.getTableMetadata(session, insert.getHandle()).getTable(),
+                        target.supportsReportingWrittenBytes(metadata, session),
+                        target.supportsMultipleWritersPerPartition(metadata, session));
             }
             if (target instanceof DeleteTarget) {
                 DeleteTarget delete = (DeleteTarget) target;
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java
index 9aaa921b56a5..0473b70034e3 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java
@@ -207,6 +207,8 @@ public abstract static class WriterTarget
         public abstract String toString();
 
         public abstract boolean supportsReportingWrittenBytes(Metadata metadata, Session session);
+
+        public abstract boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session);
     }
 
     // only used during planning -- will not be serialized
@@ -239,6 +241,12 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
             return metadata.supportsReportingWrittenBytes(session, fullTableName, tableMetadata.getProperties());
         }
 
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            return layout.map(tableLayout -> tableLayout.getLayout().isMultipleWritersPerPartitionSupported()).orElse(true);
+        }
+
         public Optional<TableLayout> getLayout()
         {
             return layout;
@@ -262,16 +270,19 @@ public static class CreateTarget
         private final OutputTableHandle handle;
         private final SchemaTableName schemaTableName;
         private final boolean reportingWrittenBytesSupported;
+        private final boolean multipleWritersPerPartitionSupported;
 
         @JsonCreator
         public CreateTarget(
                 @JsonProperty("handle") OutputTableHandle handle,
                 @JsonProperty("schemaTableName") SchemaTableName schemaTableName,
-                @JsonProperty("reportingWrittenBytesSupported") boolean reportingWrittenBytesSupported)
+                @JsonProperty("reportingWrittenBytesSupported") boolean reportingWrittenBytesSupported,
+                @JsonProperty("multipleWritersPerPartitionSupported") boolean multipleWritersPerPartitionSupported)
         {
             this.handle = requireNonNull(handle, "handle is null");
             this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null");
             this.reportingWrittenBytesSupported = reportingWrittenBytesSupported;
+            this.multipleWritersPerPartitionSupported = multipleWritersPerPartitionSupported;
         }
 
         @JsonProperty
@@ -292,6 +303,12 @@ public boolean getReportingWrittenBytesSupported()
             return reportingWrittenBytesSupported;
         }
 
+        @JsonProperty
+        public boolean isMultipleWritersPerPartitionSupported()
+        {
+            return multipleWritersPerPartitionSupported;
+        }
+
         @Override
         public String toString()
         {
@@ -303,6 +320,12 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
         {
             return reportingWrittenBytesSupported;
         }
+
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            return multipleWritersPerPartitionSupported;
+        }
     }
 
     // only used during planning -- will not be serialized
@@ -339,6 +362,14 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
         {
             return metadata.supportsReportingWrittenBytes(session, handle);
         }
+
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            return metadata.getInsertLayout(session, handle)
+                    .map(layout -> layout.getLayout().isMultipleWritersPerPartitionSupported())
+                    .orElse(true);
+        }
     }
 
     public static class InsertTarget
@@ -347,16 +378,19 @@ public static class InsertTarget
         private final InsertTableHandle handle;
         private final SchemaTableName schemaTableName;
         private final boolean reportingWrittenBytesSupported;
+        private final boolean multipleWritersPerPartitionSupported;
 
         @JsonCreator
         public InsertTarget(
                 @JsonProperty("handle") InsertTableHandle handle,
                 @JsonProperty("schemaTableName") SchemaTableName schemaTableName,
-                @JsonProperty("reportingWrittenBytesSupported") boolean reportingWrittenBytesSupported)
+                @JsonProperty("reportingWrittenBytesSupported") boolean reportingWrittenBytesSupported,
+                @JsonProperty("multipleWritersPerPartitionSupported") boolean multipleWritersPerPartitionSupported)
         {
             this.handle = requireNonNull(handle, "handle is null");
             this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null");
             this.reportingWrittenBytesSupported = reportingWrittenBytesSupported;
+            this.multipleWritersPerPartitionSupported = multipleWritersPerPartitionSupported;
         }
 
         @JsonProperty
@@ -377,6 +411,12 @@ public boolean getReportingWrittenBytesSupported()
             return reportingWrittenBytesSupported;
         }
 
+        @JsonProperty
+        public boolean isMultipleWritersPerPartitionSupported()
+        {
+            return multipleWritersPerPartitionSupported;
+        }
+
         @Override
         public String toString()
         {
@@ -388,6 +428,12 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
         {
             return reportingWrittenBytesSupported;
         }
+
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            return multipleWritersPerPartitionSupported;
+        }
     }
 
     public static class RefreshMaterializedViewReference
@@ -430,6 +476,14 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
         {
             return metadata.supportsReportingWrittenBytes(session, storageTableHandle);
         }
+
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            return metadata.getInsertLayout(session, storageTableHandle)
+                    .map(layout -> layout.getLayout().isMultipleWritersPerPartitionSupported())
+                    .orElse(true);
+        }
     }
 
     public static class RefreshMaterializedViewTarget
@@ -488,6 +542,14 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
         {
             return metadata.supportsReportingWrittenBytes(session, tableHandle);
         }
+
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            return metadata.getInsertLayout(session, tableHandle)
+                    .map(layout -> layout.getLayout().isMultipleWritersPerPartitionSupported())
+                    .orElse(true);
+        }
     }
 
     public static class DeleteTarget
@@ -534,6 +596,12 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
         {
             throw new UnsupportedOperationException();
         }
+
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            throw new UnsupportedOperationException();
+        }
     }
 
     public static class UpdateTarget
@@ -599,6 +667,12 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
         {
             throw new UnsupportedOperationException();
         }
+
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            throw new UnsupportedOperationException();
+        }
     }
 
     public static class TableExecuteTarget
@@ -662,6 +736,14 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
         {
             return sourceHandle.map(tableHandle -> metadata.supportsReportingWrittenBytes(session, tableHandle)).orElse(reportingWrittenBytesSupported);
         }
+
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            return metadata.getLayoutForTableExecute(session, executeHandle)
+                    .map(layout -> layout.getLayout().isMultipleWritersPerPartitionSupported())
+                    .orElse(true);
+        }
     }
 
     public static class MergeTarget
@@ -720,6 +802,12 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
         {
             return false;
         }
+
+        @Override
+        public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+        {
+            return false;
+        }
     }
 
     public static class MergeParadigmAndTypes
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestTableWriterOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestTableWriterOperator.java
index 248a5d333aa8..d3295f45d018 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestTableWriterOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestTableWriterOperator.java
@@ -296,6 +296,7 @@ private Operator createTableWriterOperator(
                                 new ConnectorTransactionHandle() {},
                                 new ConnectorOutputTableHandle() {}),
                         schemaTableName,
+                        false,
                         false),
                 ImmutableList.of(0),
                 session,
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestingWriterTarget.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestingWriterTarget.java
index ec55ce67ddab..7fbdb80d50f1 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/TestingWriterTarget.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestingWriterTarget.java
@@ -32,4 +32,10 @@ public boolean supportsReportingWrittenBytes(Metadata metadata, Session session)
     {
         return false;
     }
+
+    @Override
+    public boolean supportsMultipleWritersPerPartition(Metadata metadata, Session session)
+    {
+        return false;
+    }
 }
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java
index bdb1e2a1642c..6d08451a526f 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java
@@ -759,7 +759,7 @@ private DeleteTarget deleteTarget(SchemaTableName schemaTableName)
                 schemaTableName);
     }
 
-    public CreateTarget createTarget(CatalogHandle catalogHandle, SchemaTableName schemaTableName, boolean reportingWrittenBytesSupported)
+    public CreateTarget createTarget(CatalogHandle catalogHandle, SchemaTableName schemaTableName, boolean reportingWrittenBytesSupported, boolean multipleWritersPerPartitionSupported)
     {
         OutputTableHandle tableHandle = new OutputTableHandle(
                 catalogHandle,
@@ -769,7 +769,8 @@ public CreateTarget createTarget(CatalogHandle catalogHandle, SchemaTableName sc
         return new CreateTarget(
                 tableHandle,
                 schemaTableName,
-                reportingWrittenBytesSupported);
+                reportingWrittenBytesSupported,
+                multipleWritersPerPartitionSupported);
     }
 
     public TableFinishNode tableUpdate(SchemaTableName schemaTableName, PlanNode updateSource, Symbol updateRowId, List<Symbol> columnsToBeUpdated)
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java
index 1ba64923024e..8175372adac9 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java
@@ -103,7 +103,7 @@ public void testScaledWritersUsedAndTargetSupportsIt()
         PlanNode root = planBuilder.output(
                 outputBuilder -> outputBuilder
                         .source(planBuilder.tableWithExchangeCreate(
-                                planBuilder.createTarget(catalogSupportingScaledWriters, schemaTableName, true),
+                                planBuilder.createTarget(catalogSupportingScaledWriters, schemaTableName, true, true),
                                 tableWriterSource,
                                 symbol,
                                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))));
@@ -125,7 +125,7 @@ public void testScaledWritersUsedAndTargetDoesNotSupportIt()
         PlanNode root = planBuilder.output(
                 outputBuilder -> outputBuilder
                         .source(planBuilder.tableWithExchangeCreate(
-                                planBuilder.createTarget(catalogNotSupportingScaledWriters, schemaTableName, false),
+                                planBuilder.createTarget(catalogNotSupportingScaledWriters, schemaTableName, false, true),
                                 tableWriterSource,
                                 symbol,
                                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))));
@@ -155,7 +155,7 @@ public void testScaledWritersUsedAndTargetDoesNotSupportItMultipleSourceExchange
         PlanNode root = planBuilder.output(
                 outputBuilder -> outputBuilder
                         .source(planBuilder.tableWithExchangeCreate(
-                                planBuilder.createTarget(catalogNotSupportingScaledWriters, schemaTableName, false),
+                                planBuilder.createTarget(catalogNotSupportingScaledWriters, schemaTableName, false, true),
                                 tableWriterSource,
                                 symbol,
                                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))));
@@ -185,7 +185,7 @@ public void testScaledWritersUsedAndTargetSupportsItMultipleSourceExchanges()
         PlanNode root = planBuilder.output(
                 outputBuilder -> outputBuilder
                         .source(planBuilder.tableWithExchangeCreate(
-                                planBuilder.createTarget(catalogSupportingScaledWriters, schemaTableName, true),
+                                planBuilder.createTarget(catalogSupportingScaledWriters, schemaTableName, true, true),
                                 tableWriterSource,
                                 symbol,
                                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))));
@@ -207,7 +207,7 @@ public void testScaledWritersUsedAboveTableWriterInThePlanTree()
         PlanNode root = planBuilder.output(
                 outputBuilder -> outputBuilder
                         .source(planBuilder.tableWithExchangeCreate(
-                                planBuilder.createTarget(catalogNotSupportingScaledWriters, schemaTableName, false),
+                                planBuilder.createTarget(catalogNotSupportingScaledWriters, schemaTableName, false, true),
                                 tableWriterSource,
                                 symbol,
                                 new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))));
@@ -226,7 +226,7 @@ public void testScaledWritersTwoTableWritersNodes()
                                 ImmutableList.of("column_a"),
                                 Optional.empty(),
                                 Optional.empty(),
-                                planBuilder.createTarget(catalogSupportingScaledWriters, schemaTableName, true),
+                                planBuilder.createTarget(catalogSupportingScaledWriters, schemaTableName, true, true),
                                 planBuilder.exchange(innerExchange ->
                                         innerExchange
                                                 .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
@@ -236,7 +236,7 @@ public void testScaledWritersTwoTableWritersNodes()
         PlanNode root = planBuilder.output(
                 outputBuilder -> outputBuilder
                         .source(planBuilder.tableWithExchangeCreate(
-                                planBuilder.createTarget(catalogNotSupportingScaledWriters, schemaTableName, false),
+                                planBuilder.createTarget(catalogNotSupportingScaledWriters, schemaTableName, false, true),
                                 tableWriterSource,
                                 symbol,
                                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))));
diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableLayout.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableLayout.java
index 8ffadbf10cea..bd02162d36ee 100644
--- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableLayout.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableLayout.java
@@ -22,11 +22,20 @@ public class ConnectorTableLayout
 {
     private final Optional<ConnectorPartitioningHandle> partitioning;
     private final List<String> partitionColumns;
+    private final boolean multipleWritersPerPartitionSupported;
 
     public ConnectorTableLayout(ConnectorPartitioningHandle partitioning, List<String> partitionColumns)
+    {
+        // For backward compatibility, default multipleWritersPerPartitionSupported
+        // to false when partitioning is present.
+        this(partitioning, partitionColumns, false);
+    }
+
+    public ConnectorTableLayout(ConnectorPartitioningHandle partitioning, List<String> partitionColumns, boolean multipleWritersPerPartitionSupported)
     {
         this.partitioning = Optional.of(requireNonNull(partitioning, "partitioning is null"));
         this.partitionColumns = requireNonNull(partitionColumns, "partitionColumns is null");
+        this.multipleWritersPerPartitionSupported = multipleWritersPerPartitionSupported;
     }
 
     /**
@@ -37,6 +46,7 @@ public ConnectorTableLayout(List<String> partitionColumns)
     {
         this.partitioning = Optional.empty();
         this.partitionColumns = requireNonNull(partitionColumns, "partitionColumns is null");
+        this.multipleWritersPerPartitionSupported = true;
     }
 
     public Optional<ConnectorPartitioningHandle> getPartitioning()
@@ -48,4 +58,9 @@ public List<String> getPartitionColumns()
     {
         return partitionColumns;
     }
+
+    public boolean isMultipleWritersPerPartitionSupported()
+    {
+        return multipleWritersPerPartitionSupported;
+    }
 }
diff --git a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java
index a7682ec6173b..2c63744aab9a 100644
--- a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java
+++ b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java
@@ -241,7 +241,7 @@ public Optional<ConnectorTableLayout> getNewTableLayout(ConnectorSession connect
             throw new TrinoException(INVALID_TABLE_PROPERTY, "Distribute columns not defined on table: " + undefinedColumns);
         }
 
-        return Optional.of(new ConnectorTableLayout(BlackHolePartitioningHandle.INSTANCE, distributeColumns));
+        return Optional.of(new ConnectorTableLayout(BlackHolePartitioningHandle.INSTANCE, distributeColumns, true));
     }
 
     @Override
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java
index 07733d89681c..6bfcd44e525a 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java
@@ -1710,7 +1710,7 @@ private Optional<ConnectorTableLayout> getLayoutForOptimize(DeltaLakeTableExecut
             partitioningColumns.add(columnsByName.get(columnName));
         }
         DeltaLakePartitioningHandle partitioningHandle = new DeltaLakePartitioningHandle(partitioningColumns.build());
-        return Optional.of(new ConnectorTableLayout(partitioningHandle, partitionColumnNames));
+        return Optional.of(new ConnectorTableLayout(partitioningHandle, partitionColumnNames, true));
     }
 
     @Override
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java
index 603f51066ce0..bcd078142416 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java
@@ -3296,6 +3296,10 @@ else if (isFullAcidTable(table.getParameters())) {
                 .map(Column::getName)
                 .forEach(partitioningColumns::add);
 
+        // For transactional bucketed tables we don't want to split output files,
+        // so we need a single writer per partition.
+        boolean multipleWritersPerPartitionSupported = !isTransactionalTable(table.getParameters());
+
         HivePartitioningHandle partitioningHandle = new HivePartitioningHandle(
                 hiveBucketHandle.get().getBucketingVersion(),
                 hiveBucketHandle.get().getTableBucketCount(),
@@ -3304,7 +3308,7 @@ else if (isFullAcidTable(table.getParameters())) {
                         .collect(toImmutableList()),
                 OptionalInt.of(hiveBucketHandle.get().getTableBucketCount()),
                 !partitionColumns.isEmpty() && isParallelPartitionedBucketedWrites(session));
-        return Optional.of(new ConnectorTableLayout(partitioningHandle, partitioningColumns.build()));
+        return Optional.of(new ConnectorTableLayout(partitioningHandle, partitioningColumns.build(), multipleWritersPerPartitionSupported));
     }
 
     @Override
@@ -3328,6 +3332,10 @@ public Optional<ConnectorTableLayout> getNewTableLayout(ConnectorSession session
             throw new TrinoException(NOT_SUPPORTED, "Writing to bucketed sorted Hive tables is disabled");
         }
 
+        // For transactional bucketed tables we don't want to split output files,
+        // so we need a single writer per partition.
+        boolean multipleWritersPerPartitionSupported = !isTransactional(tableMetadata.getProperties()).orElse(false);
+
         List<String> bucketedBy = bucketProperty.get().getBucketedBy();
         Map<String, HiveType> hiveTypeMap = tableMetadata.getColumns().stream()
                 .collect(toMap(ColumnMetadata::getName, column -> toHiveType(column.getType())));
@@ -3343,7 +3351,8 @@ public Optional<ConnectorTableLayout> getNewTableLayout(ConnectorSession session
                 ImmutableList.<String>builder()
                         .addAll(bucketedBy)
                         .addAll(partitionedBy)
-                        .build()));
+                        .build(),
+                multipleWritersPerPartitionSupported));
     }
 
     @Override
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
index 2ec8b67852ef..2605c116eb79 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
@@ -746,7 +746,7 @@ private Optional<ConnectorTableLayout> getWriteLayout(Schema tableSchema, Partit
             return Optional.of(new ConnectorTableLayout(partitioningColumnNames));
         }
         IcebergPartitioningHandle partitioningHandle = new IcebergPartitioningHandle(toPartitionFields(partitionSpec), partitioningColumns);
-        return Optional.of(new ConnectorTableLayout(partitioningHandle, partitioningColumnNames));
+        return Optional.of(new ConnectorTableLayout(partitioningHandle, partitioningColumnNames, true));
     }
 
     @Override

From 0034b05ee0091710dfde2fc1d29ebb25b9dc73c0 Mon Sep 17 00:00:00 2001
From: Gaurav Sehgal <gaurav.sehgal8297@gmail.com>
Date: Fri, 28 Oct 2022 09:15:17 +0530
Subject: [PATCH 2/4] Add separate method createPartitionPagePreparer

---
 .../operator/exchange/LocalExchange.java      | 23 +++++++++++--------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
index 6f5f2b693e13..7062ad33dbbf 100644
--- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
+++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
@@ -143,18 +143,10 @@ else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalog
                         partitionChannels,
                         partitionChannelTypes,
                         partitionHashChannel);
-                Function<Page, Page> partitionPagePreparer;
-                if (isSystemPartitioning(partitioning)) {
-                    partitionPagePreparer = identity();
-                }
-                else {
-                    int[] partitionChannelsArray = Ints.toArray(partitionChannels);
-                    partitionPagePreparer = page -> page.getColumns(partitionChannelsArray);
-                }
                 return new PartitioningExchanger(
                         buffers,
                         memoryManager,
-                        partitionPagePreparer,
+                        createPartitionPagePreparer(partitioning, partitionChannels),
                         partitionFunction);
             };
         }
@@ -195,6 +187,19 @@ LocalExchangeSource getSource(int partitionIndex)
         return sources.get(partitionIndex);
     }
 
+    private static Function<Page, Page> createPartitionPagePreparer(PartitioningHandle partitioning, List<Integer> partitionChannels)
+    {
+        Function<Page, Page> partitionPagePreparer;
+        if (partitioning.getConnectorHandle() instanceof SystemPartitioningHandle) {
+            partitionPagePreparer = identity();
+        }
+        else {
+            int[] partitionChannelsArray = Ints.toArray(partitionChannels);
+            partitionPagePreparer = page -> page.getColumns(partitionChannelsArray);
+        }
+        return partitionPagePreparer;
+    }
+
     private static PartitionFunction createPartitionFunction(
             NodePartitioningManager nodePartitioningManager,
             Session session,

From cc3bacf6bbd458f539badc2332d67a6e0d83cf6d Mon Sep 17 00:00:00 2001
From: Gaurav Sehgal <gaurav.sehgal8297@gmail.com>
Date: Mon, 24 Oct 2022 01:47:14 +0530
Subject: [PATCH 3/4] Pass physicalWrittenBytes per writer in local exchange

Previously physicalWrittenBytesSupplier was passed
directly to the LocalExchange constructor, which made
it hard to map a buffer to the physical written bytes
of its writer. With this change, each operator passes
its own supplier when it acquires its next buffer.

This refactoring will help with scaling partitioned
writes in the presence of data skew.
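
For example, each source operator now registers its own supplier
when it acquires a buffer (sketch based on the
LocalExchangeSourceOperator change in this patch):

    // track the driver's physical written bytes per buffer
    LocalExchangeSource source = localExchange.getNextSource(
            driverContext::getPhysicalWrittenDataSize);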
---
 .../operator/exchange/LocalExchange.java      | 24 ++---
 .../exchange/LocalExchangeSourceOperator.java |  2 +-
 .../exchange/LocalMergeSourceOperator.java    |  2 +-
 .../sql/planner/LocalExecutionPlanner.java    |  2 -
 .../operator/exchange/TestLocalExchange.java  | 88 +++++++++----------
 .../io/trino/operator/join/JoinTestUtils.java |  1 -
 .../join/unspilled/JoinTestUtils.java         |  1 -
 7 files changed, 57 insertions(+), 63 deletions(-)

diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
index 7062ad33dbbf..88e3668ae034 100644
--- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
+++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
@@ -13,7 +13,6 @@
  */
 package io.trino.operator.exchange;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.common.primitives.Ints;
 import io.airlift.slice.XxHash64;
@@ -41,6 +40,7 @@
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
@@ -67,6 +67,9 @@ public class LocalExchange
 
     private final LocalExchangeMemoryManager memoryManager;
 
+    // Physical written bytes for each writer in the same order as source buffers
+    private final List<Supplier<Long>> physicalWrittenBytesSuppliers = new CopyOnWriteArrayList<>();
+
     @GuardedBy("this")
     private boolean allSourcesFinished;
 
@@ -92,7 +95,6 @@ public LocalExchange(
             Optional<Integer> partitionHashChannel,
             DataSize maxBufferedBytes,
             BlockTypeOperators blockTypeOperators,
-            Supplier<Long> physicalWrittenBytesSupplier,
             DataSize writerMinSize)
     {
         ImmutableList.Builder<LocalExchangeSource> sources = ImmutableList.builder();
@@ -128,7 +130,14 @@ else if (partitioning.equals(SCALED_WRITER_DISTRIBUTION)) {
                     buffers,
                     memoryManager,
                     maxBufferedBytes.toBytes(),
-                    physicalWrittenBytesSupplier,
+                    () -> {
+                        // Avoid using the Stream API for performance reasons
+                        long physicalWrittenBytes = 0;
+                        for (Supplier<Long> physicalWrittenBytesSupplier : physicalWrittenBytesSuppliers) {
+                            physicalWrittenBytes += physicalWrittenBytesSupplier.get();
+                        }
+                        return physicalWrittenBytes;
+                    },
                     writerMinSize);
         }
         else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() ||
@@ -173,20 +182,15 @@ public synchronized LocalExchangeSinkFactory createSinkFactory()
         return newFactory;
     }
 
-    public synchronized LocalExchangeSource getNextSource()
+    public synchronized LocalExchangeSource getNextSource(Supplier<Long> physicalWrittenBytesSupplier)
     {
         checkState(nextSourceIndex < sources.size(), "All operators already created");
         LocalExchangeSource result = sources.get(nextSourceIndex);
+        physicalWrittenBytesSuppliers.add(physicalWrittenBytesSupplier);
         nextSourceIndex++;
         return result;
     }
 
-    @VisibleForTesting
-    LocalExchangeSource getSource(int partitionIndex)
-    {
-        return sources.get(partitionIndex);
-    }
-
     private static Function<Page, Page> createPartitionPagePreparer(PartitioningHandle partitioning, List<Integer> partitionChannels)
     {
         Function<Page, Page> partitionPagePreparer;
diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchangeSourceOperator.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchangeSourceOperator.java
index 57b6153d3b47..3a78a8deb96c 100644
--- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchangeSourceOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchangeSourceOperator.java
@@ -48,7 +48,7 @@ public Operator createOperator(DriverContext driverContext)
             checkState(!closed, "Factory is already closed");
 
             OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, LocalExchangeSourceOperator.class.getSimpleName());
-            return new LocalExchangeSourceOperator(operatorContext, localExchange.getNextSource());
+            return new LocalExchangeSourceOperator(operatorContext, localExchange.getNextSource(driverContext::getPhysicalWrittenDataSize));
         }
 
         @Override
diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalMergeSourceOperator.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalMergeSourceOperator.java
index ccf07a621a6f..19172faf1744 100644
--- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalMergeSourceOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalMergeSourceOperator.java
@@ -77,7 +77,7 @@ public Operator createOperator(DriverContext driverContext)
             PageWithPositionComparator comparator = orderingCompiler.compilePageWithPositionComparator(types, sortChannels, orderings);
             List<LocalExchangeSource> sources = IntStream.range(0, localExchange.getBufferCount())
                     .boxed()
-                    .map(index -> localExchange.getNextSource())
+                    .map(index -> localExchange.getNextSource(driverContext::getPhysicalWrittenDataSize))
                     .collect(toImmutableList());
             return new LocalMergeSourceOperator(operatorContext, sources, types, comparator);
         }
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
index e1dd28e740e1..55a6eec503df 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
@@ -3533,7 +3533,6 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlan
                     Optional.empty(),
                     maxLocalExchangeBufferSize,
                     blockTypeOperators,
-                    context.getTaskContext()::getPhysicalWrittenDataSize,
                     getWriterMinSize(session));
 
             List<Symbol> expectedLayout = node.getInputs().get(0);
@@ -3611,7 +3610,6 @@ else if (context.getDriverInstanceCount().isPresent()) {
                     hashChannel,
                     maxLocalExchangeBufferSize,
                     blockTypeOperators,
-                    context.getTaskContext()::getPhysicalWrittenDataSize,
                     getWriterMinSize(session));
             for (int i = 0; i < node.getSources().size(); i++) {
                 DriverFactoryParameters driverFactoryParameters = driverFactoryParametersList.get(i);
diff --git a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java
index d349b292a5ab..f16bbd575f87 100644
--- a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java
+++ b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java
@@ -50,7 +50,6 @@
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Consumer;
-import java.util.function.Supplier;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static io.trino.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap;
@@ -79,7 +78,6 @@ public class TestLocalExchange
     private static final DataSize LOCAL_EXCHANGE_MAX_BUFFERED_BYTES = DataSize.of(32, DataSize.Unit.MEGABYTE);
     private static final BlockTypeOperators TYPE_OPERATOR_FACTORY = new BlockTypeOperators(new TypeOperators());
     private static final Session SESSION = testSessionBuilder().build();
-    private static final Supplier<Long> PHYSICAL_WRITTEN_BYTES_SUPPLIER = () -> DataSize.of(32, DataSize.Unit.MEGABYTE).toBytes();
     private static final DataSize WRITER_MIN_SIZE = DataSize.of(32, DataSize.Unit.MEGABYTE);
 
     private final ConcurrentMap<CatalogHandle, ConnectorNodePartitioningProvider> partitionManagers = new ConcurrentHashMap<>();
@@ -115,7 +113,6 @@ public void testGatherSingleWriter()
                 Optional.empty(),
                 DataSize.ofBytes(retainedSizeOfPages(99)),
                 TYPE_OPERATOR_FACTORY,
-                PHYSICAL_WRITTEN_BYTES_SUPPLIER,
                 WRITER_MIN_SIZE);
 
         run(localExchange, exchange -> {
@@ -125,7 +122,7 @@ public void testGatherSingleWriter()
             LocalExchangeSinkFactory sinkFactory = exchange.createSinkFactory();
             sinkFactory.noMoreSinkFactories();
 
-            LocalExchangeSource source = exchange.getSource(0);
+            LocalExchangeSource source = getNextSource(exchange);
             assertSource(source, 0);
 
             LocalExchangeSink sink = sinkFactory.createSink();
@@ -189,7 +186,6 @@ public void testBroadcast()
                 Optional.empty(),
                 LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
                 TYPE_OPERATOR_FACTORY,
-                PHYSICAL_WRITTEN_BYTES_SUPPLIER,
                 WRITER_MIN_SIZE);
 
         run(localExchange, exchange -> {
@@ -204,10 +200,10 @@ public void testBroadcast()
             assertSinkCanWrite(sinkB);
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(0);
+            LocalExchangeSource sourceA = getNextSource(exchange);
             assertSource(sourceA, 0);
 
-            LocalExchangeSource sourceB = exchange.getSource(1);
+            LocalExchangeSource sourceB = getNextSource(exchange);
             assertSource(sourceB, 0);
 
             sinkA.addPage(createPage(0));
@@ -278,7 +274,6 @@ public void testRandom()
                 Optional.empty(),
                 LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
                 TYPE_OPERATOR_FACTORY,
-                PHYSICAL_WRITTEN_BYTES_SUPPLIER,
                 WRITER_MIN_SIZE);
 
         run(localExchange, exchange -> {
@@ -291,10 +286,10 @@ public void testRandom()
             assertSinkCanWrite(sink);
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(0);
+            LocalExchangeSource sourceA = getNextSource(exchange);
             assertSource(sourceA, 0);
 
-            LocalExchangeSource sourceB = exchange.getSource(1);
+            LocalExchangeSource sourceB = getNextSource(exchange);
             assertSource(sourceB, 0);
 
             for (int i = 0; i < 100; i++) {
@@ -318,7 +313,6 @@ public void testRandom()
     @Test
     public void testScaleWriter()
     {
-        AtomicLong physicalWrittenBytes = new AtomicLong(0);
         LocalExchange localExchange = new LocalExchange(
                 nodePartitioningManager,
                 SESSION,
@@ -329,7 +323,6 @@ public void testScaleWriter()
                 Optional.empty(),
                 DataSize.ofBytes(retainedSizeOfPages(4)),
                 TYPE_OPERATOR_FACTORY,
-                physicalWrittenBytes::get,
                 DataSize.ofBytes(retainedSizeOfPages(2)));
 
         run(localExchange, exchange -> {
@@ -342,13 +335,16 @@ public void testScaleWriter()
             assertSinkCanWrite(sink);
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(0);
+            AtomicLong physicalWrittenBytesA = new AtomicLong(0);
+            LocalExchangeSource sourceA = exchange.getNextSource(physicalWrittenBytesA::get);
             assertSource(sourceA, 0);
 
-            LocalExchangeSource sourceB = exchange.getSource(1);
+            AtomicLong physicalWrittenBytesB = new AtomicLong(0);
+            LocalExchangeSource sourceB = exchange.getNextSource(physicalWrittenBytesB::get);
             assertSource(sourceB, 0);
 
-            LocalExchangeSource sourceC = exchange.getSource(2);
+            AtomicLong physicalWrittenBytesC = new AtomicLong(0);
+            LocalExchangeSource sourceC = exchange.getNextSource(physicalWrittenBytesC::get);
             assertSource(sourceC, 0);
 
             sink.addPage(createPage(0));
@@ -358,7 +354,7 @@ public void testScaleWriter()
             assertEquals(sourceC.getBufferInfo().getBufferedPages(), 0);
 
             // writer min file and buffered data size limits are exceeded, so we should see pages in sourceB
-            physicalWrittenBytes.set(retainedSizeOfPages(2));
+            physicalWrittenBytesA.set(retainedSizeOfPages(2));
             sink.addPage(createPage(0));
             assertEquals(sourceA.getBufferInfo().getBufferedPages(), 2);
             assertEquals(sourceB.getBufferInfo().getBufferedPages(), 1);
@@ -368,7 +364,7 @@ public void testScaleWriter()
             assertRemovePage(sourceA, createPage(0));
 
             // no limit is breached, so we should see round-robin distribution across sourceA and sourceB
-            physicalWrittenBytes.set(retainedSizeOfPages(3));
+            physicalWrittenBytesB.set(retainedSizeOfPages(1));
             sink.addPage(createPage(0));
             sink.addPage(createPage(0));
             sink.addPage(createPage(0));
@@ -378,7 +374,8 @@ public void testScaleWriter()
 
             // writer min file and buffered data size limits are exceeded again, but according to
             // round-robin sourceB should receive a page
-            physicalWrittenBytes.set(retainedSizeOfPages(6));
+            physicalWrittenBytesA.set(retainedSizeOfPages(4));
+            physicalWrittenBytesB.set(retainedSizeOfPages(2));
             sink.addPage(createPage(0));
             assertEquals(sourceA.getBufferInfo().getBufferedPages(), 2);
             assertEquals(sourceB.getBufferInfo().getBufferedPages(), 3);
@@ -388,7 +385,7 @@ public void testScaleWriter()
             assertRemoveAllPages(sourceA, createPage(0));
 
             // sourceC should receive a page
-            physicalWrittenBytes.set(retainedSizeOfPages(7));
+            physicalWrittenBytesB.set(retainedSizeOfPages(3));
             sink.addPage(createPage(0));
             assertEquals(sourceA.getBufferInfo().getBufferedPages(), 0);
             assertEquals(sourceB.getBufferInfo().getBufferedPages(), 3);
@@ -399,7 +396,6 @@ public void testScaleWriter()
     @Test
     public void testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded()
     {
-        AtomicLong physicalWrittenBytes = new AtomicLong(0);
         LocalExchange localExchange = new LocalExchange(
                 nodePartitioningManager,
                 SESSION,
@@ -410,7 +406,6 @@ public void testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded()
                 Optional.empty(),
                 DataSize.ofBytes(retainedSizeOfPages(4)),
                 TYPE_OPERATOR_FACTORY,
-                physicalWrittenBytes::get,
                 DataSize.ofBytes(retainedSizeOfPages(2)));
 
         run(localExchange, exchange -> {
@@ -423,13 +418,13 @@ public void testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded()
             assertSinkCanWrite(sink);
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(0);
+            LocalExchangeSource sourceA = getNextSource(exchange);
             assertSource(sourceA, 0);
 
-            LocalExchangeSource sourceB = exchange.getSource(1);
+            LocalExchangeSource sourceB = getNextSource(exchange);
             assertSource(sourceB, 0);
 
-            LocalExchangeSource sourceC = exchange.getSource(2);
+            LocalExchangeSource sourceC = getNextSource(exchange);
             assertSource(sourceC, 0);
 
             range(0, 6).forEach(i -> sink.addPage(createPage(0)));
@@ -442,7 +437,6 @@ public void testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded()
     @Test
     public void testNoWriterScalingWhenOnlyWriterMinSizeLimitIsExceeded()
     {
-        AtomicLong physicalWrittenBytes = new AtomicLong(0);
         LocalExchange localExchange = new LocalExchange(
                 nodePartitioningManager,
                 SESSION,
@@ -453,7 +447,6 @@ public void testNoWriterScalingWhenOnlyWriterMinSizeLimitIsExceeded()
                 Optional.empty(),
                 DataSize.ofBytes(retainedSizeOfPages(20)),
                 TYPE_OPERATOR_FACTORY,
-                physicalWrittenBytes::get,
                 DataSize.ofBytes(retainedSizeOfPages(2)));
 
         run(localExchange, exchange -> {
@@ -466,17 +459,18 @@ public void testNoWriterScalingWhenOnlyWriterMinSizeLimitIsExceeded()
             assertSinkCanWrite(sink);
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(0);
+            AtomicLong physicalWrittenBytesA = new AtomicLong(0);
+            LocalExchangeSource sourceA = exchange.getNextSource(physicalWrittenBytesA::get);
             assertSource(sourceA, 0);
 
-            LocalExchangeSource sourceB = exchange.getSource(1);
+            LocalExchangeSource sourceB = getNextSource(exchange);
             assertSource(sourceB, 0);
 
-            LocalExchangeSource sourceC = exchange.getSource(2);
+            LocalExchangeSource sourceC = getNextSource(exchange);
             assertSource(sourceC, 0);
 
             range(0, 8).forEach(i -> sink.addPage(createPage(0)));
-            physicalWrittenBytes.set(retainedSizeOfPages(8));
+            physicalWrittenBytesA.set(retainedSizeOfPages(8));
             sink.addPage(createPage(0));
             assertEquals(sourceA.getBufferInfo().getBufferedPages(), 9);
             assertEquals(sourceB.getBufferInfo().getBufferedPages(), 0);
@@ -497,7 +491,6 @@ public void testPassthrough()
                 Optional.empty(),
                 DataSize.ofBytes(retainedSizeOfPages(1)),
                 TYPE_OPERATOR_FACTORY,
-                PHYSICAL_WRITTEN_BYTES_SUPPLIER,
                 WRITER_MIN_SIZE);
 
         run(localExchange, exchange -> {
@@ -512,10 +505,10 @@ public void testPassthrough()
             assertSinkCanWrite(sinkB);
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(0);
+            LocalExchangeSource sourceA = getNextSource(exchange);
             assertSource(sourceA, 0);
 
-            LocalExchangeSource sourceB = exchange.getSource(1);
+            LocalExchangeSource sourceB = getNextSource(exchange);
             assertSource(sourceB, 0);
 
             sinkA.addPage(createPage(0));
@@ -565,7 +558,6 @@ public void testPartition()
                 Optional.empty(),
                 LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
                 TYPE_OPERATOR_FACTORY,
-                PHYSICAL_WRITTEN_BYTES_SUPPLIER,
                 WRITER_MIN_SIZE);
 
         run(localExchange, exchange -> {
@@ -578,10 +570,10 @@ public void testPartition()
             assertSinkCanWrite(sink);
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(0);
+            LocalExchangeSource sourceA = getNextSource(exchange);
             assertSource(sourceA, 0);
 
-            LocalExchangeSource sourceB = exchange.getSource(1);
+            LocalExchangeSource sourceB = getNextSource(exchange);
             assertSource(sourceB, 0);
 
             sink.addPage(createPage(0));
@@ -662,7 +654,6 @@ public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHa
                 Optional.empty(),
                 LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
                 TYPE_OPERATOR_FACTORY,
-                PHYSICAL_WRITTEN_BYTES_SUPPLIER,
                 WRITER_MIN_SIZE);
 
         run(localExchange, exchange -> {
@@ -675,12 +666,12 @@ public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHa
             assertSinkCanWrite(sink);
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(1);
-            assertSource(sourceA, 0);
-
-            LocalExchangeSource sourceB = exchange.getSource(0);
+            LocalExchangeSource sourceB = getNextSource(exchange);
             assertSource(sourceB, 0);
 
+            LocalExchangeSource sourceA = getNextSource(exchange);
+            assertSource(sourceA, 0);
+
             Page pageA = SequencePageBuilder.createSequencePage(types, 1, 100, 42);
             sink.addPage(pageA);
 
@@ -714,7 +705,6 @@ public void writeUnblockWhenAllReadersFinish()
                 Optional.empty(),
                 LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
                 TYPE_OPERATOR_FACTORY,
-                PHYSICAL_WRITTEN_BYTES_SUPPLIER,
                 WRITER_MIN_SIZE);
 
         run(localExchange, exchange -> {
@@ -729,10 +719,10 @@ public void writeUnblockWhenAllReadersFinish()
             assertSinkCanWrite(sinkB);
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(0);
+            LocalExchangeSource sourceA = getNextSource(exchange);
             assertSource(sourceA, 0);
 
-            LocalExchangeSource sourceB = exchange.getSource(1);
+            LocalExchangeSource sourceB = getNextSource(exchange);
             assertSource(sourceB, 0);
 
             sourceA.finish();
@@ -762,7 +752,6 @@ public void writeUnblockWhenAllReadersFinishAndPagesConsumed()
                 Optional.empty(),
                 DataSize.ofBytes(1),
                 TYPE_OPERATOR_FACTORY,
-                PHYSICAL_WRITTEN_BYTES_SUPPLIER,
                 WRITER_MIN_SIZE);
 
         run(localExchange, exchange -> {
@@ -783,10 +772,10 @@ public void writeUnblockWhenAllReadersFinishAndPagesConsumed()
 
             sinkFactory.close();
 
-            LocalExchangeSource sourceA = exchange.getSource(0);
+            LocalExchangeSource sourceA = getNextSource(exchange);
             assertSource(sourceA, 0);
 
-            LocalExchangeSource sourceB = exchange.getSource(1);
+            LocalExchangeSource sourceB = getNextSource(exchange);
             assertSource(sourceB, 0);
 
             sinkA.addPage(createPage(0));
@@ -828,6 +817,11 @@ private void run(LocalExchange localExchange, Consumer<LocalExchange> test)
         test.accept(localExchange);
     }
 
+    private LocalExchangeSource getNextSource(LocalExchange exchange)
+    {
+        return exchange.getNextSource(() -> DataSize.of(0, DataSize.Unit.MEGABYTE).toBytes());
+    }
+
     private static void assertSource(LocalExchangeSource source, int pageCount)
     {
         LocalExchangeBufferInfo bufferInfo = source.getBufferInfo();
diff --git a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java
index b9e9044623d0..97f2c7c5f125 100644
--- a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java
+++ b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java
@@ -155,7 +155,6 @@ public static BuildSideSetup setupBuildSide(
                 buildPages.getHashChannel(),
                 DataSize.of(32, DataSize.Unit.MEGABYTE),
                 TYPE_OPERATOR_FACTORY,
-                taskContext::getPhysicalWrittenDataSize,
                 DataSize.of(32, DataSize.Unit.MEGABYTE));
 
         // collect input data into the partitioned exchange
diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java
index a9969a647b22..a491399917f0 100644
--- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java
+++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java
@@ -151,7 +151,6 @@ public static BuildSideSetup setupBuildSide(
                 buildPages.getHashChannel(),
                 DataSize.of(32, DataSize.Unit.MEGABYTE),
                 TYPE_OPERATOR_FACTORY,
-                taskContext::getPhysicalWrittenDataSize,
                 DataSize.of(32, DataSize.Unit.MEGABYTE));
 
         // collect input data into the partitioned exchange

From c3796862fddc51e8050b6322074355634b940744 Mon Sep 17 00:00:00 2001
From: Gaurav Sehgal <gaurav.sehgal8297@gmail.com>
Date: Tue, 8 Nov 2022 14:15:11 +0530
Subject: [PATCH 4/4] Add scaleWriters flag in PartitioningHandle

Instead of using a dedicated SystemPartitioningHandle
for scaled writer partitioning, introduce a separate
flag inside PartitioningHandle. This will eventually
let us enable scaled writers for any kind of
partitioning, including connector-specific partitioning.
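
For example, writer-scaling checks can then consult the flag
instead of comparing against a single system handle (hypothetical
sketch; the exact accessor name is defined in PartitioningHandle):

    // any partitioning, system or connector-specific, may opt in
    boolean scaleWriters = partitioningHandle.isScaleWriters();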
---
 .../EventDrivenTaskSourceFactory.java         |  4 +--
 .../scheduler/PipelinedQueryScheduler.java    |  8 ++---
 .../scheduler/StageTaskSourceFactory.java     |  4 +--
 .../operator/exchange/LocalExchange.java      |  6 ++--
 .../sql/planner/LocalExecutionPlanner.java    |  6 ++--
 .../trino/sql/planner/PartitioningHandle.java | 32 ++++++++++++++++---
 .../sql/planner/SystemPartitioningHandle.java | 10 ++++--
 .../planner/optimizations/AddExchanges.java   |  4 +--
 .../optimizations/AddLocalExchanges.java      |  4 +--
 .../sanity/ValidateScaledWritersUsage.java    |  4 +--
 .../operator/exchange/TestLocalExchange.java  |  8 ++---
 .../TestAddExchangesScaledWriters.java        |  8 ++---
 ...tAddLocalExchangesForTaskScaleWriters.java |  4 +--
 .../TestValidateScaledWritersUsage.java       | 14 ++++----
 14 files changed, 71 insertions(+), 45 deletions(-)

diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java
index 0fdb967d274e..c8556526de35 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenTaskSourceFactory.java
@@ -52,7 +52,7 @@
 import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
 import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE;
@@ -176,7 +176,7 @@ private SplitAssigner createSplitAssigner(
                             .addAll(replicatedSources)
                             .build());
         }
-        if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_DISTRIBUTION) || partitioning.equals(SOURCE_DISTRIBUTION)) {
+        if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION) || partitioning.equals(SOURCE_DISTRIBUTION)) {
             return new ArbitraryDistributionSplitAssigner(
                     partitioning.getCatalogHandle(),
                     partitionedSources,
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java
index 7f0137c4b513..fc05d82fbab9 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java
@@ -135,7 +135,7 @@
 import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE;
 import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
 import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
 import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE;
@@ -952,7 +952,7 @@ private static Optional<int[]> getBucketToPartition(
                 PlanNode fragmentRoot,
                 List<RemoteSourceNode> remoteSourceNodes)
         {
-            if (partitioningHandle.equals(SOURCE_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
+            if (partitioningHandle.equals(SOURCE_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
                 return Optional.of(new int[1]);
             }
             if (searchFrom(fragmentRoot).where(node -> node instanceof TableScanNode).findFirst().isPresent()) {
@@ -986,7 +986,7 @@ private static Map<PlanFragmentId, PipelinedOutputBufferManager> createOutputBuf
                     if (partitioningHandle.equals(FIXED_BROADCAST_DISTRIBUTION)) {
                         outputBufferManager = new BroadcastPipelinedOutputBufferManager();
                     }
-                    else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
+                    else if (partitioningHandle.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
                         outputBufferManager = new ScaledPipelinedOutputBufferManager();
                     }
                     else {
@@ -1058,7 +1058,7 @@ public void stateChanged(QueryState newState)
                         () -> childStageExecutions.stream().anyMatch(StageExecution::isAnyTaskBlocked));
             }
 
-            if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
+            if (partitioningHandle.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
                 Supplier<Collection<TaskStatus>> sourceTasksProvider = () -> childStageExecutions.stream()
                         .map(StageExecution::getTaskStatuses)
                         .flatMap(List::stream)
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java
index c80e12f77d0f..d4d22ad896e4 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java
@@ -96,7 +96,7 @@
 import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
 import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE;
@@ -160,7 +160,7 @@ public TaskSource create(
         if (partitioning.equals(SINGLE_DISTRIBUTION) || partitioning.equals(COORDINATOR_DISTRIBUTION)) {
             return SingleDistributionTaskSource.create(fragment, exchangeSourceHandles, nodeManager, partitioning.equals(COORDINATOR_DISTRIBUTION));
         }
-        if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_DISTRIBUTION)) {
+        if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
             return ArbitraryDistributionTaskSource.create(
                     fragment,
                     exchangeSourceHandles,
diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
index 88e3668ae034..330c8df478ff 100644
--- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
+++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
@@ -53,7 +53,7 @@
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static java.util.Objects.requireNonNull;
 import static java.util.function.Function.identity;
@@ -125,7 +125,7 @@ else if (partitioning.equals(FIXED_PASSTHROUGH_DISTRIBUTION)) {
                 return new PassthroughExchanger(sourceIterator.next(), maxBufferedBytes.toBytes() / bufferCount, memoryManager::updateMemoryUsage);
             };
         }
-        else if (partitioning.equals(SCALED_WRITER_DISTRIBUTION)) {
+        else if (partitioning.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
             exchangerSupplier = () -> new ScaleWriterExchanger(
                     buffers,
                     memoryManager,
@@ -367,7 +367,7 @@ else if (partitioning.equals(FIXED_PASSTHROUGH_DISTRIBUTION)) {
             bufferCount = defaultConcurrency;
             checkArgument(partitionChannels.isEmpty(), "Passthrough exchange must not have partition channels");
         }
-        else if (partitioning.equals(SCALED_WRITER_DISTRIBUTION)) {
+        else if (partitioning.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
             // Even when scale writers is enabled, the buffer count or the number of drivers will remain constant.
             // However, only some of them are actively doing the work.
             bufferCount = defaultConcurrency;
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
index 55a6eec503df..89a3d11e85a6 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
@@ -342,7 +342,7 @@
 import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
 import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL;
@@ -516,7 +516,7 @@ public LocalExecutionPlan plan(
 
         if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION) ||
                 partitioningScheme.getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION) ||
-                partitioningScheme.getPartitioning().getHandle().equals(SCALED_WRITER_DISTRIBUTION) ||
+                partitioningScheme.getPartitioning().getHandle().equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION) ||
                 partitioningScheme.getPartitioning().getHandle().equals(SINGLE_DISTRIBUTION) ||
                 partitioningScheme.getPartitioning().getHandle().equals(COORDINATOR_DISTRIBUTION)) {
             return plan(taskContext, plan, outputLayout, types, partitionedSourceOrder, new TaskOutputFactory(outputBuffer));
@@ -3506,7 +3506,7 @@ private boolean isLocalScaledWriterExchange(PlanNode node)
 
             return result.isPresent()
                     && result.get() instanceof ExchangeNode
-                    && ((ExchangeNode) result.get()).getPartitioningScheme().getPartitioning().getHandle().equals(SCALED_WRITER_DISTRIBUTION);
+                    && ((ExchangeNode) result.get()).getPartitioningScheme().getPartitioning().getHandle().isScaleWriters();
         }
 
         private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlanContext context)
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningHandle.java b/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningHandle.java
index 5127bace8194..af00e831ff72 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningHandle.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningHandle.java
@@ -30,17 +30,29 @@ public class PartitioningHandle
     private final Optional<CatalogHandle> catalogHandle;
     private final Optional<ConnectorTransactionHandle> transactionHandle;
     private final ConnectorPartitioningHandle connectorHandle;
+    private final boolean scaleWriters;
+
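+    // Convenience constructor; scale writers is disabled unless requested explicitly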
+    public PartitioningHandle(
+            Optional<CatalogHandle> catalogHandle,
+            Optional<ConnectorTransactionHandle> transactionHandle,
+            ConnectorPartitioningHandle connectorHandle)
+    {
+        this(catalogHandle, transactionHandle, connectorHandle, false);
+    }
 
     @JsonCreator
     public PartitioningHandle(
             @JsonProperty("catalogHandle") Optional<CatalogHandle> catalogHandle,
             @JsonProperty("transactionHandle") Optional<ConnectorTransactionHandle> transactionHandle,
-            @JsonProperty("connectorHandle") ConnectorPartitioningHandle connectorHandle)
+            @JsonProperty("connectorHandle") ConnectorPartitioningHandle connectorHandle,
+            @JsonProperty("scaleWriters") boolean scaleWriters)
     {
         this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null");
         this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null");
         checkArgument(catalogHandle.isEmpty() || transactionHandle.isPresent(), "transactionHandle is required when catalogHandle is present");
         this.connectorHandle = requireNonNull(connectorHandle, "connectorHandle is null");
+        this.scaleWriters = scaleWriters;
     }
 
     @JsonProperty
@@ -61,6 +72,12 @@ public ConnectorPartitioningHandle getConnectorHandle()
         return connectorHandle;
     }
 
+    @JsonProperty
+    public boolean isScaleWriters()
+    {
+        return scaleWriters;
+    }
+
     public boolean isSingleNode()
     {
         return connectorHandle.isSingleNode();
@@ -84,21 +101,26 @@ public boolean equals(Object o)
 
         return Objects.equals(catalogHandle, that.catalogHandle) &&
                 Objects.equals(transactionHandle, that.transactionHandle) &&
-                Objects.equals(connectorHandle, that.connectorHandle);
+                Objects.equals(connectorHandle, that.connectorHandle) &&
+                scaleWriters == that.scaleWriters;
     }
 
     @Override
     public int hashCode()
     {
-        return Objects.hash(catalogHandle, transactionHandle, connectorHandle);
+        return Objects.hash(catalogHandle, transactionHandle, connectorHandle, scaleWriters);
     }
 
     @Override
     public String toString()
     {
+        String result = connectorHandle.toString();
+        if (scaleWriters) {
+            result = result + " (scale writers)";
+        }
         if (catalogHandle.isPresent()) {
-            return catalogHandle.get() + ":" + connectorHandle;
+            result = catalogHandle.get() + ":" + result;
         }
-        return connectorHandle.toString();
+        return result;
     }
 }
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java b/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java
index fe9f8a389a1b..90cfbe712780 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java
@@ -41,7 +41,6 @@ enum SystemPartitioning
         SINGLE,
         FIXED,
         SOURCE,
-        SCALED,
         COORDINATOR_ONLY,
         ARBITRARY
     }
@@ -51,7 +50,7 @@ enum SystemPartitioning
     public static final PartitioningHandle FIXED_HASH_DISTRIBUTION = createSystemPartitioning(SystemPartitioning.FIXED, SystemPartitionFunction.HASH);
     public static final PartitioningHandle FIXED_ARBITRARY_DISTRIBUTION = createSystemPartitioning(SystemPartitioning.FIXED, SystemPartitionFunction.ROUND_ROBIN);
     public static final PartitioningHandle FIXED_BROADCAST_DISTRIBUTION = createSystemPartitioning(SystemPartitioning.FIXED, SystemPartitionFunction.BROADCAST);
-    public static final PartitioningHandle SCALED_WRITER_DISTRIBUTION = createSystemPartitioning(SystemPartitioning.SCALED, SystemPartitionFunction.ROUND_ROBIN);
+    public static final PartitioningHandle SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION = createScaledWriterSystemPartitioning(SystemPartitionFunction.ROUND_ROBIN);
     public static final PartitioningHandle SOURCE_DISTRIBUTION = createSystemPartitioning(SystemPartitioning.SOURCE, SystemPartitionFunction.UNKNOWN);
     public static final PartitioningHandle ARBITRARY_DISTRIBUTION = createSystemPartitioning(SystemPartitioning.ARBITRARY, SystemPartitionFunction.UNKNOWN);
     public static final PartitioningHandle FIXED_PASSTHROUGH_DISTRIBUTION = createSystemPartitioning(SystemPartitioning.FIXED, SystemPartitionFunction.UNKNOWN);
@@ -61,6 +60,12 @@ private static PartitioningHandle createSystemPartitioning(SystemPartitioning pa
         return new PartitioningHandle(Optional.empty(), Optional.empty(), new SystemPartitioningHandle(partitioning, function));
     }
 
+    private static PartitioningHandle createScaledWriterSystemPartitioning(SystemPartitionFunction function)
+    {
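+        // A scaled writer handle is an ARBITRARY system partitioning with the scaleWriters flag set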
+        return new PartitioningHandle(Optional.empty(), Optional.empty(), new SystemPartitioningHandle(SystemPartitioning.ARBITRARY, function), true);
+    }
+
     private final SystemPartitioning partitioning;
     private final SystemPartitionFunction function;
 
@@ -125,7 +129,7 @@ public int hashCode()
     @Override
     public String toString()
     {
-        if (partitioning == SystemPartitioning.FIXED) {
+        if (partitioning == SystemPartitioning.FIXED || partitioning == SystemPartitioning.ARBITRARY) {
             return function.toString();
         }
         return partitioning.toString();
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
index d0584ef65aac..bebc24758799 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
@@ -106,7 +106,7 @@
 import static io.trino.sql.planner.FragmentTableScanCounter.hasMultipleSources;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static io.trino.sql.planner.optimizations.ActualProperties.Global.partitionedOn;
 import static io.trino.sql.planner.optimizations.ActualProperties.Global.singleStreamPartition;
@@ -644,7 +644,7 @@ private PlanWithProperties getWriterPlanWithProperties(Optional<PartitioningSche
         {
             if (partitioningScheme.isEmpty()) {
                 if (scaleWriters && writerTarget.supportsReportingWrittenBytes(plannerContext.getMetadata(), session)) {
-                    partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols()));
+                    partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols()));
                 }
                 else if (redistributeWrites) {
                     partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols()));
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java
index e9d424bd7dbe..c9513b7293e4 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java
@@ -84,7 +84,7 @@
 import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static io.trino.sql.planner.optimizations.StreamPreferredProperties.any;
 import static io.trino.sql.planner.optimizations.StreamPreferredProperties.defaultParallelism;
@@ -626,7 +626,7 @@ private PlanWithProperties visitUnpartitionedWriter(PlanNode node, PlanNode sour
                                 LOCAL,
                                 newSource.getNode(),
                                 new PartitioningScheme(
-                                        Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()),
+                                        Partitioning.create(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, ImmutableList.of()),
                                         newSource.getNode().getOutputSymbols())),
                         newSource.getProperties());
                 return rebaseAndDeriveProperties(node, ImmutableList.of(exchange));
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateScaledWritersUsage.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateScaledWritersUsage.java
index 55387e4cf381..1fe7702852da 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateScaledWritersUsage.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateScaledWritersUsage.java
@@ -31,7 +31,7 @@
 import java.util.stream.Collectors;
 
 import static com.google.common.base.Preconditions.checkState;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static java.util.Objects.requireNonNull;
 
 /**
@@ -75,7 +75,7 @@ protected List<PartitioningHandle> visitPlan(PlanNode node, Void context)
         public List<PartitioningHandle> visitTableWriter(TableWriterNode node, Void context)
         {
             List<PartitioningHandle> children = collectPartitioningHandles(node.getSources());
-            boolean anyScaledWriterDistribution = children.stream().anyMatch(partitioningHandle -> partitioningHandle == SCALED_WRITER_DISTRIBUTION);
+            boolean anyScaledWriterDistribution = children.stream().anyMatch(partitioningHandle -> partitioningHandle == SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION);
             TableWriterNode.WriterTarget target = node.getTarget();
             checkState(!anyScaledWriterDistribution || target.supportsReportingWrittenBytes(plannerContext.getMetadata(), session),
                     "The partitioning scheme is set to SCALED_WRITER_DISTRIBUTION but writer target %s does support for it", target);
diff --git a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java
index f16bbd575f87..dacde01aa24a 100644
--- a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java
+++ b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java
@@ -59,7 +59,7 @@
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE;
 import static io.trino.testing.TestingSession.testSessionBuilder;
@@ -317,7 +317,7 @@ public void testScaleWriter()
                 nodePartitioningManager,
                 SESSION,
                 3,
-                SCALED_WRITER_DISTRIBUTION,
+                SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 Optional.empty(),
@@ -400,7 +400,7 @@ public void testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded()
                 nodePartitioningManager,
                 SESSION,
                 3,
-                SCALED_WRITER_DISTRIBUTION,
+                SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 Optional.empty(),
@@ -441,7 +441,7 @@ public void testNoWriterScalingWhenOnlyWriterMinSizeLimitIsExceeded()
                 nodePartitioningManager,
                 SESSION,
                 3,
-                SCALED_WRITER_DISTRIBUTION,
+                SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 Optional.empty(),
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesScaledWriters.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesScaledWriters.java
index ba9fa4415571..fe93273a64ed 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesScaledWriters.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddExchangesScaledWriters.java
@@ -26,7 +26,7 @@
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.testing.TestingSession.testSessionBuilder;
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -74,11 +74,11 @@ public void testScaledWritersEnabled(boolean isScaleWritersEnabled)
         String query = "CREATE TABLE mock_report_written_bytes.mock.test AS SELECT * FROM tpch.tiny.nation";
         SubPlan subPlan = subplan(query, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false, session);
         if (isScaleWritersEnabled) {
-            assertThat(subPlan.getAllFragments().get(1).getPartitioning().getConnectorHandle()).isEqualTo(SCALED_WRITER_DISTRIBUTION.getConnectorHandle());
+            assertThat(subPlan.getAllFragments().get(1).getPartitioning().getConnectorHandle()).isEqualTo(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION.getConnectorHandle());
         }
         else {
             subPlan.getAllFragments().forEach(
-                    fragment -> assertThat(fragment.getPartitioning().getConnectorHandle()).isNotEqualTo(SCALED_WRITER_DISTRIBUTION.getConnectorHandle()));
+                    fragment -> assertThat(fragment.getPartitioning().getConnectorHandle()).isNotEqualTo(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION.getConnectorHandle()));
         }
     }
 
@@ -93,6 +93,6 @@ public void testScaledWritersDisabled(boolean isScaleWritersEnabled)
         String query = "CREATE TABLE mock_dont_report_written_bytes.mock.test AS SELECT * FROM tpch.tiny.nation";
         SubPlan subPlan = subplan(query, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false, session);
         subPlan.getAllFragments().forEach(
-                fragment -> assertThat(fragment.getPartitioning().getConnectorHandle()).isNotEqualTo(SCALED_WRITER_DISTRIBUTION.getConnectorHandle()));
+                fragment -> assertThat(fragment.getPartitioning().getConnectorHandle()).isNotEqualTo(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION.getConnectorHandle()));
     }
 }
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddLocalExchangesForTaskScaleWriters.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddLocalExchangesForTaskScaleWriters.java
index db85f3d20893..60373777ee3d 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddLocalExchangesForTaskScaleWriters.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestAddLocalExchangesForTaskScaleWriters.java
@@ -26,7 +26,7 @@
 import static io.trino.SystemSessionProperties.SCALE_WRITERS;
 import static io.trino.SystemSessionProperties.TASK_SCALE_WRITERS_ENABLED;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
 import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
@@ -77,7 +77,7 @@ public void testLocalScaledWriterDistributionWithSupportsReportingWrittenBytes()
                         tableWriter(
                                 ImmutableList.of("nationkey"),
                                 ImmutableList.of("nationkey"),
-                                exchange(LOCAL, REPARTITION, SCALED_WRITER_DISTRIBUTION,
+                                exchange(LOCAL, REPARTITION, SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION,
                                         exchange(REMOTE, REPARTITION, FIXED_ARBITRARY_DISTRIBUTION,
                                                 tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))));
 
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java
index 8175372adac9..65b52ac77328 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateScaledWritersUsage.java
@@ -41,7 +41,7 @@
 
 import static io.trino.SessionTestUtils.TEST_SESSION;
 import static io.trino.spi.type.BigintType.BIGINT;
-import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer;
 import static io.trino.testing.TestingHandles.createTestCatalogHandle;
@@ -97,7 +97,7 @@ public void testScaledWritersUsedAndTargetSupportsIt()
                         .addInputsSet(symbol)
                         .addSource(planBuilder.exchange(innerExchange ->
                                 innerExchange
-                                        .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
+                                        .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
                                         .addInputsSet(symbol)
                                         .addSource(tableScanNode))));
         PlanNode root = planBuilder.output(
@@ -119,7 +119,7 @@ public void testScaledWritersUsedAndTargetDoesNotSupportIt()
                         .addInputsSet(symbol)
                         .addSource(planBuilder.exchange(innerExchange ->
                                 innerExchange
-                                        .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
+                                        .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
                                         .addInputsSet(symbol)
                                         .addSource(tableScanNode))));
         PlanNode root = planBuilder.output(
@@ -144,7 +144,7 @@ public void testScaledWritersUsedAndTargetDoesNotSupportItMultipleSourceExchange
                         .addInputsSet(symbol, symbol)
                         .addSource(planBuilder.exchange(innerExchange ->
                                 innerExchange
-                                        .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
+                                        .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
                                         .addInputsSet(symbol)
                                         .addSource(tableScanNode)))
                         .addSource(planBuilder.exchange(innerExchange ->
@@ -174,7 +174,7 @@ public void testScaledWritersUsedAndTargetSupportsItMultipleSourceExchanges()
                         .addInputsSet(symbol, symbol)
                         .addSource(planBuilder.exchange(innerExchange ->
                                 innerExchange
-                                        .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
+                                        .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
                                         .addInputsSet(symbol)
                                         .addSource(tableScanNode)))
                         .addSource(planBuilder.exchange(innerExchange ->
@@ -210,7 +210,7 @@ public void testScaledWritersUsedAboveTableWriterInThePlanTree()
                                 planBuilder.createTarget(catalogNotSupportingScaledWriters, schemaTableName, false, true),
                                 tableWriterSource,
                                 symbol,
-                                new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))));
+                                new PartitioningScheme(Partitioning.create(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))));
         validatePlan(root);
     }
 
@@ -229,7 +229,7 @@ public void testScaledWritersTwoTableWritersNodes()
                                 planBuilder.createTarget(catalogSupportingScaledWriters, schemaTableName, true, true),
                                 planBuilder.exchange(innerExchange ->
                                         innerExchange
-                                                .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
+                                                .partitioningScheme(new PartitioningScheme(Partitioning.create(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)))
                                                 .addInputsSet(symbol)
                                                 .addSource(tableScanNode)),
                                 symbol)));