From 4c2e3d8882eb314c0d68d44d8ab0599e75fcb284 Mon Sep 17 00:00:00 2001
From: zjuwangg
Date: Wed, 6 Mar 2024 21:17:13 +0800
Subject: [PATCH] [GLUTEN-4875][VL] Support Spark SQL conf sortBeforeRepartition
 to avoid stage partial retry causing result mismatch

---
 .../velox/SparkPlanExecApiImpl.scala          | 23 +++++++++++++++----
 .../execution/TestOperator.scala              | 23 +++++++++++++++++++
 cpp/core/jni/JniWrapper.cc                    |  4 +++-
 cpp/core/shuffle/Options.h                    |  1 +
 cpp/core/shuffle/Partitioner.cc               |  4 ++--
 cpp/core/shuffle/Partitioner.h                |  2 +-
 cpp/core/shuffle/RoundRobinPartitioner.h      |  4 ++--
 cpp/velox/shuffle/VeloxShuffleWriter.cc       |  3 ++-
 ...lebornHashBasedColumnarShuffleWriter.scala |  4 ++++
 ...lebornHashBasedColumnarShuffleWriter.scala |  3 ++-
 .../vectorized/ShuffleWriterJniWrapper.java   | 15 ++++++++----
 .../spark/shuffle/ColumnarShuffleWriter.scala |  7 +++++-
 12 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
index 713befb84883..8901d6aa950a 100644
--- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
@@ -36,14 +36,14 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, Round, StringSplit, StringTrim}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, Round, SortOrder, StringSplit, StringTrim}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter}
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RoundRobinPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{BroadcastUtils, ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan, VeloxColumnarWriteFilesExec}
+import org.apache.spark.sql.execution.{BroadcastUtils, ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SortExec, SparkPlan, VeloxColumnarWriteFilesExec}
 import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -215,7 +215,22 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
           TransformHints.tagNotTransformable(shuffle, validationResult)
           shuffle.withNewChildren(newChild :: Nil)
         }
-
+      case RoundRobinPartitioning(num) if SQLConf.get.sortBeforeRepartition && num > 1 =>
+        val hashExpr = new Murmur3Hash(newChild.output)
+        val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ newChild.output
+        val projectTransformer = ProjectExecTransformer(projectList, newChild)
+        val sortOrder = SortOrder(projectList.head, Ascending)
+        val sortByHashCode = SortExecTransformer(Seq(sortOrder), global = false, projectTransformer)
+        val projectValidationResult = projectTransformer.doValidate()
+        val sortValidationResult = sortByHashCode.doValidate()
+        if (projectValidationResult.isValid && sortValidationResult.isValid) {
+          ColumnarShuffleExchangeExec(shuffle, sortByHashCode, sortByHashCode.output.drop(1))
+        } else {
+          TransformHints.tagNotTransformable(
+            shuffle,
+            if (projectValidationResult.isValid) sortValidationResult else projectValidationResult)
+          shuffle.withNewChildren(newChild :: Nil)
+        }
       case _ =>
         ColumnarShuffleExchangeExec(shuffle, newChild, null)
     }
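Note: the rewrite above mirrors what vanilla Spark's ShuffleExchangeExec does for
RoundRobinPartitioning when spark.sql.execution.sortBeforeRepartition is true: rows
are ordered by a Murmur3 hash of the whole row before partition ids are handed out,
so a retried map task reproduces the same partition assignment. A minimal sketch of
the idea in plain Scala (illustrative names, not the Gluten operators):

    // Illustrative only: round-robin assignment depends on row order, so two runs
    // over the same data in different orders scatter rows into different partitions.
    def roundRobinPartition[T](rows: Seq[T], numPartitions: Int, startPid: Int): Map[Int, Seq[T]] =
      rows.zipWithIndex
        .map { case (row, i) => ((startPid + i) % numPartitions, row) }
        .groupBy(_._1)
        .map { case (pid, pairs) => pid -> pairs.map(_._2) }

    // Sorting by a pure function of the row (here a hash) first makes the assignment
    // depend only on the data, which is what the injected SortExecTransformer achieves.
    def sortedRoundRobin[T](rows: Seq[T], numPartitions: Int, startPid: Int)(hash: T => Int): Map[Int, Seq[T]] =
      roundRobinPartition(rows.sortBy(hash), numPartitions, startPid)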
diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
index 831a6014f0ba..79d0ce23d444 100644
--- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
+++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
@@ -1156,4 +1156,27 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
       checkOperatorMatch[HashAggregateExecTransformer]
     }
   }
+
+  test("test roundrobin with sort") {
+    // scalastyle:off
+    runQueryAndCompare("SELECT /*+ REPARTITION(3) */ * FROM lineitem") {
+      /*
+        ColumnarExchange RoundRobinPartitioning(3), REPARTITION_BY_NUM, [l_orderkey#16L, l_partkey#17L, l_suppkey#18L, l_linenumber#19, l_quantity#20, l_extendedprice#21, l_discount#22, l_tax#23, l_returnflag#24, l_linestatus#25, l_shipdate#26, l_commitdate#27, l_receiptdate#28, l_shipinstruct#29, l_shipmode#30, l_comment#31], [id=#131], [id=#131], [OUTPUT] List(l_orderkey:LongType, l_partkey:LongType, l_suppkey:LongType, l_linenumber:IntegerType, l_quantity:DecimalType(12,2), l_extendedprice:DecimalType(12,2), l_discount:DecimalType(12,2), l_tax:DecimalType(12,2), l_returnflag:StringType, l_linestatus:StringType, l_shipdate:DateType, l_commitdate:DateType, l_receiptdate:DateType, l_shipinstruct:StringType, l_shipmode:StringType, l_comment:StringType), [OUTPUT] List(l_orderkey:LongType, l_partkey:LongType, l_suppkey:LongType, l_linenumber:IntegerType, l_quantity:DecimalType(12,2), l_extendedprice:DecimalType(12,2), l_discount:DecimalType(12,2), l_tax:DecimalType(12,2), l_returnflag:StringType, l_linestatus:StringType, l_shipdate:DateType, l_commitdate:DateType, l_receiptdate:DateType, l_shipinstruct:StringType, l_shipmode:StringType, l_comment:StringType)
+        +- ^(2) SortExecTransformer [hash_partition_key#302 ASC NULLS FIRST], false, 0
+           +- ^(2) ProjectExecTransformer [hash(l_orderkey#16L, l_partkey#17L, l_suppkey#18L, l_linenumber#19, l_quantity#20, l_extendedprice#21, l_discount#22, l_tax#23, l_returnflag#24, l_linestatus#25, l_shipdate#26, l_commitdate#27, l_receiptdate#28, l_shipinstruct#29, l_shipmode#30, l_comment#31, 42) AS hash_partition_key#302, l_orderkey#16L, l_partkey#17L, l_suppkey#18L, l_linenumber#19, l_quantity#20, l_extendedprice#21, l_discount#22, l_tax#23, l_returnflag#24, l_linestatus#25, l_shipdate#26, l_commitdate#27, l_receiptdate#28, l_shipinstruct#29, l_shipmode#30, l_comment#31]
+              +- ^(2) BatchScanExecTransformer[l_orderkey#16L, l_partkey#17L, l_suppkey#18L, l_linenumber#19, l_quantity#20, l_extendedprice#21, l_discount#22, l_tax#23, l_returnflag#24, l_linestatus#25, l_shipdate#26, l_commitdate#27, l_receiptdate#28, l_shipinstruct#29, l_shipmode#30, l_comment#31] ParquetScan DataFilters: [], Format: parquet, Location: InMemoryFileIndex(1 paths)[file:/home/wanggang.terry/gluten/backends-velox/target/scala-2.12/test..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<...>
+       */
+      // scalastyle:on
+      df => assert(getExecutedPlan(df).count(_.isInstanceOf[SortExecTransformer]) == 1)
+    }
+
+    withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") {
+      runQueryAndCompare("SELECT /*+ REPARTITION(3) */ * FROM lineitem") {
+        df =>
+          {
+            assert(getExecutedPlan(df).count(_.isInstanceOf[SortExecTransformer]) == 0)
+          }
+      }
+    }
+  }
 }
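Note: the gating conf is Spark's existing spark.sql.execution.sortBeforeRepartition
(SQLConf.SORT_BEFORE_REPARTITION, default true), not a new Gluten conf. A usage
sketch, assuming a registered `lineitem` table and a `spark` session:

    // Disable the extra sort per session, accepting possible nondeterminism on
    // stage partial retry, e.g. to trade correctness guarantees for speed in tests.
    spark.conf.set("spark.sql.execution.sortBeforeRepartition", "false")
    spark.sql("SELECT /*+ REPARTITION(3) */ * FROM lineitem").collect()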
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index 17a07aa09197..8274198fb4b6 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -830,7 +830,8 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
     jint startPartitionId,
     jint pushBufferMaxSize,
     jobject partitionPusher,
-    jstring partitionWriterTypeJstr) {
+    jstring partitionWriterTypeJstr,
+    jboolean sortBeforeRR) {
   JNI_METHOD_START
   auto ctx = gluten::getRuntime(env, wrapper);
   auto memoryManager = jniCastOrThrow<MemoryManager>(memoryManagerHandle);
@@ -844,6 +845,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
       .partitioning = gluten::toPartitioning(jStringToCString(env, partitioningNameJstr)),
       .taskAttemptId = (int64_t)taskAttemptId,
       .startPartitionId = startPartitionId,
+      .sortBeforeRR = sortBeforeRR,
   };
 
   jclass cls = env->FindClass("java/lang/Thread");
diff --git a/cpp/core/shuffle/Options.h b/cpp/core/shuffle/Options.h
index 6531f150ede1..443ad3d26cb3 100644
--- a/cpp/core/shuffle/Options.h
+++ b/cpp/core/shuffle/Options.h
@@ -48,6 +48,7 @@ struct ShuffleWriterOptions {
   int64_t taskAttemptId = -1;
   int32_t startPartitionId = 0;
   int64_t threadId = -1;
+  bool sortBeforeRR = true;
 };
 
 struct PartitionWriterOptions {
diff --git a/cpp/core/shuffle/Partitioner.cc b/cpp/core/shuffle/Partitioner.cc
index 80b4598a1f17..ded889962bc2 100644
--- a/cpp/core/shuffle/Partitioner.cc
+++ b/cpp/core/shuffle/Partitioner.cc
@@ -24,12 +24,12 @@ namespace gluten {
 
 arrow::Result<std::shared_ptr<Partitioner>>
-Partitioner::make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId) {
+Partitioner::make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId, bool sortBeforeRR) {
   switch (partitioning) {
     case Partitioning::kHash:
       return std::make_shared<HashPartitioner>(numPartitions);
     case Partitioning::kRoundRobin:
-      return std::make_shared<RoundRobinPartitioner>(numPartitions, startPartitionId);
+      return std::make_shared<RoundRobinPartitioner>(numPartitions, startPartitionId, sortBeforeRR);
     case Partitioning::kSingle:
       return std::make_shared<SinglePartitioner>();
     case Partitioning::kRange:
diff --git a/cpp/core/shuffle/Partitioner.h b/cpp/core/shuffle/Partitioner.h
index 79a38be133f1..0adb82dffc74 100644
--- a/cpp/core/shuffle/Partitioner.h
+++ b/cpp/core/shuffle/Partitioner.h
@@ -27,7 +27,7 @@ namespace gluten {
 class Partitioner {
  public:
   static arrow::Result<std::shared_ptr<Partitioner>>
-  make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId);
+  make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId, bool sortBeforeRR);
 
   // Whether the first column is partition key.
   bool hasPid() const {
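Note: for round-robin partitioning, `sortBeforeRR` is ultimately consumed as the
partitioner's `hasPid` flag: it only tells the shuffle writer that a leading
hash_partition_key column is present and must be stripped from the payload; the
round-robin assignment itself never reads that column. A behavioral sketch of the
assignment loop (a Scala stand-in for RoundRobinPartitioner::compute; the class and
method names here are illustrative, only the field semantics come from the diff):

    final class RoundRobinLike(numPartitions: Int, startPartitionId: Int) {
      // Mirrors pidSelection_(startPartitionId % numPartitions) in the C++ constructor;
      // the selection carries over across batches so the cycle is continuous.
      private var pidSelection = startPartitionId % numPartitions

      // Assign partition ids row by row, cycling across partitions. The optional
      // leading pid/hash column is ignored here and only dropped from the payload.
      def compute(numRows: Int): Array[Int] = {
        val pids = new Array[Int](numRows)
        var i = 0
        while (i < numRows) {
          pids(i) = pidSelection
          pidSelection = (pidSelection + 1) % numPartitions
          i += 1
        }
        pids
      }
    }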
diff --git a/cpp/core/shuffle/RoundRobinPartitioner.h b/cpp/core/shuffle/RoundRobinPartitioner.h
index 74fb8dcef855..a01ff65c4575 100644
--- a/cpp/core/shuffle/RoundRobinPartitioner.h
+++ b/cpp/core/shuffle/RoundRobinPartitioner.h
@@ -23,8 +23,8 @@ namespace gluten {
 
 class RoundRobinPartitioner final : public Partitioner {
  public:
-  RoundRobinPartitioner(int32_t numPartitions, int32_t startPartitionId)
-      : Partitioner(numPartitions, false), pidSelection_(startPartitionId % numPartitions) {}
+  RoundRobinPartitioner(int32_t numPartitions, int32_t startPartitionId, bool hasPid)
+      : Partitioner(numPartitions, hasPid), pidSelection_(startPartitionId % numPartitions) {}
 
   arrow::Status compute(
       const int32_t* pidArr,
diff --git a/cpp/velox/shuffle/VeloxShuffleWriter.cc b/cpp/velox/shuffle/VeloxShuffleWriter.cc
index b16acfc670d5..a2fda384991d 100644
--- a/cpp/velox/shuffle/VeloxShuffleWriter.cc
+++ b/cpp/velox/shuffle/VeloxShuffleWriter.cc
@@ -221,7 +221,8 @@ arrow::Status VeloxShuffleWriter::init() {
   VELOX_CHECK_LE(options_.bufferSize, 32 * 1024);
 
   ARROW_ASSIGN_OR_RAISE(
-      partitioner_, Partitioner::make(options_.partitioning, numPartitions_, options_.startPartitionId));
+      partitioner_,
+      Partitioner::make(options_.partitioning, numPartitions_, options_.startPartitionId, options_.sortBeforeRR));
 
   // pre-allocated buffer size for each partition, unit is row count
   // when partitioner is SinglePart, partial variables don`t need init
diff --git a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornHashBasedColumnarShuffleWriter.scala b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornHashBasedColumnarShuffleWriter.scala
index 1310a60cadb0..d3c66192f42c 100644
--- a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornHashBasedColumnarShuffleWriter.scala
+++ b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornHashBasedColumnarShuffleWriter.scala
@@ -22,6 +22,7 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.storage.BlockManager
 
 import org.apache.celeborn.client.ShuffleClient
@@ -81,6 +82,9 @@ abstract class CelebornHashBasedColumnarShuffleWriter[K, V](
 
   protected var partitionLengths: Array[Long] = _
 
+  protected var sortBeforeRR: Boolean =
+    SQLConf.get.sortBeforeRepartition && dep.partitioner.numPartitions > 1
+
   @throws[IOException]
   final override def write(records: Iterator[Product2[K, V]]): Unit = {
     internalWrite(records)
diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala
index 172328d8119a..2991226d935e 100644
--- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala
+++ b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala
@@ -125,7 +125,8 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
           context.taskAttemptId(),
           GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, context.partitionId),
           "celeborn",
-          GlutenConfig.getConf.columnarShuffleReallocThreshold
+          GlutenConfig.getConf.columnarShuffleReallocThreshold,
+          sortBeforeRR
         )
       }
     val startTime = System.nanoTime()
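Note: the Celeborn writer above and the local ColumnarShuffleWriter below derive the
flag with the same guard; `numPartitions > 1` reflects that a single-partition round
robin cannot scatter rows across partitions, so the extra sort would be pure
overhead. Sketch of the shared predicate (a hypothetical helper; the patch inlines
it at both call sites rather than factoring it out):

    import org.apache.spark.sql.internal.SQLConf

    // Hypothetical helper: sort before round-robin repartition only when it can matter.
    def shouldSortBeforeRR(numPartitions: Int): Boolean =
      SQLConf.get.sortBeforeRepartition && numPartitions > 1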
diff --git a/gluten-data/src/main/java/io/glutenproject/vectorized/ShuffleWriterJniWrapper.java b/gluten-data/src/main/java/io/glutenproject/vectorized/ShuffleWriterJniWrapper.java
index 6f49f6025e37..70350852ada5 100644
--- a/gluten-data/src/main/java/io/glutenproject/vectorized/ShuffleWriterJniWrapper.java
+++ b/gluten-data/src/main/java/io/glutenproject/vectorized/ShuffleWriterJniWrapper.java
@@ -69,7 +69,8 @@ public long make(
       double reallocThreshold,
       long handle,
       long taskAttemptId,
-      int startPartitionId) {
+      int startPartitionId,
+      boolean sortBeforeRR) {
     return nativeMake(
         part.getShortName(),
         part.getNumPartitions(),
@@ -91,7 +92,8 @@ public long make(
         startPartitionId,
         0,
         null,
-        "local");
+        "local",
+        sortBeforeRR);
   }
 
   /**
@@ -116,7 +118,8 @@ public long makeForRSS(
       long taskAttemptId,
       int startPartitionId,
       String partitionWriterType,
-      double reallocThreshold) {
+      double reallocThreshold,
+      boolean sortBeforeRR) {
     return nativeMake(
         part.getShortName(),
         part.getNumPartitions(),
@@ -138,7 +141,8 @@ public long makeForRSS(
         startPartitionId,
         pushBufferMaxSize,
         pusher,
-        partitionWriterType);
+        partitionWriterType,
+        sortBeforeRR);
   }
 
   public native long nativeMake(
@@ -162,7 +166,8 @@ public native long nativeMake(
       int startPartitionId,
       int pushBufferMaxSize,
       Object pusher,
-      String partitionWriterType);
+      String partitionWriterType,
+      boolean sortBeforeRR);
 
   /**
    * Evict partition data.
diff --git a/gluten-data/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/gluten-data/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
index 99710b3f668a..daa4e485b575 100644
--- a/gluten-data/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
+++ b/gluten-data/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.SHUFFLE_COMPRESS
 import org.apache.spark.memory.SparkMemoryUtil
 import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.{SparkDirectoryUtil, SparkResourceUtil, Utils}
 
@@ -106,6 +107,9 @@ class ColumnarShuffleWriter[K, V](
 
   private val taskContext: TaskContext = TaskContext.get()
 
+  private val sortBeforeRR: Boolean =
+    SQLConf.get.sortBeforeRepartition && dep.partitioner.numPartitions > 1
+
   private def availableOffHeapPerTask(): Long = {
     val perTask =
       SparkMemoryUtil.getCurrentAvailableOffHeapMemory / SparkResourceUtil.getTaskSlots(conf)
@@ -176,7 +180,8 @@ class ColumnarShuffleWriter[K, V](
         reallocThreshold,
         handle,
         taskContext.taskAttemptId(),
-        GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, taskContext.partitionId)
+        GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, taskContext.partitionId),
+        sortBeforeRR
       )
     }
     val startTime = System.nanoTime()
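Note: the new boolean threads through every layer touched by this patch, which is
worth tracing when auditing the JNI signature change (all names below come from the
diff itself):

    ColumnarShuffleWriter.sortBeforeRR                       (Scala, computed from SQLConf)
      -> ShuffleWriterJniWrapper.make(..., sortBeforeRR)     (Java wrapper)
      -> nativeMake(..., partitionWriterType, sortBeforeRR)  (JNI boundary)
      -> ShuffleWriterOptions{.sortBeforeRR = ...}           (C++, Options.h)
      -> Partitioner::make(..., options_.sortBeforeRR)       (VeloxShuffleWriter::init)
      -> RoundRobinPartitioner(num, start, hasPid = flag)    (RoundRobinPartitioner.h)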