Realize that bypass never buffers; proceed to delete tons of code
JoshRosen committed May 25, 2015
1 parent 6185ee2 commit b6cc1eb
Showing 6 changed files with 48 additions and 176 deletions.
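
The whole point of the change, in one sketch: in the bypass path a record is never held in memory — it is appended straight to a per-partition temp file as it arrives, and the temp files are simply concatenated at the end. The self-contained illustration below (all names hypothetical; only the shape of one-open-writer-per-partition, direct writes, and final concatenation mirrors BypassMergeSortShuffleWriter) shows why no Spillable machinery, size tracking, or in-memory buffer is needed:

import java.io.{DataOutputStream, File, FileInputStream, FileOutputStream}

// Illustrative stand-in for the buffer-free bypass write path.
class BypassSketch(numPartitions: Int, partition: String => Int) {
  // One temp file and one open writer per reduce partition, created up front.
  private val tempFiles = Array.fill(numPartitions)(File.createTempFile("shuffle-part-", ".tmp"))
  private val writers = tempFiles.map(f => new DataOutputStream(new FileOutputStream(f)))

  // No in-memory collection and no spilling: route each record as it arrives.
  def insertAll(records: Iterator[(String, String)]): Unit =
    records.foreach { case (k, v) => writers(partition(k)).writeUTF(s"$k=$v") }

  // Concatenate the per-partition files into a single output file and return
  // each partition's byte length (what an index file would later record).
  def writePartitionedFile(output: File): Array[Long] = {
    writers.foreach(_.close())
    val out = new FileOutputStream(output, true)
    try {
      tempFiles.map { f =>
        val in = new FileInputStream(f)
        try { val copied = in.transferTo(out); f.delete(); copied } // transferTo: Java 9+
        finally in.close()
      }
    } finally {
      out.close()
    }
  }
}
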
@@ -19,8 +19,6 @@ package org.apache.spark.shuffle.sort

import java.io.{File, FileInputStream, FileOutputStream}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BlockManager, BlockObjectWriter}
@@ -44,129 +42,53 @@ private[spark] class BypassMergeSortShuffleWriter[K, V](
conf: SparkConf,
blockManager: BlockManager,
partitioner: Partitioner,
writeMetrics: ShuffleWriteMetrics,
serializer: Option[Serializer] = None)
extends Logging
with Spillable[WritablePartitionedPairCollection[K, V]]
with SortShuffleSorter[K, V] {
extends Logging with SortShuffleSorter[K, V] {

private[this] val numPartitions = partitioner.numPartitions
private[this] val shouldPartition = numPartitions > 1
private def getPartition(key: K): Int = {
if (shouldPartition) partitioner.getPartition(key) else 0
}

private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)

private val ser = Serializer.getSerializer(serializer)
private val serInstance = ser.newInstance()

/**
* Allocates a new buffer. Called in the constructor and after every spill.
*/
private def newBuffer: () => WritablePartitionedPairCollection[K, V] with SizeTracker = {
val useSerializedPairBuffer =
conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
ser.supportsRelocationOfSerializedObjects
if (useSerializedPairBuffer) {
val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
() => new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
} else {
() => new PartitionedPairBuffer[K, V]
}
}
private var buffer = newBuffer()

private var _diskBytesSpilled = 0L
def diskBytesSpilled: Long = _diskBytesSpilled

/**
* Information about a spilled file.
*
* @param file the file
* @param blockId the block id
* @param serializerBatchSizes sizes, in bytes, of "batches" written by the serializer as we
* periodically reset its stream
* @param elementsPerPartition the number of elements in each partition, used to efficiently
* keep track of partitions when merging.
*/
private[this] case class SpilledFile(
file: File,
blockId: BlockId,
serializerBatchSizes: Array[Long],
elementsPerPartition: Array[Long])
private val spills = new ArrayBuffer[SpilledFile]

/** Array of file writers for each partition, used if we've spilled */
/** Array of file writers for each partition */
private var partitionWriters: Array[BlockObjectWriter] = _

/**
* Write metrics for spill. This is initialized when partitionWriters is created */
private var spillWriteMetrics: ShuffleWriteMetrics = _

def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
// SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
assert (partitionWriters == null)
if (records.hasNext) {
spill(
WritablePartitionedIterator.fromIterator(records.map { kv =>
((getPartition(kv._1), kv._1), kv._2.asInstanceOf[V])
})
)
}
}

/**
* Spill the current in-memory collection to disk if needed.
*
* @param usingMap whether we're using a map or buffer as our current in-memory collection
*/
private def maybeSpillCollection(usingMap: Boolean): Unit = {
if (spillingEnabled && maybeSpill(buffer, buffer.estimateSize())) {
buffer = newBuffer()
}
}

/**
* Spill our in-memory collection to separate files, one for each partition, then clear the
* collection.
*/
override protected[this] def spill(collection: WritablePartitionedPairCollection[K, V]): Unit = {
spill(collection.writablePartitionedIterator())
}

private def spill(iterator: WritablePartitionedIterator): Unit = {
// Create our file writers if we haven't done so yet
if (partitionWriters == null) {
spillWriteMetrics = new ShuffleWriteMetrics()
val openStartTime = System.nanoTime
partitionWriters = Array.fill(numPartitions) {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
val (blockId, file) = blockManager.diskBlockManager.createTempShuffleBlock()
val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize,
spillWriteMetrics)
val writer =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics)
writer.open()
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
// included in the shuffle write time.
spillWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
}
writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)

// No need to sort stuff, just write each element out
while (iterator.hasNext) {
val partitionId = iterator.nextPartition()
iterator.writeNext(partitionWriters(partitionId))
while (records.hasNext) {
val record = records.next()
val key: K = record._1
partitionWriters(getPartition(key)).write(key, record._2)
}
}
}

/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
* Write all the data added into this writer into a single file in the disk store. This is
* called by the SortShuffleWriter and can go through an efficient path of just concatenating
* binary files if we decided to avoid merge-sorting.
* the per-partition binary files.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
@@ -180,64 +102,35 @@ private[spark] class BypassMergeSortShuffleWriter[K, V](
// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)

if (spills.isEmpty) {
// Case where we only have in-memory data
assert (partitionWriters == null)
assert (spillWriteMetrics == null)

val it = buffer.writablePartitionedIterator()
while (it.hasNext) {
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics.shuffleWriteMetrics.get)
val partitionId = it.nextPartition()
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
writer.commitAndClose()
val segment = writer.fileSegment()
lengths(partitionId) = segment.length
}
} else {
// Case where we have both in-memory and spilled data.
assert (partitionWriters != null)
assert (spillWriteMetrics != null)
// For simplicity, spill out the current in-memory collection so that everything is in files.
spill(buffer)
partitionWriters.foreach(_.commitAndClose())
val out = new FileOutputStream(outputFile, true)
val writeStartTime = System.nanoTime
Utils.tryWithSafeFinally {
for (i <- 0 until numPartitions) {
val in = new FileInputStream(partitionWriters(i).fileSegment().file)
Utils.tryWithSafeFinally {
lengths(i) = Utils.copyStream(in, out, closeStreams = false, transferToEnabled)
} {
in.close()
}
// TODO: handle case where partition writers is null (e.g. we haven't written any data).

partitionWriters.foreach(_.commitAndClose())
// Concatenate the per-partition files.
val out = new FileOutputStream(outputFile, true)
val writeStartTime = System.nanoTime
Utils.tryWithSafeFinally {
for (i <- 0 until numPartitions) {
val in = new FileInputStream(partitionWriters(i).fileSegment().file)
Utils.tryWithSafeFinally {
lengths(i) = Utils.copyStream(in, out, closeStreams = false, transferToEnabled)
} {
in.close()
}
} {
out.close()
context.taskMetrics.shuffleWriteMetrics.foreach { m =>
m.incShuffleWriteTime(System.nanoTime - writeStartTime)
if (!blockManager.diskBlockManager.getFile(partitionWriters(i).blockId).delete()) {
logError(s"Unable to delete file for partition $i")
}
}
}
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics.shuffleWriteMetrics.foreach { m =>
if (spillWriteMetrics != null) {
m.incShuffleBytesWritten(spillWriteMetrics.shuffleBytesWritten)
m.incShuffleWriteTime(spillWriteMetrics.shuffleWriteTime)
m.incShuffleRecordsWritten(spillWriteMetrics.shuffleRecordsWritten)
} {
out.close()
context.taskMetrics.shuffleWriteMetrics.foreach { m =>
m.incShuffleWriteTime(System.nanoTime - writeStartTime)
}
}

lengths
}

def stop(): Unit = {
spills.foreach(s => s.file.delete())
spills.clear()
if (partitionWriters != null) {
partitionWriters.foreach { w =>
w.revertPartialWritesAndClose()
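
The lengths returned by writePartitionedFile above are what ultimately back the "blockId.name + .index" file mentioned in its docstring: they are stored as cumulative offsets, with one more entry than there are partitions, so a reducer can seek directly to its slice of the single data file. A small sketch of that bookkeeping, under the assumption that the actual index writing lives in the shuffle block manager rather than in this writer:

// Turn per-partition segment lengths into the cumulative offsets an index
// file stores: offsets(0) = 0 and offsets(i + 1) = offsets(i) + lengths(i).
def lengthsToOffsets(lengths: Array[Long]): Array[Long] =
  lengths.scanLeft(0L)(_ + _)

// Example: lengths 10, 0, 25 -> offsets 0, 10, 10, 35; reducer i reads the
// byte range [offsets(i), offsets(i + 1)) of the concatenated data file.
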
@@ -65,7 +65,8 @@ private[spark] class SortShuffleWriter[K, V, C](
// them at the end. This avoids doing serialization and deserialization twice to merge
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleWriter[K, V](conf, blockManager, dep.partitioner, dep.serializer)
new BypassMergeSortShuffleWriter[K, V](
conf, blockManager, dep.partitioner, writeMetrics, dep.serializer)
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
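
The comment in this hunk spells out the tradeoff behind choosing the bypass writer: it skips serializing and deserializing the data a second time to merge spilled files, at the price of keeping one open file (and its write buffer) per partition. Roughly, the guard for this branch looks like the sketch below — the helper name and parameters are illustrative, but spark.shuffle.sort.bypassMergeThreshold (default 200) is the real knob that caps the number of simultaneously open files:

// Hypothetical condenser of the bypass decision described above.
def shouldUseBypassWriter(
    numPartitions: Int,
    hasMapSideCombine: Boolean,
    requiresSortedKeys: Boolean,
    bypassMergeThreshold: Int = 200): Boolean = {
  // Bypass is only correct when records need no map-side aggregation or
  // key ordering, and only cheap when the per-partition file count stays small.
  !hasMapSideCombine && !requiresSortedKeys && numPartitions <= bypassMergeThreshold
}
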
@@ -34,10 +34,6 @@ private[spark] class PartitionedAppendOnlyMap[K, V]
destructiveSortedIterator(comparator)
}

def writablePartitionedIterator(): WritablePartitionedIterator = {
WritablePartitionedIterator.fromIterator(super.iterator)
}

def insert(partition: Int, key: K, value: V): Unit = {
update((partition, key), value)
}
@@ -71,10 +71,6 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
iterator
}

override def writablePartitionedIterator(): WritablePartitionedIterator = {
WritablePartitionedIterator.fromIterator(iterator)
}

private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
var pos = 0

@@ -119,10 +119,6 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
: WritablePartitionedIterator = {
sort(keyComparator)
writablePartitionedIterator
}

override def writablePartitionedIterator(): WritablePartitionedIterator = {
new WritablePartitionedIterator {
// current position in the meta buffer in ints
var pos = 0
@@ -47,13 +47,20 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
*/
def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
: WritablePartitionedIterator = {
WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator))
}
val it = partitionedDestructiveSortedIterator(keyComparator)
new WritablePartitionedIterator {
var cur = if (it.hasNext) it.next() else null

/**
* Iterate through the data and write out the elements instead of returning them.
*/
def writablePartitionedIterator(): WritablePartitionedIterator
def writeNext(writer: BlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}

def hasNext(): Boolean = cur != null

def nextPartition(): Int = cur._1._1
}
}
}

private[spark] object WritablePartitionedPairCollection {
@@ -94,20 +101,3 @@ private[spark] trait WritablePartitionedIterator {

def nextPartition(): Int
}

private[spark] object WritablePartitionedIterator {
def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = {
new WritablePartitionedIterator {
var cur = if (it.hasNext) it.next() else null

def writeNext(writer: BlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}

def hasNext(): Boolean = cur != null

def nextPartition(): Int = cur._1._1
}
}
}
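
For context on how the surviving WritablePartitionedIterator trait is meant to be driven: elements come out already grouped by partition id, so a caller writes runs of equal partition ids to one writer at a time, much like the loop deleted from BypassMergeSortShuffleWriter's writePartitionedFile earlier in this diff. A sketch of that consumption pattern, assuming the private[spark] BlockObjectWriter and WritablePartitionedIterator types shown above are in scope:

// Drain a partition-grouped iterator into one BlockObjectWriter per run of
// equal partition ids, without ever materializing the elements.
def writeAllPartitions(
    it: WritablePartitionedIterator,
    writerFor: Int => BlockObjectWriter): Unit = {
  while (it.hasNext()) {
    val partitionId = it.nextPartition()
    val writer = writerFor(partitionId)
    while (it.hasNext() && it.nextPartition() == partitionId) {
      it.writeNext(writer)
    }
    writer.commitAndClose()
  }
}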
