
Commit

Code review feedback:
1. Rename the package from cm to nio.

2. Refine the BlockTransferService and ManagedBuffer interfaces.
rxin committed Sep 3, 2014
1 parent 2c6b1e1 commit 8a1046e
Showing 26 changed files with 141 additions and 127 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.cm.CMBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
@@ -226,7 +226,7 @@ object SparkEnv extends Logging {

val shuffleMemoryManager = new ShuffleMemoryManager(conf)

val blockTransferService = new CMBlockTransferService(conf, securityManager)
val blockTransferService = new NioBlockTransferService(conf, securityManager)

val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
@@ -31,7 +31,7 @@ trait BlockFetchingListener extends EventListener {
def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit

/**
* Called upon failures.
* Called upon failures. For each failure, this is called only once (i.e. not once per block).
*/
def onBlockFetchFailure(exception: Throwable): Unit
}
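
For illustration only, a minimal sketch of a BlockFetchingListener implementation; the CountDownLatch, the class name, and the println are assumptions of the example, not part of the interface:

    import java.util.concurrent.CountDownLatch

    import org.apache.spark.network.{BlockFetchingListener, ManagedBuffer}

    // Hypothetical listener that tracks completion of a known number of blocks.
    class CountingListener(numBlocks: Int) extends BlockFetchingListener {
      val latch = new CountDownLatch(numBlocks)

      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
        // Invoked once per block, as each fetch completes.
        println(s"fetched $blockId (${data.size} bytes)")
        latch.countDown()
      }

      override def onBlockFetchFailure(exception: Throwable): Unit = {
        // Invoked once per failure, not once per outstanding block,
        // so release everything that is still pending.
        exception.printStackTrace()
        while (latch.getCount > 0) latch.countDown()
      }
    }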
core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,6 +17,9 @@

package org.apache.spark.network

import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.storage.StorageLevel


@@ -48,9 +51,11 @@ abstract class BlockTransferService {
* available only after [[init]] is invoked.
*
* Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block,
* while [[BlockFetchingListener.onBlockFetchSuccess]] is called once per failure.
* while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block).
*
* This takes a sequence so the implementation can batch requests.
* Note that this API takes a sequence so the implementation can batch requests, and does not
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
def fetchBlocks(
hostName: String,
@@ -59,12 +64,21 @@ abstract class BlockTransferService {
listener: BlockFetchingListener): Unit
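
A usage sketch. The parameters folded out of the hunk above are assumed here to be the remote port and the block IDs, mirroring the call that fetchBlockSync makes below; `service`, the host, port, and block IDs are all made up for the example:

    // Assumes `service` is an already-init()'d BlockTransferService and
    // CountingListener is the example listener sketched earlier.
    val listener = new CountingListener(numBlocks = 2)
    service.fetchBlocks("shuffle-host", 50050, Seq("shuffle_0_1_2", "shuffle_0_1_3"), listener)
    // No Future is returned: onBlockFetchSuccess fires per block as soon as its
    // bytes arrive, rather than after the whole batch completes.
    listener.latch.await()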

/**
* Fetch a single block from a remote node, synchronously,
* available only after [[init]] is invoked.
* Upload a single block to a remote node, available only after [[init]] is invoked.
*/
def fetchBlock(hostName: String, port: Int, blockId: String): ManagedBuffer = {
// TODO(rxin): Add timeout?
def uploadBlock(
hostname: String,
port: Int,
blockId: String,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]
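
Since uploadBlock returns a Future[Unit], callers can replicate asynchronously; a small sketch, where the host, block ID, `buffer`, and the choice of execution context are assumptions of the example:

    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.util.{Failure, Success}

    import org.apache.spark.storage.StorageLevel

    // Fire off an asynchronous replication and react when it finishes.
    service.uploadBlock("replica-host", 50050, "rdd_0_7", buffer, StorageLevel.MEMORY_ONLY)
      .onComplete {
        case Success(_) => println("block rdd_0_7 replicated")
        case Failure(t) => t.printStackTrace()
      }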

/**
* A special case of [[fetchBlocks]], since it only fetches one block and is blocking.
*
* It is also only available after [[init]] is invoked.
*/
def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = {
// A monitor for the thread to wait on.
val lock = new Object
@volatile var result: Either[ManagedBuffer, Throwable] = null
@@ -103,12 +117,15 @@
/**
* Upload a single block to a remote node, available only after [[init]] is invoked.
*
* This call blocks until the upload completes, or throws an exception upon failures.
* This method is similar to [[uploadBlock]], except this one blocks the thread
* until the upload finishes.
*/
def uploadBlock(
def uploadBlockSync(
hostname: String,
port: Int,
blockId: String,
blockData: ManagedBuffer,
level: StorageLevel): Unit
level: StorageLevel): Unit = {
Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
}
}
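
The body of fetchBlockSync is mostly folded out of the hunk above; what remains shows a lock plus a @volatile result slot. As a self-contained sketch of that wait/notify pattern (a generic reconstruction, not the elided code itself):

    // Generic monitor pattern: a callback thread publishes into a volatile slot
    // and wakes the caller, which blocks until the slot is filled.
    def awaitResult[T](start: (Either[T, Throwable] => Unit) => Unit): T = {
      val lock = new Object
      @volatile var result: Either[T, Throwable] = null
      start { r =>
        lock.synchronized {
          result = r
          lock.notifyAll()  // wake the waiting caller
        }
      }
      lock.synchronized {
        while (result == null) {
          lock.wait()  // loop guards against spurious wakeups
        }
      }
      result match {
        case Left(value) => value
        case Right(t) => throw t
      }
    }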
73 changes: 39 additions & 34 deletions core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
@@ -17,85 +17,90 @@

package org.apache.spark.network

import java.io.{RandomAccessFile, File, FileInputStream, InputStream}
import java.io.{FileInputStream, RandomAccessFile, File, InputStream}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

import io.netty.buffer.{ByteBufInputStream, ByteBuf, Unpooled}
import io.netty.channel.DefaultFileRegion
import io.netty.buffer.{ByteBufInputStream, ByteBuf}

import org.apache.spark.storage.FileSegment
import org.apache.spark.util.ByteBufferInputStream


/**
* Provides a buffer abstraction that allows pooling and reuse.
* This interface provides an immutable view for data in the form of bytes. The implementation
* should specify how the data is provided:
*
* - FileSegmentManagedBuffer: data backed by part of a file
* - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
* - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
*/
abstract class ManagedBuffer {
sealed abstract class ManagedBuffer {
// Note that all the methods are defined with parentheses because their implementations can
// have side effects (I/O operations).

def byteBuffer(): ByteBuffer

def fileSegment(): Option[FileSegment] = None

def inputStream(): InputStream = throw new UnsupportedOperationException

def release(): Unit = throw new UnsupportedOperationException

/** Number of bytes of the data. */
def size: Long

private[network] def toNetty(): AnyRef
/**
* Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
* returned ByteBuffer should not affect the content of this buffer.
*/
def nioByteBuffer(): ByteBuffer

/**
* Exposes this buffer's data as an InputStream. The underlying implementation does not
* necessarily check for the length of bytes read, so the caller is responsible for making sure
* it does not go over the limit.
*/
def inputStream(): InputStream
}


/**
* A ManagedBuffer backed by a segment in a file.
* A [[ManagedBuffer]] backed by a segment in a file
*/
final class FileSegmentManagedBuffer(file: File, offset: Long, length: Long)
final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
extends ManagedBuffer {

override def size: Long = length

override def byteBuffer(): ByteBuffer = {
override def nioByteBuffer(): ByteBuffer = {
val channel = new RandomAccessFile(file, "r").getChannel
channel.map(MapMode.READ_ONLY, offset, length)
}

override private[network] def toNetty(): AnyRef = {
val fileChannel = new FileInputStream(file).getChannel
new DefaultFileRegion(fileChannel, offset, length)
override def inputStream(): InputStream = {
val is = new FileInputStream(file)
is.skip(offset)
is
}
}


/**
* A ManagedBuffer backed by [[java.nio.ByteBuffer]].
* A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
*/
final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {

override def byteBuffer() = buf

override def inputStream() = new ByteBufferInputStream(buf)

override def size: Long = buf.remaining()

override private[network] def toNetty(): AnyRef = Unpooled.wrappedBuffer(buf)
override def nioByteBuffer() = buf

override def inputStream() = new ByteBufferInputStream(buf)
}


/**
* A ManagedBuffer backed by a Netty [[ByteBuf]].
* A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
*/
final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {

override def byteBuffer() = buf.nioBuffer()

override def inputStream() = new ByteBufInputStream(buf)
override def size: Long = buf.readableBytes()

override def release(): Unit = buf.release()
override def nioByteBuffer() = buf.nioBuffer()

override def size: Long = buf.readableBytes()
override def inputStream() = new ByteBufInputStream(buf)

override private[network] def toNetty(): AnyRef = buf
// TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
def release(): Unit = buf.release()
}
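
To make the three backings concrete, a consumption sketch; the file path and byte contents are invented for the example, and all three values are used through the common ManagedBuffer type:

    import java.io.File
    import java.nio.ByteBuffer

    import io.netty.buffer.Unpooled

    import org.apache.spark.network._

    val buffers: Seq[ManagedBuffer] = Seq(
      new FileSegmentManagedBuffer(new File("/tmp/blocks.dat"), 0L, 16L),
      new NioByteBufferManagedBuffer(ByteBuffer.allocate(16)),
      new NettyByteBufManagedBuffer(Unpooled.buffer(16).writeZero(16)))

    for (buf <- buffers) {
      println(s"size = ${buf.size}")
      val bb = buf.nioByteBuffer() // repositioning bb must not affect buf's contents
      val in = buf.inputStream()   // length is not enforced; stay within buf.size
      in.close()
    }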
core/src/main/scala/org/apache/spark/network/{cm → nio}/BlockMessage.scala
@@ -15,19 +15,20 @@
* limitations under the License.
*/

package org.apache.spark.network.cm
package org.apache.spark.network.nio

import java.nio.ByteBuffer

import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}

import scala.collection.mutable.{ArrayBuffer, StringBuilder}

// private[spark] because we need to register them in Kryo
private[spark] case class GetBlock(id: BlockId)
private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)

private[spark] class BlockMessage() {
private[nio] class BlockMessage() {
// Un-initialized: typ = 0
// GetBlock: typ = 1
// GotBlock: typ = 2
@@ -158,7 +159,7 @@ private[spark] class BlockMessage() {
}
}

private[spark] object BlockMessage {
private[nio] object BlockMessage {
val TYPE_NON_INITIALIZED: Int = 0
val TYPE_GET_BLOCK: Int = 1
val TYPE_GOT_BLOCK: Int = 2
@@ -193,16 +194,4 @@ private[spark] object BlockMessage {
newBlockMessage.set(putBlock)
newBlockMessage
}

def main(args: Array[String]) {
val B = new BlockMessage()
val blockId = TestBlockId("ABC")
B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val bMsg = B.toBufferMessage
val C = new BlockMessage()
C.set(bMsg)

println(B.getId + " " + B.getLevel)
println(C.getId + " " + C.getLevel)
}
}
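
The deleted main above was an ad-hoc round-trip check. The same exercise, kept as a usage sketch of the message types (it would now need to live in org.apache.spark.network.nio, since the classes became private[nio]):

    import java.nio.ByteBuffer

    import org.apache.spark.storage.{StorageLevel, TestBlockId}

    // Encode a PutBlock into the wire format and decode it back.
    val original = new BlockMessage()
    original.set(PutBlock(TestBlockId("ABC"), ByteBuffer.allocate(10),
      StorageLevel.MEMORY_AND_DISK_SER_2))
    val wire = original.toBufferMessage

    val parsed = new BlockMessage()
    parsed.set(wire)
    assert(parsed.getId == original.getId && parsed.getLevel == original.getLevel)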
core/src/main/scala/org/apache/spark/network/{cm → nio}/BlockMessageArray.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.network.cm
package org.apache.spark.network.nio

import java.nio.ByteBuffer

@@ -24,7 +24,7 @@ import org.apache.spark.storage.{StorageLevel, TestBlockId}

import scala.collection.mutable.ArrayBuffer

private[spark]
private[nio]
class BlockMessageArray(var blockMessages: Seq[BlockMessage])
extends Seq[BlockMessage] with Logging {

@@ -102,7 +102,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage])
}
}

private[spark] object BlockMessageArray {
private[nio] object BlockMessageArray {

def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
val newBlockMessageArray = new BlockMessageArray()
core/src/main/scala/org/apache/spark/network/{cm → nio}/BufferMessage.scala
@@ -15,15 +15,16 @@
* limitations under the License.
*/

package org.apache.spark.network.cm
package org.apache.spark.network.nio

import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.storage.BlockManager

import scala.collection.mutable.ArrayBuffer

private[spark]
private[nio]
class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
extends Message(Message.BUFFER_MESSAGE, id_) {

core/src/main/scala/org/apache/spark/network/{cm → nio}/Connection.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.network.cm
package org.apache.spark.network.nio

import java.net._
import java.nio._
@@ -25,7 +25,7 @@ import org.apache.spark._

import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}

private[spark]
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
extends Logging {
@@ -190,7 +190,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
}


private[spark]
private[nio]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
remoteId_ : ConnectionManagerId, id_ : ConnectionId)
extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
core/src/main/scala/org/apache/spark/network/{cm → nio}/ConnectionId.scala
@@ -15,13 +15,13 @@
* limitations under the License.
*/

package org.apache.spark.network.cm
package org.apache.spark.network.nio

private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
}

private[spark] object ConnectionId {
private[nio] object ConnectionId {

def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
val res = connectionIdString.split("_").map(_.trim())
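
Given the underscore-delimited toString above, parsing is a split on "_"; the rest of the body is truncated in this hunk. A round-trip sketch, assuming ConnectionManagerId is built from a host and port as its uses here suggest, and noting the format breaks if the host itself contains an underscore:

    // Hypothetical round trip through the string form "host_port_uniqId".
    val id = ConnectionId(ConnectionManagerId("worker-7", 41000), 12)
    val s = id.toString // "worker-7_41000_12"
    assert(ConnectionId.createConnectionIdFromString(s) == id)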
core/src/main/scala/org/apache/spark/network/{cm → nio}/ConnectionManager.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.network.cm
package org.apache.spark.network.nio

import java.io.IOException
import java.net._
@@ -26,15 +26,16 @@ import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
import java.util.{Timer, TimerTask}

import org.apache.spark._
import org.apache.spark.util.{SystemClock, Utils}

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps

private[spark] class ConnectionManager(
import org.apache.spark._
import org.apache.spark.util.{SystemClock, Utils}


private[nio] class ConnectionManager(
port: Int,
conf: SparkConf,
securityManager: SecurityManager,
(Remaining changed files not shown.)
