diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornHashBasedColumnarShuffleWriter.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornHashBasedColumnarShuffleWriter.scala
index a7836e4a13d1..86e1aa1ce4b0 100644
--- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornHashBasedColumnarShuffleWriter.scala
+++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornHashBasedColumnarShuffleWriter.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.protocol.ShuffleMode
 
 import java.io.IOException
 import java.util.Locale
@@ -55,61 +56,16 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V](
 
   private var splitResult: CHSplitResult = _
 
-  private val nativeBufferSize: Int = GlutenConfig.getConf.shuffleWriterBufferSize
-
   @throws[IOException]
   override def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
-    if (!records.hasNext) {
-      handleEmptyIterator()
-      return
-    }
-
-    if (nativeShuffleWriter == -1L) {
-      nativeShuffleWriter = jniWrapper.makeForRSS(
-        dep.nativePartitioning,
-        shuffleId,
-        mapId,
-        nativeBufferSize,
-        customizedCompressCodec,
-        GlutenConfig.getConf.chColumnarShuffleSpillThreshold,
-        CHBackendSettings.shuffleHashAlgorithm,
-        celebornPartitionPusher,
-        GlutenConfig.getConf.chColumnarThrowIfMemoryExceed,
-        GlutenConfig.getConf.chColumnarFlushBlockBufferBeforeEvict,
-        GlutenConfig.getConf.chColumnarForceExternalSortShuffle,
-        GlutenConfig.getConf.chColumnarForceMemorySortShuffle
-      )
-      CHNativeMemoryAllocators.createSpillable(
-        "CelebornShuffleWriter",
-        new Spiller() {
-          override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
-            if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
-              return 0L
-            }
-            if (nativeShuffleWriter == -1L) {
-              throw new IllegalStateException(
-                "Fatal: spill() called before a celeborn shuffle writer " +
-                  "is created. This behavior should be" +
-                  "optimized by moving memory " +
-                  "allocations from make() to split()")
-            }
-            logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
-            val spilled = jniWrapper.evict(nativeShuffleWriter)
-            logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data")
-            spilled
-          }
-        }
-      )
-    }
     while (records.hasNext) {
       val cb = records.next()._2.asInstanceOf[ColumnarBatch]
       if (cb.numRows == 0 || cb.numCols == 0) {
         logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols")
       } else {
+        initShuffleWriter(cb)
         val col = cb.column(0).asInstanceOf[CHColumnVector]
-        val block = col.getBlockAddress
-        jniWrapper
-          .split(nativeShuffleWriter, block)
+        jniWrapper.split(nativeShuffleWriter, col.getBlockAddress)
         dep.metrics("numInputRows").add(cb.numRows)
         dep.metrics("inputBatches").add(1)
         // This metric is important, AQE use it to decide if EliminateLimit
@@ -117,6 +73,12 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V](
       }
     }
 
+    // If all ColumnarBatches have empty rows, nativeShuffleWriter still equals -1L
+    if (nativeShuffleWriter == -1L) {
+      handleEmptyIterator()
+      return
+    }
+
     splitResult = jniWrapper.stop(nativeShuffleWriter)
 
     dep.metrics("splitTime").add(splitResult.getSplitTime)
@@ -135,6 +97,43 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, splitResult.getRawPartitionLengths, mapId)
   }
 
+  override def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = {
+    nativeShuffleWriter = jniWrapper.makeForRSS(
+      dep.nativePartitioning,
+      shuffleId,
+      mapId,
+      nativeBufferSize,
+      customizedCompressCodec,
+      GlutenConfig.getConf.chColumnarShuffleSpillThreshold,
+      CHBackendSettings.shuffleHashAlgorithm,
+      celebornPartitionPusher,
+      GlutenConfig.getConf.chColumnarThrowIfMemoryExceed,
+      GlutenConfig.getConf.chColumnarFlushBlockBufferBeforeEvict,
+      GlutenConfig.getConf.chColumnarForceExternalSortShuffle,
+      GlutenConfig.getConf.chColumnarForceMemorySortShuffle
+        || ShuffleMode.SORT.name.equalsIgnoreCase(shuffleWriterType)
+    )
+    CHNativeMemoryAllocators.createSpillable(
+      "CelebornShuffleWriter",
+      new Spiller() {
+        override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
+          if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+            return 0L
+          }
+          if (nativeShuffleWriter == -1L) {
+            throw new IllegalStateException(
+              "Fatal: spill() called before a celeborn shuffle writer is created. " +
" + + "This behavior should be optimized by moving memory allocations from make() to split()") + } + logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data") + val spilled = jniWrapper.evict(nativeShuffleWriter) + logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data") + spilled + } + } + ) + } + override def closeShuffleWriter(): Unit = { jniWrapper.close(nativeShuffleWriter) } diff --git a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornHashBasedColumnarShuffleWriter.scala b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornHashBasedColumnarShuffleWriter.scala index efd891498131..dbc8933de5f8 100644 --- a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornHashBasedColumnarShuffleWriter.scala +++ b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornHashBasedColumnarShuffleWriter.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.SHUFFLE_COMPRESS import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.BlockManager import org.apache.celeborn.client.ShuffleClient @@ -52,12 +53,23 @@ abstract class CelebornHashBasedColumnarShuffleWriter[K, V]( protected val mapId: Int = context.partitionId() + protected lazy val nativeBufferSize: Int = { + val bufferSize = GlutenConfig.getConf.shuffleWriterBufferSize + val maxBatchSize = GlutenConfig.getConf.maxBatchSize + if (bufferSize > maxBatchSize) { + logInfo( + s"${GlutenConfig.SHUFFLE_WRITER_BUFFER_SIZE.key} ($bufferSize) exceeds max " + + s" batch size. Limited to ${GlutenConfig.COLUMNAR_MAX_BATCH_SIZE.key} ($maxBatchSize).") + maxBatchSize + } else { + bufferSize + } + } + protected val clientPushBufferMaxSize: Int = celebornConf.clientPushBufferMaxSize protected val clientPushSortMemoryThreshold: Long = celebornConf.clientPushSortMemoryThreshold - protected val clientSortMemoryMaxSize: Long = celebornConf.clientPushSortMemoryThreshold - protected val shuffleWriterType: String = celebornConf.shuffleWriterMode.name.toLowerCase(Locale.ROOT) @@ -96,6 +108,10 @@ abstract class CelebornHashBasedColumnarShuffleWriter[K, V]( @throws[IOException] final override def write(records: Iterator[Product2[K, V]]): Unit = { + if (!records.hasNext) { + handleEmptyIterator() + return + } internalWrite(records) } @@ -122,10 +138,18 @@ abstract class CelebornHashBasedColumnarShuffleWriter[K, V]( } } + def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = {} + def closeShuffleWriter(): Unit = {} def getPartitionLengths: Array[Long] = partitionLengths + def initShuffleWriter(columnarBatch: ColumnarBatch): Unit = { + if (nativeShuffleWriter == -1L) { + createShuffleWriter(columnarBatch) + } + } + def pushMergedDataToCeleborn(): Unit = { val pushMergedDataTime = System.nanoTime client.prepareForMergeData(shuffleId, mapId, context.attemptNumber()) diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala index 87b16c65bd09..e9b6eeb27eb1 100644 --- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala +++ b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala @@ -55,25 
@@ -55,25 +55,6 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
 
   private var splitResult: SplitResult = _
 
-  private lazy val nativeBufferSize = {
-    val bufferSize = GlutenConfig.getConf.shuffleWriterBufferSize
-    val maxBatchSize = GlutenConfig.getConf.maxBatchSize
-    if (bufferSize > maxBatchSize) {
-      logInfo(
-        s"${GlutenConfig.SHUFFLE_WRITER_BUFFER_SIZE.key} ($bufferSize) exceeds max " +
-          s" batch size. Limited to ${GlutenConfig.COLUMNAR_MAX_BATCH_SIZE.key} ($maxBatchSize).")
-      maxBatchSize
-    } else {
-      bufferSize
-    }
-  }
-
-  private val memoryLimit: Long = if ("sort".equals(shuffleWriterType)) {
-    Math.min(clientSortMemoryMaxSize, clientPushBufferMaxSize * numPartitions)
-  } else {
-    availableOffHeapPerTask()
-  }
-
   private def availableOffHeapPerTask(): Long = {
     val perTask =
       SparkMemoryUtil.getCurrentAvailableOffHeapMemory / SparkResourceUtil.getTaskSlots(conf)
@@ -82,49 +63,13 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
 
   @throws[IOException]
   override def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
-    if (!records.hasNext) {
-      handleEmptyIterator()
-      return
-    }
-
     while (records.hasNext) {
       val cb = records.next()._2.asInstanceOf[ColumnarBatch]
       if (cb.numRows == 0 || cb.numCols == 0) {
         logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols")
       } else {
+        initShuffleWriter(cb)
         val handle = ColumnarBatches.getNativeHandle(cb)
-        if (nativeShuffleWriter == -1L) {
-          nativeShuffleWriter = jniWrapper.makeForRSS(
-            dep.nativePartitioning,
-            nativeBufferSize,
-            customizedCompressionCodec,
-            compressionLevel,
-            bufferCompressThreshold,
-            GlutenConfig.getConf.columnarShuffleCompressionMode,
-            clientPushBufferMaxSize,
-            clientPushSortMemoryThreshold,
-            celebornPartitionPusher,
-            handle,
-            context.taskAttemptId(),
-            GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, context.partitionId),
-            "celeborn",
-            shuffleWriterType,
-            GlutenConfig.getConf.columnarShuffleReallocThreshold
-          )
-          runtime.addSpiller(new Spiller() {
-            override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
-              if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
-                return 0L
-              }
-              logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
-              // fixme pass true when being called by self
-              val pushed =
-                jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
-              logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data")
-              pushed
-            }
-          })
-        }
         val startTime = System.nanoTime()
         jniWrapper.write(nativeShuffleWriter, cb.numRows, handle, availableOffHeapPerTask())
         dep.metrics("splitTime").add(System.nanoTime() - startTime)
@@ -135,8 +80,13 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
       }
     }
 
+    // If all ColumnarBatches have empty rows, nativeShuffleWriter still equals -1L
+    if (nativeShuffleWriter == -1L) {
+      handleEmptyIterator()
+      return
+    }
+
     val startTime = System.nanoTime()
-    assert(nativeShuffleWriter != -1L)
     splitResult = jniWrapper.stop(nativeShuffleWriter)
 
     dep
@@ -155,6 +105,38 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
   }
 
+  override def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = {
+    nativeShuffleWriter = jniWrapper.makeForRSS(
+      dep.nativePartitioning,
+      nativeBufferSize,
+      customizedCompressionCodec,
+      compressionLevel,
+      bufferCompressThreshold,
+      GlutenConfig.getConf.columnarShuffleCompressionMode,
+      clientPushBufferMaxSize,
+      clientPushSortMemoryThreshold,
+      celebornPartitionPusher,
+      ColumnarBatches.getNativeHandle(columnarBatch),
+      context.taskAttemptId(),
+      GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, context.partitionId),
+      "celeborn",
+      shuffleWriterType,
+      GlutenConfig.getConf.columnarShuffleReallocThreshold
+    )
+    runtime.addSpiller(new Spiller() {
+      override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
+        if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+          return 0L
+        }
+        logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
+        // fixme pass true when being called by self
+        val pushed = jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
+        logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data")
+        pushed
+      }
+    })
+  }
+
   override def closeShuffleWriter(): Unit = {
     jniWrapper.close(nativeShuffleWriter)
   }
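Note on the shared pattern: all three writers now defer native writer creation to the first non-empty batch via `initShuffleWriter`/`createShuffleWriter`, and fall back to `handleEmptyIterator()` when no such batch ever arrives. A minimal, self-contained sketch of that flow, using simplified stand-in types instead of the real JNI wrapper (all names below are illustrative only, not part of the patch):

```scala
// Sketch of the lazy writer-creation flow introduced by this patch.
// `String` stands in for ColumnarBatch; `Long` stands in for the native handle.
object LazyInitSketch {
  final class Writer {
    private var nativeShuffleWriter: Long = -1L // -1L marks "not yet created"

    // Stand-in for the backend-specific makeForRSS(...) call; the real Velox
    // code needs the first batch here, which is why creation is deferred.
    private def createShuffleWriter(batch: String): Unit = {
      nativeShuffleWriter = batch.hashCode.toLong
      println(s"created native writer from first batch '$batch'")
    }

    // Create the native writer at most once, on the first non-empty batch.
    private def initShuffleWriter(batch: String): Unit =
      if (nativeShuffleWriter == -1L) createShuffleWriter(batch)

    private def handleEmptyIterator(): Unit =
      println("no non-empty batch seen; emitting empty map status")

    def write(batches: Iterator[String]): Unit = {
      batches.filter(_.nonEmpty).foreach {
        b =>
          initShuffleWriter(b) // lazily allocate on the first real batch
          println(s"split '$b'")
      }
      // If every batch was empty, no writer was ever created.
      if (nativeShuffleWriter == -1L) handleEmptyIterator()
    }
  }

  def main(args: Array[String]): Unit = {
    new Writer().write(Iterator("a", "b")) // writer created once, on "a"
    new Writer().write(Iterator("", ""))   // all-empty input: empty-iterator path
  }
}
```

Deferring `makeForRSS` this way lets the Velox writer pass the first batch's native handle into creation (hence `createShuffleWriter(columnarBatch: ColumnarBatch)`) and avoids allocating native memory for tasks whose input turns out to be entirely empty.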