[GLUTEN-4875][VL] Support spark sql conf sortBeforeRepartition to avoid stage partial retry causing result mismatch
zjuwangg committed Mar 7, 2024
1 parent f9daf49 commit 4c2e3d8
Showing 12 changed files with 75 additions and 18 deletions.
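The change keys off Spark's existing `spark.sql.execution.sortBeforeRepartition` flag, the same key exercised by the new test below. A minimal sketch of toggling it from a Spark session, assuming a registered `lineitem` table as in the test suite:

```scala
// When the flag is on (Spark's default), Gluten now plans a hash projection plus a
// local sort ahead of a round-robin shuffle, so retried map tasks emit rows in a
// deterministic order and a partial stage retry cannot change the result.
spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true")
spark.sql("SELECT /*+ REPARTITION(3) */ * FROM lineitem").collect()

// Turning the flag off skips the extra project + sort and falls back to the plain
// round-robin exchange.
spark.conf.set("spark.sql.execution.sortBeforeRepartition", "false")
```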
@@ -36,14 +36,14 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, Round, StringSplit, StringTrim}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, Round, SortOrder, StringSplit, StringTrim}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{BroadcastUtils, ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan, VeloxColumnarWriteFilesExec}
import org.apache.spark.sql.execution.{BroadcastUtils, ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SortExec, SparkPlan, VeloxColumnarWriteFilesExec}
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -215,7 +215,22 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
TransformHints.tagNotTransformable(shuffle, validationResult)
shuffle.withNewChildren(newChild :: Nil)
}

case RoundRobinPartitioning(num) if SQLConf.get.sortBeforeRepartition && num > 1 =>
val hashExpr = new Murmur3Hash(newChild.output)
val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ newChild.output
val projectTransformer = ProjectExecTransformer(projectList, newChild)
val sortOrder = SortOrder(projectList.head, Ascending)
val sortByHashCode = SortExecTransformer(Seq(sortOrder), global = false, projectTransformer)
val projectValidationResult = projectTransformer.doValidate()
val sortValidationResult = sortByHashCode.doValidate()
if (projectValidationResult.isValid && sortValidationResult.isValid) {
ColumnarShuffleExchangeExec(shuffle, sortByHashCode, sortByHashCode.output.drop(1))
} else {
TransformHints.tagNotTransformable(
shuffle,
if (projectValidationResult.isValid) sortValidationResult else projectValidationResult)
shuffle.withNewChildren(newChild :: Nil)
}
case _ =>
ColumnarShuffleExchangeExec(shuffle, newChild, null)
}
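For intuition, the same shape expressed with the plain DataFrame API rather than Gluten's transformers; this is only an illustrative sketch, where `df` stands for any input DataFrame and `hash` is Spark's built-in Murmur3 function, with the column name mirroring the alias used above:

```scala
import org.apache.spark.sql.functions.{col, hash}

// Prepend a Murmur3 hash of every column, sort each partition locally on it,
// round-robin repartition, then drop the key from the shuffled output,
// mirroring the ProjectExecTransformer -> SortExecTransformer -> ColumnarShuffleExchangeExec chain built above.
val repartitioned = df
  .withColumn("hash_partition_key", hash(df.columns.map(col): _*))
  .sortWithinPartitions("hash_partition_key")
  .repartition(3)
  .drop("hash_partition_key")
```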
@@ -1156,4 +1156,27 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
checkOperatorMatch[HashAggregateExecTransformer]
}
}

test("test roundrobine with sort") {
// scalastyle:off
runQueryAndCompare("SELECT /*+ REPARTITION(3) */ * FROM lineitem") {
/*
ColumnarExchange RoundRobinPartitioning(3), REPARTITION_BY_NUM, [l_orderkey#16L, l_partkey#17L, l_suppkey#18L, l_linenumber#19, l_quantity#20, l_extendedprice#21, l_discount#22, l_tax#23, l_returnflag#24, l_linestatus#25, l_shipdate#26, l_commitdate#27, l_receiptdate#28, l_shipinstruct#29, l_shipmode#30, l_comment#31], [id=#131], [id=#131], [OUTPUT] List(l_orderkey:LongType, l_partkey:LongType, l_suppkey:LongType, l_linenumber:IntegerType, l_quantity:DecimalType(12,2), l_extendedprice:DecimalType(12,2), l_discount:DecimalType(12,2), l_tax:DecimalType(12,2), l_returnflag:StringType, l_linestatus:StringType, l_shipdate:DateType, l_commitdate:DateType, l_receiptdate:DateType, l_shipinstruct:StringType, l_shipmode:StringType, l_comment:StringType), [OUTPUT] List(l_orderkey:LongType, l_partkey:LongType, l_suppkey:LongType, l_linenumber:IntegerType, l_quantity:DecimalType(12,2), l_extendedprice:DecimalType(12,2), l_discount:DecimalType(12,2), l_tax:DecimalType(12,2), l_returnflag:StringType, l_linestatus:StringType, l_shipdate:DateType, l_commitdate:DateType, l_receiptdate:DateType, l_shipinstruct:StringType, l_shipmode:StringType, l_comment:StringType)
+- ^(2) SortExecTransformer [hash_partition_key#302 ASC NULLS FIRST], false, 0
+- ^(2) ProjectExecTransformer [hash(l_orderkey#16L, l_partkey#17L, l_suppkey#18L, l_linenumber#19, l_quantity#20, l_extendedprice#21, l_discount#22, l_tax#23, l_returnflag#24, l_linestatus#25, l_shipdate#26, l_commitdate#27, l_receiptdate#28, l_shipinstruct#29, l_shipmode#30, l_comment#31, 42) AS hash_partition_key#302, l_orderkey#16L, l_partkey#17L, l_suppkey#18L, l_linenumber#19, l_quantity#20, l_extendedprice#21, l_discount#22, l_tax#23, l_returnflag#24, l_linestatus#25, l_shipdate#26, l_commitdate#27, l_receiptdate#28, l_shipinstruct#29, l_shipmode#30, l_comment#31]
+- ^(2) BatchScanExecTransformer[l_orderkey#16L, l_partkey#17L, l_suppkey#18L, l_linenumber#19, l_quantity#20, l_extendedprice#21, l_discount#22, l_tax#23, l_returnflag#24, l_linestatus#25, l_shipdate#26, l_commitdate#27, l_receiptdate#28, l_shipinstruct#29, l_shipmode#30, l_comment#31] ParquetScan DataFilters: [], Format: parquet, Location: InMemoryFileIndex(1 paths)[file:/home/wanggang.terry/gluten/backends-velox/target/scala-2.12/test..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<l_orderkey:bigint,l_partkey:bigint,l_suppkey:bigint,l_linenumber:int,l_quantity:decimal(12..., PushedFilters: [] RuntimeFilters: []
*/
checkOperatorMatch[SortExecTransformer]
}
// scalastyle:on

withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") {
runQueryAndCompare("SELECT /*+ REPARTITION(3) */ * FROM lineitem") {
df =>
{
assert(getExecutedPlan(df).count(_.isInstanceOf[SortExecTransformer]) == 0)
}
}
}
}
}
4 changes: 3 additions & 1 deletion cpp/core/jni/JniWrapper.cc
@@ -830,7 +830,8 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
jint startPartitionId,
jint pushBufferMaxSize,
jobject partitionPusher,
jstring partitionWriterTypeJstr) {
jstring partitionWriterTypeJstr,
jboolean sortBeforeRR) {
JNI_METHOD_START
auto ctx = gluten::getRuntime(env, wrapper);
auto memoryManager = jniCastOrThrow<MemoryManager>(memoryManagerHandle);
@@ -844,6 +845,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
.partitioning = gluten::toPartitioning(jStringToCString(env, partitioningNameJstr)),
.taskAttemptId = (int64_t)taskAttemptId,
.startPartitionId = startPartitionId,
.sortBeforeRR = sortBeforeRR,
};

jclass cls = env->FindClass("java/lang/Thread");
1 change: 1 addition & 0 deletions cpp/core/shuffle/Options.h
@@ -48,6 +48,7 @@ struct ShuffleWriterOptions {
int64_t taskAttemptId = -1;
int32_t startPartitionId = 0;
int64_t threadId = -1;
bool sortBeforeRR = true;
};

struct PartitionWriterOptions {
4 changes: 2 additions & 2 deletions cpp/core/shuffle/Partitioner.cc
@@ -24,12 +24,12 @@
namespace gluten {

arrow::Result<std::shared_ptr<Partitioner>>
Partitioner::make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId) {
Partitioner::make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId, bool sortBeforeRR) {
switch (partitioning) {
case Partitioning::kHash:
return std::make_shared<HashPartitioner>(numPartitions);
case Partitioning::kRoundRobin:
return std::make_shared<RoundRobinPartitioner>(numPartitions, startPartitionId);
return std::make_shared<RoundRobinPartitioner>(numPartitions, startPartitionId, sortBeforeRR);
case Partitioning::kSingle:
return std::make_shared<SinglePartitioner>();
case Partitioning::kRange:
2 changes: 1 addition & 1 deletion cpp/core/shuffle/Partitioner.h
@@ -27,7 +27,7 @@ namespace gluten {
class Partitioner {
public:
static arrow::Result<std::shared_ptr<Partitioner>>
make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId);
make(Partitioning partitioning, int32_t numPartitions, int32_t startPartitionId, bool sortBeforeRR);

// Whether the first column is partition key.
bool hasPid() const {
4 changes: 2 additions & 2 deletions cpp/core/shuffle/RoundRobinPartitioner.h
@@ -23,8 +23,8 @@ namespace gluten {

class RoundRobinPartitioner final : public Partitioner {
public:
RoundRobinPartitioner(int32_t numPartitions, int32_t startPartitionId)
: Partitioner(numPartitions, false), pidSelection_(startPartitionId % numPartitions) {}
RoundRobinPartitioner(int32_t numPartitions, int32_t startPartitionId, bool hasPid)
: Partitioner(numPartitions, hasPid), pidSelection_(startPartitionId % numPartitions) {}

arrow::Status compute(
const int32_t* pidArr,
3 changes: 2 additions & 1 deletion cpp/velox/shuffle/VeloxShuffleWriter.cc
@@ -221,7 +221,8 @@ arrow::Status VeloxShuffleWriter::init() {
VELOX_CHECK_LE(options_.bufferSize, 32 * 1024);

ARROW_ASSIGN_OR_RAISE(
partitioner_, Partitioner::make(options_.partitioning, numPartitions_, options_.startPartitionId));
partitioner_,
Partitioner::make(options_.partitioning, numPartitions_, options_.startPartitionId, options_.sortBeforeRR));

// pre-allocated buffer size for each partition, unit is row count
// when partitioner is SinglePart, partial variables don`t need init
@@ -22,6 +22,7 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.BlockManager

import org.apache.celeborn.client.ShuffleClient
@@ -81,6 +82,9 @@ abstract class CelebornHashBasedColumnarShuffleWriter[K, V](

protected var partitionLengths: Array[Long] = _

protected var sortBeforeRR: Boolean =
SQLConf.get.sortBeforeRepartition && dep.partitioner.numPartitions > 1

@throws[IOException]
final override def write(records: Iterator[Product2[K, V]]): Unit = {
internalWrite(records)
@@ -125,7 +125,8 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
context.taskAttemptId(),
GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, context.partitionId),
"celeborn",
GlutenConfig.getConf.columnarShuffleReallocThreshold
GlutenConfig.getConf.columnarShuffleReallocThreshold,
sortBeforeRR
)
}
val startTime = System.nanoTime()
@@ -69,7 +69,8 @@ public long make(
double reallocThreshold,
long handle,
long taskAttemptId,
int startPartitionId) {
int startPartitionId,
boolean sortBeforeRR) {
return nativeMake(
part.getShortName(),
part.getNumPartitions(),
@@ -91,7 +92,8 @@ public long make(
startPartitionId,
0,
null,
"local");
"local",
sortBeforeRR);
}

/**
@@ -116,7 +118,8 @@ public long makeForRSS(
long taskAttemptId,
int startPartitionId,
String partitionWriterType,
double reallocThreshold) {
double reallocThreshold,
boolean sortBeforeRR) {
return nativeMake(
part.getShortName(),
part.getNumPartitions(),
@@ -138,7 +141,8 @@ public long makeForRSS(
startPartitionId,
pushBufferMaxSize,
pusher,
partitionWriterType);
partitionWriterType,
sortBeforeRR);
}

public native long nativeMake(
@@ -162,7 +166,8 @@ public native long nativeMake(
int startPartitionId,
int pushBufferMaxSize,
Object pusher,
String partitionWriterType);
String partitionWriterType,
boolean sortBeforeRR);

/**
* Evict partition data.
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.SHUFFLE_COMPRESS
import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SparkDirectoryUtil, SparkResourceUtil, Utils}

@@ -106,6 +107,9 @@ class ColumnarShuffleWriter[K, V](

private val taskContext: TaskContext = TaskContext.get()

private val sortBeforeRR: Boolean =
SQLConf.get.sortBeforeRepartition && dep.partitioner.numPartitions > 1

private def availableOffHeapPerTask(): Long = {
val perTask =
SparkMemoryUtil.getCurrentAvailableOffHeapMemory / SparkResourceUtil.getTaskSlots(conf)
@@ -176,7 +180,8 @@
reallocThreshold,
handle,
taskContext.taskAttemptId(),
GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, taskContext.partitionId)
GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, taskContext.partitionId),
sortBeforeRR
)
}
val startTime = System.nanoTime()