
Commit 064747b
Reference count buffers and clean them up properly.
rxin authored and aarondav committed Oct 10, 2014
1 parent 2b44cf1 commit 064747b
Showing 10 changed files with 164 additions and 71 deletions.
9 changes: 8 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,13 +32,15 @@ import org.apache.spark.api.python.PythonWorkerFactory
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.network.nio.NioBlockTransferService
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
 import org.apache.spark.storage._
 import org.apache.spark.util.{AkkaUtils, Utils}
 
+
 /**
  * :: DeveloperApi ::
  * Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -231,7 +233,12 @@ object SparkEnv extends Logging {

     val shuffleMemoryManager = new ShuffleMemoryManager(conf)
 
-    val blockTransferService = new NioBlockTransferService(conf, securityManager)
+    // TODO(rxin): Config option based on class name, similar to shuffle mgr and compression codec.
+    val blockTransferService = if (conf.getBoolean("spark.shuffle.use.netty", false)) {
+      new NettyBlockTransferService(conf)
+    } else {
+      new NioBlockTransferService(conf, securityManager)
+    }
 
     val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
       "BlockManagerMaster",
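The new code path is opt-in: NIO remains the default, and the Netty service is only constructed when the flag checked above is set. A minimal sketch of enabling it, using only the property name shown in this hunk (the surrounding conf setup is illustrative):

    import org.apache.spark.SparkConf

    // Opt in to the Netty-based block transfer service; when the flag is
    // absent or false, SparkEnv constructs NioBlockTransferService instead.
    val conf = new SparkConf().set("spark.shuffle.use.netty", "true")
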
core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -23,11 +23,10 @@ import org.apache.spark.storage.StorageLevel
 trait BlockDataManager {
 
   /**
-   * Interface to get local block data.
-   *
-   * @return Some(buffer) if the block exists locally, and None if it doesn't.
+   * Interface to get local block data. Throws an exception if the block cannot be found or
+   * cannot be read successfully.
    */
-  def getBlockData(blockId: String): Option[ManagedBuffer]
+  def getBlockData(blockId: String): ManagedBuffer
 
   /**
    * Put the block locally, using the given storage level.
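This changes the contract for callers: a missing block is now reported by exception rather than by None. A hedged sketch of the resulting caller pattern (handleMissingBlock is hypothetical; the exception type comes from the BlockManager hunk further down):

    try {
      val buf = dataProvider.getBlockData(blockId)
      // ... hand buf to the network layer ...
    } catch {
      case e: Exception =>
        // e.g. BlockNotFoundException thrown by BlockManager.getBlockData
        handleMissingBlock(blockId, e) // hypothetical error path
    }
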
41 changes: 36 additions & 5 deletions core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
@@ -35,9 +35,14 @@ import org.apache.spark.util.{ByteBufferInputStream, Utils}
  * This interface provides an immutable view for data in the form of bytes. The implementation
  * should specify how the data is provided:
  *
- *  - FileSegmentManagedBuffer: data backed by part of a file
- *  - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
- *  - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
+ *  - [[FileSegmentManagedBuffer]]: data backed by part of a file
+ *  - [[NioByteBufferManagedBuffer]]: data backed by a NIO ByteBuffer
+ *  - [[NettyByteBufManagedBuffer]]: data backed by a Netty ByteBuf
+ *
+ * The concrete buffer implementation might be managed outside the JVM garbage collector.
+ * For example, in the case of [[NettyByteBufManagedBuffer]], the buffers are reference counted.
+ * In that case, if the buffer is going to be passed around to a different thread, retain/release
+ * should be called.
  */
 abstract class ManagedBuffer {
   // Note that all the methods are defined with parenthesis because their implementations can
@@ -59,6 +64,17 @@ abstract class ManagedBuffer {
    */
   def inputStream(): InputStream
 
+  /**
+   * Increment the reference count by one if applicable.
+   */
+  def retain(): this.type
+
+  /**
+   * If applicable, decrement the reference count by one and deallocates the buffer if the
+   * reference count reaches zero.
+   */
+  def release(): this.type
+
   /**
    * Convert the buffer into an Netty object, used to write the data out.
    */
@@ -123,6 +139,10 @@ final class FileSegmentManagedBuffer
     val fileChannel = new FileInputStream(file).getChannel
     new DefaultFileRegion(fileChannel, offset, length)
   }
+
+  // Content of file segments are not in-memory, so no need to reference count.
+  override def retain(): this.type = this
+  override def release(): this.type = this
 }


@@ -138,6 +158,10 @@ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
   override def inputStream() = new ByteBufferInputStream(buf)
 
   private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf)
+
+  // [[ByteBuffer]] is managed by the JVM garbage collector itself.
+  override def retain(): this.type = this
+  override def release(): this.type = this
 }


@@ -154,6 +178,13 @@ final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {

   private[network] override def convertToNetty(): AnyRef = buf
 
-  // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
-  def release(): Unit = buf.release()
+  override def retain(): this.type = {
+    buf.retain()
+    this
+  }
+
+  override def release(): this.type = {
+    buf.release()
+    this
+  }
 }
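
The class comment added above states the rule this commit enforces: a buffer handed to another thread must be retained first and released by the receiver. A sketch of that discipline, assuming a reference-counted (Netty-backed) buffer; consume and the executor wiring are illustrative, while retain, release, and nioByteBuffer come from the diff:

    import java.util.concurrent.Executors

    val executor = Executors.newSingleThreadExecutor()
    val buf: ManagedBuffer = dataProvider.getBlockData(blockId)

    // Bump the reference count before the buffer crosses the thread boundary,
    // so a release() on the I/O thread cannot free the bytes while still in use.
    buf.retain()
    executor.submit(new Runnable {
      override def run(): Unit = {
        try {
          consume(buf.nioByteBuffer()) // hypothetical reader
        } finally {
          buf.release() // balances the retain() above
        }
      }
    })
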
core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala
@@ -40,10 +40,6 @@ import org.apache.spark.util.Utils
 private[netty]
 class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Logging {
 
-  def this(sparkConf: SparkConf, dataProvider: BlockDataManager) = {
-    this(new NettyConfig(sparkConf), dataProvider)
-  }
-
   def port: Int = _port
 
   def hostName: String = _hostName
@@ -117,7 +113,8 @@ class BlockServer

     val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
     _port = addr.getPort
-    _hostName = addr.getHostName
+    //_hostName = addr.getHostName
+    _hostName = Utils.localHostName()
   }
 
   /** Shutdown the server. */
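The hostname fix addresses a real pitfall: when the channel is bound to the wildcard address, the bound socket reports the wildcard rather than a name peers can dial, so the server now advertises Utils.localHostName() instead. A small illustration with plain java.net (behavior may vary slightly by platform):

    import java.net.InetSocketAddress

    val addr = new InetSocketAddress(0) // wildcard host, ephemeral port
    // Typically prints "0.0.0.0" rather than a reachable host name,
    // which is what addr.getHostName would have advertised here.
    println(addr.getHostName)
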
core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala
@@ -66,33 +66,28 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager)
     logTrace(s"Received request from $client to fetch block $blockId")
 
     // First make sure we can find the block. If not, send error back to the user.
-    var blockData: Option[ManagedBuffer] = null
+    var buf: ManagedBuffer = null
     try {
-      blockData = dataProvider.getBlockData(blockId)
+      buf = dataProvider.getBlockData(blockId)
     } catch {
       case e: Exception =>
         logError(s"Error opening block $blockId for request from $client", e)
         respondWithError(e.getMessage)
         return
     }
 
-    blockData match {
-      case Some(buf) =>
-        ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener(
-          new ChannelFutureListener {
-            override def operationComplete(future: ChannelFuture): Unit = {
-              if (future.isSuccess) {
-                logTrace(s"Sent block $blockId (${buf.size} B) back to $client")
-              } else {
-                logError(
-                  s"Error sending block $blockId to $client; closing connection", future.cause)
-                ctx.close()
-              }
-            }
-          }
-        )
-      case None =>
-        respondWithError("Block not found")
-    }
-  }
+    ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener(
+      new ChannelFutureListener {
+        override def operationComplete(future: ChannelFuture): Unit = {
+          if (future.isSuccess) {
+            logTrace(s"Sent block $blockId (${buf.size} B) back to $client")
+          } else {
+            logError(
+              s"Error sending block $blockId to $client; closing connection", future.cause)
+            ctx.close()
+          }
+        }
+      }
+    )
+  } // end of processBlockRequest
 }
core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -197,7 +197,8 @@ final class NioBlockTransferService
   private def getBlock(blockId: String): ByteBuffer = {
     val startTimeMs = System.currentTimeMillis()
     logDebug("GetBlock " + blockId + " started from " + startTimeMs)
-    val buffer = blockDataManager.getBlockData(blockId).orNull
+    // TODO(rxin): propagate error back to the client?
+    val buffer = blockDataManager.getBlockData(blockId)
     logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
       + " and got buffer " + buffer)
     if (buffer == null) null else buffer.nioByteBuffer()
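One side effect worth noting: since getBlockData now throws when a block is missing, the null guard above survives only as a leftover from the Option-based API, and a failure propagates out of getBlock to its caller, which is what the new TODO about sending errors back to the client refers to. In condensed form (a sketch, not code from the commit):

    // before: blockDataManager.getBlockData(blockId).orNull  => null if missing
    // after:  blockDataManager.getBlockData(blockId)         => throws if missing
    // net effect: getBlock no longer returns null for a missing block; it throws.
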
core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -216,17 +216,17 @@ private[spark] class BlockManager(
    *
    * @return Some(buffer) if the block exists locally, and None if it doesn't.
    */
-  override def getBlockData(blockId: String): Option[ManagedBuffer] = {
+  override def getBlockData(blockId: String): ManagedBuffer = {
     val bid = BlockId(blockId)
     if (bid.isShuffle) {
-      Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]))
+      shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])
     } else {
       val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
       if (blockBytesOpt.isDefined) {
         val buffer = blockBytesOpt.get
-        Some(new NioByteBufferManagedBuffer(buffer))
+        new NioByteBufferManagedBuffer(buffer)
       } else {
-        None
+        throw new BlockNotFoundException(blockId)
       }
     }
   }
3 changed files not shown.
