diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index aba713cb4267a..373ce795a309e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -39,6 +40,7 @@ import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} + /** * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -231,7 +233,12 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - val blockTransferService = new NioBlockTransferService(conf, securityManager) + // TODO(rxin): Config option based on class name, similar to shuffle mgr and compression codec. + val blockTransferService = if (conf.getBoolean("spark.shuffle.use.netty", false)) { + new NettyBlockTransferService(conf) + } else { + new NioBlockTransferService(conf, securityManager) + } val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index e0e91724271c8..638e05f481f55 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -23,11 +23,10 @@ import org.apache.spark.storage.StorageLevel trait BlockDataManager { /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ - def getBlockData(blockId: String): Option[ManagedBuffer] + def getBlockData(blockId: String): ManagedBuffer /** * Put the block locally, using the given storage level. diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 9c298132fcfba..7f364947dd930 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -35,9 +35,14 @@ import org.apache.spark.util.{ByteBufferInputStream, Utils} * This interface provides an immutable view for data in the form of bytes. The implementation * should specify how the data is provided: * - * - FileSegmentManagedBuffer: data backed by part of a file - * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer - * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf + * - [[FileSegmentManagedBuffer]]: data backed by part of a file + * - [[NioByteBufferManagedBuffer]]: data backed by a NIO ByteBuffer + * - [[NettyByteBufManagedBuffer]]: data backed by a Netty ByteBuf + * + * The concrete buffer implementation might be managed outside the JVM garbage collector. 
+ * For example, in the case of [[NettyByteBufManagedBuffer]], the buffers are reference counted. + * In that case, if the buffer is going to be passed around to a different thread, retain/release + * should be called. */ abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can @@ -59,6 +64,17 @@ abstract class ManagedBuffer { */ def inputStream(): InputStream + /** + * Increment the reference count by one if applicable. + */ + def retain(): this.type + + /** + * If applicable, decrement the reference count by one and deallocate the buffer if the + * reference count reaches zero. + */ + def release(): this.type + /** * Convert the buffer into an Netty object, used to write the data out. */ @@ -123,6 +139,10 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt val fileChannel = new FileInputStream(file).getChannel new DefaultFileRegion(fileChannel, offset, length) } + + // Contents of file segments are not in memory, so no need to reference count. + override def retain(): this.type = this + override def release(): this.type = this } @@ -138,6 +158,10 @@ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def inputStream() = new ByteBufferInputStream(buf) private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf) + + // [[ByteBuffer]] is managed by the JVM garbage collector itself. + override def retain(): this.type = this + override def release(): this.type = this } @@ -154,6 +178,13 @@ final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { private[network] override def convertToNetty(): AnyRef = buf - // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. - def release(): Unit = buf.release() + override def retain(): this.type = { + buf.retain() + this + } + + override def release(): this.type = { + buf.release() + this + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 05443a74094d7..ceae31efac939 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -40,10 +40,6 @@ private[netty] class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Logging { - def this(sparkConf: SparkConf, dataProvider: BlockDataManager) = { - this(new NettyConfig(sparkConf), dataProvider) - } - def port: Int = _port def hostName: String = _hostName @@ -117,7 +113,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] _port = addr.getPort - _hostName = addr.getHostName + //_hostName = addr.getHostName + _hostName = Utils.localHostName() } /** Shutdown the server.
*/ diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala index 739526a4fc6bc..c3b4d41829f4e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala @@ -66,9 +66,9 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) logTrace(s"Received request from $client to fetch block $blockId") // First make sure we can find the block. If not, send error back to the user. - var blockData: Option[ManagedBuffer] = null + var buf: ManagedBuffer = null try { - blockData = dataProvider.getBlockData(blockId) + buf = dataProvider.getBlockData(blockId) } catch { case e: Exception => logError(s"Error opening block $blockId for request from $client", e) @@ -76,23 +76,18 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) return } - blockData match { - case Some(buf) => - ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.size} B) back to $client") - } else { - logError( - s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } + ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${buf.size} B) back to $client") + } else { + logError( + s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() } - ) - case None => - respondWithError("Block not found") - } + } + } + ) } // end of processBlockRequest } diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index b389b9a2022c6..457ba106ced89 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -197,7 +197,8 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private def getBlock(blockId: String): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId).orNull + // TODO(rxin): propagate error back to the client? + val buffer = blockDataManager.getBlockData(blockId) logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) if (buffer == null) null else buffer.nioByteBuffer() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3f5d06e1aeee7..9995a25c224fe 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -216,17 +216,17 @@ private[spark] class BlockManager( * * @return Some(buffer) if the block exists locally, and None if it doesn't. 
*/ - override def getBlockData(blockId: String): Option[ManagedBuffer] = { + override def getBlockData(blockId: String): ManagedBuffer = { val bid = BlockId(blockId) if (bid.isShuffle) { - Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])) + shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - Some(new NioByteBufferManagedBuffer(buffer)) + new NioByteBufferManagedBuffer(buffer) } else { - None + throw new BlockNotFoundException(blockId) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 71b276b5f18e4..23f7d56895fe5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,10 +23,10 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue -import org.apache.spark.{TaskContext, Logging} +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{CompletionIterator, Utils} /** @@ -88,17 +88,49 @@ final class ShuffleBlockFetcherIterator( */ private[this] val results = new LinkedBlockingQueue[FetchResult] - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight + /** + * Current [[FetchResult]] being processed. We track this so we can release the current buffer + * in case of a runtime exception when processing the current buffer. + */ + private[this] var currentResult: FetchResult = null + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + * the number of bytes in flight is limited to maxBytesInFlight. + */ private[this] val fetchRequests = new Queue[FetchRequest] - // Current bytes in flight from our requests + /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @volatile private[this] var isZombie = false + initialize() + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. 
+ */ + private[this] def cleanup() { + isZombie = true + // Release the current buffer if necessary + if (currentResult != null && currentResult.buf != null) { + currentResult.buf.release() + } + + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result.buf.release() + } + } + private[this] def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) @@ -110,13 +142,17 @@ final class ShuffleBlockFetcherIterator( blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, new BlockFetchingListener { - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), - () => serializer.newInstance().deserializeStream( - blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator - )) - shuffleMetrics.remoteBytesRead += data.size - shuffleMetrics.remoteBlocksFetched += 1 + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf)) + shuffleMetrics.remoteBytesRead += buf.size + shuffleMetrics.remoteBlocksFetched += 1 + } logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } @@ -138,7 +174,7 @@ final class ShuffleBlockFetcherIterator( // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. @@ -185,26 +221,34 @@ final class ShuffleBlockFetcherIterator( remoteRequests } + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. + */ private[this] def fetchLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocks) { + val iter = localBlocks.iterator + while (iter.hasNext) { + val blockId = iter.next() try { + val buf = blockManager.getBlockData(blockId.toString) shuffleMetrics.localBlocksFetched += 1 - results.put(new FetchResult( - id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get)) - logDebug("Got local block " + id) + buf.retain() + results.put(new FetchResult(blockId, 0, buf)) } catch { case e: Exception => + // If we see an exception, stop immediately. 
logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) + results.put(new FetchResult(blockId, -1, null)) return } } } private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success and failure cases) to cleanup. + context.addTaskCompletionListener(_ => cleanup()) + // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -229,7 +273,8 @@ final class ShuffleBlockFetcherIterator( override def next(): (BlockId, Option[Iterator[Any]]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() - val result = results.take() + currentResult = results.take() + val result = currentResult val stopFetchWait = System.currentTimeMillis() shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) if (!result.failed) { @@ -240,7 +285,21 @@ final class ShuffleBlockFetcherIterator( (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } - (result.blockId, if (result.failed) None else Some(result.deserialize())) + + val iteratorOpt: Option[Iterator[Any]] = if (result.failed) { + None + } else { + val is = blockManager.wrapForCompression(result.blockId, result.buf.inputStream()) + val iter = serializer.newInstance().deserializeStream(is).asIterator + Some(CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + result.buf.release() + })) + } + + (result.blockId, iteratorOpt) } } @@ -262,10 +321,10 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block. A failure is represented as size == -1. * @param blockId block id * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. + * Note that this is NOT the exact bytes. -1 if the fetch failed. + * @param buf [[ManagedBuffer]] for the content. null if the fetch failed.
*/ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { + class FetchResult(val blockId: BlockId, val size: Long, val buf: ManagedBuffer) { def failed: Boolean = size == -1 } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 178c60a048b9f..72d7c4b531099 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.SparkConf import org.apache.spark.network._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{BlockNotFoundException, StorageLevel} /** @@ -62,14 +62,14 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { fp.write(fileContent) fp.close() - server = new BlockServer(new SparkConf, new BlockDataManager { - override def getBlockData(blockId: String): Option[ManagedBuffer] = { + server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager { + override def getBlockData(blockId: String): ManagedBuffer = { if (blockId == bufferBlockId) { - Some(new NioByteBufferManagedBuffer(buf)) + new NioByteBufferManagedBuffer(buf) } else if (blockId == fileBlockId) { - Some(new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25)) + new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25) } else { - None + throw new BlockNotFoundException(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala index 6ae2d3b3faf91..1d13fd92e1f23 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala @@ -65,4 +65,8 @@ class TestManagedBuffer(len: Int) extends ManagedBuffer { } case _ => false } + + override def retain(): this.type = this + + override def release(): this.type = this }
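
Note on the retain()/release() contract (illustrative sketch, not part of the patch): the new methods only do real work for buffers whose memory is managed outside the JVM garbage collector, such as NettyByteBufManagedBuffer. Whenever a buffer is handed from the Netty event-loop thread to a task thread, as onBlockFetchSuccess() and the CompletionIterator callback do above, the producer should retain() before the hand-off and the consumer should release() once it is done reading. A minimal sketch of that pattern, with a hypothetical class name and queue-based hand-off:

  import java.util.concurrent.LinkedBlockingQueue

  import org.apache.spark.network.ManagedBuffer

  class BufferHandOff {
    private val pending = new LinkedBlockingQueue[ManagedBuffer]()

    /** Called on the Netty event-loop thread when a block arrives. */
    def onFetchSuccess(buf: ManagedBuffer): Unit = {
      buf.retain()     // keep the (possibly reference-counted) bytes alive across threads
      pending.put(buf)
    }

    /** Called on the task thread that consumes the block. */
    def consumeOne(read: java.io.InputStream => Unit): Unit = {
      val buf = pending.take()
      try {
        read(buf.inputStream())
      } finally {
        buf.release()  // balance the retain() above so Netty ByteBufs are freed eagerly
      }
    }
  }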
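
Note on enabling the Netty path (usage sketch, not part of the patch): as the TODO in SparkEnv says, the transfer service is currently chosen by a boolean flag rather than a class name. Turning it on for a job would look like this; the application name is illustrative:

  import org.apache.spark.SparkConf

  val conf = new SparkConf()
    .setAppName("netty-block-transfer-test")   // illustrative
    .set("spark.shuffle.use.netty", "true")    // default is "false", which keeps NioBlockTransferService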
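
Note on releasing after deserialization (illustrative sketch, not part of the patch): next() wraps the deserialization stream in a CompletionIterator so the buffer is released exactly once, after the last record has been read, while cleanup() covers the case where the task finishes before that happens. The same pattern in isolation, using a hypothetical helper; CompletionIterator is spark-private, so this assumes code living under the org.apache.spark package:

  import java.io.InputStream

  import org.apache.spark.network.ManagedBuffer
  import org.apache.spark.util.CompletionIterator

  object BufferIterators {
    /** Returns an iterator that releases `buf` once it has been fully consumed. */
    def recordsFrom(buf: ManagedBuffer, deserialize: InputStream => Iterator[Any]): Iterator[Any] = {
      val underlying = deserialize(buf.inputStream())
      CompletionIterator[Any, Iterator[Any]](underlying, {
        buf.release()  // balances the retain() done when the buffer was queued
      })
    }
  }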