Relax assumptions on compressors and serializers when batching
This commit introduces an intermediate input stream layer at the batch level.
This guards against interference from higher-level streams (i.e. compression and
deserialization streams), especially pre-fetching, without specifically targeting
particular libraries (Kryo) or forcing shuffle spill compression to use LZF.
andrewor14 committed Feb 4, 2014
1 parent 0386f42 commit 164489d
Showing 2 changed files with 90 additions and 84 deletions.
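
The core idea of the change, sketched below as a minimal standalone helper (not part of the diff; names are illustrative and only plain java.io is used): copy exactly one batch's worth of bytes into an in-memory stream before any compression or deserialization stream sees them, so a pre-fetching reader can never consume bytes that belong to the next batch.

import java.io.{ByteArrayInputStream, EOFException, InputStream}

object BatchStreamSketch {
  /** Read exactly `batchSize` bytes from `in` and expose them as a self-contained stream. */
  def nextBatchStream(in: InputStream, batchSize: Long): ByteArrayInputStream = {
    val buf = new Array[Byte](batchSize.toInt)
    var read = 0
    while (read < buf.length) {
      val n = in.read(buf, read, buf.length - read)
      if (n == -1) throw new EOFException(s"expected $batchSize bytes, got only $read")
      read += n
    }
    // Any stream layered on top of this one (decompression, deserialization) hits EOF
    // at the batch boundary, so pre-fetching cannot cross into the next batch.
    new ByteArrayInputStream(buf)
  }
}

A per-batch decompression and deserialization stream is then layered on top of the returned stream, mirroring what DiskMapIterator.nextBatchStream does in the diff below.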
@@ -66,6 +66,11 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
* Cumulative time spent performing blocking writes, in ns.
*/
def timeWriting(): Long

/**
* Number of bytes written so far
*/
def bytesWritten: Long
}

/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
@@ -183,7 +188,8 @@ private[spark] class DiskBlockObjectWriter(
// Only valid if called after close()
override def timeWriting() = _timeWriting

def bytesWritten: Long = {
// Only valid if called after commit()
override def bytesWritten: Long = {
lastValidPosition - initialPosition
}
}
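
The new bytesWritten accessor is what lets the spill path record how large each committed batch is on disk. Below is a hedged sketch of the calling pattern, outside the diff; the Writer trait is an illustrative stand-in for the interface above, not the real class. In the actual diff a fresh disk writer is opened per batch, so bytesWritten equals the size of the batch just committed.

import scala.collection.mutable.ArrayBuffer

// Illustrative stand-in for the writer surface used by the spill path.
trait Writer {
  def commit(): Unit
  def bytesWritten: Long   // only valid after commit(), per the comment above
}

// After writing one batch of objects: commit it and remember its on-disk length,
// so the read path later knows exactly how many bytes belong to this batch.
def recordBatch(writer: Writer, batchSizes: ArrayBuffer[Long]): Long = {
  writer.commit()
  val size = writer.bytesWritten
  batchSizes.append(size)
  size
}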
@@ -26,9 +26,8 @@ import scala.collection.mutable.ArrayBuffer
import it.unimi.dsi.fastutil.io.FastBufferedInputStream

import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.io.LZFCompressionCodec
import org.apache.spark.serializer.{KryoDeserializationStream, Serializer}
import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockObjectWriter}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BlockManager}

/**
* An append-only map that spills sorted content to disk when there is insufficient space for it
@@ -84,12 +83,14 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
// Number of in-memory pairs inserted before tracking the map's shuffle memory usage
private val trackMemoryThreshold = 1000

// Size of object batches when reading/writing from serializers. Objects are written in
// batches, with each batch using its own serialization stream. This cuts down on the size
// of reference-tracking maps constructed when deserializing a stream.
//
// NOTE: Setting this too low can cause excess copying when serializing, since some serializers
// grow internal data structures by growing + copying every time the number of objects doubles.
/* Size of object batches when reading/writing from serializers.
*
* Objects are written in batches, with each batch using its own serialization stream. This
* cuts down on the size of reference-tracking maps constructed when deserializing a stream.
*
* NOTE: Setting this too low can cause excess copying when serializing, since some serializers
* grow internal data structures by growing + copying every time the number of objects doubles.
*/
private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)

// How many times we have spilled so far
@@ -100,7 +101,6 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
private var _diskBytesSpilled = 0L

private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false)
private val comparator = new KCComparator[K, C]
private val ser = serializer.newInstance()

@@ -153,37 +153,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
.format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
var objectsWritten = 0

/* IMPORTANT NOTE: To avoid having to keep large object graphs in memory, this approach
* closes and re-opens serialization and compression streams within each file. This makes some
* assumptions about the way that serialization and compression streams work, specifically:
*
* 1) The serializer input streams do not pre-fetch data from the underlying stream.
*
* 2) Several compression streams can be opened, written to, and flushed on the write path
* while only one compression input stream is created on the read path
*
* In practice (1) is only true for Java, so we add a special fix below to make it work for
* Kryo. (2) is only true for LZF and not Snappy, so we coerce this to use LZF.
*
* To avoid making these assumptions we should create an intermediate stream that batches
* objects and sends an EOF to the higher layer streams to make sure they never prefetch data.
* This is a bit tricky because, within each segment, you'd need to track the total number
* of bytes written and then re-wind and write it at the beginning of the segment. This will
* most likely require using the file channel API.
*/
// List of batch sizes (bytes) in the order they are written to disk
val batchSizes = new ArrayBuffer[Long]

val shouldCompress = blockManager.shouldCompress(blockId)
val compressionCodec = new LZFCompressionCodec(sparkConf)
def wrapForCompression(outputStream: OutputStream) = {
if (shouldCompress) compressionCodec.compressedOutputStream(outputStream) else outputStream
// Flush the disk writer's contents to disk, and update relevant variables
def flush() = {
writer.commit()
val bytesWritten = writer.bytesWritten
batchSizes.append(bytesWritten)
_diskBytesSpilled += bytesWritten
objectsWritten = 0
}

def getNewWriter = new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize,
wrapForCompression, syncWrites)

var writer = getNewWriter
var objectsWritten = 0
try {
val it = currentMap.destructiveSortedIterator(comparator)
while (it.hasNext) {
@@ -192,22 +176,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
objectsWritten += 1

if (objectsWritten == serializerBatchSize) {
writer.commit()
flush()
writer.close()
_diskBytesSpilled += writer.bytesWritten
writer = getNewWriter
objectsWritten = 0
writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
}
}

if (objectsWritten > 0) writer.commit()
if (objectsWritten > 0) {
flush()
}
} finally {
// Partial failures cannot be tolerated; do not revert partial writes
writer.close()
_diskBytesSpilled += writer.bytesWritten
}

currentMap = new SizeTrackingAppendOnlyMap[K, C]
spilledMaps.append(new DiskMapIterator(file, blockId))
spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))

// Reset the amount of shuffle memory used by this map in the global pool
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
@@ -252,8 +235,9 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
}

/**
* Fetch from the given iterator until a key of different hash is retrieved. In the
* event of key hash collisions, this ensures no pairs are hidden from being merged.
* Fetch from the given iterator until a key of different hash is retrieved.
*
* In the event of key hash collisions, this ensures no pairs are hidden from being merged.
* Assume the given iterator is in sorted order.
*/
def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
@@ -293,7 +277,8 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty)

/**
* Select a key with the minimum hash, then combine all values with the same key from all input streams.
* Select a key with the minimum hash, then combine all values with the same key from all
* input streams
*/
override def next(): (K, C) = {
// Select a key from the StreamBuffer that holds the lowest key hash
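
The two comments above describe the merge performed over the sorted, spilled streams: pull pairs until the key hash changes, then combine equal keys across streams. The following is a hedged standalone sketch of one merge step, not the Spark implementation; the types, names, and combine function are illustrative.

import scala.collection.mutable.ArrayBuffer

// Each buffer plays the role of a StreamBuffer: the pairs already fetched from one
// sorted stream, all sharing (at least) that stream's minimum key hash.
def mergeStep[K, C](buffers: Seq[ArrayBuffer[(K, C)]], combine: (C, C) => C): Option[(K, C)] = {
  val candidates = buffers.filter(_.nonEmpty)
  if (candidates.isEmpty) return None
  // Select the buffer whose head pair has the minimum key hash.
  val minBuffer = candidates.minBy(_.head._1.hashCode())
  val (minKey, firstValue) = minBuffer.remove(0)
  var combined = firstValue
  // Combine values for the same key from every other buffer. Hash collisions mean a
  // buffer can hold several distinct keys, so match on the key itself, not its hash.
  for (buffer <- candidates if buffer ne minBuffer) {
    val idx = buffer.indexWhere(_._1 == minKey)
    if (idx >= 0) combined = combine(combined, buffer.remove(idx)._2)
  }
  Some((minKey, combined))
}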
@@ -355,51 +340,66 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
/**
* An iterator that returns (K, C) pairs in sorted order from an on-disk map
*/
private class DiskMapIterator(file: File, blockId: BlockId) extends Iterator[(K, C)] {
private class DiskMapIterator(file: File,
blockId: BlockId,
batchSizes: ArrayBuffer[Long]) extends Iterator[(K, C)] {
val fileStream = new FileInputStream(file)
val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize)

val shouldCompress = blockManager.shouldCompress(blockId)
val compressionCodec = new LZFCompressionCodec(sparkConf)
val compressedStream =
if (shouldCompress) {
compressionCodec.compressedInputStream(bufferedStream)
} else {
bufferedStream
}
var deserializeStream = ser.deserializeStream(compressedStream)
var objectsRead = 0
// An intermediate stream that holds all the bytes from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
var batchStream = nextBatchStream(bufferedStream)

var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
var deserializeStream = ser.deserializeStream(compressedStream)
var nextItem: (K, C) = null
var eof = false

/**
* Construct a stream that contains all the bytes from the next batch
*/
def nextBatchStream(stream: InputStream): ByteArrayInputStream = {
var batchBytes = Array[Byte]()
if (batchSizes.length > 0) {
val batchSize = batchSizes.remove(0)

// Read batchSize number of bytes into batchBytes
while (batchBytes.length < batchSize) {
val numBytesToRead = Math.min(8192, batchSize - batchBytes.length).toInt
val bytesRead = new Array[Byte](numBytesToRead)
stream.read(bytesRead, 0, numBytesToRead)
batchBytes ++= bytesRead
}
} else {
// No more batches left
eof = true
}
new ByteArrayInputStream(batchBytes)
}

/**
* Return the next (K, C) pair from the deserialization stream.
*
* If the underlying batch stream is drained, construct a new stream for the next batch
* (if there is one) and stream from it. If there are no more batches left, return null.
*/
def readNextItem(): (K, C) = {
if (!eof) {
try {
if (objectsRead == serializerBatchSize) {
val newInputStream = deserializeStream match {
case stream: KryoDeserializationStream =>
// Kryo's serializer stores an internal buffer that pre-fetches from the underlying
// stream. We need to capture this buffer and feed it to the new serialization
// stream so that the bytes are not lost.
val kryoInput = stream.input
val remainingBytes = kryoInput.limit() - kryoInput.position()
val extraBuf = kryoInput.readBytes(remainingBytes)
new SequenceInputStream(new ByteArrayInputStream(extraBuf), compressedStream)
case _ => compressedStream
}
deserializeStream = ser.deserializeStream(newInputStream)
objectsRead = 0
}
objectsRead += 1
return deserializeStream.readObject().asInstanceOf[(K, C)]
} catch {
case e: EOFException =>
eof = true
try {
deserializeStream.readObject().asInstanceOf[(K, C)]
} catch {
// End of current batch
case e: EOFException =>
batchStream = nextBatchStream(bufferedStream)
if (!eof) {
compressedStream = blockManager.wrapForCompression(blockId, batchStream)
deserializeStream = ser.deserializeStream(compressedStream)
readNextItem()
} else {
// No more batches left
cleanup()
}
null
}
}
null
}

override def hasNext: Boolean = {
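
Taken together, the write and read paths above frame each batch by its byte length. The following is a self-contained round-trip sketch of the same scheme, outside the diff, using plain Java serialization in place of Spark's serializer and compression layers; all names are illustrative.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, EOFException, File}
import java.io.{FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer

object BatchRoundTrip extends App {
  val batchSizes = new ArrayBuffer[Long]()
  val file = File.createTempFile("spill-batches", ".bin")

  // Write path: one serialization stream per batch, lengths recorded as we go.
  val fileOut = new FileOutputStream(file)
  for (batch <- Seq(Seq("a" -> 1, "b" -> 2), Seq("c" -> 3))) {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    batch.foreach(out.writeObject)
    out.close()
    batchSizes.append(bytes.size().toLong)   // the batch's on-disk length
    bytes.writeTo(fileOut)
  }
  fileOut.close()

  // Read path: hand the deserializer exactly one batch's bytes at a time, so any
  // pre-fetching it does stops at the batch boundary.
  val fileIn = new FileInputStream(file)
  for (size <- batchSizes) {
    val buf = new Array[Byte](size.toInt)
    var read = 0
    while (read < buf.length) {
      val n = fileIn.read(buf, read, buf.length - read)
      require(n != -1, "file ended before the recorded batch size was read")
      read += n
    }
    val in = new ObjectInputStream(new ByteArrayInputStream(buf))
    try {
      while (true) println(in.readObject())
    } catch {
      case _: EOFException => // end of this batch, as expected
    }
  }
  fileIn.close()
  file.delete()
}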