Added more documentation.
rxin authored and aarondav committed Oct 10, 2014
1 parent 1760d32 commit 2b44cf1
Showing 6 changed files with 80 additions and 58 deletions.
@@ -19,68 +19,35 @@ package org.apache.spark.network.netty

import java.util.concurrent.TimeoutException

import io.netty.bootstrap.Bootstrap
import io.netty.buffer.PooledByteBufAllocator
import io.netty.channel.socket.SocketChannel
import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelInitializer, ChannelOption}
import io.netty.channel.{ChannelFuture, ChannelFutureListener}

import org.apache.spark.Logging
import org.apache.spark.network.BlockFetchingListener


/**
* Client for [[NettyBlockTransferService]]. Use [[BlockClientFactory]] to
* instantiate this client.
* Client for [[NettyBlockTransferService]]. The connection to the server must have been established
* using [[BlockClientFactory]] before instantiating this.
*
* The constructor blocks until a connection is successfully established.
* This class is used to make requests to the server, while [[BlockClientHandler]] is responsible
* for handling responses from the server.
*
* Concurrency: thread safe and can be called from multiple threads.
*
* @param cf the ChannelFuture for the connection.
* @param handler [[BlockClientHandler]] for handling outstanding requests.
*/
@throws[TimeoutException]
private[netty]
class BlockClient(factory: BlockClientFactory, hostname: String, port: Int)
extends Logging {

private val handler = new BlockClientHandler
private val encoder = new ClientRequestEncoder
private val decoder = new ServerResponseDecoder

/** Netty Bootstrap for creating the TCP connection. */
private val bootstrap: Bootstrap = {
val b = new Bootstrap
b.group(factory.workerGroup)
.channel(factory.socketChannelClass)
// Use pooled buffers to reduce temporary buffer allocation
.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
// Disable Nagle's Algorithm since we don't want packets to wait
.option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
.option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
.option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs)

b.handler(new ChannelInitializer[SocketChannel] {
override def initChannel(ch: SocketChannel): Unit = {
ch.pipeline
.addLast("clientRequestEncoder", encoder)
.addLast("frameDecoder", ProtocolUtils.createFrameDecoder())
.addLast("serverResponseDecoder", decoder)
.addLast("handler", handler)
}
})
b
}
class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Logging {

/** Netty ChannelFuture for the connection. */
private val cf: ChannelFuture = bootstrap.connect(hostname, port)
if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) {
throw new TimeoutException(
s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)")
}
private[this] val serverAddr = cf.channel().remoteAddress().toString

/**
* Ask the remote server for a sequence of blocks, and execute the callback.
*
* Note that this is asynchronous and returns immediately. Upstream caller should throttle the
* rate of fetching; otherwise we could run out of memory.
* rate of fetching; otherwise we could run out of memory due to large outstanding fetches.
*
* @param blockIds sequence of block ids to fetch.
* @param listener callback to fire on fetch success / failure.
@@ -89,7 +56,7 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int)
var startTime: Long = 0
logTrace {
startTime = System.nanoTime
s"Sending request $blockIds to $hostname:$port"
s"Sending request $blockIds to $serverAddr"
}

blockIds.foreach { blockId =>
@@ -101,12 +68,12 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int)
if (future.isSuccess) {
logTrace {
val timeTaken = (System.nanoTime - startTime).toDouble / 1000000
s"Sending request $blockIds to $hostname:$port took $timeTaken ms"
s"Sending request $blockIds to $serverAddr took $timeTaken ms"
}
} else {
// Fail all blocks.
val errorMsg =
s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}"
logError(errorMsg, future.cause)
blockIds.foreach { blockId =>
handler.removeRequest(blockId)
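A minimal usage sketch of the client API as it stands after this change. It assumes a hypothetical host, port and block ids, assumes the fetch method is named fetchBlocks, and assumes the BlockFetchingListener callbacks are onBlockFetchSuccess/onBlockFetchFailure; none of those names are confirmed by this diff.

package org.apache.spark.network.netty

import org.apache.spark.SparkConf
import org.apache.spark.network.{BlockFetchingListener, ManagedBuffer}

// Hedged sketch only: host, port, block ids and the listener callback names
// are placeholders/assumptions for illustration.
object BlockClientUsageSketch {
  def main(args: Array[String]): Unit = {
    val factory = new BlockClientFactory(new SparkConf())
    // createClient blocks until the TCP connection is established or times out.
    val client = factory.createClient("example-host", 12345)

    // fetchBlocks returns immediately; callbacks fire on a Netty event loop
    // thread, so the caller must bound how many block ids it keeps in flight.
    client.fetchBlocks(Seq("shuffle_0_0_0", "shuffle_0_1_0"), new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
        println(s"fetched $blockId")
      override def onBlockFetchFailure(exception: Throwable): Unit =
        println(s"fetch failed: ${exception.getMessage}")
    })

    // In real code, wait for the outstanding fetches to complete first.
    factory.stop()  // shuts down the shared worker group
  }
}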
@@ -17,12 +17,17 @@

package org.apache.spark.network.netty

import java.util.concurrent.TimeoutException

import io.netty.bootstrap.Bootstrap
import io.netty.buffer.PooledByteBufAllocator
import io.netty.channel._
import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.oio.OioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.channel.socket.oio.OioSocketChannel
import io.netty.channel.{Channel, EventLoopGroup}

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils
@@ -38,12 +43,16 @@ class BlockClientFactory(val conf: NettyConfig) {
def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))

/** A thread factory so the threads are named (for debugging). */
private[netty] val threadFactory = Utils.namedThreadFactory("spark-shuffle-client")
private[netty] val threadFactory = Utils.namedThreadFactory("spark-netty-client")

/** The following two are instantiated by the [[init]] method, depending on the ioMode. */
private[netty] var socketChannelClass: Class[_ <: Channel] = _
private[netty] var workerGroup: EventLoopGroup = _

// The encoders are stateless and can be shared among multiple clients.
private[this] val encoder = new ClientRequestEncoder
private[this] val decoder = new ServerResponseDecoder

init()

/** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */
@@ -78,7 +87,36 @@ class BlockClientFactory(val conf: NettyConfig) {
* Concurrency: This method is safe to call from multiple threads.
*/
def createClient(remoteHost: String, remotePort: Int): BlockClient = {
new BlockClient(this, remoteHost, remotePort)
val handler = new BlockClientHandler

val bootstrap = new Bootstrap
bootstrap.group(workerGroup)
.channel(socketChannelClass)
// Use pooled buffers to reduce temporary buffer allocation
.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
// Disable Nagle's Algorithm since we don't want packets to wait
.option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
.option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
.option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs)

bootstrap.handler(new ChannelInitializer[SocketChannel] {
override def initChannel(ch: SocketChannel): Unit = {
ch.pipeline
.addLast("clientRequestEncoder", encoder)
.addLast("frameDecoder", ProtocolUtils.createFrameDecoder())
.addLast("serverResponseDecoder", decoder)
.addLast("handler", handler)
}
})

// Connect to the remote server
val cf: ChannelFuture = bootstrap.connect(remoteHost, remotePort)
if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) {
throw new TimeoutException(
s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)")
}

new BlockClient(cf, handler)
}

def stop(): Unit = {
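Since the factory now owns the Bootstrap, the event loop group and the stateless @Sharable encoder/decoder, a single instance is meant to be shared for all outgoing connections. A hedged lifecycle sketch follows; host names and ports are placeholders.

package org.apache.spark.network.netty

import org.apache.spark.SparkConf

// Hedged sketch of the intended factory lifecycle: one shared factory, one
// client per remote peer, one stop() at service shutdown.
private[netty] object FactoryLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val factory = new BlockClientFactory(new SparkConf())

    // createClient is documented above as safe to call from multiple threads,
    // e.g. one fetcher thread per remote executor; each call opens a fresh
    // connection while reusing the factory's worker group and shared codecs.
    val threads = Seq(("host-a", 7077), ("host-b", 7077)).map { case (host, port) =>
      new Thread(new Runnable {
        override def run(): Unit = {
          val client = factory.createClient(host, port)
          // ... issue fetch requests on `client` ...
        }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())

    factory.stop()  // releases the worker group when the service shuts down
  }
}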
@@ -24,15 +24,16 @@ import org.apache.spark.network.BlockFetchingListener


/**
* Handler that processes server responses.
* Handler that processes server responses, in response to requests issued by [[BlockClient]].
* It works by tracking the list of outstanding requests (and their callbacks).
*
* Concurrency: thread safe and can be called from multiple threads.
*/
private[netty]
class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging {

/** Tracks the list of outstanding requests and their listeners on success/failure. */
private val outstandingRequests = java.util.Collections.synchronizedMap {
private[this] val outstandingRequests = java.util.Collections.synchronizedMap {
new java.util.HashMap[String, BlockFetchingListener]
}

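The handler's job, as described above, is to pair each outstanding block id with the listener that must eventually be fired. The sketch below illustrates that bookkeeping pattern in isolation; it is not the actual BlockClientHandler API, and the method names here are hypothetical.

package org.apache.spark.network.netty

import java.util.Collections

import org.apache.spark.network.BlockFetchingListener

// Illustrative sketch only: each outstanding block id maps to the listener
// that should be fired exactly once, either when its response arrives or when
// the request fails to be sent.
private[netty] class OutstandingRequestsSketch {
  private[this] val outstanding = Collections.synchronizedMap(
    new java.util.HashMap[String, BlockFetchingListener])

  def add(blockId: String, listener: BlockFetchingListener): Unit = {
    outstanding.put(blockId, listener)
  }

  // Called both when a BlockFetchSuccess/BlockFetchFailure arrives and when the
  // send itself fails (see BlockClient.fetchBlocks above), so the same listener
  // is never invoked twice for one block id.
  def remove(blockId: String): Option[BlockFetchingListener] = {
    Option(outstanding.remove(blockId))
  }
}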
@@ -58,8 +58,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log
/** Initialize the server. */
private def init(): Unit = {
bootstrap = new ServerBootstrap
val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")
val bossThreadFactory = Utils.namedThreadFactory("spark-netty-server-boss")
val workerThreadFactory = Utils.namedThreadFactory("spark-netty-server-worker")

// Use only one thread to accept connections, and 2 * num_cores for worker.
def initNio(): Unit = {
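The body of initNio() is elided above; the sketch below is a hedged reconstruction based only on the comment that one thread accepts connections and 2 * num_cores threads do the work. It is an illustration, not the actual code.

package org.apache.spark.network.netty

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.nio.NioServerSocketChannel

import org.apache.spark.util.Utils

// Hedged sketch of the server-side NIO threading described in the comment above.
private[netty] object ServerThreadingSketch {
  def configure(bootstrap: ServerBootstrap): Unit = {
    // One boss thread to accept connections.
    val bossGroup = new NioEventLoopGroup(
      1, Utils.namedThreadFactory("spark-netty-server-boss"))
    // 2 * num_cores worker threads to handle established channels.
    val workerGroup = new NioEventLoopGroup(
      2 * Runtime.getRuntime.availableProcessors(),
      Utils.namedThreadFactory("spark-netty-server-worker"))
    bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel])
  }
}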
@@ -28,37 +28,50 @@ import org.apache.spark.Logging
import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer}


/** Messages from the client to the server. */
sealed trait ClientRequest {
def id: Byte
}

/**
* Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can
* correspond to multiple [[ServerResponse]]s.
*/
final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest {
override def id = 0
}

/**
* Request to upload a block to the server. Currently the server does not ack the upload request.
*/
final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest {
require(blockId.length <= Byte.MaxValue)
override def id = 1
}


/** Messages from the server to the client (usually in response to some [[ClientRequest]]). */
sealed trait ServerResponse {
def id: Byte
}

/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */
final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse {
require(blockId.length <= Byte.MaxValue)
override def id = 0
}

/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */
final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse {
require(blockId.length <= Byte.MaxValue)
override def id = 1
}


/**
* Encoder used by the client side to encode client-to-server responses.
* Encoder for [[ClientRequest]], used on the client side.
*
* This encoder is stateless so it is safe to be shared by multiple threads.
*/
@Sharable
final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] {
@@ -109,6 +122,7 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest]

/**
* Decoder on the server side to decode client requests.
* This decoder is stateless so it is safe to be shared by multiple threads.
*
* This assumes the inbound messages have been processed by a frame decoder created by
* [[ProtocolUtils.createFrameDecoder()]].
@@ -138,6 +152,7 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] {

/**
* Encoder used by the server side to encode server-to-client responses.
* This encoder is stateless so it is safe to be shared by multiple threads.
*/
@Sharable
final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging {
@@ -190,6 +205,7 @@ final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse

/**
* Decoder on the client side to decode server responses.
* This decoder is stateless so it is safe to be shared by multiple threads.
*
* This assumes the inbound messages have been processed by a frame decoder created by
* [[ProtocolUtils.createFrameDecoder()]].
@@ -229,6 +245,7 @@ private[netty] object ProtocolUtils {
new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8)
}

// TODO(rxin): Make sure these work for all charsets.
def readBlockId(in: ByteBuf): String = {
val numBytesToRead = in.readByte().toInt
val bytes = new Array[Byte](numBytesToRead)
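The frame decoder parameters above fix the wire framing, and readBlockId fixes how block ids are laid out; the sketch below spells both out. writeBlockId is an assumed counterpart that is not part of this commit, and the explicit UTF-8 charset is one possible way to resolve the TODO rather than the current behaviour.

package org.apache.spark.network.netty

import io.netty.buffer.ByteBuf
import io.netty.util.CharsetUtil

// Framing implied by LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8):
// every message is sent as [8-byte length, counting those 8 bytes][payload],
// and the length header is stripped before ClientRequestDecoder /
// ServerResponseDecoder see the payload (a 1-byte message id followed by the
// message body, per the id: Byte fields above).
private[netty] object BlockIdCodecSketch {
  // Assumed inverse of readBlockId, for illustration only.
  def writeBlockId(out: ByteBuf, blockId: String): Unit = {
    val bytes = blockId.getBytes(CharsetUtil.UTF_8)
    // Block ids are length-prefixed with a single byte, which is why the
    // protocol messages require(blockId.length <= Byte.MaxValue).
    require(bytes.length <= Byte.MaxValue, s"block id too long: $blockId")
    out.writeByte(bytes.length)
    out.writeBytes(bytes)
  }
}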
@@ -34,7 +34,7 @@ import org.apache.spark.storage.StorageLevel


/**
* Test suite that makes sure the server and the client implementations share the same protocol.
* Test cases that create real clients and servers and connect them to each other.
*/
class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {

@@ -93,8 +93,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
/** A ByteBuf for file_block */
lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25)

def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) =
{
def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = {
val client = clientFactory.createClient(server.hostName, server.port)
val sem = new Semaphore(0)
val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
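The rest of the fetchBlocks test helper is elided above. A hedged sketch of how it presumably continues is shown below, reusing the client, sem, receivedBlockIds and blockIds values defined in the visible lines; the listener callback names are assumptions, and the real helper also collects buffers and failed block ids for the returned tuple.

    // Hedged continuation sketch, not the actual test code: the listener
    // records each result and releases one permit per answered block, and the
    // helper blocks until every requested block has been answered.
    client.fetchBlocks(blockIds, new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
        receivedBlockIds.add(blockId)
        sem.release()
      }
      override def onBlockFetchFailure(exception: Throwable): Unit = {
        sem.release()
      }
    })
    sem.acquire(blockIds.size)  // wait until all outstanding fetches complete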
