From 939f276fe42feb3c333e39f371e9c6400fe22ddc Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Sun, 12 Oct 2014 16:45:55 -0700 Subject: [PATCH] Attempt to make comm. bidirectional --- .../spark/network/BlockFetchingListener.scala | 4 +- .../spark/network/BlockTransferService.scala | 15 ++- .../network/netty/NettyBlockFetcher.scala | 6 +- .../network/netty/NettyBlockRpcServer.scala | 33 ++++-- .../netty/NettyBlockTransferService.scala | 56 ++++++--- .../network/nio/NioBlockTransferService.scala | 4 +- .../apache/spark/serializer/Serializer.scala | 51 +++++++- .../apache/spark/storage/BlockManager.scala | 6 +- .../org/apache/spark/util/AkkaUtils.scala | 2 +- .../org/apache/spark/ShuffleNettySuite.scala | 4 +- .../apache/spark/network/SluiceContext.java | 111 ++++++++++++++++++ .../spark/network/client/SluiceClient.java | 50 ++++---- .../network/client/SluiceClientFactory.java | 55 +++++---- ...andler.java => SluiceResponseHandler.java} | 55 +++++---- .../ClientRequest.java => Message.java} | 26 ++-- .../protocol/request/ChunkFetchRequest.java | 4 +- .../request/ClientRequestDecoder.java | 57 --------- .../request/ClientRequestEncoder.java | 46 -------- .../protocol/request/RequestMessage.java | 25 ++++ .../network/protocol/request/RpcRequest.java | 6 +- .../protocol/response/ChunkFetchFailure.java | 7 +- .../protocol/response/ChunkFetchSuccess.java | 4 +- ...sponseDecoder.java => MessageDecoder.java} | 22 +++- ...sponseEncoder.java => MessageEncoder.java} | 16 ++- .../protocol/response/ResponseMessage.java | 25 ++++ .../network/protocol/response/RpcFailure.java | 7 +- .../protocol/response/RpcResponse.java | 2 +- .../protocol/response/ServerResponse.java | 63 ---------- .../network/server/DefaultStreamManager.java | 14 ++- .../spark/network/server/MessageHandler.java | 36 ++++++ .../spark/network/server/RpcHandler.java | 9 +- .../network/server/SluiceChannelHandler.java | 88 ++++++++++++++ ...Handler.java => SluiceRequestHandler.java} | 71 ++++++----- .../spark/network/server/SluiceServer.java | 26 ++-- .../spark/network/server/StreamManager.java | 2 +- .../org/apache/spark/network/util/IOMode.java | 2 +- .../apache/spark/network/util/JavaUtils.java | 14 ++- .../apache/spark/network/util/NettyUtils.java | 21 ++-- .../spark/network/IntegrationSuite.java | 22 ++-- .../apache/spark/network/NoOpRpcHandler.java | 3 +- .../apache/spark/network/ProtocolSuite.java | 21 ++-- .../network/SluiceClientFactorySuite.java | 12 +- .../network/SluiceClientHandlerSuite.java | 26 ++-- 43 files changed, 702 insertions(+), 427 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/SluiceContext.java rename network/common/src/main/java/org/apache/spark/network/client/{SluiceClientHandler.java => SluiceResponseHandler.java} (75%) rename network/common/src/main/java/org/apache/spark/network/protocol/{request/ClientRequest.java => Message.java} (67%) delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java rename network/common/src/main/java/org/apache/spark/network/protocol/response/{ServerResponseDecoder.java => MessageDecoder.java} (70%) rename network/common/src/main/java/org/apache/spark/network/protocol/response/{ServerResponseEncoder.java => MessageEncoder.java} (78%) create mode 100644 
network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java delete mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java rename network/common/src/main/java/org/apache/spark/network/server/{SluiceServerHandler.java => SluiceRequestHandler.java} (65%) diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index e35fdb4e95899..645793fde806d 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -29,7 +29,9 @@ private[spark] trait BlockFetchingListener extends EventListener { /** - * Called once per successfully fetched block. + * Called once per successfully fetched block. After this call returns, data will be released + * automatically. If the data will be passed to another thread, the receiver should retain() + * and release() the buffer on their own, or copy the data to a new buffer. */ def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 8287a0fc81cfe..b083f465334fe 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -18,14 +18,14 @@ package org.apache.spark.network import java.io.Closeable - -import org.apache.spark.network.buffer.ManagedBuffer +import java.nio.ByteBuffer import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel +import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} +import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils private[spark] @@ -72,7 +72,7 @@ abstract class BlockTransferService extends Closeable with Logging { def uploadBlock( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] @@ -94,7 +94,10 @@ abstract class BlockTransferService extends Closeable with Logging { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { lock.synchronized { - result = Left(data) + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + ret.flip() + result = Left(new NioManagedBuffer(ret)) lock.notify() } } @@ -126,7 +129,7 @@ abstract class BlockTransferService extends Closeable with Logging { def uploadBlockSync( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Unit = { Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala index aefd8a6335b2a..a03e7c39428ee 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala @@ -20,9 +20,10 @@ package 
org.apache.spark.network.netty import java.nio.ByteBuffer import java.util -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.network.BlockFetchingListener -import org.apache.spark.serializer.Serializer +import org.apache.spark.network.netty.NettyMessages._ +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, SluiceClient} import org.apache.spark.storage.BlockId @@ -52,7 +53,6 @@ class NettyBlockFetcher( val chunkCallback = new ChunkReceivedCallback { // On receipt of a chunk, pass it upwards as a block. def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions { - buffer.retain() listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index c8658ec98b82c..9206237256e0b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -22,18 +22,24 @@ import java.nio.ByteBuffer import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.serializer.Serializer -import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} +import org.apache.spark.network.client.{SluiceClient, RpcResponseCallback} import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler} -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{StorageLevel, BlockId} import scala.collection.JavaConversions._ -/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ -case class OpenBlocks(blockIds: Seq[BlockId]) +object NettyMessages { -/** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */ -case class ShuffleStreamHandle(streamId: Long, numChunks: Int) + /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ + case class OpenBlocks(blockIds: Seq[BlockId]) + + /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ + case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel) + + /** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */ + case class ShuffleStreamHandle(streamId: Long, numChunks: Int) +} /** * Serves requests to open blocks by simply registering one chunk per block requested. 
@@ -44,16 +50,27 @@ class NettyBlockRpcServer( blockManager: BlockDataManager) extends RpcHandler with Logging { - override def receive(messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { + import NettyMessages._ + + override def receive( + client: SluiceClient, + messageBytes: Array[Byte], + responseContext: RpcResponseCallback): Unit = { val ser = serializer.newInstance() val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) logTrace(s"Received request: $message") + message match { case OpenBlocks(blockIds) => val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) val streamId = streamManager.registerStream(blocks.iterator) + logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess( ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + + case UploadBlock(blockId, blockData, level) => + blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level) + responseContext.onSuccess(new Array[Byte](0)) } } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 7576d51e22175..6145c86c65617 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,24 +17,23 @@ package org.apache.spark.network.netty +import scala.concurrent.{Promise, Future} + import org.apache.spark.SparkConf import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{SluiceClient, SluiceClientFactory} -import org.apache.spark.network.server.{DefaultStreamManager, SluiceServer} +import org.apache.spark.network.client.{RpcResponseCallback, SluiceClient, SluiceClientFactory} +import org.apache.spark.network.netty.NettyMessages.UploadBlock +import org.apache.spark.network.server._ import org.apache.spark.network.util.{ConfigProvider, SluiceConfig} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils -import scala.concurrent.Future - /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { - var client: SluiceClient = _ - // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
val serializer = new JavaSerializer(conf) @@ -42,22 +41,24 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { private[this] val sluiceConf = new SluiceConfig( new ConfigProvider { override def get(name: String) = conf.get(name) }) + private[this] var sluiceContext: SluiceContext = _ private[this] var server: SluiceServer = _ private[this] var clientFactory: SluiceClientFactory = _ override def init(blockDataManager: BlockDataManager): Unit = { val streamManager = new DefaultStreamManager val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager) - server = new SluiceServer(sluiceConf, streamManager, rpcHandler) - clientFactory = new SluiceClientFactory(sluiceConf) + sluiceContext = new SluiceContext(sluiceConf, streamManager, rpcHandler) + clientFactory = sluiceContext.createClientFactory() + server = sluiceContext.createServer() } override def fetchBlocks( - hostName: String, + hostname: String, port: Int, blockIds: Seq[String], listener: BlockFetchingListener): Unit = { - val client = clientFactory.createClient(hostName, port) + val client = clientFactory.createClient(hostname, port) new NettyBlockFetcher(serializer, client, blockIds, listener) } @@ -65,13 +66,40 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { override def port: Int = server.getPort - // TODO: Implement override def uploadBlock( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, - level: StorageLevel): Future[Unit] = ??? + level: StorageLevel): Future[Unit] = { + val result = Promise[Unit]() + val client = clientFactory.createClient(hostname, port) + + // Convert or copy nio buffer into array in order to serialize it. + val nioBuffer = blockData.nioByteBuffer() + val array = if (nioBuffer.hasArray) { + nioBuffer.array() + } else { + val data = new Array[Byte](nioBuffer.remaining()) + nioBuffer.get(data) + data + } + + val ser = serializer.newInstance() + client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(), + new RpcResponseCallback { + override def onSuccess(response: Array[Byte]): Unit = { + logTrace(s"Successfully uploaded block $blockId") + result.success() + } + override def onFailure(e: Throwable): Unit = { + logError(s"Error while uploading block $blockId", e) + result.failure(e) + } + }) + + result.future + } override def close(): Unit = server.close() } diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index bce1069548437..e91f0af0e87a7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -127,12 +127,12 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa override def uploadBlock( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel) : Future[Unit] = { checkInit() - val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level) + val msg = PutBlock(blockId, blockData.nioByteBuffer(), level) val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) val remoteCmId = new ConnectionManagerId(hostName, port) val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala 
b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index a9144cdd97b8c..4024dea31845c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -17,14 +17,14 @@ package org.apache.spark.serializer -import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream} +import java.io._ import java.nio.ByteBuffer import scala.reflect.ClassTag -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.{ByteBufferInputStream, NextIterator} +import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator} /** * :: DeveloperApi :: @@ -142,3 +142,48 @@ abstract class DeserializationStream { } } } + + +class NoOpReadSerializer(conf: SparkConf) extends Serializer with Serializable { + override def newInstance(): SerializerInstance = { + new NoOpReadSerializerInstance() + } +} + +private[spark] class NoOpReadSerializerInstance() + extends SerializerInstance { + + override def serialize[T: ClassTag](t: T): ByteBuffer = { + val bos = new ByteArrayOutputStream() + val out = serializeStream(bos) + out.writeObject(t) + out.close() + ByteBuffer.wrap(bos.toByteArray) + } + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + null.asInstanceOf[T] + } + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + null.asInstanceOf[T] + } + + override def serializeStream(s: OutputStream): SerializationStream = { + new JavaSerializationStream(s, 100) + } + + override def deserializeStream(s: InputStream): DeserializationStream = { + new NoOpDeserializationStream(s, Utils.getContextOrSparkClassLoader) + } + + def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { + new NoOpDeserializationStream(s, loader) + } +} + +private[spark] class NoOpDeserializationStream(in: InputStream, loader: ClassLoader) + extends DeserializationStream { + def readObject[T: ClassTag](): T = throw new EOFException() + def close() { } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 4d8b5c1e1b084..6bbc49f9de829 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -855,9 +855,9 @@ private[spark] class BlockManager( data.rewind() logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId.toString, new NioManagedBuffer(data), tLevel) - logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms" - .format((System.currentTimeMillis - onePeerStartTime))) + peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel) + logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms" + .format(System.currentTimeMillis - onePeerStartTime)) peersReplicatedTo += peer peersForReplication -= peer replicationFailed = false diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index e2d32c859bbda..f41c8d0315cb3 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -77,7 +77,7 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if 
(conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600) + val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) val akkaFailureDetector = conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index d7b2d2e1e330f..840d8273cb6a8 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -24,10 +24,10 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. override def beforeAll() { - System.setProperty("spark.shuffle.use.netty", "true") + System.setProperty("spark.shuffle.blockTransferService", "netty") } override def afterAll() { - System.clearProperty("spark.shuffle.use.netty") + System.clearProperty("spark.shuffle.blockTransferService") } } diff --git a/network/common/src/main/java/org/apache/spark/network/SluiceContext.java b/network/common/src/main/java/org/apache/spark/network/SluiceContext.java new file mode 100644 index 0000000000000..7845ceb8b7d06 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/SluiceContext.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.Channel; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.client.SluiceResponseHandler; +import org.apache.spark.network.protocol.response.MessageDecoder; +import org.apache.spark.network.protocol.response.MessageEncoder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.SluiceChannelHandler; +import org.apache.spark.network.server.SluiceRequestHandler; +import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.SluiceConfig; + +/** + * Contains the context to create a {@link SluiceServer}, {@link SluiceClientFactory}, and to setup + * Netty Channel pipelines with a {@link SluiceChannelHandler}. + * + * The SluiceServer and SluiceClientFactory both create a SluiceChannelHandler for each channel. 
+ * As each SluiceChannelHandler contains a SluiceClient, this enables server processes to send + * messages back to the client on an existing channel. + */ +public class SluiceContext { + private final Logger logger = LoggerFactory.getLogger(SluiceContext.class); + + private final SluiceConfig conf; + private final StreamManager streamManager; + private final RpcHandler rpcHandler; + + private final MessageEncoder encoder; + private final MessageDecoder decoder; + + public SluiceContext(SluiceConfig conf, StreamManager streamManager, RpcHandler rpcHandler) { + this.conf = conf; + this.streamManager = streamManager; + this.rpcHandler = rpcHandler; + this.encoder = new MessageEncoder(); + this.decoder = new MessageDecoder(); + } + + public SluiceClientFactory createClientFactory() { + return new SluiceClientFactory(this); + } + + public SluiceServer createServer() { + return new SluiceServer(this); + } + + /** + * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and + * has a {@link SluiceChannelHandler} to handle request or response messages. + * + * @return Returns the created SluiceChannelHandler, which includes a SluiceClient that can be + * used to communicate on this channel. The SluiceClient is directly associated with a + * ChannelHandler to ensure all users of the same channel get the same SluiceClient object. + */ + public SluiceChannelHandler initializePipeline(SocketChannel channel) { + try { + SluiceChannelHandler channelHandler = createChannelHandler(channel); + channel.pipeline() + .addLast("encoder", encoder) + .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast("decoder", decoder) + // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this + // would require more logic to guarantee if this were not part of the same event loop. + .addLast("handler", channelHandler); + return channelHandler; + } catch (RuntimeException e) { + logger.error("Error while initializing Netty pipeline", e); + throw e; + } + } + + /** + * Creates the server- and client-side handler which is used to handle both RequestMessages and + * ResponseMessages. The channel is expected to have been successfully created, though certain + * properties (such as the remoteAddress()) may not be available yet. 
+ */ + private SluiceChannelHandler createChannelHandler(Channel channel) { + SluiceResponseHandler responseHandler = new SluiceResponseHandler(channel); + SluiceClient client = new SluiceClient(channel, responseHandler); + SluiceRequestHandler requestHandler = new SluiceRequestHandler(channel, client, streamManager, + rpcHandler); + return new SluiceChannelHandler(client, responseHandler, requestHandler); + } + + public SluiceConfig getConf() { return conf; } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java index 1f7d3b0234e38..d6d97981eebd6 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java @@ -19,7 +19,10 @@ import java.io.Closeable; import java.util.UUID; +import java.util.concurrent.TimeUnit; +import com.google.common.base.Preconditions; +import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import org.slf4j.Logger; @@ -28,6 +31,7 @@ import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.request.ChunkFetchRequest; import org.apache.spark.network.protocol.request.RpcRequest; +import org.apache.spark.network.util.NettyUtils; /** * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow @@ -50,7 +54,7 @@ * may be used for multiple streams, but any given stream must be restricted to a single client, * in order to avoid out-of-order responses. * - * NB: This class is used to make requests to the server, while {@link SluiceClientHandler} is + * NB: This class is used to make requests to the server, while {@link SluiceResponseHandler} is * responsible for handling responses from the server. * * Concurrency: thread safe and can be called from multiple threads. 
@@ -58,24 +62,16 @@ public class SluiceClient implements Closeable { private final Logger logger = LoggerFactory.getLogger(SluiceClient.class); - private final ChannelFuture cf; - private final SluiceClientHandler handler; + private final Channel channel; + private final SluiceResponseHandler handler; - private final String serverAddr; - - SluiceClient(ChannelFuture cf, SluiceClientHandler handler) { - this.cf = cf; - this.handler = handler; - - if (cf != null && cf.channel() != null && cf.channel().remoteAddress() != null) { - serverAddr = cf.channel().remoteAddress().toString(); - } else { - serverAddr = ""; - } + public SluiceClient(Channel channel, SluiceResponseHandler handler) { + this.channel = Preconditions.checkNotNull(channel); + this.handler = Preconditions.checkNotNull(handler); } public boolean isActive() { - return cf.channel().isActive(); + return channel.isOpen() || channel.isRegistered() || channel.isActive(); } /** @@ -97,28 +93,27 @@ public void fetchChunk( long streamId, final int chunkIndex, final ChunkReceivedCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); handler.addFetchRequest(streamChunkId, callback); - cf.channel().writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; logger.debug("Sending request {} to {} took {} ms", streamChunkId, serverAddr, - timeTaken); + timeTaken); } else { - // Fail all blocks. String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - serverAddr, future.cause().getMessage()); + serverAddr, future.cause()); logger.error(errorMsg, future.cause()); - future.cause().printStackTrace(); handler.removeFetchRequest(streamChunkId); - callback.onFailure(chunkIndex, new RuntimeException(errorMsg)); + callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); } } }); @@ -129,13 +124,14 @@ public void operationComplete(ChannelFuture future) throws Exception { * with the server's response or upon any failure. */ public void sendRpc(byte[] message, final RpcResponseCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.debug("Sending RPC to {}", serverAddr); final long tag = UUID.randomUUID().getLeastSignificantBits(); handler.addRpcRequest(tag, callback); - cf.channel().writeAndFlush(new RpcRequest(tag, message)).addListener( + channel.writeAndFlush(new RpcRequest(tag, message)).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -143,12 +139,11 @@ public void operationComplete(ChannelFuture future) throws Exception { long timeTaken = System.currentTimeMillis() - startTime; logger.debug("Sending request {} to {} took {} ms", tag, serverAddr, timeTaken); } else { - // Fail all blocks. 
- String errorMsg = String.format("Failed to send request %s to %s: %s", tag, - serverAddr, future.cause().getMessage()); + String errorMsg = String.format("Failed to send RPC %s to %s: %s", tag, + serverAddr, future.cause()); logger.error(errorMsg, future.cause()); handler.removeRpcRequest(tag); - callback.onFailure(new RuntimeException(errorMsg)); + callback.onFailure(new RuntimeException(errorMsg, future.cause())); } } }); @@ -156,6 +151,7 @@ public void operationComplete(ChannelFuture future) throws Exception { @Override public void close() { - cf.channel().close(); + // close is a local operation and should finish with milliseconds; timeout just to be safe + channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java index 17491dc3f8720..5de998ef6ed55 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java @@ -21,7 +21,6 @@ import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.SocketAddress; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeoutException; @@ -37,8 +36,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.request.ClientRequestEncoder; -import org.apache.spark.network.protocol.response.ServerResponseDecoder; +import org.apache.spark.network.SluiceContext; +import org.apache.spark.network.protocol.response.MessageDecoder; +import org.apache.spark.network.protocol.response.MessageEncoder; +import org.apache.spark.network.server.SluiceChannelHandler; import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.SluiceConfig; @@ -53,19 +54,17 @@ public class SluiceClientFactory implements Closeable { private final Logger logger = LoggerFactory.getLogger(SluiceClientFactory.class); + private final SluiceContext context; private final SluiceConfig conf; - private final Map connectionPool; - private final ClientRequestEncoder encoder; - private final ServerResponseDecoder decoder; + private final ConcurrentHashMap connectionPool; private final Class socketChannelClass; private final EventLoopGroup workerGroup; - public SluiceClientFactory(SluiceConfig conf) { - this.conf = conf; + public SluiceClientFactory(SluiceContext context) { + this.context = context; + this.conf = context.getConf(); this.connectionPool = new ConcurrentHashMap(); - this.encoder = new ClientRequestEncoder(); - this.decoder = new ServerResponseDecoder(); IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); @@ -82,18 +81,18 @@ public SluiceClientFactory(SluiceConfig conf) { public SluiceClient createClient(String remoteHost, int remotePort) throws TimeoutException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. 
- InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); SluiceClient cachedClient = connectionPool.get(address); if (cachedClient != null && cachedClient.isActive()) { + System.out.println("Reusing cached client: " + cachedClient); return cachedClient; + } else if (cachedClient != null) { + connectionPool.remove(address, cachedClient); // Remove inactive clients. } + System.out.println("Creating new client: " + cachedClient); logger.debug("Creating new connection to " + address); - // There is a chance two threads are creating two different clients connecting to the same host. - // But that's probably ok, as long as the caller hangs on to their client for a single stream. - final SluiceClientHandler handler = new SluiceClientHandler(); - Bootstrap bootstrap = new Bootstrap(); bootstrap.group(workerGroup) .channel(socketChannelClass) @@ -108,11 +107,14 @@ public SluiceClient createClient(String remoteHost, int remotePort) throws Timeo bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { - ch.pipeline() - .addLast("clientRequestEncoder", encoder) - .addLast("frameDecoder", NettyUtils.createFrameDecoder()) - .addLast("serverResponseDecoder", decoder) - .addLast("handler", handler); + SluiceChannelHandler channelHandler = context.initializePipeline(ch); + SluiceClient oldClient = connectionPool.putIfAbsent(address, channelHandler.getClient()); + if (oldClient != null) { + logger.debug("Two clients were created concurrently, second one will be disposed."); + ch.close(); + // Note: this type of failure is still considered a success by Netty, and thus the + // ChannelFuture will complete successfully. + } } }); @@ -120,11 +122,18 @@ public void initChannel(SocketChannel ch) { ChannelFuture cf = bootstrap.connect(address); if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { throw new TimeoutException( - String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); + String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } - SluiceClient client = new SluiceClient(cf, handler); - connectionPool.put(address, client); + SluiceClient client = connectionPool.get(address); + if (client == null) { + // The only way we should be able to reach here is if the client we created started out + // in the "inactive" state, and someone else simultaneously tried to create another client to + // the same server. This is an error condition, as the first client failed to connect. + throw new IllegalStateException("Client was unset! 
Must have been immediately inactive."); + } else if (!client.isActive()) { + throw new IllegalStateException("Failed to create active client."); + } return client; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java similarity index 75% rename from network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java rename to network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java index ed20b032931c3..9fbd487da86a7 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceResponseHandler.java @@ -17,37 +17,43 @@ package org.apache.spark.network.client; -import java.net.SocketAddress; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import com.google.common.annotations.VisibleForTesting; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.protocol.response.ResponseMessage; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; import org.apache.spark.network.protocol.response.RpcFailure; import org.apache.spark.network.protocol.response.RpcResponse; -import org.apache.spark.network.protocol.response.ServerResponse; +import org.apache.spark.network.server.MessageHandler; +import org.apache.spark.network.util.NettyUtils; /** - * Handler that processes server responses, in response to requests issued from [[SluiceClient]]. + * Handler that processes server responses, in response to requests issued from a [[SluiceClient]]. * It works by tracking the list of outstanding requests (and their callbacks). * * Concurrency: thread safe and can be called from multiple threads. */ -public class SluiceClientHandler extends SimpleChannelInboundHandler { - private final Logger logger = LoggerFactory.getLogger(SluiceClientHandler.class); +public class SluiceResponseHandler extends MessageHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceResponseHandler.class); - private final Map outstandingFetches = - new ConcurrentHashMap(); + private final Channel channel; - private final Map outstandingRpcs = - new ConcurrentHashMap(); + private final Map outstandingFetches; + + private final Map outstandingRpcs; + + public SluiceResponseHandler(Channel channel) { + this.channel = channel; + this.outstandingFetches = new ConcurrentHashMap(); + this.outstandingRpcs = new ConcurrentHashMap(); + } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { outstandingFetches.put(streamChunkId, callback); @@ -73,41 +79,36 @@ private void failOutstandingRequests(Throwable cause) { for (Map.Entry entry : outstandingFetches.entrySet()) { entry.getValue().onFailure(entry.getKey().chunkIndex, cause); } - // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests - // as well. But I guess that is ok given the caller will fail as soon as any requests fail. + // It's OK if new fetches appear, as they will fail immediately. 
outstandingFetches.clear(); } @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + public void channelUnregistered() { if (outstandingFetches.size() > 0) { - SocketAddress remoteAddress = ctx.channel().remoteAddress(); + String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when contention from {} is closed", outstandingFetches.size(), remoteAddress); failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); } - super.channelUnregistered(ctx); } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + public void exceptionCaught(Throwable cause) { if (outstandingFetches.size() > 0) { - logger.error(String.format("Exception in connection from %s: %s", - ctx.channel().remoteAddress(), cause.getMessage()), cause); failOutstandingRequests(cause); } - ctx.close(); } @Override - public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) { - String server = ctx.channel().remoteAddress().toString(); + public void handle(ResponseMessage message) { + String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Got a response for block {} from {} but it is not outstanding", - resp.streamChunkId, server); + resp.streamChunkId, remoteAddress); resp.buffer.release(); } else { outstandingFetches.remove(resp.streamChunkId); @@ -119,7 +120,7 @@ public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) { ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Got a response for block {} from {} ({}) but it is not outstanding", - resp.streamChunkId, server, resp.errorString); + resp.streamChunkId, remoteAddress, resp.errorString); } else { outstandingFetches.remove(resp.streamChunkId); listener.onFailure(resp.streamChunkId.chunkIndex, @@ -130,7 +131,7 @@ public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) { RpcResponseCallback listener = outstandingRpcs.get(resp.tag); if (listener == null) { logger.warn("Got a response for RPC {} from {} ({} bytes) but it is not outstanding", - resp.tag, server, resp.response.length); + resp.tag, remoteAddress, resp.response.length); } else { outstandingRpcs.remove(resp.tag); listener.onSuccess(resp.response); @@ -140,11 +141,13 @@ public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) { RpcResponseCallback listener = outstandingRpcs.get(resp.tag); if (listener == null) { logger.warn("Got a response for RPC {} from {} ({}) but it is not outstanding", - resp.tag, server, resp.errorString); + resp.tag, remoteAddress, resp.errorString); } else { outstandingRpcs.remove(resp.tag); listener.onFailure(new RuntimeException(resp.errorString)); } + } else { + throw new IllegalStateException("Unknown response type: " + message.type()); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java similarity index 67% rename from network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java rename to network/common/src/main/java/org/apache/spark/network/protocol/Message.java index db075c44b4cda..6731b3f53ae82 100644 
--- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -15,28 +15,24 @@ * limitations under the License. */ -package org.apache.spark.network.protocol.request; +package org.apache.spark.network.protocol; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.protocol.Encodable; - /** Messages from the client to the server. */ -public interface ClientRequest extends Encodable { +public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); - /** - * Preceding every serialized ClientRequest is the type, which allows us to deserialize - * the request. - */ + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ public static enum Type implements Encodable { - ChunkFetchRequest(0), RpcRequest(1); + ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), + RpcRequest(3), RpcResponse(4), RpcFailure(5); private final byte id; private Type(int id) { - assert id < 128 : "Cannot have more than 128 request types"; + assert id < 128 : "Cannot have more than 128 message types"; this.id = (byte) id; } @@ -48,10 +44,14 @@ private Type(int id) { public static Type decode(ByteBuf buf) { byte id = buf.readByte(); - switch(id) { + switch (id) { case 0: return ChunkFetchRequest; - case 1: return RpcRequest; - default: throw new IllegalArgumentException("Unknown request type: " + id); + case 1: return ChunkFetchSuccess; + case 2: return ChunkFetchFailure; + case 3: return RpcRequest; + case 4: return RpcResponse; + case 5: return RpcFailure; + default: throw new IllegalArgumentException("Unknown message type: " + id); } } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java index a79eb363cf58c..99cbb8777a873 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java @@ -24,9 +24,9 @@ /** * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single - * {@link org.apache.spark.network.protocol.response.ServerResponse} (either success or failure). + * {@link org.apache.spark.network.protocol.response.ResponseMessage} (either success or failure). */ -public final class ChunkFetchRequest implements ClientRequest { +public final class ChunkFetchRequest implements RequestMessage { public final StreamChunkId streamChunkId; public ChunkFetchRequest(StreamChunkId streamChunkId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java deleted file mode 100644 index a937da4cecae0..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol.request; - -import java.util.List; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToMessageDecoder; - -/** - * Decoder in the server side to decode client requests. - * This decoder is stateless so it is safe to be shared by multiple threads. - * - * This assumes the inbound messages have been processed by a frame decoder created by - * {@link org.apache.spark.network.util.NettyUtils#createFrameDecoder()}. - */ -@ChannelHandler.Sharable -public final class ClientRequestDecoder extends MessageToMessageDecoder { - - @Override - public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { - ClientRequest.Type msgType = ClientRequest.Type.decode(in); - ClientRequest decoded = decode(msgType, in); - assert decoded.type() == msgType; - assert in.readableBytes() == 0; - out.add(decoded); - } - - private ClientRequest decode(ClientRequest.Type msgType, ByteBuf in) { - switch (msgType) { - case ChunkFetchRequest: - return ChunkFetchRequest.decode(in); - - case RpcRequest: - return RpcRequest.decode(in); - - default: throw new IllegalArgumentException("Unexpected message type: " + msgType); - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java deleted file mode 100644 index bcff4a0a25568..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol.request; - -import java.util.List; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToMessageEncoder; - -/** - * Encoder for {@link ClientRequest} used in client side. - * - * This encoder is stateless so it is safe to be shared by multiple threads. 
- */ -@ChannelHandler.Sharable -public final class ClientRequestEncoder extends MessageToMessageEncoder { - @Override - public void encode(ChannelHandlerContext ctx, ClientRequest in, List out) { - ClientRequest.Type msgType = in.type(); - // Write 8 bytes for the frame's length, followed by the request type and request itself. - int frameLength = 8 + msgType.encodedLength() + in.encodedLength(); - ByteBuf buf = ctx.alloc().buffer(frameLength); - buf.writeLong(frameLength); - msgType.encode(buf); - in.encode(buf); - assert buf.writableBytes() == 0; - out.add(buf); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java new file mode 100644 index 0000000000000..58abce25d9a2a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/RequestMessage.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import org.apache.spark.network.protocol.Message; + +/** Messages from the client to the server. */ +public interface RequestMessage extends Message { + // token interface +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java index 126370330f723..810da7a689c13 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java @@ -24,10 +24,10 @@ /** * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. - * This will correspond to a single {@link org.apache.spark.network.protocol.response.ServerResponse} - * (either success or failure). + * This will correspond to a single + * {@link org.apache.spark.network.protocol.response.ResponseMessage} (either success or failure). */ -public final class RpcRequest implements ClientRequest { +public final class RpcRequest implements RequestMessage { /** Tag is used to link an RPC request with its response. 
*/ public final long tag; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java index 3a57d71b4f3ea..18ed4d95bba4c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java @@ -17,6 +17,7 @@ package org.apache.spark.network.protocol.response; +import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -26,7 +27,7 @@ * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when there is an * error fetching the chunk. */ -public final class ChunkFetchFailure implements ServerResponse { +public final class ChunkFetchFailure implements ResponseMessage { public final StreamChunkId streamChunkId; public final String errorString; @@ -40,13 +41,13 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { @Override public int encodedLength() { - return streamChunkId.encodedLength() + 4 + errorString.getBytes().length; + return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length; } @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); - byte[] errorBytes = errorString.getBytes(); + byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); buf.writeInt(errorBytes.length); buf.writeBytes(errorBytes); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java index 874dc4f5940cf..6bc26a64b9945 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java @@ -32,7 +32,7 @@ * may be written by Netty in a more efficient manner (i.e., zero-copy write). * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. */ -public final class ChunkFetchSuccess implements ServerResponse { +public final class ChunkFetchSuccess implements ResponseMessage { public final StreamChunkId streamChunkId; public final ManagedBuffer buffer; @@ -49,7 +49,7 @@ public int encodedLength() { return streamChunkId.encodedLength(); } - /** Encoding does NOT include buffer itself. See {@link ServerResponseEncoder}. */ + /** Encoding does NOT include buffer itself. See {@link MessageEncoder}. 
*/ @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java similarity index 70% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java rename to network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java index e06198284e620..3ae80305803eb 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageDecoder.java @@ -23,30 +23,44 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.MessageToMessageDecoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.RpcRequest; /** * Decoder used by the client side to encode server-to-client responses. * This encoder is stateless so it is safe to be shared by multiple threads. */ @ChannelHandler.Sharable -public final class ServerResponseDecoder extends MessageToMessageDecoder { +public final class MessageDecoder extends MessageToMessageDecoder { + private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { - ServerResponse.Type msgType = ServerResponse.Type.decode(in); - ServerResponse decoded = decode(msgType, in); + Message.Type msgType = Message.Type.decode(in); + Message decoded = decode(msgType, in); assert decoded.type() == msgType; + logger.debug("Received message " + msgType + ": " + decoded); out.add(decoded); } - private ServerResponse decode(ServerResponse.Type msgType, ByteBuf in) { + private Message decode(Message.Type msgType, ByteBuf in) { switch (msgType) { + case ChunkFetchRequest: + return ChunkFetchRequest.decode(in); + case ChunkFetchSuccess: return ChunkFetchSuccess.decode(in); case ChunkFetchFailure: return ChunkFetchFailure.decode(in); + case RpcRequest: + return RpcRequest.decode(in); + case RpcResponse: return RpcResponse.decode(in); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java similarity index 78% rename from network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java rename to network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java index 069f42463a8fe..5ca8de42a6429 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/MessageEncoder.java @@ -26,17 +26,25 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.protocol.Message; + /** * Encoder used by the server side to encode server-to-client responses. * This encoder is stateless so it is safe to be shared by multiple threads. 
*/ @ChannelHandler.Sharable -public final class ServerResponseEncoder extends MessageToMessageEncoder { +public final class MessageEncoder extends MessageToMessageEncoder { - private final Logger logger = LoggerFactory.getLogger(ServerResponseEncoder.class); + private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + /** + * Encodes a Message by invoking its encode() method. For non-data messages, we will add one + * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. + * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the + * data to 'out', in order to enable zero-copy transfer. + */ @Override - public void encode(ChannelHandlerContext ctx, ServerResponse in, List out) { + public void encode(ChannelHandlerContext ctx, Message in, List out) { Object body = null; long bodyLength = 0; @@ -56,7 +64,7 @@ public void encode(ChannelHandlerContext ctx, ServerResponse in, List ou } } - ServerResponse.Type msgType = in.type(); + Message.Type msgType = in.type(); // All messages have the frame length, message type, and message itself. int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); long frameLength = headerLength + bodyLength; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java new file mode 100644 index 0000000000000..8f545e91d1d8e --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ResponseMessage.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import org.apache.spark.network.protocol.Message; + +/** Messages from the server to the client. */ +public interface ResponseMessage extends Message { + // token interface +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java index 274920b28bced..6b71da5708c58 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java @@ -17,11 +17,12 @@ package org.apache.spark.network.protocol.response; +import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; /** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a failed RPC.
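As a worked example of the framing described in the MessageEncoder comment above (an illustration, not code from this patch), the header written to the channel is the 8-byte frame length, the 1-byte message type, and the message's own encoding:

import org.apache.spark.network.protocol.response.RpcFailure;

// Illustration only: sizing the single ByteBuf MessageEncoder writes for a small response.
public class FrameSizeExample {
  public static void main(String[] args) {
    RpcFailure msg = new RpcFailure(42L, "boom");
    // encodedLength() = 8 (tag) + 4 (length prefix) + 4 ("boom" in UTF-8) = 16 bytes,
    // so the frame is 8 (frame length) + 1 (type) + 16 = 25 bytes. Only ChunkFetchSuccess
    // appends a second, zero-copy ManagedBuffer body after this header.
    int frameLength = 8 + msg.type().encodedLength() + msg.encodedLength();
    System.out.println("frame length = " + frameLength);
  }
}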
*/ -public final class RpcFailure implements ServerResponse { +public final class RpcFailure implements ResponseMessage { public final long tag; public final String errorString; @@ -35,13 +36,13 @@ public RpcFailure(long tag, String errorString) { @Override public int encodedLength() { - return 8 + 4 + errorString.getBytes().length; + return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length; } @Override public void encode(ByteBuf buf) { buf.writeLong(tag); - byte[] errorBytes = errorString.getBytes(); + byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); buf.writeInt(errorBytes.length); buf.writeBytes(errorBytes); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java index 0c6f8acdcdc4b..40623ce31c666 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java @@ -23,7 +23,7 @@ import io.netty.buffer.ByteBuf; /** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a successful RPC. */ -public final class RpcResponse implements ServerResponse { +public final class RpcResponse implements ResponseMessage { public final long tag; public final byte[] response; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java deleted file mode 100644 index 335f9e8ea69f9..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol.response; - -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encodable; - -/** - * Messages from server to client (usually in response to some - * {@link org.apache.spark.network.protocol.request.ClientRequest}. - */ -public interface ServerResponse extends Encodable { - /** Used to identify this response type. */ - Type type(); - - /** - * Preceding every serialized ServerResponse is the type, which allows us to deserialize - * the response. 
- */ - public static enum Type implements Encodable { - ChunkFetchSuccess(0), ChunkFetchFailure(1), RpcResponse(2), RpcFailure(3); - - private final byte id; - - private Type(int id) { - assert id < 128 : "Cannot have more than 128 response types"; - this.id = (byte) id; - } - - public byte id() { return id; } - - @Override public int encodedLength() { return 1; } - - @Override public void encode(ByteBuf buf) { buf.writeByte(id); } - - public static Type decode(ByteBuf buf) { - byte id = buf.readByte(); - switch(id) { - case 0: return ChunkFetchSuccess; - case 1: return ChunkFetchFailure; - case 2: return RpcResponse; - case 3: return RpcFailure; - default: throw new IllegalArgumentException("Unknown response type: " + id); - } - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java index 04814d9a88c4a..d93607a7c31ea 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java @@ -23,6 +23,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.network.buffer.ManagedBuffer; /** @@ -30,6 +33,8 @@ * fetched as chunks by the client. */ public class DefaultStreamManager extends StreamManager { + private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class); + private final AtomicLong nextStreamId; private final Map streams; @@ -61,7 +66,14 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { "Requested chunk index beyond end %s", chunkIndex)); } state.curChunk += 1; - return state.buffers.next(); + ManagedBuffer nextChunk = state.buffers.next(); + + if (!state.buffers.hasNext()) { + logger.trace("Removing stream id {}", streamId); + streams.remove(streamId); + } + + return nextChunk; } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java new file mode 100644 index 0000000000000..b80c15106ecbd --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.protocol.Message; + +/** + * Handles either request or response messages coming off of Netty. A MessageHandler instance + * is associated with a single Netty Channel (though it may have multiple clients on the same + * Channel.) 
+ */ +public abstract class MessageHandler { + /** Handles the receipt of a single message. */ + public abstract void handle(T message); + + /** Invoked when an exception was caught on the Channel. */ + public abstract void exceptionCaught(Throwable cause); + + /** Invoked when the channel this MessageHandler is on has been unregistered. */ + public abstract void channelUnregistered(); +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index abfbe66d875e8..5700cc83bd9c8 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -18,6 +18,7 @@ package org.apache.spark.network.server; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.SluiceClient; /** * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.SluiceClient}s. @@ -26,6 +27,12 @@ public interface RpcHandler { /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. + * @param message The serialized bytes of the RPC. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. */ - void receive(byte[] message, RpcResponseCallback callback); + void receive(SluiceClient client, byte[] message, RpcResponseCallback callback); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java new file mode 100644 index 0000000000000..d5a91ec1b6c28 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceChannelHandler.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
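To make the new receive() signature concrete, a minimal handler might look like the sketch below (the class name and echo behavior are illustrative, not part of this patch):

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.SluiceClient;
import org.apache.spark.network.server.RpcHandler;

// Illustrative only: echoes the RPC payload back to the sender. The SluiceClient argument
// shares the sender's channel, so a server-side handler can also initiate its own fetches
// or RPCs back to the original requester from here.
public class EchoRpcHandler implements RpcHandler {
  @Override
  public void receive(SluiceClient client, byte[] message, RpcResponseCallback callback) {
    callback.onSuccess(message);
  }
}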
+ */ + +package org.apache.spark.network.server; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceResponseHandler; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.request.RequestMessage; +import org.apache.spark.network.protocol.response.ResponseMessage; +import org.apache.spark.network.util.NettyUtils; + +/** + * A handler which is used for delegating requests to the + * {@link org.apache.spark.network.server.SluiceRequestHandler} and responses to the + * {@link org.apache.spark.network.client.SluiceResponseHandler}. + * + * All channels created in Sluice are bidirectional. When the Client initiates a Netty Channel + * with a RequestMessage (which gets handled by the Server's RequestHandler), the Server will + * produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server also + * gets a handle on the same Channel, so it may then begin to send RequestMessages to the Client. + * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, + * for the Client's responses to the Server's requests. + */ +public class SluiceChannelHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceChannelHandler.class); + + private final SluiceClient client; + private final SluiceResponseHandler responseHandler; + private final SluiceRequestHandler requestHandler; + + public SluiceChannelHandler( + SluiceClient client, + SluiceResponseHandler responseHandler, + SluiceRequestHandler requestHandler) { + this.client = client; + this.responseHandler = responseHandler; + this.requestHandler = requestHandler; + } + + public SluiceClient getClient() { + return client; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), + cause); + requestHandler.exceptionCaught(cause); + responseHandler.exceptionCaught(cause); + ctx.close(); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + requestHandler.channelUnregistered(); + responseHandler.channelUnregistered(); + super.channelUnregistered(ctx); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Message request) { + if (request instanceof RequestMessage) { + requestHandler.handle((RequestMessage) request); + } else { + responseHandler.handle((ResponseMessage) request); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceRequestHandler.java similarity index 65% rename from network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java rename to network/common/src/main/java/org/apache/spark/network/server/SluiceRequestHandler.java index fad72fbfc711b..5f5111e0a7638 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceRequestHandler.java @@ -21,33 +21,40 @@ import com.google.common.base.Throwables; import com.google.common.collect.Sets; +import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import 
io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.SluiceClient; import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.request.RequestMessage; import org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.ClientRequest; import org.apache.spark.network.protocol.request.RpcRequest; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; import org.apache.spark.network.protocol.response.RpcFailure; import org.apache.spark.network.protocol.response.RpcResponse; +import org.apache.spark.network.util.NettyUtils; /** - * A handler that processes requests from clients and writes chunk data back. Each handler keeps - * track of which streams have been fetched via this channel, in order to clean them up if the - * channel is terminated (see #channelUnregistered). + * A handler that processes requests from clients and writes chunk data back. Each handler is + * attached to a single Netty channel, and keeps track of which streams have been fetched via this + * channel, in order to clean them up if the channel is terminated (see #channelUnregistered). * * The messages should have been processed by the pipeline setup by {@link SluiceServer}. */ -public class SluiceServerHandler extends SimpleChannelInboundHandler { - private final Logger logger = LoggerFactory.getLogger(SluiceServerHandler.class); +public class SluiceRequestHandler extends MessageHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceRequestHandler.class); + + /** The Netty channel that this handler is associated with. */ + private final Channel channel; + + /** Client on the same channel allowing us to talk back to the requester. */ + private final SluiceClient reverseClient; /** Returns each chunk part of a stream. */ private final StreamManager streamManager; @@ -58,22 +65,24 @@ public class SluiceServerHandler extends SimpleChannelInboundHandler streamIds; - public SluiceServerHandler(StreamManager streamManager, RpcHandler rpcHandler) { + public SluiceRequestHandler( + Channel channel, + SluiceClient reverseClient, + StreamManager streamManager, + RpcHandler rpcHandler) { + this.channel = channel; + this.reverseClient = reverseClient; this.streamManager = streamManager; this.rpcHandler = rpcHandler; this.streamIds = Sets.newHashSet(); } @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - logger.error("Exception in connection from " + ctx.channel().remoteAddress(), cause); - ctx.close(); - super.exceptionCaught(ctx, cause); + public void exceptionCaught(Throwable cause) { } @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { - super.channelUnregistered(ctx); + public void channelUnregistered() { // Inform the StreamManager that these streams will no longer be read from. 
for (long streamId : streamIds) { streamManager.connectionTerminated(streamId); @@ -81,18 +90,18 @@ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelRead0(ChannelHandlerContext ctx, ClientRequest request) { + public void handle(RequestMessage request) { if (request instanceof ChunkFetchRequest) { - processFetchRequest(ctx, (ChunkFetchRequest) request); + processFetchRequest((ChunkFetchRequest) request); } else if (request instanceof RpcRequest) { - processRpcRequest(ctx, (RpcRequest) request); + processRpcRequest((RpcRequest) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } } - private void processFetchRequest(final ChannelHandlerContext ctx, final ChunkFetchRequest req) { - final String client = ctx.channel().remoteAddress().toString(); + private void processFetchRequest(final ChunkFetchRequest req) { + final String client = NettyUtils.getRemoteAddress(channel); streamIds.add(req.streamChunkId.streamId); logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); @@ -103,29 +112,29 @@ private void processFetchRequest(final ChannelHandlerContext ctx, final ChunkFet } catch (Exception e) { logger.error(String.format( "Error opening block %s for request from %s", req.streamChunkId, client), e); - respond(ctx, new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); + respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); return; } - respond(ctx, new ChunkFetchSuccess(req.streamChunkId, buf)); + respond(new ChunkFetchSuccess(req.streamChunkId, buf)); } - private void processRpcRequest(final ChannelHandlerContext ctx, final RpcRequest req) { + private void processRpcRequest(final RpcRequest req) { try { - rpcHandler.receive(req.message, new RpcResponseCallback() { + rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { - respond(ctx, new RpcResponse(req.tag, response)); + respond(new RpcResponse(req.tag, response)); } @Override public void onFailure(Throwable e) { - respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + respond(new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); } }); } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC tag " + req.tag, e); - respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + respond(new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); } } @@ -133,9 +142,9 @@ public void onFailure(Throwable e) { * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. 
*/ - private void respond(final ChannelHandlerContext ctx, final Encodable result) { - final String remoteAddress = ctx.channel().remoteAddress().toString(); - ctx.writeAndFlush(result).addListener( + private void respond(final Encodable result) { + final String remoteAddress = channel.remoteAddress().toString(); + channel.writeAndFlush(result).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -144,7 +153,7 @@ public void operationComplete(ChannelFuture future) throws Exception { } else { logger.error(String.format("Error sending result %s to %s; closing connection", result, remoteAddress), future.cause()); - ctx.close(); + channel.close(); } } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java index aa81271024156..965db536a2782 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java @@ -19,6 +19,7 @@ import java.io.Closeable; import java.net.InetSocketAddress; +import java.util.concurrent.TimeUnit; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.PooledByteBufAllocator; @@ -30,8 +31,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.protocol.request.ClientRequestDecoder; -import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.SluiceContext; import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.SluiceConfig; @@ -42,18 +42,16 @@ public class SluiceServer implements Closeable { private final Logger logger = LoggerFactory.getLogger(SluiceServer.class); + private final SluiceContext context; private final SluiceConfig conf; - private final StreamManager streamManager; - private final RpcHandler rpcHandler; private ServerBootstrap bootstrap; private ChannelFuture channelFuture; private int port; - public SluiceServer(SluiceConfig conf, StreamManager streamManager, RpcHandler rpcHandler) { - this.conf = conf; - this.streamManager = streamManager; - this.rpcHandler = rpcHandler; + public SluiceServer(SluiceContext context) { + this.context = context; + this.conf = context.getConf(); init(); } @@ -86,16 +84,9 @@ private void init() { } bootstrap.childHandler(new ChannelInitializer() { - @Override protected void initChannel(SocketChannel ch) throws Exception { - ch.pipeline() - .addLast("frameDecoder", NettyUtils.createFrameDecoder()) - .addLast("clientRequestDecoder", new ClientRequestDecoder()) - .addLast("serverResponseEncoder", new ServerResponseEncoder()) - // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this - // would require more logic to guarantee if this were not part of the same event loop. 
- .addLast("handler", new SluiceServerHandler(streamManager, rpcHandler)); + context.initializePipeline(ch); } }); @@ -109,7 +100,8 @@ protected void initChannel(SocketChannel ch) throws Exception { @Override public void close() { if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly(); + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); channelFuture = null; } if (bootstrap != null && bootstrap.group() != null) { diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 2e07f5a270cb9..47b74b229fdec 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -21,7 +21,7 @@ /** * The StreamManager is used to fetch individual chunks from a stream. This is used in - * {@link SluiceServerHandler} in order to respond to fetchChunk() requests. Creation of the + * {@link SluiceRequestHandler} in order to respond to fetchChunk() requests. Creation of the * stream is outside the scope of Sluice, but a given stream is guaranteed to be read by only one * client connection, meaning that getChunk() for a particular stream will be called serially and * that once the connection associated with the stream is closed, that stream will never be used diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java index 91cb3e0e6f8f6..c0aa12c81ba64 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java +++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java @@ -19,7 +19,7 @@ /** * Selector for which form of low-level IO we should use. - * NIO is always available, while EPOLL is only available on certain machines. + * NIO is always available, while EPOLL is only available on Linux. * AUTO is used to select EPOLL if it's available, or NIO otherwise. */ public enum IOMode { diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index fafdcad04aeb6..32ba3f5b07f7a 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -18,13 +18,21 @@ package org.apache.spark.network.util; import java.io.Closeable; +import java.io.IOException; import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class JavaUtils { + private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); + /** Closes the given object, ignoring IOExceptions.
*/ - @SuppressWarnings("deprecation") - public static void closeQuietly(Closeable closable) { - Closeables.closeQuietly(closable); + public static void closeQuietly(Closeable closeable) { + try { + closeable.close(); + } catch (IOException e) { + logger.error("IOException should not have been thrown.", e); + } } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 3d20dc9e1c1cd..a925c05469d3c 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -44,11 +44,11 @@ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String } ThreadFactory threadFactory = new ThreadFactoryBuilder() - .setDaemon(true) - .setNameFormat(threadPrefix + "-%d") - .build(); + .setDaemon(true) + .setNameFormat(threadPrefix + "-%d") + .build(); - switch(mode) { + switch (mode) { case NIO: return new NioEventLoopGroup(numThreads, threadFactory); case EPOLL: @@ -63,7 +63,7 @@ public static Class getClientChannelClass(IOMode mode) { if (mode == IOMode.AUTO) { mode = autoselectMode(); } - switch(mode) { + switch (mode) { case NIO: return NioSocketChannel.class; case EPOLL: @@ -78,7 +78,7 @@ public static Class getServerChannelClass(IOMode mode) if (mode == IOMode.AUTO) { mode = autoselectMode(); } - switch(mode) { + switch (mode) { case NIO: return NioServerSocketChannel.class; case EPOLL: @@ -101,9 +101,16 @@ public static ByteToMessageDecoder createFrameDecoder() { return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); } + /** Returns the remote address on the channel or "" if none exists. */ + public static String getRemoteAddress(Channel channel) { + if (channel != null && channel.remoteAddress() != null) { + return channel.remoteAddress().toString(); + } + return ""; + } + /** Returns EPOLL if it's available on this system, NIO otherwise. */ private static IOMode autoselectMode() { return Epoll.isAvailable() ? 
IOMode.EPOLL : IOMode.NIO; } } - diff --git a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java index d20528558cae1..d38f6db99c09b 100644 --- a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java @@ -94,8 +94,9 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { } } }; - server = new SluiceServer(conf, streamManager, new NoOpRpcHandler()); - clientFactory = new SluiceClientFactory(conf); + SluiceContext context = new SluiceContext(conf, streamManager, new NoOpRpcHandler()); + server = context.createServer(); + clientFactory = context.createClientFactory(); } @AfterClass @@ -118,6 +119,7 @@ public void releaseBuffers() { } private FetchResult fetchChunks(List chunkIndices) throws Exception { + System.out.println("----------------------------------------------------------------"); SluiceClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); final Semaphore sem = new Semaphore(0); @@ -170,6 +172,14 @@ public void fetchFileChunk() throws Exception { res.releaseBuffers(); } + @Test + public void fetchNonExistentChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(12345)); + assertTrue(res.successChunks.isEmpty()); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertTrue(res.buffers.isEmpty()); + } + @Test public void fetchBothChunks() throws Exception { FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); @@ -179,14 +189,6 @@ public void fetchBothChunks() throws Exception { res.releaseBuffers(); } - @Test - public void fetchNonExistentChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(12345)); - assertTrue(res.successChunks.isEmpty()); - assertEquals(res.failedChunks, Sets.newHashSet(12345)); - assertTrue(res.buffers.isEmpty()); - } - @Test public void fetchChunkAndNonExistent() throws Exception { FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java index af35709319957..ccfb7576afadb 100644 --- a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java +++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java @@ -17,10 +17,11 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.client.SluiceClient; public class NoOpRpcHandler implements RpcHandler { @Override - public void receive(byte[] message, RpcResponseCallback callback) { + public void receive(SluiceClient client, byte[] message, RpcResponseCallback callback) { callback.onSuccess(new byte[0]); } } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index cf74a9d8993fe..d2476e7f2ac22 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -22,25 +22,22 @@ import static org.junit.Assert.assertEquals; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.StreamChunkId; import 
org.apache.spark.network.protocol.request.ChunkFetchRequest; -import org.apache.spark.network.protocol.request.ClientRequest; -import org.apache.spark.network.protocol.request.ClientRequestDecoder; -import org.apache.spark.network.protocol.request.ClientRequestEncoder; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; -import org.apache.spark.network.protocol.response.ServerResponse; -import org.apache.spark.network.protocol.response.ServerResponseDecoder; -import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.protocol.response.MessageDecoder; +import org.apache.spark.network.protocol.response.MessageEncoder; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { - private void testServerToClient(ServerResponse msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new ServerResponseEncoder()); + private void testServerToClient(Message msg) { + EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new ServerResponseDecoder()); + NettyUtils.createFrameDecoder(), new MessageDecoder()); while (!serverChannel.outboundMessages().isEmpty()) { clientChannel.writeInbound(serverChannel.readOutbound()); @@ -50,12 +47,12 @@ private void testServerToClient(ServerResponse msg) { assertEquals(msg, clientChannel.readInbound()); } - private void testClientToServer(ClientRequest msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new ClientRequestEncoder()); + private void testClientToServer(Message msg) { + EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new ClientRequestDecoder()); + NettyUtils.createFrameDecoder(), new MessageDecoder()); while (!clientChannel.outboundMessages().isEmpty()) { serverChannel.writeInbound(clientChannel.readOutbound()); diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java index e6b59b9ad8e5c..219d6cc998bd7 100644 --- a/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java @@ -38,6 +38,7 @@ public class SluiceClientFactorySuite { private SluiceConfig conf; + private SluiceContext context; private SluiceServer server1; private SluiceServer server2; @@ -46,8 +47,9 @@ public void setUp() { conf = new SluiceConfig(new DefaultConfigProvider()); StreamManager streamManager = new DefaultStreamManager(); RpcHandler rpcHandler = new NoOpRpcHandler(); - server1 = new SluiceServer(conf, streamManager, rpcHandler); - server2 = new SluiceServer(conf, streamManager, rpcHandler); + context = new SluiceContext(conf, streamManager, rpcHandler); + server1 = context.createServer(); + server2 = context.createServer(); } @After @@ -58,7 +60,7 @@ public void tearDown() { @Test public void createAndReuseBlockClients() throws TimeoutException { - SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClientFactory factory = context.createClientFactory(); SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); SluiceClient c2 = 
factory.createClient(TestUtils.getLocalHost(), server1.getPort()); SluiceClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); @@ -71,7 +73,7 @@ public void createAndReuseBlockClients() throws TimeoutException { @Test public void neverReturnInactiveClients() throws Exception { - SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClientFactory factory = context.createClientFactory(); SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); c1.close(); @@ -89,7 +91,7 @@ public void neverReturnInactiveClients() throws Exception { @Test public void closeBlockClientsWithFactory() throws TimeoutException { - SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClientFactory factory = context.createClientFactory(); SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); assertTrue(c1.isActive()); diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java index cab0597fb948a..c665f2313c589 100644 --- a/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java @@ -18,17 +18,17 @@ package org.apache.spark.network; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalChannel; import org.junit.Test; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.SluiceClientHandler; +import org.apache.spark.network.client.SluiceResponseHandler; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.response.ChunkFetchFailure; import org.apache.spark.network.protocol.response.ChunkFetchSuccess; @@ -38,53 +38,45 @@ public class SluiceClientHandlerSuite { public void handleSuccessfulFetch() { StreamChunkId streamChunkId = new StreamChunkId(1, 0); - SluiceClientHandler handler = new SluiceClientHandler(); + SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(streamChunkId, callback); assertEquals(1, handler.numOutstandingRequests()); - EmbeddedChannel channel = new EmbeddedChannel(handler); - - channel.writeInbound(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); + handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); assertEquals(0, handler.numOutstandingRequests()); - assertFalse(channel.finish()); } @Test public void handleFailedFetch() { StreamChunkId streamChunkId = new StreamChunkId(1, 0); - SluiceClientHandler handler = new SluiceClientHandler(); + SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(streamChunkId, callback); assertEquals(1, handler.numOutstandingRequests()); - EmbeddedChannel channel = new EmbeddedChannel(handler); - 
channel.writeInbound(new ChunkFetchFailure(streamChunkId, "some error msg")); + handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); - assertFalse(channel.finish()); } @Test public void clearAllOutstandingRequests() { - SluiceClientHandler handler = new SluiceClientHandler(); + SluiceResponseHandler handler = new SluiceResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(new StreamChunkId(1, 0), callback); handler.addFetchRequest(new StreamChunkId(1, 1), callback); handler.addFetchRequest(new StreamChunkId(1, 2), callback); assertEquals(3, handler.numOutstandingRequests()); - EmbeddedChannel channel = new EmbeddedChannel(handler); - - channel.writeInbound(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); - channel.pipeline().fireExceptionCaught(new Exception("duh duh duhhhh")); + handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); + handler.exceptionCaught(new Exception("duh duh duhhhh")); // should fail both b2 and b3 verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); - assertFalse(channel.finish()); } }
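Taken together, the new entry point for both halves of the bidirectional channel is SluiceContext. A rough usage sketch modeled on the test setup above (the package location of DefaultConfigProvider is assumed; error handling omitted):

package org.apache.spark.network;

import org.apache.spark.network.client.SluiceClient;
import org.apache.spark.network.client.SluiceClientFactory;
import org.apache.spark.network.server.DefaultStreamManager;
import org.apache.spark.network.server.SluiceServer;
import org.apache.spark.network.util.DefaultConfigProvider;  // assumed package
import org.apache.spark.network.util.SluiceConfig;

// Rough sketch (not part of this patch): one SluiceContext builds both the server and the
// client factory, so both ends of a channel share the MessageEncoder/MessageDecoder pipeline
// and the SluiceChannelHandler that splits traffic into request and response handlers.
public class SluiceExample {
  public static void main(String[] args) throws Exception {
    SluiceConfig conf = new SluiceConfig(new DefaultConfigProvider());
    SluiceContext context = new SluiceContext(conf, new DefaultStreamManager(), new NoOpRpcHandler());

    SluiceServer server = context.createServer();
    SluiceClientFactory clientFactory = context.createClientFactory();

    SluiceClient client = clientFactory.createClient("localhost", server.getPort());
    // RPCs and chunk fetches would be issued here; the server's RpcHandler can reach back
    // to this process through the SluiceClient it receives in receive().

    client.close();
    server.close();
  }
}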