From 2b44cf1b7547919bbe7386e954fe2f56be046790 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Sep 2014 14:36:31 -0700 Subject: [PATCH] Added more documentation. --- .../spark/network/netty/BlockClient.scala | 61 +++++-------------- .../network/netty/BlockClientFactory.scala | 44 ++++++++++++- .../network/netty/BlockClientHandler.scala | 5 +- .../spark/network/netty/BlockServer.scala | 4 +- .../apache/spark/network/netty/protocol.scala | 19 +++++- .../netty/ServerClientIntegrationSuite.scala | 5 +- 6 files changed, 80 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 95af2565bcc39..9333fefa92957 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -19,68 +19,35 @@ package org.apache.spark.network.netty import java.util.concurrent.TimeoutException -import io.netty.bootstrap.Bootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.socket.SocketChannel -import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelInitializer, ChannelOption} +import io.netty.channel.{ChannelFuture, ChannelFutureListener} import org.apache.spark.Logging import org.apache.spark.network.BlockFetchingListener /** - * Client for [[NettyBlockTransferService]]. Use [[BlockClientFactory]] to - * instantiate this client. + * Client for [[NettyBlockTransferService]]. The connection to server must have been established + * using [[BlockClientFactory]] before instantiating this. * - * The constructor blocks until a connection is successfully established. + * This class is used to make requests to the server, while [[BlockClientHandler]] is responsible + * for handling responses from the server. * * Concurrency: thread safe and can be called from multiple threads. 
+ * + * @param cf the ChannelFuture for the connection. + * @param handler [[BlockClientHandler]] for handling outstanding requests. */ @throws[TimeoutException] private[netty] -class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) - extends Logging { - - private val handler = new BlockClientHandler - private val encoder = new ClientRequestEncoder - private val decoder = new ServerResponseDecoder - - /** Netty Bootstrap for creating the TCP connection. */ - private val bootstrap: Bootstrap = { - val b = new Bootstrap - b.group(factory.workerGroup) - .channel(factory.socketChannelClass) - // Use pooled buffers to reduce temporary buffer allocation - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs) - - b.handler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("clientRequestEncoder", encoder) - .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) - .addLast("serverResponseDecoder", decoder) - .addLast("handler", handler) - } - }) - b - } +class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Logging { - /** Netty ChannelFuture for the connection. */ - private val cf: ChannelFuture = bootstrap.connect(hostname, port) - if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") - } + private[this] val serverAddr = cf.channel().remoteAddress().toString /** * Ask the remote server for a sequence of blocks, and execute the callback. * * Note that this is asynchronous and returns immediately. 
Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory. + * rate of fetching; otherwise we could run out of memory due to large outstanding fetches. * * @param blockIds sequence of block ids to fetch. * @param listener callback to fire on fetch success / failure. @@ -89,7 +56,7 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) var startTime: Long = 0 logTrace { startTime = System.nanoTime - s"Sending request $blockIds to $hostname:$port" + s"Sending request $blockIds to $serverAddr" } blockIds.foreach { blockId => @@ -101,12 +68,12 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) if (future.isSuccess) { logTrace { val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 - s"Sending request $blockIds to $hostname:$port took $timeTaken ms" + s"Sending request $blockIds to $serverAddr took $timeTaken ms" } } else { // Fail all blocks. val errorMsg = - s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" + s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}" logError(errorMsg, future.cause) blockIds.foreach { blockId => handler.removeRequest(blockId) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 0777275cd4fe3..f05f1419ded14 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -17,12 +17,17 @@ package org.apache.spark.network.netty +import java.util.concurrent.TimeoutException + +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel._ import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.oio.OioEventLoopGroup +import 
io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.oio.OioSocketChannel -import io.netty.channel.{Channel, EventLoopGroup} import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -38,12 +43,16 @@ class BlockClientFactory(val conf: NettyConfig) { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) /** A thread factory so the threads are named (for debugging). */ - private[netty] val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") + private[netty] val threadFactory = Utils.namedThreadFactory("spark-netty-client") /** The following two are instantiated by the [[init]] method, depending ioMode. */ private[netty] var socketChannelClass: Class[_ <: Channel] = _ private[netty] var workerGroup: EventLoopGroup = _ + // The encoders are stateless and can be shared among multiple clients. + private[this] val encoder = new ClientRequestEncoder + private[this] val decoder = new ServerResponseDecoder + init() /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ @@ -78,7 +87,36 @@ class BlockClientFactory(val conf: NettyConfig) { * Concurrency: This method is safe to call from multiple threads. 
*/ def createClient(remoteHost: String, remotePort: Int): BlockClient = { - new BlockClient(this, remoteHost, remotePort) + val handler = new BlockClientHandler + + val bootstrap = new Bootstrap + bootstrap.group(workerGroup) + .channel(socketChannelClass) + // Use pooled buffers to reduce temporary buffer allocation + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs) + + bootstrap.handler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("clientRequestEncoder", encoder) + .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) + .addLast("serverResponseDecoder", decoder) + .addLast("handler", handler) + } + }) + + // Connect to the remote server + val cf: ChannelFuture = bootstrap.connect(remoteHost, remotePort) + if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) { + throw new TimeoutException( + s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)") + } + + new BlockClient(cf, handler) } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala index b41c831f3d7e5..2a474cd71eab8 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -24,7 +24,8 @@ import org.apache.spark.network.BlockFetchingListener /** - * Handler that processes server responses. + * Handler that processes server responses, in response to requests issued from [[BlockClient]]. + * It works by tracking the list of outstanding requests (and their callbacks). 
* * Concurrency: thread safe and can be called from multiple threads. */ @@ -32,7 +33,7 @@ private[netty] class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { /** Tracks the list of outstanding requests and their listeners on success/failure. */ - private val outstandingRequests = java.util.Collections.synchronizedMap { + private[this] val outstandingRequests = java.util.Collections.synchronizedMap { new java.util.HashMap[String, BlockFetchingListener] } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 3433c5763ab3c..05443a74094d7 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -58,8 +58,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log /** Initialize the server. */ private def init(): Unit = { bootstrap = new ServerBootstrap - val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss") - val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker") + val bossThreadFactory = Utils.namedThreadFactory("spark-netty-server-boss") + val workerThreadFactory = Utils.namedThreadFactory("spark-netty-server-worker") // Use only one thread to accept connections, and 2 * num_cores for worker. def initNio(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala index 0159eca1d3b41..ac6a4d00f654f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -28,29 +28,40 @@ import org.apache.spark.Logging import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} +/** Messages from the client to the server. 
 */ sealed trait ClientRequest { def id: Byte } +/** + * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can + * correspond to multiple [[ServerResponse]]s. + */ final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { override def id = 0 } +/** + * Request to upload a block to the server. Currently the server does not ack the upload request. + */ final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest { require(blockId.length <= Byte.MaxValue) override def id = 1 } +/** Messages from server to client (usually in response to some [[ClientRequest]]). */ sealed trait ServerResponse { def id: Byte } +/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */ final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { require(blockId.length <= Byte.MaxValue) override def id = 0 } +/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */ final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { require(blockId.length <= Byte.MaxValue) override def id = 1 @@ -58,7 +69,9 @@ final case class BlockFetchFailure(blockId: String, error: String) extends Serve /** - * Encoder used by the client side to encode client-to-server responses. + * Encoder for [[ClientRequest]] used in client side. + * + * This encoder is stateless so it is safe to be shared by multiple threads. */ @Sharable final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { @@ -109,6 +122,7 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] /** * Decoder in the server side to decode client requests. + * This decoder is stateless so it is safe to be shared by multiple threads. * * This assumes the inbound messages have been processed by a frame decoder created by * [[ProtocolUtils.createFrameDecoder()]]. 
@@ -138,6 +152,7 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { /** * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. */ @Sharable final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { @@ -190,6 +205,7 @@ final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse /** * Decoder in the client side to decode server responses. + * This decoder is stateless so it is safe to be shared by multiple threads. * * This assumes the inbound messages have been processed by a frame decoder created by * [[ProtocolUtils.createFrameDecoder()]]. @@ -229,6 +245,7 @@ private[netty] object ProtocolUtils { new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8) } + // TODO(rxin): Make sure these work for all charsets. def readBlockId(in: ByteBuf): String = { val numBytesToRead = in.readByte().toInt val bytes = new Array[Byte](numBytesToRead) diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index a468764fb1848..178c60a048b9f 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.storage.StorageLevel /** -* Test suite that makes sure the server and the client implementations share the same protocol. +* Test cases that create real clients and servers and connect. 
*/ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { @@ -93,8 +93,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { /** A ByteBuf for file_block */ lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = - { + def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = { val client = clientFactory.createClient(server.hostName, server.port) val sem = new Semaphore(0) val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])