From 29c6dcfaacb2e8b1f0582c6d5e435349c52e29af Mon Sep 17 00:00:00 2001
From: Aaron Davidson
Date: Sun, 5 Oct 2014 17:58:43 -0700
Subject: [PATCH] [SPARK-3453] Netty-based BlockTransferService, extracted from Spark core

This PR encapsulates #2330, which is itself a continuation of #2240. The first goal of this PR is to provide an alternative, simpler implementation of the ConnectionManager that is based on Netty. Beyond that, we also want to resolve [SPARK-3796](https://issues.apache.org/jira/browse/SPARK-3796), which calls for a standalone shuffle service that can be integrated into the YARN NodeManager, the Standalone Worker, or run on its own. This PR takes the first step in that direction by keeping the actual Netty service as small as possible and extracting it from Spark core. Given this, we should be able to build a standalone jar that can be included in other JVMs without incurring significant dependency or runtime issues. The actual work to ensure that such a standalone shuffle service would work in Spark is left for a future PR, however.

In order to minimize dependencies and allow the service to be long-running (possibly much longer-running than Spark, and possibly having to support multiple versions of Spark simultaneously), the entire service has been ported to Java, where we have full control over the binary compatibility of the components and do not depend on the Scala runtime or version.

The following issues have been addressed by folding in #2330:

SPARK-3453: Refactor Netty module to use BlockTransferService interface
SPARK-3018: Release all buffers upon task completion/failure
SPARK-3002: Create a connection pool and reuse clients across different threads
SPARK-3017: Integration tests and unit tests for connection failures
SPARK-3049: Make sure client doesn't block when server/connection has error(s)
SPARK-3502: SO_RCVBUF and SO_SNDBUF should be bootstrap childOption, not option
SPARK-3503: Disable thread local cache in PooledByteBufAllocator

TODO before mergeable:
[ ] Implement uploadBlock()
[ ] Unit tests for RPC side of code
[ ] Performance testing
[ ] Turn OFF by default (currently on for unit testing)
---
 core/pom.xml | 5 +
 .../scala/org/apache/spark/SparkEnv.scala | 17 +-
 .../spark/network/BlockDataManager.scala | 8 +-
 .../spark/network/BlockFetchingListener.scala | 2 +
 .../spark/network/BlockTransferService.scala | 13 +-
 .../apache/spark/network/ManagedBuffer.scala | 187 ----------
 .../spark/network/netty/BlockClient.scala | 125 -------
 .../network/netty/BlockClientFactory.scala | 175 ----------
 .../network/netty/BlockClientHandler.scala | 138 --------
 .../spark/network/netty/BlockServer.scala | 127 -------
 .../network/netty/BlockServerHandler.scala | 125 -------
 .../network/netty/NettyBlockFetcher.scala | 92 +++++
 .../network/netty/NettyBlockRpcServer.scala | 59 ++++
 .../netty/NettyBlockTransferService.scala | 69 ++--
 .../apache/spark/network/netty/protocol.scala | 326 ------------------
 .../network/nio/NioBlockTransferService.scala | 18 +-
 .../shuffle/FileShuffleBlockManager.scala | 6 +-
 .../shuffle/IndexShuffleBlockManager.scala | 2 +-
 .../spark/shuffle/ShuffleBlockManager.scala | 3 +-
 .../apache/spark/storage/BlockManager.scala | 26 +-
 .../storage/BlockNotFoundException.scala | 1 -
 .../storage/ShuffleBlockFetcherIterator.scala | 11 +-
 .../apache/spark/storage/StorageLevel.scala | 3 +-
 .../netty/BlockClientFactorySuite.scala | 91 -----
 .../netty/BlockClientHandlerSuite.scala | 114 ------
 .../spark/network/netty/ProtocolSuite.scala | 113 ------
.../netty/ServerClientIntegrationSuite.scala | 174 ---------- .../network/netty/TestManagedBuffer.scala | 72 ---- .../hash/HashShuffleManagerSuite.scala | 8 +- .../ShuffleBlockFetcherIteratorSuite.scala | 3 +- network/common/pom.xml | 94 +++++ .../buffer/FileSegmentManagedBuffer.java | 146 ++++++++ .../spark/network/buffer/ManagedBuffer.java | 70 ++++ .../network/buffer/NettyManagedBuffer.java | 76 ++++ .../network/buffer/NioManagedBuffer.java | 75 ++++ .../client/ChunkFetchFailureException.java | 37 ++ .../network/client/ChunkReceivedCallback.java | 47 +++ .../network/client/RpcResponseCallback.java | 30 ++ .../spark/network/client/SluiceClient.java | 161 +++++++++ .../network/client/SluiceClientFactory.java | 173 ++++++++++ .../network/client/SluiceClientHandler.java | 155 +++++++++ .../spark/network/protocol/Encodable.java | 35 ++ .../spark/network/protocol/StreamChunkId.java | 73 ++++ .../protocol/request/ChunkFetchRequest.java | 68 ++++ .../protocol/request/ClientRequest.java | 58 ++++ .../request/ClientRequestDecoder.java | 57 +++ .../request/ClientRequestEncoder.java | 46 +++ .../network/protocol/request/RpcRequest.java | 81 +++++ .../protocol/response/ChunkFetchFailure.java | 78 +++++ .../protocol/response/ChunkFetchSuccess.java | 82 +++++ .../network/protocol/response/RpcFailure.java | 73 ++++ .../protocol/response/RpcResponse.java | 72 ++++ .../protocol/response/ServerResponse.java | 63 ++++ .../response/ServerResponseDecoder.java | 60 ++++ .../response/ServerResponseEncoder.java | 74 ++++ .../network/server/DefaultStreamManager.java | 87 +++++ .../spark/network/server/RpcHandler.java | 31 ++ .../spark/network/server/SluiceServer.java | 124 +++++++ .../network/server/SluiceServerHandler.java | 153 ++++++++ .../spark/network/server/StreamManager.java | 52 +++ .../spark/network/util/ConfigProvider.java | 52 +++ .../network/util/DefaultConfigProvider.java | 32 ++ .../org/apache/spark/network/util/IOMode.java | 27 ++ .../apache/spark/network/util/JavaUtils.java | 19 +- .../apache/spark/network/util/NettyUtils.java | 109 ++++++ .../spark/network/util/SluiceConfig.java | 38 +- .../spark/network/IntegrationSuite.java | 217 ++++++++++++ .../apache/spark/network/NoOpRpcHandler.java | 26 ++ .../apache/spark/network/ProtocolSuite.java | 84 +++++ .../network/SluiceClientFactorySuite.java | 101 ++++++ .../network/SluiceClientHandlerSuite.java | 90 +++++ .../spark/network/TestManagedBuffer.java | 104 ++++++ .../org/apache/spark/network/TestUtils.java | 30 ++ pom.xml | 1 + project/MimaExcludes.scala | 2 - .../streaming/scheduler/ReceiverTracker.scala | 2 +- 76 files changed, 3579 insertions(+), 1899 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/protocol.scala delete mode 100644 
core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala create mode 100644 network/common/pom.xml create mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java create mode 100644 network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java create mode 100644 
network/common/src/main/java/org/apache/spark/network/server/StreamManager.java create mode 100644 network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java create mode 100644 network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java create mode 100644 network/common/src/main/java/org/apache/spark/network/util/IOMode.java rename core/src/main/scala/org/apache/spark/network/exceptions.scala => network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java (65%) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java rename core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala => network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java (58%) create mode 100644 network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java create mode 100644 network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java create mode 100644 network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java create mode 100644 network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java create mode 100644 network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java create mode 100644 network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java create mode 100644 network/common/src/test/java/org/apache/spark/network/TestUtils.java diff --git a/core/pom.xml b/core/pom.xml index a5a178079bc57..aff0d989d01bb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -44,6 +44,11 @@ + + org.apache.spark + network + ${project.version} + net.java.dev.jets3t jets3t diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 373ce795a309e..867173e04714e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.netty.{NettyBlockTransferService} import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -40,7 +40,6 @@ import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} - /** * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -233,12 +232,14 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - // TODO(rxin): Config option based on class name, similar to shuffle mgr and compression codec. - val blockTransferService = if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new NettyBlockTransferService(conf) - } else { - new NioBlockTransferService(conf, securityManager) - } + // TODO: This is only netty by default for initial testing -- it should not be merged as such!!! 
+ val blockTransferService = + conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { + case "netty" => + new NettyBlockTransferService(conf) + case "nio" => + new NioBlockTransferService(conf, securityManager) + } val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 0eeffe0e7c5e6..1745d52c81923 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -17,8 +17,8 @@ package org.apache.spark.network -import org.apache.spark.storage.StorageLevel - +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] trait BlockDataManager { @@ -27,10 +27,10 @@ trait BlockDataManager { * Interface to get local block data. Throws an exception if the block cannot be found or * cannot be read successfully. */ - def getBlockData(blockId: String): ManagedBuffer + def getBlockData(blockId: BlockId): ManagedBuffer /** * Put the block locally, using the given storage level. */ - def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit + def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index dd70e26647939..e35fdb4e95899 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -19,6 +19,8 @@ package org.apache.spark.network import java.util.EventListener +import org.apache.spark.network.buffer.ManagedBuffer + /** * Listener callback interface for [[BlockTransferService.fetchBlocks]]. 
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index d3ed683c7e880..8287a0fc81cfe 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -18,16 +18,18 @@ package org.apache.spark.network import java.io.Closeable -import java.nio.ByteBuffer + +import org.apache.spark.network.buffer.ManagedBuffer import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration +import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.Utils private[spark] -abstract class BlockTransferService extends Closeable { +abstract class BlockTransferService extends Closeable with Logging { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -92,10 +94,7 @@ abstract class BlockTransferService extends Closeable { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { lock.synchronized { - val ret = ByteBuffer.allocate(data.size.toInt) - ret.put(data.nioByteBuffer()) - ret.flip() - result = Left(new NioManagedBuffer(ret)) + result = Left(data) lock.notify() } } diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala deleted file mode 100644 index dd808d2500fbc..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network - -import java.io._ -import java.nio.ByteBuffer -import java.nio.channels.FileChannel -import java.nio.channels.FileChannel.MapMode - -import scala.util.Try - -import com.google.common.io.ByteStreams -import io.netty.buffer.{Unpooled, ByteBufInputStream, ByteBuf} -import io.netty.channel.DefaultFileRegion - -import org.apache.spark.util.{ByteBufferInputStream, Utils} - - -/** - * This interface provides an immutable view for data in the form of bytes. The implementation - * should specify how the data is provided: - * - * - [[FileSegmentManagedBuffer]]: data backed by part of a file - * - [[NioManagedBuffer]]: data backed by a NIO ByteBuffer - * - [[NettyManagedBuffer]]: data backed by a Netty ByteBuf - * - * The concrete buffer implementation might be managed outside the JVM garbage collector. - * For example, in the case of [[NettyManagedBuffer]], the buffers are reference counted. - * In that case, if the buffer is going to be passed around to a different thread, retain/release - * should be called. 
- */ -private[spark] -abstract class ManagedBuffer { - // Note that all the methods are defined with parenthesis because their implementations can - // have side effects (io operations). - - /** Number of bytes of the data. */ - def size: Long - - /** - * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the - * returned ByteBuffer should not affect the content of this buffer. - */ - def nioByteBuffer(): ByteBuffer - - /** - * Exposes this buffer's data as an InputStream. The underlying implementation does not - * necessarily check for the length of bytes read, so the caller is responsible for making sure - * it does not go over the limit. - */ - def inputStream(): InputStream - - /** - * Increment the reference count by one if applicable. - */ - def retain(): this.type - - /** - * If applicable, decrement the reference count by one and deallocates the buffer if the - * reference count reaches zero. - */ - def release(): this.type - - /** - * Convert the buffer into an Netty object, used to write the data out. - */ - private[network] def convertToNetty(): AnyRef -} - - -/** - * A [[ManagedBuffer]] backed by a segment in a file - */ -private[spark] -final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) - extends ManagedBuffer { - - override def size: Long = length - - override def nioByteBuffer(): ByteBuffer = { - var channel: FileChannel = null - try { - channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) - } catch { - case e: IOException => - Try(channel.size).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - } finally { - if (channel != null) { - Utils.tryLog(channel.close()) - } - } - } - - override def inputStream(): InputStream = { - var is: FileInputStream = null - try { - is = new FileInputStream(file) - is.skip(offset) - ByteStreams.limit(is, length) - } catch { - case e: IOException => - if (is != null) { - Utils.tryLog(is.close()) - } - Try(file.length).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - case e: Throwable => - if (is != null) { - Utils.tryLog(is.close()) - } - throw e - } - } - - override def toString: String = s"${getClass.getName}($file, $offset, $length)" -} - - -/** - * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. - */ -private[spark] -final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { - - override def size: Long = buf.remaining() - - override def nioByteBuffer() = buf.duplicate() - - override def inputStream() = new ByteBufferInputStream(buf) - - private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf) - - // [[ByteBuffer]] is managed by the JVM garbage collector itself. - override def retain(): this.type = this - override def release(): this.type = this -} - - -/** - * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. 
- */ -private[spark] -final class NettyManagedBuffer(buf: ByteBuf) extends ManagedBuffer { - - override def size: Long = buf.readableBytes() - - override def nioByteBuffer() = buf.nioBuffer() - - override def inputStream() = new ByteBufInputStream(buf) - - private[network] override def convertToNetty(): AnyRef = buf - - override def retain(): this.type = { - buf.retain() - this - } - - override def release(): this.type = { - buf.release() - this - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala deleted file mode 100644 index 6bdbf88d337ce..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.Closeable -import java.util.concurrent.TimeoutException - -import scala.concurrent.{Future, promise} - -import io.netty.channel.{ChannelFuture, ChannelFutureListener} - -import org.apache.spark.Logging -import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener} -import org.apache.spark.storage.StorageLevel - - -/** - * Client for [[NettyBlockTransferService]]. The connection to server must have been established - * using [[BlockClientFactory]] before instantiating this. - * - * This class is used to make requests to the server , while [[BlockClientHandler]] is responsible - * for handling responses from the server. - * - * Concurrency: thread safe and can be called from multiple threads. - * - * @param cf the ChannelFuture for the connection. - * @param handler [[BlockClientHandler]] for handling outstanding requests. - */ -@throws[TimeoutException] -private[netty] -class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closeable with Logging { - - private[this] val serverAddr = cf.channel().remoteAddress().toString - - def isActive: Boolean = cf.channel().isActive - - /** - * Ask the remote server for a sequence of blocks, and execute the callback. - * - * Note that this is asynchronous and returns immediately. Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory due to large outstanding fetches. - * - * @param blockIds sequence of block ids to fetch. - * @param listener callback to fire on fetch success / failure. 
- */ - def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { - var startTime: Long = 0 - logTrace { - startTime = System.currentTimeMillis() - s"Sending request $blockIds to $serverAddr" - } - - blockIds.foreach { blockId => - handler.addFetchRequest(blockId, listener) - } - - cf.channel().writeAndFlush(BlockFetchRequest(blockIds)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = System.currentTimeMillis() - startTime - s"Sending request $blockIds to $serverAddr took $timeTaken ms" - } - } else { - // Fail all blocks. - val errorMsg = - s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}" - logError(errorMsg, future.cause) - blockIds.foreach { blockId => - handler.removeFetchRequest(blockId) - listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) - } - } - } - }) - } - - def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = - { - var startTime: Long = 0 - logTrace { - startTime = System.currentTimeMillis() - s"Uploading block ($blockId) to $serverAddr" - } - val f = cf.channel().writeAndFlush(new BlockUploadRequest(blockId, data, storageLevel)) - - val p = promise[Unit]() - handler.addUploadRequest(blockId, p) - f.addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = System.currentTimeMillis() - startTime - s"Uploading block ($blockId) to $serverAddr took $timeTaken ms" - } - } else { - // Fail all blocks. - val errorMsg = - s"Failed to upload block $blockId to $serverAddr: ${future.cause.getMessage}" - logError(errorMsg, future.cause) - } - } - }) - - p.future - } - - /** Close the connection. This does NOT block till the connection is closed. */ - def close(): Unit = cf.channel().close() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala deleted file mode 100644 index 8021cfdf42d1a..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import java.io.Closeable -import java.util.concurrent.{ConcurrentHashMap, TimeoutException} - -import io.netty.bootstrap.Bootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel._ -import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.util.internal.PlatformDependent - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.util.Utils - - -/** - * Factory for creating [[BlockClient]] by using createClient. - * - * The factory maintains a connection pool to other hosts and should return the same [[BlockClient]] - * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. - */ -private[netty] -class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { - - def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) - - /** A thread factory so the threads are named (for debugging). */ - private[this] val threadFactory = Utils.namedThreadFactory("spark-netty-client") - - /** Socket channel type, initialized by [[init]] depending ioMode. */ - private[this] var socketChannelClass: Class[_ <: Channel] = _ - - /** Thread pool shared by all clients. */ - private[this] var workerGroup: EventLoopGroup = _ - - private[this] val connectionPool = new ConcurrentHashMap[(String, Int), BlockClient] - - // The encoders are stateless and can be shared among multiple clients. - private[this] val encoder = new ClientRequestEncoder - private[this] val decoder = new ServerResponseDecoder - - init() - - /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ - private def init(): Unit = { - def initNio(): Unit = { - socketChannelClass = classOf[NioSocketChannel] - workerGroup = new NioEventLoopGroup(conf.clientThreads, threadFactory) - } - def initEpoll(): Unit = { - socketChannelClass = classOf[EpollSocketChannel] - workerGroup = new EpollEventLoopGroup(conf.clientThreads, threadFactory) - } - - // For auto mode, first try epoll (only available on Linux), then nio. - conf.ioMode match { - case "nio" => initNio() - case "epoll" => initEpoll() - case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() - } - } - - /** - * Create a new BlockFetchingClient connecting to the given remote host / port. - * - * This blocks until a connection is successfully established. - * - * Concurrency: This method is safe to call from multiple threads. - */ - def createClient(remoteHost: String, remotePort: Int): BlockClient = { - // Get connection from the connection pool first. - // If it is not found or not active, create a new one. - val cachedClient = connectionPool.get((remoteHost, remotePort)) - if (cachedClient != null && cachedClient.isActive) { - return cachedClient - } - - logDebug(s"Creating new connection to $remoteHost:$remotePort") - - // There is a chance two threads are creating two different clients connecting to the same host. - // But that's probably ok ... 
- - val handler = new BlockClientHandler - - val bootstrap = new Bootstrap - bootstrap.group(workerGroup) - .channel(socketChannelClass) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs) - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()) - - bootstrap.handler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("clientRequestEncoder", encoder) - .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) - .addLast("serverResponseDecoder", decoder) - .addLast("handler", handler) - } - }) - - // Connect to the remote server - val cf: ChannelFuture = bootstrap.connect(remoteHost, remotePort) - if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)") - } - - val client = new BlockClient(cf, handler) - connectionPool.put((remoteHost, remotePort), client) - client - } - - /** Close all connections in the connection pool, and shutdown the worker thread pool. */ - override def close(): Unit = { - val iter = connectionPool.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.close() - connectionPool.remove(entry.getKey) - } - - if (workerGroup != null) { - workerGroup.shutdownGracefully() - } - } - - /** - * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches - * are disabled because the ByteBufs are allocated by the event loop thread, but released by the - * executor thread rather than the event loop thread. Those thread-local caches actually delay - * the recycling of buffers, leading to larger memory usage. - */ - private def createPooledByteBufAllocator(): PooledByteBufAllocator = { - def getPrivateStaticField(name: String): Int = { - val f = PooledByteBufAllocator.DEFAULT.getClass.getDeclaredField(name) - f.setAccessible(true) - f.getInt(null) - } - new PooledByteBufAllocator( - PlatformDependent.directBufferPreferred(), - getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), - getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), - getPrivateStaticField("DEFAULT_PAGE_SIZE"), - getPrivateStaticField("DEFAULT_MAX_ORDER"), - 0, // tinyCacheSize - 0, // smallCacheSize - 0 // normalCacheSize - ) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala deleted file mode 100644 index 5e28a07a461fa..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.util.concurrent.ConcurrentHashMap - -import scala.concurrent.Promise - -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging -import org.apache.spark.network.{BlockFetchFailureException, BlockUploadFailureException, BlockFetchingListener} - - -/** - * Handler that processes server responses, in response to requests issued from [[BlockClient]]. - * It works by tracking the list of outstanding requests (and their callbacks). - * - * Concurrency: thread safe and can be called from multiple threads. - */ -private[netty] -class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { - - /** Tracks the list of outstanding requests and their listeners on success/failure. */ - private[this] val outstandingFetches: java.util.Map[String, BlockFetchingListener] = - new ConcurrentHashMap[String, BlockFetchingListener] - - private[this] val outstandingUploads: java.util.Map[String, Promise[Unit]] = - new ConcurrentHashMap[String, Promise[Unit]] - - def addFetchRequest(blockId: String, listener: BlockFetchingListener): Unit = { - outstandingFetches.put(blockId, listener) - } - - def removeFetchRequest(blockId: String): Unit = { - outstandingFetches.remove(blockId) - } - - def addUploadRequest(blockId: String, promise: Promise[Unit]): Unit = { - outstandingUploads.put(blockId, promise) - } - - /** - * Fire the failure callback for all outstanding requests. This is called when we have an - * uncaught exception or pre-mature connection termination. - */ - private def failOutstandingRequests(cause: Throwable): Unit = { - val iter1 = outstandingFetches.entrySet().iterator() - while (iter1.hasNext) { - val entry = iter1.next() - entry.getValue.onBlockFetchFailure(entry.getKey, cause) - } - // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests - // as well. But I guess that is ok given the caller will fail as soon as any requests fail. 
- outstandingFetches.clear() - - val iter2 = outstandingUploads.entrySet().iterator() - while (iter2.hasNext) { - val entry = iter2.next() - entry.getValue.failure(new RuntimeException(s"Failed to upload block ${entry.getKey}")) - } - outstandingUploads.clear() - } - - override def channelUnregistered(ctx: ChannelHandlerContext): Unit = { - if (outstandingFetches.size() > 0) { - logError("Still have " + outstandingFetches.size() + " requests outstanding " + - s"when connection from ${ctx.channel.remoteAddress} is closed") - failOutstandingRequests(new RuntimeException( - s"Connection from ${ctx.channel.remoteAddress} closed")) - } - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - if (outstandingFetches.size() > 0) { - logError( - s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}", cause) - failOutstandingRequests(cause) - } - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, response: ServerResponse) { - val server = ctx.channel.remoteAddress.toString - response match { - case BlockFetchSuccess(blockId, buf) => - val listener = outstandingFetches.get(blockId) - if (listener == null) { - logWarning(s"Got a response for block $blockId from $server but it is not outstanding") - buf.release() - } else { - outstandingFetches.remove(blockId) - listener.onBlockFetchSuccess(blockId, buf) - buf.release() - } - case BlockFetchFailure(blockId, errorMsg) => - val listener = outstandingFetches.get(blockId) - if (listener == null) { - logWarning( - s"Got a response for block $blockId from $server ($errorMsg) but it is not outstanding") - } else { - outstandingFetches.remove(blockId) - listener.onBlockFetchFailure(blockId, new BlockFetchFailureException(blockId, errorMsg)) - } - case BlockUploadSuccess(blockId) => - val p = outstandingUploads.get(blockId) - if (p == null) { - logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") - } else { - outstandingUploads.remove(blockId) - p.success(Unit) - } - case BlockUploadFailure(blockId, error) => - val p = outstandingUploads.get(blockId) - if (p == null) { - logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") - } else { - outstandingUploads.remove(blockId) - p.failure(new BlockUploadFailureException(blockId)) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala deleted file mode 100644 index e2eb7c379f14d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import java.io.Closeable -import java.net.InetSocketAddress - -import io.netty.bootstrap.ServerBootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} - -import org.apache.spark.Logging -import org.apache.spark.network.BlockDataManager -import org.apache.spark.util.Utils - - -/** - * Server for the [[NettyBlockTransferService]]. - */ -private[netty] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) - extends Closeable with Logging { - - def port: Int = _port - - def hostName: String = _hostName - - private var _port: Int = conf.serverPort - private var _hostName: String = "" - private var bootstrap: ServerBootstrap = _ - private var channelFuture: ChannelFuture = _ - - init() - - /** Initialize the server. */ - private def init(): Unit = { - bootstrap = new ServerBootstrap - val threadFactory = Utils.namedThreadFactory("spark-netty-server") - - // Use only one thread to accept connections, and 2 * num_cores for worker. - def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(conf.serverThreads, threadFactory) - val workerGroup = bossGroup - bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) - } - def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(conf.serverThreads, threadFactory) - val workerGroup = bossGroup - bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) - } - - conf.ioMode match { - case "nio" => initNio() - case "epoll" => initEpoll() - case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() - } - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - - // Various (advanced) user-configured settings. - conf.backLog.foreach { backLog => - bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) - } - conf.receiveBuf.foreach { receiveBuf => - bootstrap.childOption[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) - } - conf.sendBuf.foreach { sendBuf => - bootstrap.childOption[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) - } - - bootstrap.childHandler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) - .addLast("clientRequestDecoder", new ClientRequestDecoder) - .addLast("serverResponseEncoder", new ServerResponseEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } - }) - - channelFuture = bootstrap.bind(new InetSocketAddress(_port)) - channelFuture.sync() - - val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] - _port = addr.getPort - // _hostName = addr.getHostName - _hostName = Utils.localHostName() - - logInfo(s"Server started ${_hostName}:${_port}") - } - - /** Shutdown the server. 
*/ - def close(): Unit = { - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly() - channelFuture = null - } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully() - } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully() - } - bootstrap = null - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala deleted file mode 100644 index 44687f0b770e9..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.channel._ - -import org.apache.spark.Logging -import org.apache.spark.network.{ManagedBuffer, BlockDataManager} -import org.apache.spark.storage.StorageLevel - - -/** - * A handler that processes requests from clients and writes block data back. - * - * The messages should have been processed by the pipeline setup by BlockServerChannelInitializer. - */ -private[netty] class BlockServerHandler(dataProvider: BlockDataManager) - extends SimpleChannelInboundHandler[ClientRequest] with Logging { - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, request: ClientRequest): Unit = { - request match { - case BlockFetchRequest(blockIds) => - blockIds.foreach(processFetchRequest(ctx, _)) - case BlockUploadRequest(blockId, data, level) => - processUploadRequest(ctx, blockId, data, level) - } - } // end of channelRead0 - - private def processFetchRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { - // A helper function to send error message back to the client. - def client = ctx.channel.remoteAddress.toString - - def respondWithError(error: String): Unit = { - ctx.writeAndFlush(new BlockFetchFailure(blockId, error)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (!future.isSuccess) { - // TODO: Maybe log the success case as well. - logError(s"Error sending error back to $client", future.cause) - ctx.close() - } - } - } - ) - } - - logTrace(s"Received request from $client to fetch block $blockId") - - // First make sure we can find the block. If not, send error back to the user. 
- var buf: ManagedBuffer = null - try { - buf = dataProvider.getBlockData(blockId) - } catch { - case e: Exception => - logError(s"Error opening block $blockId for request from $client", e) - respondWithError(e.getMessage) - return - } - - ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.size} B) back to $client") - } else { - logError( - s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - } - ) - } // end of processBlockRequest - - private def processUploadRequest( - ctx: ChannelHandlerContext, - blockId: String, - data: ManagedBuffer, - level: StorageLevel): Unit = { - // A helper function to send error message back to the client. - def client = ctx.channel.remoteAddress.toString - - try { - dataProvider.putBlockData(blockId, data, level) - ctx.writeAndFlush(BlockUploadSuccess(blockId)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (!future.isSuccess) { - logError(s"Error sending an ACK back to client $client") - } - } - }) - } catch { - case e: Throwable => - logError(s"Error processing uploaded block $blockId", e) - ctx.writeAndFlush(BlockUploadFailure(blockId, e.getMessage)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (!future.isSuccess) { - logError(s"Error sending an ACK back to client $client") - } - } - }) - } - } // end of processUploadRequest -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala new file mode 100644 index 0000000000000..aefd8a6335b2a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer +import java.util + +import org.apache.spark.Logging +import org.apache.spark.network.BlockFetchingListener +import org.apache.spark.serializer.Serializer +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, SluiceClient} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.Utils + +/** + * Responsible for holding the state for a request for a single set of blocks. This assumes that + * the chunks will be returned in the same order as requested, and that there will be exactly + * one chunk per block. + * + * Upon receipt of any block, the listener will be called back. 
Upon failure part way through, + * the listener will receive a failure callback for each outstanding block. + */ +class NettyBlockFetcher( + serializer: Serializer, + client: SluiceClient, + blockIds: Seq[String], + listener: BlockFetchingListener) + extends Logging { + + require(blockIds.nonEmpty) + + val ser = serializer.newInstance() + + var streamHandle: ShuffleStreamHandle = _ + + val chunkCallback = new ChunkReceivedCallback { + // On receipt of a chunk, pass it upwards as a block. + def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions { + buffer.retain() + listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer) + } + + // On receipt of a failure, fail every block from chunkIndex onwards. + def onFailure(chunkIndex: Int, e: Throwable): Unit = { + blockIds.drop(chunkIndex).foreach { blockId => + listener.onBlockFetchFailure(blockId, e); + } + } + } + + // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle. + client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(), + new RpcResponseCallback { + override def onSuccess(response: Array[Byte]): Unit = { + try { + streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response)) + logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.") + + // Immediately request all chunks -- we expect that the total size of the request is + // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. + for (i <- 0 until streamHandle.numChunks) { + client.fetchChunk(streamHandle.streamId, i, chunkCallback) + } + } catch { + case e: Exception => + logError("Failed while starting block fetches", e) + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } + } + + override def onFailure(e: Throwable): Unit = { + logError("Failed while starting block fetches") + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } + }) +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala new file mode 100644 index 0000000000000..c8658ec98b82c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer + +import org.apache.spark.Logging +import org.apache.spark.network.BlockDataManager +import org.apache.spark.serializer.Serializer +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler} +import org.apache.spark.storage.BlockId + +import scala.collection.JavaConversions._ + +/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ +case class OpenBlocks(blockIds: Seq[BlockId]) + +/** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */ +case class ShuffleStreamHandle(streamId: Long, numChunks: Int) + +/** + * Serves requests to open blocks by simply registering one chunk per block requested. + */ +class NettyBlockRpcServer( + serializer: Serializer, + streamManager: DefaultStreamManager, + blockManager: BlockDataManager) + extends RpcHandler with Logging { + + override def receive(messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { + val ser = serializer.newInstance() + val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) + logTrace(s"Received request: $message") + message match { + case OpenBlocks(blockIds) => + val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) + val streamId = streamManager.registerStream(blocks.iterator) + responseContext.onSuccess( + ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b7f979dccd0f5..7576d51e22175 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,38 +17,39 @@ package org.apache.spark.network.netty -import scala.concurrent.Future - import org.apache.spark.SparkConf import org.apache.spark.network._ +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.{SluiceClient, SluiceClientFactory} +import org.apache.spark.network.server.{DefaultStreamManager, SluiceServer} +import org.apache.spark.network.util.{ConfigProvider, SluiceConfig} +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils +import scala.concurrent.Future /** - * A [[BlockTransferService]] implementation based on Netty. - * - * See protocol.scala for the communication protocol between server and client + * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -private[spark] -final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { +class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { + var client: SluiceClient = _ - private[this] val nettyConf: NettyConfig = new NettyConfig(conf) + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. + val serializer = new JavaSerializer(conf) - private[this] var server: BlockServer = _ - private[this] var clientFactory: BlockClientFactory = _ + // Create a SluiceConfig using SparkConf. 
+ private[this] val sluiceConf = new SluiceConfig( + new ConfigProvider { override def get(name: String) = conf.get(name) }) - override def init(blockDataManager: BlockDataManager): Unit = { - server = new BlockServer(nettyConf, blockDataManager) - clientFactory = new BlockClientFactory(nettyConf) - } + private[this] var server: SluiceServer = _ + private[this] var clientFactory: SluiceClientFactory = _ - override def close(): Unit = { - if (server != null) { - server.close() - } - if (clientFactory != null) { - clientFactory.close() - } + override def init(blockDataManager: BlockDataManager): Unit = { + val streamManager = new DefaultStreamManager + val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager) + server = new SluiceServer(sluiceConf, streamManager, rpcHandler) + clientFactory = new SluiceClientFactory(sluiceConf) } override def fetchBlocks( @@ -56,29 +57,21 @@ final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferServ port: Int, blockIds: Seq[String], listener: BlockFetchingListener): Unit = { - clientFactory.createClient(hostName, port).fetchBlocks(blockIds, listener) + val client = clientFactory.createClient(hostName, port) + new NettyBlockFetcher(serializer, client, blockIds, listener) } + override def hostName: String = Utils.localHostName() + + override def port: Int = server.getPort + + // TODO: Implement override def uploadBlock( hostname: String, port: Int, blockId: String, blockData: ManagedBuffer, - level: StorageLevel): Future[Unit] = { - clientFactory.createClient(hostName, port).uploadBlock(blockId, blockData, level) - } + level: StorageLevel): Future[Unit] = ??? - override def hostName: String = { - if (server == null) { - throw new IllegalStateException("Server has not been started") - } - server.hostName - } - - override def port: Int = { - if (server == null) { - throw new IllegalStateException("Server has not been started") - } - server.port - } + override def close(): Unit = server.close() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala deleted file mode 100644 index 13942f3d0adcd..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
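Putting the service-level pieces together, a minimal sketch of how this transfer service is stood up and used; `blockDataManager` and `listener` are assumed to be the local BlockManager and a BlockFetchingListener, and the host and port literals are examples.

import org.apache.spark.SparkConf

val transferService = new NettyBlockTransferService(new SparkConf())
transferService.init(blockDataManager)  // starts a SluiceServer backed by NettyBlockRpcServer

// Fetch blocks from a peer's transfer service; results arrive on the listener asynchronously.
transferService.fetchBlocks("remote-host", 12345, Seq("shuffle_0_0_0"), listener)

// This server's own advertised address; uploadBlock is still unimplemented in this patch.
val (host, port) = (transferService.hostName, transferService.port)
transferService.close()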
- */ - -package org.apache.spark.network.netty - -import java.nio.ByteBuffer -import java.util.{List => JList} - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.channel.ChannelHandler.Sharable -import io.netty.handler.codec._ - -import org.apache.spark.Logging -import org.apache.spark.network.{NioManagedBuffer, NettyManagedBuffer, ManagedBuffer} -import org.apache.spark.storage.StorageLevel - - -/** Messages from the client to the server. */ -private[netty] -sealed trait ClientRequest { - def id: Byte -} - -/** - * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can - * correspond to multiple [[ServerResponse]]s. - */ -private[netty] -final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { - override def id = 0 -} - -/** - * Request to upload a block to the server. Currently the server does not ack the upload request. - */ -private[netty] -final case class BlockUploadRequest( - blockId: String, - data: ManagedBuffer, - level: StorageLevel) - extends ClientRequest { - require(blockId.length <= Byte.MaxValue) - override def id = 1 -} - - -/** Messages from server to client (usually in response to some [[ClientRequest]]. */ -private[netty] -sealed trait ServerResponse { - def id: Byte -} - -/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */ -private[netty] -final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 0 -} - -/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */ -private[netty] -final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 1 -} - -/** Response to [[BlockUploadRequest]] when a block is successfully uploaded. */ -private[netty] -final case class BlockUploadSuccess(blockId: String) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 2 -} - -/** Response to [[BlockUploadRequest]] when there is an error uploading the block. */ -private[netty] -final case class BlockUploadFailure(blockId: String, error: String) extends ServerResponse { - require(blockId.length <= Byte.MaxValue) - override def id = 3 -} - - -/** - * Encoder for [[ClientRequest]] used in client side. - * - * This encoder is stateless so it is safe to be shared by multiple threads. 
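Because the request encoder defined next keeps no per-connection state, one @Sharable instance can be installed in any number of pipelines. A small sketch of what that permits for the Scala encoder being removed here (EmbeddedChannel is used purely for illustration):

import io.netty.channel.embedded.EmbeddedChannel

// A single stateless encoder instance serving two independent channels.
val sharedEncoder = new ClientRequestEncoder
val channelA = new EmbeddedChannel(sharedEncoder)
val channelB = new EmbeddedChannel(sharedEncoder)
channelA.writeOutbound(BlockFetchRequest(Seq("b1")))
channelB.writeOutbound(BlockFetchRequest(Seq("b2")))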
- */ -@Sharable -private[netty] -final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { - override def encode(ctx: ChannelHandlerContext, in: ClientRequest, out: JList[Object]): Unit = { - in match { - case BlockFetchRequest(blocks) => - // 8 bytes: frame size - // 1 byte: BlockFetchRequest vs BlockUploadRequest - // 4 byte: num blocks - // then for each block id write 1 byte for blockId.length and then blockId itself - val frameLength = 8 + 1 + 4 + blocks.size + blocks.map(_.size).fold(0)(_ + _) - val buf = ctx.alloc().buffer(frameLength) - - buf.writeLong(frameLength) - buf.writeByte(in.id) - buf.writeInt(blocks.size) - blocks.foreach { blockId => - ProtocolUtils.writeBlockId(buf, blockId) - } - - assert(buf.writableBytes() == 0) - out.add(buf) - - case BlockUploadRequest(blockId, data, level) => - // 8 bytes: frame size - // 1 byte: msg id (BlockFetchRequest vs BlockUploadRequest) - // 1 byte: blockId.length - // data itself (length can be derived from: frame size - 1 - blockId.length) - val headerLength = 8 + 1 + 1 + blockId.length + 5 - val frameLength = headerLength + data.size - val header = ctx.alloc().buffer(headerLength) - - // Call this before we add header to out so in case of exceptions - // we don't send anything at all. - val body = data.convertToNetty() - - header.writeLong(frameLength) - header.writeByte(in.id) - ProtocolUtils.writeBlockId(header, blockId) - header.writeInt(level.toInt) - header.writeByte(level.replication) - - assert(header.writableBytes() == 0) - out.add(header) - out.add(body) - } - } -} - - -/** - * Decoder in the server side to decode client requests. - * This decoder is stateless so it is safe to be shared by multiple threads. - * - * This assumes the inbound messages have been processed by a frame decoder created by - * [[ProtocolUtils.createFrameDecoder()]]. - */ -@Sharable -private[netty] -final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { - override protected def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = - { - val msgTypeId = in.readByte() - val decoded = msgTypeId match { - case 0 => // BlockFetchRequest - val numBlocks = in.readInt() - val blockIds = Seq.fill(numBlocks) { ProtocolUtils.readBlockId(in) } - BlockFetchRequest(blockIds) - - case 1 => // BlockUploadRequest - val blockId = ProtocolUtils.readBlockId(in) - val level = new StorageLevel(in.readInt(), in.readByte()) - - val ret = ByteBuffer.allocate(in.readableBytes()) - ret.put(in.nioBuffer()) - ret.flip() - BlockUploadRequest(blockId, new NioManagedBuffer(ret), level) - } - - assert(decoded.id == msgTypeId) - out.add(decoded) - } -} - - -/** - * Encoder used by the server side to encode server-to-client responses. - * This encoder is stateless so it is safe to be shared by multiple threads. - */ -@Sharable -private[netty] -final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { - override def encode(ctx: ChannelHandlerContext, in: ServerResponse, out: JList[Object]): Unit = { - in match { - case BlockFetchSuccess(blockId, data) => - // Handle the body first so if we encounter an error getting the body, we can respond - // with an error instead. - var body: AnyRef = null - try { - body = data.convertToNetty() - } catch { - case e: Exception => - // Re-encode this message as BlockFetchFailure. 
- logError(s"Error opening block $blockId for client ${ctx.channel.remoteAddress}", e) - encode(ctx, new BlockFetchFailure(blockId, e.getMessage), out) - return - } - - // If we got here, body cannot be null - // 8 bytes = long for frame length - // 1 byte = message id (type) - // 1 byte = block id length - // followed by block id itself - val headerLength = 8 + 1 + 1 + blockId.length - val frameLength = headerLength + data.size - val header = ctx.alloc().buffer(headerLength) - header.writeLong(frameLength) - header.writeByte(in.id) - ProtocolUtils.writeBlockId(header, blockId) - - assert(header.writableBytes() == 0) - out.add(header) - out.add(body) - - case BlockFetchFailure(blockId, error) => - val frameLength = 8 + 1 + 1 + blockId.length + error.length - val buf = ctx.alloc().buffer(frameLength) - buf.writeLong(frameLength) - buf.writeByte(in.id) - ProtocolUtils.writeBlockId(buf, blockId) - buf.writeBytes(error.getBytes) - - assert(buf.writableBytes() == 0) - out.add(buf) - - case BlockUploadSuccess(blockId) => - val frameLength = 8 + 1 + 1 + blockId.length - val buf = ctx.alloc().buffer(frameLength) - buf.writeLong(frameLength) - buf.writeByte(in.id) - ProtocolUtils.writeBlockId(buf, blockId) - - assert(buf.writableBytes() == 0) - out.add(buf) - - case BlockUploadFailure(blockId, error) => - val frameLength = 8 + 1 + 1 + blockId.length + + error.length - val buf = ctx.alloc().buffer(frameLength) - buf.writeLong(frameLength) - buf.writeByte(in.id) - ProtocolUtils.writeBlockId(buf, blockId) - buf.writeBytes(error.getBytes) - - assert(buf.writableBytes() == 0) - out.add(buf) - } - } -} - - -/** - * Decoder in the client side to decode server responses. - * This decoder is stateless so it is safe to be shared by multiple threads. - * - * This assumes the inbound messages have been processed by a frame decoder created by - * [[ProtocolUtils.createFrameDecoder()]]. - */ -@Sharable -private[netty] -final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { - override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { - val msgId = in.readByte() - val decoded = msgId match { - case 0 => // BlockFetchSuccess - val blockId = ProtocolUtils.readBlockId(in) - in.retain() - BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) - - case 1 => // BlockFetchFailure - val blockId = ProtocolUtils.readBlockId(in) - val errorBytes = new Array[Byte](in.readableBytes()) - in.readBytes(errorBytes) - BlockFetchFailure(blockId, new String(errorBytes)) - - case 2 => // BlockUploadSuccess - BlockUploadSuccess(ProtocolUtils.readBlockId(in)) - - case 3 => // BlockUploadFailure - val blockId = ProtocolUtils.readBlockId(in) - val errorBytes = new Array[Byte](in.readableBytes()) - in.readBytes(errorBytes) - BlockUploadFailure(blockId, new String(errorBytes)) - } - - assert(decoded.id == msgId) - out.add(decoded) - } -} - - -private[netty] object ProtocolUtils { - - /** LengthFieldBasedFrameDecoder used before all decoders. */ - def createFrameDecoder(): ByteToMessageDecoder = { - // maxFrameLength = 2G - // lengthFieldOffset = 0 - // lengthFieldLength = 8 - // lengthAdjustment = -8, i.e. exclude the 8 byte length itself - // initialBytesToStrip = 8, i.e. strip out the length field itself - new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8) - } - - // TODO(rxin): Make sure these work for all charsets. 
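As a quick sanity check on the framing comments above (this describes the wire format of the Scala protocol being deleted here, not the new Java module):

// Worked example: frame length of a BlockFetchRequest for two block ids.
val blocks = Seq("b1", "b2")
val frameLength = 8 + 1 + 4 + blocks.size + blocks.map(_.length).sum
// = 8 (length field) + 1 (type id) + 4 (block count) + 2 (one length byte per id) + 4 (id bytes) = 19
// createFrameDecoder()'s LengthFieldBasedFrameDecoder strips the 8-byte length field,
// so ClientRequestDecoder sees the remaining 11 bytes, starting with the type id.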
- def readBlockId(in: ByteBuf): String = { - val numBytesToRead = in.readByte().toInt - val bytes = new Array[Byte](numBytesToRead) - in.readBytes(bytes) - new String(bytes) - } - - def writeBlockId(out: ByteBuf, blockId: String): Unit = { - out.writeByte(blockId.length) - out.writeBytes(blockId.getBytes) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index e942b43d9cc4a..bce1069548437 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -19,12 +19,13 @@ package org.apache.spark.network.nio import java.nio.ByteBuffer -import scala.concurrent.Future - -import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} + +import scala.concurrent.Future /** @@ -153,12 +154,11 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => { + case e: Exception => logError("Exception handling buffer message", e) val errorMessage = Message.createBufferMessage(msg.id) errorMessage.hasError = true Some(errorMessage) - } } case otherMessage: Any => @@ -174,13 +174,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa case BlockMessage.TYPE_PUT_BLOCK => val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) logDebug("Received [" + msg + "]") - putBlock(msg.id.toString, msg.data, msg.level) + putBlock(msg.id, msg.data, msg.level) None case BlockMessage.TYPE_GET_BLOCK => val msg = new GetBlock(blockMessage.getId) logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id.toString) + val buffer = getBlock(msg.id) if (buffer == null) { return None } @@ -190,7 +190,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } } - private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) @@ -198,7 +198,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa + " with data size: " + bytes.limit) } - private def getBlock(blockId: String): ByteBuffer = { + private def getBlock(blockId: BlockId): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) val buffer = blockDataManager.getBlockData(blockId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 439981d232349..c35aa2481ad03 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ 
-24,14 +24,14 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConversions._ -import org.apache.spark.{SparkEnv, SparkConf, Logging} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.{Logging, SparkConf, SparkEnv} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 4ab34336d3f01..6a9fa4ec65d5d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -21,7 +21,7 @@ import java.io._ import java.nio.ByteBuffer import org.apache.spark.SparkEnv -import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.storage._ /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala index 63863cc0250a3..b521f0c7fc77e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala @@ -18,8 +18,7 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer - -import org.apache.spark.network.ManagedBuffer +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ac0599f30ef22..4d8b5c1e1b084 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,15 +17,13 @@ package org.apache.spark.storage -import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream} +import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import scala.concurrent.ExecutionContext.Implicits.global - -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{Await, Future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} import scala.util.Random import akka.actor.{ActorSystem, Props} @@ -35,11 +33,11 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import 
org.apache.spark.util._ - private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues @@ -215,17 +213,17 @@ private[spark] class BlockManager( * Interface to get local block data. Throws an exception if the block cannot be found or * cannot be read successfully. */ - override def getBlockData(blockId: String): ManagedBuffer = { - val bid = BlockId(blockId) - if (bid.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]) + override def getBlockData(blockId: BlockId): ManagedBuffer = { + if (blockId.isShuffle) { + shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { - val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) + .asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get new NioManagedBuffer(buffer) } else { - throw new BlockNotFoundException(blockId) + throw new BlockNotFoundException(blockId.toString) } } } @@ -233,8 +231,8 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. */ - override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = { - putBytes(BlockId(blockId), data.nioByteBuffer(), level) + override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = { + putBytes(blockId, data.nioByteBuffer(), level) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala index 9ef453605f4f1..81f5f2d31dbd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala @@ -17,5 +17,4 @@ package org.apache.spark.storage - class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found") diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d095452a261db..23313fe9271fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -19,14 +19,13 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue +import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} +import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} import org.apache.spark.serializer.Serializer +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.{Logging, TaskContext} /** @@ -228,7 +227,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val blockId = iter.next() try { - val buf = blockManager.getBlockData(blockId.toString) + val buf = blockManager.getBlockData(blockId) shuffleMetrics.localBlocksFetched += 1 buf.retain() results.put(new FetchResult(blockId, 0, 
buf)) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 2fc7c7d9b8312..1e35abaab5353 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -42,7 +42,7 @@ class StorageLevel private( extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - private[spark] def this(flags: Int, replication: Int) { + private def this(flags: Int, replication: Int) { this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -98,7 +98,6 @@ class StorageLevel private( } override def writeExternal(out: ObjectOutput) { - /* If the wire protocol changes, please also update [[ClientRequestEncoder]] */ out.writeByte(toInt) out.writeByte(_replication) } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala deleted file mode 100644 index 2d4baafcf03d0..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import scala.concurrent.{Await, future} -import scala.concurrent.duration._ -import scala.concurrent.ExecutionContext.Implicits.global - -import org.scalatest.{BeforeAndAfterAll, FunSuite} - -import org.apache.spark.SparkConf - - -class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { - - private val conf = new SparkConf - private var server1: BlockServer = _ - private var server2: BlockServer = _ - - override def beforeAll() { - server1 = new BlockServer(new NettyConfig(conf), null) - server2 = new BlockServer(new NettyConfig(conf), null) - } - - override def afterAll() { - if (server1 != null) { - server1.close() - } - if (server2 != null) { - server2.close() - } - } - - test("BlockClients created are active and reused") { - val factory = new BlockClientFactory(conf) - val c1 = factory.createClient(server1.hostName, server1.port) - val c2 = factory.createClient(server1.hostName, server1.port) - val c3 = factory.createClient(server2.hostName, server2.port) - assert(c1.isActive) - assert(c3.isActive) - assert(c1 === c2) - assert(c1 !== c3) - factory.close() - } - - test("never return inactive clients") { - val factory = new BlockClientFactory(conf) - val c1 = factory.createClient(server1.hostName, server1.port) - c1.close() - - // Block until c1 is no longer active - val f = future { - while (c1.isActive) { - Thread.sleep(10) - } - } - Await.result(f, 3.seconds) - assert(!c1.isActive) - - // Create c2, which should be different from c1 - val c2 = factory.createClient(server1.hostName, server1.port) - assert(c1 !== c2) - factory.close() - } - - test("BlockClients are close when BlockClientFactory is stopped") { - val factory = new BlockClientFactory(conf) - val c1 = factory.createClient(server1.hostName, server1.port) - val c2 = factory.createClient(server2.hostName, server2.port) - assert(c1.isActive) - assert(c2.isActive) - factory.close() - assert(!c1.isActive) - assert(!c2.isActive) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala deleted file mode 100644 index 4c3a649081574..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled -import io.netty.channel.embedded.EmbeddedChannel - -import org.mockito.Mockito._ -import org.mockito.Matchers.{any, eq => meq} - -import org.scalatest.{FunSuite, PrivateMethodTester} - -import org.apache.spark.network._ - - -class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { - - /** Helper method to get num. outstanding requests from a private field using reflection. */ - private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { - val f = handler.getClass.getDeclaredField( - "org$apache$spark$network$netty$BlockClientHandler$$outstandingFetches") - f.setAccessible(true) - f.get(handler).asInstanceOf[java.util.Map[_, _]].size - } - - test("handling block data (successful fetch)") { - val blockId = "test_block" - val blockData = "blahblahblahblahblah" - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest(blockId, listener) - assert(sizeOfOutstandingRequests(handler) === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(blockData.size) // 4 bytes for the length field itself - buf.put(blockData.getBytes) - buf.flip() - - channel.writeInbound(BlockFetchSuccess(blockId, new NioManagedBuffer(buf))) - verify(listener, times(1)).onBlockFetchSuccess(meq(blockId), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } - - test("handling error message (failed fetch)") { - val blockId = "test_block" - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest(blockId, listener) - assert(sizeOfOutstandingRequests(handler) === 1) - - val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchFailure(blockId, "some error msg")) - verify(listener, times(0)).onBlockFetchSuccess(any(), any()) - verify(listener, times(1)).onBlockFetchFailure(meq(blockId), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } - - test("clear all outstanding request upon uncaught exception") { - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest("b1", listener) - handler.addFetchRequest("b2", listener) - handler.addFetchRequest("b3", listener) - assert(sizeOfOutstandingRequests(handler) === 3) - - val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchSuccess("b1", new NettyManagedBuffer(Unpooled.buffer()))) - channel.pipeline().fireExceptionCaught(new Exception("duh duh duh")) - - // should fail both b2 and b3 - verify(listener, times(1)).onBlockFetchSuccess(any(), any()) - verify(listener, times(2)).onBlockFetchFailure(any(), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } - - test("clear all outstanding request upon connection close") { - val handler = new BlockClientHandler - val listener = mock(classOf[BlockFetchingListener]) - handler.addFetchRequest("c1", listener) - handler.addFetchRequest("c2", listener) - handler.addFetchRequest("c3", listener) - assert(sizeOfOutstandingRequests(handler) === 3) - - val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchSuccess("c1", new NettyManagedBuffer(Unpooled.buffer()))) - channel.finish() - - // should fail both b2 and b3 - verify(listener, times(1)).onBlockFetchSuccess(any(), any()) - verify(listener, 
times(2)).onBlockFetchFailure(any(), any()) - assert(sizeOfOutstandingRequests(handler) === 0) - assert(channel.finish() === false) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala deleted file mode 100644 index 8d1b7276f4082..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - -import org.apache.spark.api.java.StorageLevels - - -/** - * Test client/server encoder/decoder protocol. - */ -class ProtocolSuite extends FunSuite { - - /** - * Helper to test server to client message protocol by encoding a message and decoding it. - */ - private def testServerToClient(msg: ServerResponse) { - val serverChannel = new EmbeddedChannel(new ServerResponseEncoder) - serverChannel.writeOutbound(msg) - - val clientChannel = new EmbeddedChannel( - ProtocolUtils.createFrameDecoder(), - new ServerResponseDecoder) - - // Drain all server outbound messages and write them to the client's server decoder. - while (!serverChannel.outboundMessages().isEmpty) { - clientChannel.writeInbound(serverChannel.readOutbound()) - } - - assert(clientChannel.inboundMessages().size === 1) - // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is - // overridden. - assert(msg === clientChannel.readInbound()) - } - - /** - * Helper to test client to server message protocol by encoding a message and decoding it. - */ - private def testClientToServer(msg: ClientRequest) { - val clientChannel = new EmbeddedChannel(new ClientRequestEncoder) - clientChannel.writeOutbound(msg) - - val serverChannel = new EmbeddedChannel( - ProtocolUtils.createFrameDecoder(), - new ClientRequestDecoder) - - // Drain all client outbound messages and write them to the server's decoder. - while (!clientChannel.outboundMessages().isEmpty) { - serverChannel.writeInbound(clientChannel.readOutbound()) - } - - assert(serverChannel.inboundMessages().size === 1) - // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is - // overridden. 
- assert(msg === serverChannel.readInbound()) - } - - test("server to client protocol - BlockFetchSuccess(\"a1234\", new TestManagedBuffer(10))") { - testServerToClient(BlockFetchSuccess("a1234", new TestManagedBuffer(10))) - } - - test("server to client protocol - BlockFetchSuccess(\"\", new TestManagedBuffer(0))") { - testServerToClient(BlockFetchSuccess("", new TestManagedBuffer(0))) - } - - test("server to client protocol - BlockFetchFailure(\"abcd\", \"this is an error\")") { - testServerToClient(BlockFetchFailure("abcd", "this is an error")) - } - - test("server to client protocol - BlockFetchFailure(\"\", \"\")") { - testServerToClient(BlockFetchFailure("", "")) - } - - test("client to server protocol - BlockFetchRequest(Seq.empty[String])") { - testClientToServer(BlockFetchRequest(Seq.empty[String])) - } - - test("client to server protocol - BlockFetchRequest(Seq(\"b1\"))") { - testClientToServer(BlockFetchRequest(Seq("b1"))) - } - - test("client to server protocol - BlockFetchRequest(Seq(\"b1\", \"b2\", \"b3\"))") { - testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) - } - - test("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") { - testClientToServer( - BlockUploadRequest("", new TestManagedBuffer(0), StorageLevels.MEMORY_AND_DISK)) - } - - test("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { - testClientToServer( - BlockUploadRequest("b_upload", new TestManagedBuffer(10), StorageLevels.MEMORY_AND_DISK_2)) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala deleted file mode 100644 index 35ff90a2dabc5..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.{RandomAccessFile, File} -import java.nio.ByteBuffer -import java.util.{Collections, HashSet} -import java.util.concurrent.{TimeUnit, Semaphore} - -import scala.collection.JavaConversions._ - -import io.netty.buffer.Unpooled - -import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.Span -import org.scalatest.time.Seconds - -import org.apache.spark.SparkConf -import org.apache.spark.network._ -import org.apache.spark.storage.{BlockNotFoundException, StorageLevel} - - -/** -* Test cases that create real clients and servers and connect. 
-*/ -class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { - - val bufSize = 100000 - var buf: ByteBuffer = _ - var testFile: File = _ - var server: BlockServer = _ - var clientFactory: BlockClientFactory = _ - - val bufferBlockId = "buffer_block" - val fileBlockId = "file_block" - - val fileContent = new Array[Byte](1024) - scala.util.Random.nextBytes(fileContent) - - override def beforeAll() = { - buf = ByteBuffer.allocate(bufSize) - for (i <- 1 to bufSize) { - buf.put(i.toByte) - } - buf.flip() - - testFile = File.createTempFile("netty-test-file", "txt") - val fp = new RandomAccessFile(testFile, "rw") - fp.write(fileContent) - fp.close() - - server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager { - override def getBlockData(blockId: String): ManagedBuffer = { - if (blockId == bufferBlockId) { - new NioManagedBuffer(buf) - } else if (blockId == fileBlockId) { - new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25) - } else { - throw new BlockNotFoundException(blockId) - } - } - - /** - * Put the block locally, using the given storage level. - */ - def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = ??? - }) - - clientFactory = new BlockClientFactory(new SparkConf) - } - - override def afterAll() = { - server.close() - clientFactory.close() - } - - /** A ByteBuf for buffer_block */ - lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf) - - /** A ByteBuf for file_block */ - lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = { - val client = clientFactory.createClient(server.hostName, server.port) - val sem = new Semaphore(0) - val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) - val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) - val receivedBuffers = Collections.synchronizedSet(new HashSet[ManagedBuffer]) - - client.fetchBlocks( - blockIds, - new BlockFetchingListener { - override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { - errorBlockIds.add(blockId) - sem.release() - } - - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - data.retain() - receivedBlockIds.add(blockId) - receivedBuffers.add(data) - sem.release() - } - } - ) - if (!sem.tryAcquire(blockIds.size, 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server") - } - client.close() - (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet) - } - - test("fetch a ByteBuffer block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a FileSegment block via zero-copy send") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) - assert(blockIds === Set(fileBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) - assert(blockIds.isEmpty) - assert(buffers.isEmpty) - assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) - } - - test("fetch both ByteBuffer block and FileSegment block") { - val (blockIds, buffers, 
failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) - assert(blockIds === Set(bufferBlockId, fileBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference, fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch both ByteBuffer block and a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) - assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) - } - - test("shutting down server should also close client") { - val client = clientFactory.createClient(server.hostName, server.port) - server.close() - eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) } - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala deleted file mode 100644 index e47e4d03fa898..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.InputStream -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled - -import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} - - -/** - * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). - * - * Used for testing. 
- */ -class TestManagedBuffer(len: Int) extends ManagedBuffer { - - require(len <= Byte.MaxValue) - - private val byteArray: Array[Byte] = Array.tabulate[Byte](len)(_.toByte) - - private val underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)) - - override def size: Long = underlying.size - - override private[network] def convertToNetty(): AnyRef = underlying.convertToNetty() - - override def nioByteBuffer(): ByteBuffer = underlying.nioByteBuffer() - - override def inputStream(): InputStream = underlying.inputStream() - - override def toString: String = s"${getClass.getName}($len)" - - override def equals(other: Any): Boolean = other match { - case otherBuf: ManagedBuffer => - val nioBuf = otherBuf.nioByteBuffer() - if (nioBuf.remaining() != len) { - return false - } else { - var i = 0 - while (i < len) { - if (nioBuf.get() != i) { - return false - } - i += 1 - } - return true - } - case _ => false - } - - override def retain(): this.type = this - - override def release(): this.type = this -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index ba47fe5e25b9b..6790388f96603 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FileShuffleBlockManager import org.apache.spark.storage.{ShuffleBlockId, FileSegment} @@ -36,9 +36,9 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath) - assert(expected.offset === segment.offset) - assert(expected.length === segment.length) + assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) + assert(expected.offset === segment.getOffset) + assert(expected.length === segment.getLength) } test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 7d4086313fcc1..3beb503b206f2 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -31,6 +31,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.network._ +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.TestSerializer @@ -71,7 +72,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) localBlocks.foreach { case (blockId, buf) => - 
doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString)) + doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } // Make sure remote blocks would return diff --git a/network/common/pom.xml b/network/common/pom.xml new file mode 100644 index 0000000000000..e3b7e328701b4 --- /dev/null +++ b/network/common/pom.xml @@ -0,0 +1,94 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + network + jar + Shuffle Streaming Service + http://spark.apache.org/ + + network + + + + + + io.netty + netty-all + + + org.slf4j + slf4j-api + + + + + com.google.guava + guava + provided + + + + + junit + junit + test + + + log4j + log4j + test + + + org.mockito + mockito-all + test + + + + + + target/java/classes + target/java/test-classes + + + org.apache.maven.plugins + maven-surefire-plugin + 2.17 + + false + + **/Test*.java + **/*Test.java + **/*Suite.java + + + + + + diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java new file mode 100644 index 0000000000000..224f1e6c515ea --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; + +import com.google.common.base.Objects; +import com.google.common.io.ByteStreams; +import io.netty.channel.DefaultFileRegion; + +import org.apache.spark.network.util.JavaUtils; + +/** + * A {@link ManagedBuffer} backed by a segment in a file. + */ +public final class FileSegmentManagedBuffer extends ManagedBuffer { + + /** + * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). + * Avoid unless there's a good reason not to. + */ + private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; + + private final File file; + private final long offset; + private final long length; + + public FileSegmentManagedBuffer(File file, long offset, long length) { + this.file = file; + this.offset = offset; + this.length = length; + } + + @Override + public long size() { + return length; + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + FileChannel channel = null; + try { + channel = new RandomAccessFile(file, "r").getChannel(); + // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. 
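+ // (MIN_MEMORY_MAP_BYTES is 2 MB, so for example a 1 MB segment is read into a heap ByteBuffer with a positional read, while a 200 MB segment is memory-mapped instead.)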
+ if (length < MIN_MEMORY_MAP_BYTES) { + ByteBuffer buf = ByteBuffer.allocate((int) length); + channel.read(buf, offset); + buf.flip(); + return buf; + } else { + return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); + } + } catch (IOException e) { + try { + if (channel != null) { + long size = channel.size(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } + throw new IOException("Error in opening " + this, e); + } finally { + JavaUtils.closeQuietly(channel); + } + } + + @Override + public InputStream inputStream() throws IOException { + FileInputStream is = null; + try { + is = new FileInputStream(file); + is.skip(offset); + return ByteStreams.limit(is, length); + } catch (IOException e) { + try { + if (is != null) { + long size = file.length(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } finally { + JavaUtils.closeQuietly(is); + } + throw new IOException("Error in opening " + this, e); + } catch (RuntimeException e) { + JavaUtils.closeQuietly(is); + throw e; + } + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + FileChannel fileChannel = new FileInputStream(file).getChannel(); + return new DefaultFileRegion(fileChannel, offset, length); + } + + public File getFile() { return file; } + + public long getOffset() { return offset; } + + public long getLength() { return length; } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", file) + .add("offset", offset) + .add("length", length) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java new file mode 100644 index 0000000000000..1735f5540c61b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +/** + * This interface provides an immutable view for data in the form of bytes. 
The implementation + * should specify how the data is provided: + * + * - {@link FileSegmentManagedBuffer}: data backed by part of a file + * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer + * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf + * + * The concrete buffer implementation might be managed outside the JVM garbage collector. + * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted. + * In that case, if the buffer is going to be passed around to a different thread, retain/release + * should be called. + */ +public abstract class ManagedBuffer { + + /** Number of bytes of the data. */ + public abstract long size(); + + /** + * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the + * returned ByteBuffer should not affect the content of this buffer. + */ + public abstract ByteBuffer nioByteBuffer() throws IOException; + + /** + * Exposes this buffer's data as an InputStream. The underlying implementation does not + * necessarily check for the length of bytes read, so the caller is responsible for making sure + * it does not go over the limit. + */ + public abstract InputStream inputStream() throws IOException; + + /** + * Increment the reference count by one if applicable. + */ + public abstract ManagedBuffer retain(); + + /** + * If applicable, decrement the reference count by one and deallocates the buffer if the + * reference count reaches zero. + */ + public abstract ManagedBuffer release(); + + /** + * Convert the buffer into an Netty object, used to write the data out. + */ + public abstract Object convertToNetty() throws IOException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java new file mode 100644 index 0000000000000..d928980423f1f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; + +/** + * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}. 
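For the reference-counted case, the intended calling pattern looks roughly like this (a hedged sketch; the wrapped bytes are an example):

import io.netty.buffer.Unpooled
import org.apache.spark.network.buffer.NettyManagedBuffer

val buf = new NettyManagedBuffer(Unpooled.wrappedBuffer(Array[Byte](1, 2, 3)))
buf.retain()    // take an extra reference before handing the buffer to another thread
val in = buf.inputStream()
// ... consume the stream ...
buf.release()   // drop that reference; the underlying ByteBuf is freed once the count reaches zero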
+ */ +public final class NettyManagedBuffer extends ManagedBuffer { + private final ByteBuf buf; + + public NettyManagedBuffer(ByteBuf buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.readableBytes(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.nioBuffer(); + } + + @Override + public InputStream inputStream() throws IOException { + return new ByteBufInputStream(buf); + } + + @Override + public ManagedBuffer retain() { + buf.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + buf.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return buf; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java new file mode 100644 index 0000000000000..3953ef89fbf88 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; + +/** + * A {@link ManagedBuffer} backed by {@link ByteBuffer}. 
+ */ +public final class NioManagedBuffer extends ManagedBuffer { + private final ByteBuffer buf; + + public NioManagedBuffer(ByteBuffer buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.remaining(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.duplicate(); + } + + @Override + public InputStream inputStream() throws IOException { + return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return Unpooled.wrappedBuffer(buf); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} + diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java new file mode 100644 index 0000000000000..40a1fe67b1c5b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * General exception caused by a remote exception while fetching a chunk. + */ +public class ChunkFetchFailureException extends RuntimeException { + private final int chunkIndex; + + public ChunkFetchFailureException(int chunkIndex, String errorMsg, Throwable cause) { + super(errorMsg, cause); + this.chunkIndex = chunkIndex; + } + + public ChunkFetchFailureException(int chunkIndex, String errorMsg) { + super(errorMsg); + this.chunkIndex = chunkIndex; + } + + public int getChunkIndex() { return chunkIndex; } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java new file mode 100644 index 0000000000000..519e6cb470d0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Callback for the result of a single chunk result. For a single stream, the callbacks are + * guaranteed to be called by the same thread in the same order as the requests for chunks were + * made. + * + * Note that if a general stream failure occurs, all outstanding chunk requests may be failed. + */ +public interface ChunkReceivedCallback { + /** + * Called upon receipt of a particular chunk. + * + * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this + * call returns. You must therefore either retain() the buffer or copy its contents before + * returning. + */ + void onSuccess(int chunkIndex, ManagedBuffer buffer); + + /** + * Called upon failure to fetch a particular chunk. Note that this may actually be called due + * to failure to fetch a prior chunk in this stream. + * + * After receiving a failure, the stream may or may not be valid. The client should not assume + * that the server's side of the stream has been closed. + */ + void onFailure(int chunkIndex, Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java new file mode 100644 index 0000000000000..6ec960d795420 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * Callback for the result of a single RPC. This will be invoked once with either success or + * failure. + */ +public interface RpcResponseCallback { + /** Successful serialized result from server. */ + void onSuccess(byte[] response); + + /** Exception either propagated from server or raised on client side. 
*/ + void onFailure(Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java new file mode 100644 index 0000000000000..1f7d3b0234e38 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.util.UUID; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.RpcRequest; + +/** + * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow + * efficient transfer of a large amount of data, broken up into chunks with size ranging from + * hundreds of KB to a few MB. + * + * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane), + * the actual setup of the streams is done outside the scope of Sluice. The convenience method + * "sendRPC" is provided to enable control plane communication between the client and server to + * perform this setup. + * + * For example, a typical workflow might be: + * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 + * client.fetchChunk(streamId = 100, chunkIndex = 0, callback) + * client.fetchChunk(streamId = 100, chunkIndex = 1, callback) + * ... + * client.sendRPC(new CloseStream(100)) + * + * Construct an instance of SluiceClient using {@link SluiceClientFactory}. A single SluiceClient + * may be used for multiple streams, but any given stream must be restricted to a single client, + * in order to avoid out-of-order responses. + * + * NB: This class is used to make requests to the server, while {@link SluiceClientHandler} is + * responsible for handling responses from the server. + * + * Concurrency: thread safe and can be called from multiple threads. 
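+ *
+ * As an illustration of the fetch path (the callback body is a sketch, not part of this API),
+ * a caller typically supplies a ChunkReceivedCallback per chunk:
+ *
+ *   client.fetchChunk(streamId, 0, new ChunkReceivedCallback() {
+ *     public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ *       buffer.retain();  // released as soon as this call returns, so retain before handing off
+ *     }
+ *     public void onFailure(int chunkIndex, Throwable e) {
+ *       // surface the error; later chunks of the same stream may also fail
+ *     }
+ *   });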
+ */ +public class SluiceClient implements Closeable { + private final Logger logger = LoggerFactory.getLogger(SluiceClient.class); + + private final ChannelFuture cf; + private final SluiceClientHandler handler; + + private final String serverAddr; + + SluiceClient(ChannelFuture cf, SluiceClientHandler handler) { + this.cf = cf; + this.handler = handler; + + if (cf != null && cf.channel() != null && cf.channel().remoteAddress() != null) { + serverAddr = cf.channel().remoteAddress().toString(); + } else { + serverAddr = ""; + } + } + + public boolean isActive() { + return cf.channel().isActive(); + } + + /** + * Requests a single chunk from the remote side, from the pre-negotiated streamId. + * + * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though + * some streams may not support this. + * + * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed + * to be returned in the same order that they were requested, assuming only a single SluiceClient + * is used to fetch the chunks. + * + * @param streamId Identifier that refers to a stream in the remote StreamManager. This should + * be agreed upon by client and server beforehand. + * @param chunkIndex 0-based index of the chunk to fetch + * @param callback Callback invoked upon successful receipt of chunk, or upon any failure. + */ + public void fetchChunk( + long streamId, + final int chunkIndex, + final ChunkReceivedCallback callback) { + final long startTime = System.currentTimeMillis(); + logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); + + final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); + handler.addFetchRequest(streamChunkId, callback); + + cf.channel().writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.debug("Sending request {} to {} took {} ms", streamChunkId, serverAddr, + timeTaken); + } else { + // Fail all blocks. + String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, + serverAddr, future.cause().getMessage()); + logger.error(errorMsg, future.cause()); + future.cause().printStackTrace(); + handler.removeFetchRequest(streamChunkId); + callback.onFailure(chunkIndex, new RuntimeException(errorMsg)); + } + } + }); + } + + /** + * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked + * with the server's response or upon any failure. + */ + public void sendRpc(byte[] message, final RpcResponseCallback callback) { + final long startTime = System.currentTimeMillis(); + logger.debug("Sending RPC to {}", serverAddr); + + final long tag = UUID.randomUUID().getLeastSignificantBits(); + handler.addRpcRequest(tag, callback); + + cf.channel().writeAndFlush(new RpcRequest(tag, message)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.debug("Sending request {} to {} took {} ms", tag, serverAddr, timeTaken); + } else { + // Fail all blocks. 
+ String errorMsg = String.format("Failed to send request %s to %s: %s", tag, + serverAddr, future.cause().getMessage()); + logger.error(errorMsg, future.cause()); + handler.removeRpcRequest(tag); + callback.onFailure(new RuntimeException(errorMsg)); + } + } + }); + } + + @Override + public void close() { + cf.channel().close(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java new file mode 100644 index 0000000000000..17491dc3f8720 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeoutException; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.request.ClientRequestEncoder; +import org.apache.spark.network.protocol.response.ServerResponseDecoder; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.SluiceConfig; + +/** + * Factory for creating {@link SluiceClient}s by using createClient. + * + * The factory maintains a connection pool to other hosts and should return the same + * {@link SluiceClient} for the same remote host. It also shares a single worker thread pool for + * all {@link SluiceClient}s. 
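+ *
+ * Typical use, sketched with illustrative host/port values and assuming a SluiceConfig named conf:
+ *
+ *   SluiceClientFactory factory = new SluiceClientFactory(conf);
+ *   SluiceClient client = factory.createClient("shuffle-host", 7337);  // may throw TimeoutException
+ *   // ... fetchChunk / sendRpc calls ...
+ *   factory.close();  // closes pooled clients and shuts down the event loop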
+ */
+public class SluiceClientFactory implements Closeable {
+  private final Logger logger = LoggerFactory.getLogger(SluiceClientFactory.class);
+
+  private final SluiceConfig conf;
+  private final Map<SocketAddress, SluiceClient> connectionPool;
+  private final ClientRequestEncoder encoder;
+  private final ServerResponseDecoder decoder;
+
+  private final Class<? extends Channel> socketChannelClass;
+  private final EventLoopGroup workerGroup;
+
+  public SluiceClientFactory(SluiceConfig conf) {
+    this.conf = conf;
+    this.connectionPool = new ConcurrentHashMap<SocketAddress, SluiceClient>();
+    this.encoder = new ClientRequestEncoder();
+    this.decoder = new ServerResponseDecoder();
+
+    IOMode ioMode = IOMode.valueOf(conf.ioMode());
+    this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+    this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
+  }
+
+  /**
+   * Create a new SluiceClient connecting to the given remote host / port.
+   *
+   * This blocks until a connection is successfully established.
+   *
+   * Concurrency: This method is safe to call from multiple threads.
+   */
+  public SluiceClient createClient(String remoteHost, int remotePort) throws TimeoutException {
+    // Get connection from the connection pool first.
+    // If it is not found or not active, create a new one.
+    InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+    SluiceClient cachedClient = connectionPool.get(address);
+    if (cachedClient != null && cachedClient.isActive()) {
+      return cachedClient;
+    }
+
+    logger.debug("Creating new connection to " + address);
+
+    // There is a chance two threads are creating two different clients connecting to the same host.
+    // But that's probably ok, as long as the caller hangs on to their client for a single stream.
+    final SluiceClientHandler handler = new SluiceClientHandler();
+
+    Bootstrap bootstrap = new Bootstrap();
+    bootstrap.group(workerGroup)
+      .channel(socketChannelClass)
+      // Disable Nagle's Algorithm since we don't want packets to wait
+      .option(ChannelOption.TCP_NODELAY, true)
+      .option(ChannelOption.SO_KEEPALIVE, true)
+      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs());
+
+    // Use pooled buffers to reduce temporary buffer allocation
+    bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
+
+    bootstrap.handler(new ChannelInitializer<SocketChannel>() {
+      @Override
+      public void initChannel(SocketChannel ch) {
+        ch.pipeline()
+          .addLast("clientRequestEncoder", encoder)
+          .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+          .addLast("serverResponseDecoder", decoder)
+          .addLast("handler", handler);
+      }
+    });
+
+    // Connect to the remote server
+    ChannelFuture cf = bootstrap.connect(address);
+    if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
+      throw new TimeoutException(
+        String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+    }
+
+    SluiceClient client = new SluiceClient(cf, handler);
+    connectionPool.put(address, client);
+    return client;
+  }
+
+  /** Close all connections in the connection pool, and shut down the worker thread pool. */
+  @Override
+  public void close() {
+    for (SluiceClient client : connectionPool.values()) {
+      client.close();
+    }
+    connectionPool.clear();
+
+    if (workerGroup != null) {
+      workerGroup.shutdownGracefully();
+    }
+  }
+
+  /**
+   * Create a pooled ByteBuf allocator but disable the thread-local cache.
Thread-local caches + * are disabled because the ByteBufs are allocated by the event loop thread, but released by the + * executor thread rather than the event loop thread. Those thread-local caches actually delay + * the recycling of buffers, leading to larger memory usage. + */ + private PooledByteBufAllocator createPooledByteBufAllocator() { + return new PooledByteBufAllocator( + PlatformDependent.directBufferPreferred(), + getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), + getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + 0, // tinyCacheSize + 0, // smallCacheSize + 0 // normalCacheSize + ); + } + + /** Used to get defaults from Netty's private static fields. */ + private int getPrivateStaticField(String name) { + try { + Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.getInt(null); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java new file mode 100644 index 0000000000000..ed20b032931c3 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.net.SocketAddress; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.google.common.annotations.VisibleForTesting; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.RpcFailure; +import org.apache.spark.network.protocol.response.RpcResponse; +import org.apache.spark.network.protocol.response.ServerResponse; + +/** + * Handler that processes server responses, in response to requests issued from [[SluiceClient]]. + * It works by tracking the list of outstanding requests (and their callbacks). + * + * Concurrency: thread safe and can be called from multiple threads. 
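+ *
+ * The calling pattern (a sketch of what SluiceClient does, not an additional API) is to register
+ * the callback before writing the request, so a fast response cannot race the bookkeeping:
+ *
+ *   handler.addFetchRequest(streamChunkId, callback);
+ *   channel.writeAndFlush(new ChunkFetchRequest(streamChunkId));
+ *   // if the write fails: handler.removeFetchRequest(streamChunkId) and fail the callback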
+ */
+public class SluiceClientHandler extends SimpleChannelInboundHandler<ServerResponse> {
+  private final Logger logger = LoggerFactory.getLogger(SluiceClientHandler.class);
+
+  private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches =
+    new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+
+  private final Map<Long, RpcResponseCallback> outstandingRpcs =
+    new ConcurrentHashMap<Long, RpcResponseCallback>();
+
+  public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+    outstandingFetches.put(streamChunkId, callback);
+  }
+
+  public void removeFetchRequest(StreamChunkId streamChunkId) {
+    outstandingFetches.remove(streamChunkId);
+  }
+
+  public void addRpcRequest(long tag, RpcResponseCallback callback) {
+    outstandingRpcs.put(tag, callback);
+  }
+
+  public void removeRpcRequest(long tag) {
+    outstandingRpcs.remove(tag);
+  }
+
+  /**
+   * Fire the failure callback for all outstanding requests. This is called when we have an
+   * uncaught exception or premature connection termination.
+   */
+  private void failOutstandingRequests(Throwable cause) {
+    for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+      entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+    }
+    // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests
+    // as well. But I guess that is ok given the caller will fail as soon as any requests fail.
+    outstandingFetches.clear();
+  }
+
+  @Override
+  public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+    if (outstandingFetches.size() > 0) {
+      SocketAddress remoteAddress = ctx.channel().remoteAddress();
+      logger.error("Still have {} requests outstanding when connection from {} is closed",
+        outstandingFetches.size(), remoteAddress);
+      failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+    }
+    super.channelUnregistered(ctx);
+  }
+
+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+    if (outstandingFetches.size() > 0) {
+      logger.error(String.format("Exception in connection from %s: %s",
+        ctx.channel().remoteAddress(), cause.getMessage()), cause);
+      failOutstandingRequests(cause);
+    }
+    ctx.close();
+  }
+
+  @Override
+  public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) {
+    String server = ctx.channel().remoteAddress().toString();
+    if (message instanceof ChunkFetchSuccess) {
+      ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+      if (listener == null) {
+        logger.warn("Got a response for block {} from {} but it is not outstanding",
+          resp.streamChunkId, server);
+        resp.buffer.release();
+      } else {
+        outstandingFetches.remove(resp.streamChunkId);
+        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
+        resp.buffer.release();
+      }
+    } else if (message instanceof ChunkFetchFailure) {
+      ChunkFetchFailure resp = (ChunkFetchFailure) message;
+      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+      if (listener == null) {
+        logger.warn("Got a response for block {} from {} ({}) but it is not outstanding",
+          resp.streamChunkId, server, resp.errorString);
+      } else {
+        outstandingFetches.remove(resp.streamChunkId);
+        listener.onFailure(resp.streamChunkId.chunkIndex,
+          new ChunkFetchFailureException(resp.streamChunkId.chunkIndex, resp.errorString));
+      }
+    } else if (message instanceof RpcResponse) {
+      RpcResponse resp = (RpcResponse) message;
+      RpcResponseCallback listener = outstandingRpcs.get(resp.tag);
+      if (listener == null) {
+        logger.warn("Got a response for RPC {} from
{} ({} bytes) but it is not outstanding", + resp.tag, server, resp.response.length); + } else { + outstandingRpcs.remove(resp.tag); + listener.onSuccess(resp.response); + } + } else if (message instanceof RpcFailure) { + RpcFailure resp = (RpcFailure) message; + RpcResponseCallback listener = outstandingRpcs.get(resp.tag); + if (listener == null) { + logger.warn("Got a response for RPC {} from {} ({}) but it is not outstanding", + resp.tag, server, resp.errorString); + } else { + outstandingRpcs.remove(resp.tag); + listener.onFailure(new RuntimeException(resp.errorString)); + } + } + } + + @VisibleForTesting + public int numOutstandingRequests() { + return outstandingFetches.size(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java new file mode 100644 index 0000000000000..363ea5ecfa936 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +/** + * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are + * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length. + */ +public interface Encodable { + /** Number of bytes of the encoded form of this object. */ + int encodedLength(); + + /** + * Serializes this object by writing into the given ByteBuf. + * This method must write exactly encodedLength() bytes. + */ + void encode(ByteBuf buf); +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java new file mode 100644 index 0000000000000..d46a263884807 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** +* Encapsulates a request for a particular chunk of a stream. +*/ +public final class StreamChunkId implements Encodable { + public final long streamId; + public final int chunkIndex; + + public StreamChunkId(long streamId, int chunkIndex) { + this.streamId = streamId; + this.chunkIndex = chunkIndex; + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + public void encode(ByteBuf buffer) { + buffer.writeLong(streamId); + buffer.writeInt(chunkIndex); + } + + public static StreamChunkId decode(ByteBuf buffer) { + assert buffer.readableBytes() >= 8 + 4; + long streamId = buffer.readLong(); + int chunkIndex = buffer.readInt(); + return new StreamChunkId(streamId, chunkIndex); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, chunkIndex); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamChunkId) { + StreamChunkId o = (StreamChunkId) other; + return streamId == o.streamId && chunkIndex == o.chunkIndex; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("chunkIndex", chunkIndex) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java new file mode 100644 index 0000000000000..a79eb363cf58c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.StreamChunkId; + +/** + * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single + * {@link org.apache.spark.network.protocol.response.ServerResponse} (either success or failure). 
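+ *
+ * As a worked example derived from ClientRequestEncoder and StreamChunkId (not a separate spec):
+ * the encoded request is a 21-byte frame -- an 8-byte frame length (21, counting itself), a 1-byte
+ * type id (0 for ChunkFetchRequest), then the 8-byte streamId and 4-byte chunkIndex.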
+ */ +public final class ChunkFetchRequest implements ClientRequest { + public final StreamChunkId streamChunkId; + + public ChunkFetchRequest(StreamChunkId streamChunkId) { + this.streamChunkId = streamChunkId; + } + + @Override + public Type type() { return Type.ChunkFetchRequest; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + public static ChunkFetchRequest decode(ByteBuf buf) { + return new ChunkFetchRequest(StreamChunkId.decode(buf)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchRequest) { + ChunkFetchRequest o = (ChunkFetchRequest) other; + return streamChunkId.equals(o.streamChunkId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java new file mode 100644 index 0000000000000..db075c44b4cda --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** Messages from the client to the server. */ +public interface ClientRequest extends Encodable { + /** Used to identify this request type. */ + Type type(); + + /** + * Preceding every serialized ClientRequest is the type, which allows us to deserialize + * the request. 
+ */ + public static enum Type implements Encodable { + ChunkFetchRequest(0), RpcRequest(1); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 request types"; + this.id = (byte) id; + } + + public byte id() { return id; } + + @Override public int encodedLength() { return 1; } + + @Override public void encode(ByteBuf buf) { buf.writeByte(id); } + + public static Type decode(ByteBuf buf) { + byte id = buf.readByte(); + switch(id) { + case 0: return ChunkFetchRequest; + case 1: return RpcRequest; + default: throw new IllegalArgumentException("Unknown request type: " + id); + } + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java new file mode 100644 index 0000000000000..a937da4cecae0 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; + +/** + * Decoder in the server side to decode client requests. + * This decoder is stateless so it is safe to be shared by multiple threads. + * + * This assumes the inbound messages have been processed by a frame decoder created by + * {@link org.apache.spark.network.util.NettyUtils#createFrameDecoder()}. 
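+ *
+ * A sketch of the expected server-side wiring (handler names are illustrative; the actual setup
+ * lives in SluiceServer):
+ *
+ *   pipeline
+ *     .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+ *     .addLast("clientRequestDecoder", new ClientRequestDecoder())
+ *     .addLast("handler", serverHandler);
+ *
+ * so that each decode() call sees exactly one complete frame.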
+ */ +@ChannelHandler.Sharable +public final class ClientRequestDecoder extends MessageToMessageDecoder { + + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + ClientRequest.Type msgType = ClientRequest.Type.decode(in); + ClientRequest decoded = decode(msgType, in); + assert decoded.type() == msgType; + assert in.readableBytes() == 0; + out.add(decoded); + } + + private ClientRequest decode(ClientRequest.Type msgType, ByteBuf in) { + switch (msgType) { + case ChunkFetchRequest: + return ChunkFetchRequest.decode(in); + + case RpcRequest: + return RpcRequest.decode(in); + + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java new file mode 100644 index 0000000000000..bcff4a0a25568 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestEncoder.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; + +/** + * Encoder for {@link ClientRequest} used in client side. + * + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class ClientRequestEncoder extends MessageToMessageEncoder { + @Override + public void encode(ChannelHandlerContext ctx, ClientRequest in, List out) { + ClientRequest.Type msgType = in.type(); + // Write 8 bytes for the frame's length, followed by the request type and request itself. + int frameLength = 8 + msgType.encodedLength() + in.encodedLength(); + ByteBuf buf = ctx.alloc().buffer(frameLength); + buf.writeLong(frameLength); + msgType.encode(buf); + in.encode(buf); + assert buf.writableBytes() == 0; + out.add(buf); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java new file mode 100644 index 0000000000000..126370330f723 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/RpcRequest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.request; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. + * This will correspond to a single {@link org.apache.spark.network.protocol.response.ServerResponse} + * (either success or failure). + */ +public final class RpcRequest implements ClientRequest { + /** Tag is used to link an RPC request with its response. */ + public final long tag; + + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public RpcRequest(long tag, byte[] message) { + this.tag = tag; + this.message = message; + } + + @Override + public Type type() { return Type.RpcRequest; } + + @Override + public int encodedLength() { + return 8 + 4 + message.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(tag); + buf.writeInt(message.length); + buf.writeBytes(message); + } + + public static RpcRequest decode(ByteBuf buf) { + long tag = buf.readLong(); + int messageLen = buf.readInt(); + byte[] message = new byte[messageLen]; + buf.readBytes(message); + return new RpcRequest(tag, message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcRequest) { + RpcRequest o = (RpcRequest) other; + return tag == o.tag && Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("tag", tag) + .add("message", message) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java new file mode 100644 index 0000000000000..3a57d71b4f3ea --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchFailure.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.protocol.response; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.StreamChunkId; + +/** + * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when there is an + * error fetching the chunk. + */ +public final class ChunkFetchFailure implements ServerResponse { + public final StreamChunkId streamChunkId; + public final String errorString; + + public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { + this.streamChunkId = streamChunkId; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.ChunkFetchFailure; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength() + 4 + errorString.getBytes().length; + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + byte[] errorBytes = errorString.getBytes(); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static ChunkFetchFailure decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new ChunkFetchFailure(streamChunkId, new String(errorBytes)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchFailure) { + ChunkFetchFailure o = (ChunkFetchFailure) other; + return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java new file mode 100644 index 0000000000000..874dc4f5940cf --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ChunkFetchSuccess.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.protocol.StreamChunkId; + +/** + * Response to {@link org.apache.spark.network.protocol.request.ChunkFetchRequest} when a chunk + * exists and has been successfully fetched. 
+ * + * Note that the server-side encoding of this messages does NOT include the buffer itself, as this + * may be written by Netty in a more efficient manner (i.e., zero-copy write). + * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. + */ +public final class ChunkFetchSuccess implements ServerResponse { + public final StreamChunkId streamChunkId; + public final ManagedBuffer buffer; + + public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { + this.streamChunkId = streamChunkId; + this.buffer = buffer; + } + + @Override + public Type type() { return Type.ChunkFetchSuccess; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + /** Encoding does NOT include buffer itself. See {@link ServerResponseEncoder}. */ + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. */ + public static ChunkFetchSuccess decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new ChunkFetchSuccess(streamChunkId, managedBuf); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchSuccess) { + ChunkFetchSuccess o = (ChunkFetchSuccess) other; + return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("buffer", buffer) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java new file mode 100644 index 0000000000000..274920b28bced --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcFailure.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a failed RPC. 
*/ +public final class RpcFailure implements ServerResponse { + public final long tag; + public final String errorString; + + public RpcFailure(long tag, String errorString) { + this.tag = tag; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.RpcFailure; } + + @Override + public int encodedLength() { + return 8 + 4 + errorString.getBytes().length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(tag); + byte[] errorBytes = errorString.getBytes(); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static RpcFailure decode(ByteBuf buf) { + long tag = buf.readLong(); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new RpcFailure(tag, new String(errorBytes)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcFailure) { + RpcFailure o = (RpcFailure) other; + return tag == o.tag && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("tag", tag) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java new file mode 100644 index 0000000000000..0c6f8acdcdc4b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/RpcResponse.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link org.apache.spark.network.protocol.request.RpcRequest} for a successful RPC. 
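+ *
+ * As a worked example of the encode()/decode() pair below: the body is an 8-byte tag, a 4-byte
+ * length, then the opaque response bytes. The 1-byte type id (2) precedes it on the wire (see
+ * ServerResponse.Type), with the frame header added by the response encoder.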
*/ +public final class RpcResponse implements ServerResponse { + public final long tag; + public final byte[] response; + + public RpcResponse(long tag, byte[] response) { + this.tag = tag; + this.response = response; + } + + @Override + public Type type() { return Type.RpcResponse; } + + @Override + public int encodedLength() { return 8 + 4 + response.length; } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(tag); + buf.writeInt(response.length); + buf.writeBytes(response); + } + + public static RpcResponse decode(ByteBuf buf) { + long tag = buf.readLong(); + int responseLen = buf.readInt(); + byte[] response = new byte[responseLen]; + buf.readBytes(response); + return new RpcResponse(tag, response); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcResponse) { + RpcResponse o = (RpcResponse) other; + return tag == o.tag && Arrays.equals(response, o.response); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("tag", tag) + .add("response", response) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java new file mode 100644 index 0000000000000..335f9e8ea69f9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponse.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Messages from server to client (usually in response to some + * {@link org.apache.spark.network.protocol.request.ClientRequest}. + */ +public interface ServerResponse extends Encodable { + /** Used to identify this response type. */ + Type type(); + + /** + * Preceding every serialized ServerResponse is the type, which allows us to deserialize + * the response. 
 + */ + public static enum Type implements Encodable { + ChunkFetchSuccess(0), ChunkFetchFailure(1), RpcResponse(2), RpcFailure(3); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 response types"; + this.id = (byte) id; + } + + public byte id() { return id; } + + @Override public int encodedLength() { return 1; } + + @Override public void encode(ByteBuf buf) { buf.writeByte(id); } + + public static Type decode(ByteBuf buf) { + byte id = buf.readByte(); + switch(id) { + case 0: return ChunkFetchSuccess; + case 1: return ChunkFetchFailure; + case 2: return RpcResponse; + case 3: return RpcFailure; + default: throw new IllegalArgumentException("Unknown response type: " + id); + } + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java new file mode 100644 index 0000000000000..e06198284e620 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseDecoder.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; + +/** + * Decoder used by the client side to decode server-to-client responses. + * This decoder is stateless so it is safe to be shared by multiple threads.
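 + * + * A note on framing: this decoder assumes the 8-byte frame length has already been stripped + * by the upstream frame decoder (see NettyUtils#createFrameDecoder), so the first byte read + * here is the response type id.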
+ */ +@ChannelHandler.Sharable +public final class ServerResponseDecoder extends MessageToMessageDecoder { + + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + ServerResponse.Type msgType = ServerResponse.Type.decode(in); + ServerResponse decoded = decode(msgType, in); + assert decoded.type() == msgType; + out.add(decoded); + } + + private ServerResponse decode(ServerResponse.Type msgType, ByteBuf in) { + switch (msgType) { + case ChunkFetchSuccess: + return ChunkFetchSuccess.decode(in); + + case ChunkFetchFailure: + return ChunkFetchFailure.decode(in); + + case RpcResponse: + return RpcResponse.decode(in); + + case RpcFailure: + return RpcFailure.decode(in); + + default: + throw new IllegalArgumentException("Unexpected message type: " + msgType); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java new file mode 100644 index 0000000000000..069f42463a8fe --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/response/ServerResponseEncoder.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol.response; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class ServerResponseEncoder extends MessageToMessageEncoder { + + private final Logger logger = LoggerFactory.getLogger(ServerResponseEncoder.class); + + @Override + public void encode(ChannelHandlerContext ctx, ServerResponse in, List out) { + Object body = null; + long bodyLength = 0; + + // Only ChunkFetchSuccesses have data besides the header. + // The body is used in order to enable zero-copy transfer for the payload. + if (in instanceof ChunkFetchSuccess) { + ChunkFetchSuccess resp = (ChunkFetchSuccess) in; + try { + bodyLength = resp.buffer.size(); + body = resp.buffer.convertToNetty(); + } catch (Exception e) { + // Re-encode this message as BlockFetchFailure. 
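 + // Reading the buffer's size or converting it to a Netty object (e.g. opening the underlying + // file segment) may fail; rather than closing the channel, the error is reported back to the + // client for just this chunk.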
+ logger.error(String.format("Error opening block %s for client %s", + resp.streamChunkId, ctx.channel().remoteAddress()), e); + encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out); + return; + } + } + + ServerResponse.Type msgType = in.type(); + // All messages have the frame length, message type, and message itself. + int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); + long frameLength = headerLength + bodyLength; + ByteBuf header = ctx.alloc().buffer(headerLength); + header.writeLong(frameLength); + msgType.encode(header); + in.encode(header); + assert header.writableBytes() == 0; + + out.add(header); + if (body != null && bodyLength > 0) { + out.add(body); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java new file mode 100644 index 0000000000000..04814d9a88c4a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.Iterator; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * StreamManager which allows registration of an Iterator, which are individually + * fetched as chunks by the client. + */ +public class DefaultStreamManager extends StreamManager { + private final AtomicLong nextStreamId; + private final Map streams; + + /** State of a single stream. */ + private static class StreamState { + final Iterator buffers; + + int curChunk = 0; + + StreamState(Iterator buffers) { + this.buffers = buffers; + } + } + + public DefaultStreamManager() { + // Start with a random stream id to help identifying different streams. + nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); + streams = new ConcurrentHashMap(); + } + + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + StreamState state = streams.get(streamId); + if (chunkIndex != state.curChunk) { + throw new IllegalStateException(String.format( + "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk)); + } else if (!state.buffers.hasNext()) { + throw new IllegalStateException(String.format( + "Requested chunk index beyond end %s", chunkIndex)); + } + state.curChunk += 1; + return state.buffers.next(); + } + + @Override + public void connectionTerminated(long streamId) { + // Release all remaining buffers. 
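 + // Chunks the client never fetched would otherwise leak whatever they hold on to + // (an open file segment, or a reference-counted Netty buffer).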
+ StreamState state = streams.remove(streamId); + if (state != null && state.buffers != null) { + while (state.buffers.hasNext()) { + state.buffers.next().release(); + } + } + } + + public long registerStream(Iterator buffers) { + long myStreamId = nextStreamId.getAndIncrement(); + streams.put(myStreamId, new StreamState(buffers)); + return myStreamId; + } + + public void unregisterStream(long streamId) { + streams.remove(streamId); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java new file mode 100644 index 0000000000000..abfbe66d875e8 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.client.RpcResponseCallback; + +/** + * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.SluiceClient}s. + */ +public interface RpcHandler { + /** + * Receive a single RPC message. Any exception thrown while in this method will be sent back to + * the client in string form as a standard RPC failure. + */ + void receive(byte[] message, RpcResponseCallback callback); +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java new file mode 100644 index 0000000000000..aa81271024156 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceServer.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.server; + +import java.io.Closeable; +import java.net.InetSocketAddress; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.request.ClientRequestDecoder; +import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.SluiceConfig; + +/** + * Server for the efficient, low-level streaming service. + */ +public class SluiceServer implements Closeable { + private final Logger logger = LoggerFactory.getLogger(SluiceServer.class); + + private final SluiceConfig conf; + private final StreamManager streamManager; + private final RpcHandler rpcHandler; + + private ServerBootstrap bootstrap; + private ChannelFuture channelFuture; + private int port; + + public SluiceServer(SluiceConfig conf, StreamManager streamManager, RpcHandler rpcHandler) { + this.conf = conf; + this.streamManager = streamManager; + this.rpcHandler = rpcHandler; + + init(); + } + + public int getPort() { return port; } + + private void init() { + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + EventLoopGroup bossGroup = + NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); + EventLoopGroup workerGroup = bossGroup; + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(NettyUtils.getServerChannelClass(ioMode)) + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + + if (conf.backLog() > 0) { + bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); + } + + if (conf.receiveBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf()); + } + + if (conf.sendBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf()); + } + + bootstrap.childHandler(new ChannelInitializer() { + + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline() + .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast("clientRequestDecoder", new ClientRequestDecoder()) + .addLast("serverResponseEncoder", new ServerResponseEncoder()) + // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this + // would require more logic to guarantee if this were not part of the same event loop. 
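 + // Because the handler below runs on the same event loop that decoded the request, responses + // are written back to the channel in the order the requests were read.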
+ .addLast("handler", new SluiceServerHandler(streamManager, rpcHandler)); + } + }); + + channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort())); + channelFuture.syncUninterruptibly(); + + port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); + logger.debug("Shuffle server started on port :" + port); + } + + @Override + public void close() { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly(); + channelFuture = null; + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(); + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully(); + } + bootstrap = null; + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java b/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java new file mode 100644 index 0000000000000..fad72fbfc711b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/SluiceServerHandler.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.Set; + +import com.google.common.base.Throwables; +import com.google.common.collect.Sets; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.ClientRequest; +import org.apache.spark.network.protocol.request.RpcRequest; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.RpcFailure; +import org.apache.spark.network.protocol.response.RpcResponse; + +/** + * A handler that processes requests from clients and writes chunk data back. Each handler keeps + * track of which streams have been fetched via this channel, in order to clean them up if the + * channel is terminated (see #channelUnregistered). + * + * The messages should have been processed by the pipeline setup by {@link SluiceServer}. + */ +public class SluiceServerHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(SluiceServerHandler.class); + + /** Returns each chunk part of a stream. 
*/ + private final StreamManager streamManager; + + /** Handles all RPC messages. */ + private final RpcHandler rpcHandler; + + /** List of all stream ids that have been read on this handler, used for cleanup. */ + private final Set streamIds; + + public SluiceServerHandler(StreamManager streamManager, RpcHandler rpcHandler) { + this.streamManager = streamManager; + this.rpcHandler = rpcHandler; + this.streamIds = Sets.newHashSet(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.error("Exception in connection from " + ctx.channel().remoteAddress(), cause); + ctx.close(); + super.exceptionCaught(ctx, cause); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + super.channelUnregistered(ctx); + // Inform the StreamManager that these streams will no longer be read from. + for (long streamId : streamIds) { + streamManager.connectionTerminated(streamId); + } + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, ClientRequest request) { + if (request instanceof ChunkFetchRequest) { + processFetchRequest(ctx, (ChunkFetchRequest) request); + } else if (request instanceof RpcRequest) { + processRpcRequest(ctx, (RpcRequest) request); + } else { + throw new IllegalArgumentException("Unknown request type: " + request); + } + } + + private void processFetchRequest(final ChannelHandlerContext ctx, final ChunkFetchRequest req) { + final String client = ctx.channel().remoteAddress().toString(); + streamIds.add(req.streamChunkId.streamId); + + logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); + + ManagedBuffer buf; + try { + buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); + } catch (Exception e) { + logger.error(String.format( + "Error opening block %s for request from %s", req.streamChunkId, client), e); + respond(ctx, new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); + return; + } + + respond(ctx, new ChunkFetchSuccess(req.streamChunkId, buf)); + } + + private void processRpcRequest(final ChannelHandlerContext ctx, final RpcRequest req) { + try { + rpcHandler.receive(req.message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + respond(ctx, new RpcResponse(req.tag, response)); + } + + @Override + public void onFailure(Throwable e) { + respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + } + }); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC tag " + req.tag, e); + respond(ctx, new RpcFailure(req.tag, Throwables.getStackTraceAsString(e))); + } + } + + /** + * Responds to a single message with some Encodable object. If a failure occurs while sending, + * it will be logged and the channel closed. 
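To make the RPC path above concrete: the RpcHandler passed into this handler is the only piece an application needs to supply. A minimal sketch, assuming only the RpcHandler and RpcResponseCallback interfaces added by this patch (the EchoRpcHandler name is invented for illustration):

    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.server.RpcHandler;

    /** Illustrative only: echoes every RPC message back to the client. */
    public class EchoRpcHandler implements RpcHandler {
      @Override
      public void receive(byte[] message, RpcResponseCallback callback) {
        // A successful callback travels back as an RpcResponse; exceptions thrown here
        // (or reported via onFailure) reach the client as an RpcFailure.
        callback.onSuccess(message);
      }
    }
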
+ */ + private void respond(final ChannelHandlerContext ctx, final Encodable result) { + final String remoteAddress = ctx.channel().remoteAddress().toString(); + ctx.writeAndFlush(result).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + logger.trace(String.format("Sent result %s to client %s", result, remoteAddress)); + } else { + logger.error(String.format("Error sending result %s to %s; closing connection", + result, remoteAddress), future.cause()); + ctx.close(); + } + } + } + ); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java new file mode 100644 index 0000000000000..2e07f5a270cb9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * The StreamManager is used to fetch individual chunks from a stream. This is used in + * {@link SluiceServerHandler} in order to respond to fetchChunk() requests. Creation of the + * stream is outside the scope of Sluice, but a given stream is guaranteed to be read by only one + * client connection, meaning that getChunk() for a particular stream will be called serially and + * that once the connection associated with the stream is closed, that stream will never be used + * again. + */ +public abstract class StreamManager { + /** + * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the + * client. A single stream will be associated with a single TCP connection, so this method + * will not be called in parallel for a particular stream. + * + * Chunks may be requested in any order, and requests may be repeated, but it is not required + * that implementations support this behavior. + * + * The returned ManagedBuffer will be release()'d after being written to the network. + * + * @param streamId id of a stream that has been previously registered with the StreamManager. + * @param chunkIndex 0-indexed chunk of the stream that's requested + */ + public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); + + /** + * Indicates that the TCP connection that was tied to the given stream has been terminated. After + * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned + * up. 
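Putting the stream pieces together, a hedged sketch of the intended flow (the buffers iterator and an already-created SluiceClient are assumed to exist; only registerStream() and fetchChunk() from this patch are used):

    // Server side: expose a sequence of ManagedBuffers as a single stream.
    DefaultStreamManager streamManager = new DefaultStreamManager();
    long streamId = streamManager.registerStream(buffers);  // buffers: Iterator<ManagedBuffer>
    // The streamId is then handed to the client out of band, e.g. in an RPC response.

    // Client side: fetch chunk 0 of that stream.
    client.fetchChunk(streamId, 0, new ChunkReceivedCallback() {
      @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { /* consume buffer */ }
      @Override public void onFailure(int chunkIndex, Throwable e) { /* handle the error */ }
    });
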
+ */ + public void connectionTerminated(long streamId) { } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java new file mode 100644 index 0000000000000..2dc0e248ae835 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +/** + * Provides a mechanism for constructing a {@link SluiceConfig} using some sort of configuration. + */ +public abstract class ConfigProvider { + /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */ + public abstract String get(String name); + + public String get(String name, String defaultValue) { + try { + return get(name); + } catch (NoSuchElementException e) { + return defaultValue; + } + } + + public int getInt(String name, int defaultValue) { + return Integer.parseInt(get(name, Integer.toString(defaultValue))); + } + + public long getLong(String name, long defaultValue) { + return Long.parseLong(get(name, Long.toString(defaultValue))); + } + + public double getDouble(String name, double defaultValue) { + return Double.parseDouble(get(name, Double.toString(defaultValue))); + } + + public boolean getBoolean(String name, boolean defaultValue) { + return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue))); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java new file mode 100644 index 0000000000000..cef88c0091eff --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/DefaultConfigProvider.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
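Because ConfigProvider is the only configuration hook, an embedding application or a test can back it with any source of settings. A minimal sketch (MapConfigProvider is a made-up name; only ConfigProvider and SluiceConfig from this patch are assumed):

    import java.util.Map;
    import java.util.NoSuchElementException;

    import org.apache.spark.network.util.ConfigProvider;
    import org.apache.spark.network.util.SluiceConfig;

    /** Illustrative only: a ConfigProvider backed by an in-memory map instead of System properties. */
    public class MapConfigProvider extends ConfigProvider {
      private final Map<String, String> values;

      public MapConfigProvider(Map<String, String> values) {
        this.values = values;
      }

      @Override
      public String get(String name) {
        String value = values.get(name);
        if (value == null) {
          throw new NoSuchElementException(name);
        }
        return value;
      }
    }

    // Keys missing from the map simply fall back to the defaults defined in SluiceConfig:
    // SluiceConfig conf = new SluiceConfig(
    //     new MapConfigProvider(java.util.Collections.singletonMap("spark.shuffle.io.mode", "NIO")));
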
+ */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +/** Uses System properties to obtain config values. */ +public class DefaultConfigProvider extends ConfigProvider { + @Override + public String get(String name) { + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java new file mode 100644 index 0000000000000..91cb3e0e6f8f6 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +/** + * Selector for which form of low-level IO we should use. + * NIO is always available, while EPOLL is only available on certain machines. + * AUTO is used to select EPOLL if it's available, or NIO otherwise. + */ +public enum IOMode { + NIO, EPOLL, AUTO +} diff --git a/core/src/main/scala/org/apache/spark/network/exceptions.scala b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java similarity index 65% rename from core/src/main/scala/org/apache/spark/network/exceptions.scala rename to network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index d918d358c4adb..fafdcad04aeb6 100644 --- a/core/src/main/scala/org/apache/spark/network/exceptions.scala +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -15,17 +15,16 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.util; -class BlockFetchFailureException(blockId: String, errorMsg: String, cause: Throwable) - extends Exception(errorMsg, cause) { +import java.io.Closeable; - def this(blockId: String, errorMsg: String) = this(blockId, errorMsg, null) -} - - -class BlockUploadFailureException(blockId: String, cause: Throwable) - extends Exception(s"Failed to fetch block $blockId", cause) { +import com.google.common.io.Closeables; - def this(blockId: String) = this(blockId, null) +public class JavaUtils { + /** Closes the given object, ignoring IOExceptions. 
*/ + @SuppressWarnings("deprecation") + public static void closeQuietly(Closeable closable) { + Closeables.closeQuietly(closable); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java new file mode 100644 index 0000000000000..3d20dc9e1c1cd --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.concurrent.ThreadFactory; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.channel.Channel; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.epoll.EpollSocketChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +/** + * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. + */ +public class NettyUtils { + /** Creates a Netty EventLoopGroup based on the IOMode. */ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + if (mode == IOMode.AUTO) { + mode = autoselectMode(); + } + + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(threadPrefix + "-%d") + .build(); + + switch(mode) { + case NIO: + return new NioEventLoopGroup(numThreads, threadFactory); + case EPOLL: + return new EpollEventLoopGroup(numThreads, threadFactory); + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct (client) SocketChannel class based on IOMode. */ + public static Class getClientChannelClass(IOMode mode) { + if (mode == IOMode.AUTO) { + mode = autoselectMode(); + } + switch(mode) { + case NIO: + return NioSocketChannel.class; + case EPOLL: + return EpollSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct ServerSocketChannel class based on IOMode. 
*/ + public static Class getServerChannelClass(IOMode mode) { + if (mode == IOMode.AUTO) { + mode = autoselectMode(); + } + switch(mode) { + case NIO: + return NioServerSocketChannel.class; + case EPOLL: + return EpollServerSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** + * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. + * This is used before all decoders. + */ + public static ByteToMessageDecoder createFrameDecoder() { + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 8 + // lengthAdjustment = -8, i.e. exclude the 8 byte length itself + // initialBytesToStrip = 8, i.e. strip out the length field itself + return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); + } + + /** Returns EPOLL if it's available on this system, NIO otherwise. */ + private static IOMode autoselectMode() { + return Epoll.isAvailable() ? IOMode.EPOLL : IOMode.NIO; + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java similarity index 58% rename from core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala rename to network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java index 7c3074e939794..26fa3229c4721 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ b/network/common/src/main/java/org/apache/spark/network/util/SluiceConfig.java @@ -15,35 +15,37 @@ * limitations under the License. */ -package org.apache.spark.network.netty - -import org.apache.spark.SparkConf +package org.apache.spark.network.util; /** - * A central location that tracks all the settings we exposed to users. + * A central location that tracks all the settings we expose to users. */ -private[spark] -class NettyConfig(conf: SparkConf) { +public class SluiceConfig { + private final ConfigProvider conf; + + public SluiceConfig(ConfigProvider conf) { + this.conf = conf; + } /** Port the server listens on. Default to a random port. */ - private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0) + public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); } - /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ - private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase + /** IO mode: nio, epoll, or auto (try epoll first and then nio). */ + public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } /** Connect timeout in secs. Default 120 secs. */ - private[netty] val connectTimeoutMs = { - conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000 + public int connectionTimeoutMs() { + return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; } - /** Requested maximum length of the queue of incoming connections. */ - private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) + /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ + public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - private[netty] val serverThreads: Int = conf.getInt("spark.shuffle.io.serverThreads", 0) + public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } /** Number of threads used in the client thread pool. 
Default to 0, which is 2x#cores. */ - private[netty] val clientThreads: Int = conf.getInt("spark.shuffle.io.clientThreads", 0) + public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } /** * Receive buffer size (SO_RCVBUF). @@ -52,10 +54,8 @@ class NettyConfig(conf: SparkConf) { * Assuming latency = 1ms, network_bandwidth = 10Gbps * buffer size should be ~ 1.25MB */ - private[netty] val receiveBuf: Option[Int] = - conf.getOption("spark.shuffle.io.receiveBuffer").map(_.toInt) + public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } /** Send buffer size (SO_SNDBUF). */ - private[netty] val sendBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) + public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } } diff --git a/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java new file mode 100644 index 0000000000000..d20528558cae1 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/IntegrationSuite.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.DefaultConfigProvider; +import org.apache.spark.network.util.SluiceConfig; + +public class IntegrationSuite { + static final long STREAM_ID = 1; + static final int BUFFER_CHUNK_INDEX = 0; + static final int FILE_CHUNK_INDEX = 1; + + static SluiceServer server; + static SluiceClientFactory clientFactory; + static StreamManager streamManager; + static File testFile; + + static ManagedBuffer bufferChunk; + static ManagedBuffer fileChunk; + + @BeforeClass + public static void setUp() throws Exception { + int bufSize = 100000; + final ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + bufferChunk = new NioManagedBuffer(buf); + + testFile = File.createTempFile("shuffle-test-file", "txt"); + testFile.deleteOnExit(); + RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + fp.close(); + fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + + SluiceConfig conf = new SluiceConfig(new DefaultConfigProvider()); + streamManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + assertEquals(STREAM_ID, streamId); + if (chunkIndex == BUFFER_CHUNK_INDEX) { + return new NioManagedBuffer(buf); + } else if (chunkIndex == FILE_CHUNK_INDEX) { + return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + } else { + throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); + } + } + }; + server = new SluiceServer(conf, streamManager, new NoOpRpcHandler()); + clientFactory = new SluiceClientFactory(conf); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + testFile.delete(); + } + + class FetchResult { + public Set successChunks; + public Set failedChunks; + public List buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + private FetchResult fetchChunks(List chunkIndices) throws Exception { + SluiceClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final FetchResult res = new FetchResult(); + res.successChunks = Collections.synchronizedSet(new HashSet()); + res.failedChunks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new 
LinkedList()); + + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + buffer.retain(); + res.successChunks.add(chunkIndex); + res.buffers.add(buffer); + sem.release(); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + res.failedChunks.add(chunkIndex); + sem.release(); + } + }; + + for (int chunkIndex : chunkIndices) { + client.fetchChunk(STREAM_ID, chunkIndex, callback); + } + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void fetchBufferChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchFileChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchBothChunks() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchNonExistentChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(12345)); + assertTrue(res.successChunks.isEmpty()); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertTrue(res.buffers.isEmpty()); + } + + @Test + public void fetchChunkAndNonExistent() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + private void assertBufferListsEqual(List list0, List list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), list1.get(i)); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i < len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java new file mode 100644 index 0000000000000..af35709319957 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java @@ -0,0 +1,26 @@ +package org.apache.spark.network;/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.server.RpcHandler; + +public class NoOpRpcHandler implements RpcHandler { + @Override + public void receive(byte[] message, RpcResponseCallback callback) { + callback.onSuccess(new byte[0]); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java new file mode 100644 index 0000000000000..cf74a9d8993fe --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.request.ChunkFetchRequest; +import org.apache.spark.network.protocol.request.ClientRequest; +import org.apache.spark.network.protocol.request.ClientRequestDecoder; +import org.apache.spark.network.protocol.request.ClientRequestEncoder; +import org.apache.spark.network.protocol.response.ChunkFetchFailure; +import org.apache.spark.network.protocol.response.ChunkFetchSuccess; +import org.apache.spark.network.protocol.response.ServerResponse; +import org.apache.spark.network.protocol.response.ServerResponseDecoder; +import org.apache.spark.network.protocol.response.ServerResponseEncoder; +import org.apache.spark.network.util.NettyUtils; + +public class ProtocolSuite { + private void testServerToClient(ServerResponse msg) { + EmbeddedChannel serverChannel = new EmbeddedChannel(new ServerResponseEncoder()); + serverChannel.writeOutbound(msg); + + EmbeddedChannel clientChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new ServerResponseDecoder()); + + while (!serverChannel.outboundMessages().isEmpty()) { + clientChannel.writeInbound(serverChannel.readOutbound()); + } + + assertEquals(1, clientChannel.inboundMessages().size()); + assertEquals(msg, clientChannel.readInbound()); + } + + private void testClientToServer(ClientRequest msg) { + EmbeddedChannel clientChannel = new EmbeddedChannel(new ClientRequestEncoder()); + clientChannel.writeOutbound(msg); + + EmbeddedChannel serverChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new ClientRequestDecoder()); + + while (!clientChannel.outboundMessages().isEmpty()) { + serverChannel.writeInbound(clientChannel.readOutbound()); + } + + assertEquals(1, serverChannel.inboundMessages().size()); + assertEquals(msg, serverChannel.readInbound()); + } + + @Test + public void s2cChunkFetchSuccess() { + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + } + + @Test + public void s2cBlockFetchFailure() { + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); + } + + @Test + public void c2sChunkFetchRequest() { + testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java new file mode 100644 index 0000000000000..e6b59b9ad8e5c --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientFactorySuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.concurrent.TimeoutException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.client.SluiceClient; +import org.apache.spark.network.client.SluiceClientFactory; +import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.SluiceServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.DefaultConfigProvider; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.SluiceConfig; + +public class SluiceClientFactorySuite { + private SluiceConfig conf; + private SluiceServer server1; + private SluiceServer server2; + + @Before + public void setUp() { + conf = new SluiceConfig(new DefaultConfigProvider()); + StreamManager streamManager = new DefaultStreamManager(); + RpcHandler rpcHandler = new NoOpRpcHandler(); + server1 = new SluiceServer(conf, streamManager, rpcHandler); + server2 = new SluiceServer(conf, streamManager, rpcHandler); + } + + @After + public void tearDown() { + JavaUtils.closeQuietly(server1); + JavaUtils.closeQuietly(server2); + } + + @Test + public void createAndReuseBlockClients() throws TimeoutException { + SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + SluiceClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c3.isActive()); + assertTrue(c1 == c2); + assertTrue(c1 != c3); + factory.close(); + } + + @Test + public void neverReturnInactiveClients() throws Exception { + SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + c1.close(); + + long start = System.currentTimeMillis(); + while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + + SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertFalse(c1 == c2); + assertTrue(c2.isActive()); + factory.close(); + } + + @Test + public void closeBlockClientsWithFactory() throws TimeoutException { + SluiceClientFactory factory = new SluiceClientFactory(conf); + SluiceClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + SluiceClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c2.isActive()); + factory.close(); + assertFalse(c1.isActive()); + assertFalse(c2.isActive()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java new file mode 100644 index 
0000000000000..cab0597fb948a
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/SluiceClientHandlerSuite.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.embedded.EmbeddedChannel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.SluiceClientHandler;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.response.ChunkFetchFailure;
+import org.apache.spark.network.protocol.response.ChunkFetchSuccess;
+
+public class SluiceClientHandlerSuite {
+  @Test
+  public void handleSuccessfulFetch() {
+    StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+
+    SluiceClientHandler handler = new SluiceClientHandler();
+    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+    handler.addFetchRequest(streamChunkId, callback);
+    assertEquals(1, handler.numOutstandingRequests());
+
+    EmbeddedChannel channel = new EmbeddedChannel(handler);
+
+    channel.writeInbound(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
+    verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+    assertEquals(0, handler.numOutstandingRequests());
+    assertFalse(channel.finish());
+  }
+
+  @Test
+  public void handleFailedFetch() {
+    StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+    SluiceClientHandler handler = new SluiceClientHandler();
+    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+    handler.addFetchRequest(streamChunkId, callback);
+    assertEquals(1, handler.numOutstandingRequests());
+
+    EmbeddedChannel channel = new EmbeddedChannel(handler);
+    channel.writeInbound(new ChunkFetchFailure(streamChunkId, "some error msg"));
+    verify(callback, times(1)).onFailure(eq(0), (Throwable) any());
+    assertEquals(0, handler.numOutstandingRequests());
+    assertFalse(channel.finish());
+  }
+
+  @Test
+  public void clearAllOutstandingRequests() {
+    SluiceClientHandler handler = new SluiceClientHandler();
+    ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+    handler.addFetchRequest(new StreamChunkId(1, 0), callback);
+    handler.addFetchRequest(new StreamChunkId(1, 1), callback);
+    handler.addFetchRequest(new StreamChunkId(1, 2), callback);
+    assertEquals(3, handler.numOutstandingRequests());
+
+    EmbeddedChannel channel = new EmbeddedChannel(handler);
+
+    channel.writeInbound(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12)));
+    channel.pipeline().fireExceptionCaught(new Exception("duh duh duhhhh"));
+
+    // chunk 0 succeeded; the exception should fail the two still-outstanding requests (chunks 1 and 2)
+    verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+    verify(callback, times(1)).onFailure(eq(1), (Throwable) any());
+    verify(callback, times(1)).onFailure(eq(2), (Throwable) any());
+    assertEquals(0, handler.numOutstandingRequests());
+    assertFalse(channel.finish());
+  }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
new file mode 100644
index 0000000000000..7e7554af70f42
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1).
+ *
+ * Used for testing.
+ */
+public class TestManagedBuffer extends ManagedBuffer {
+
+  private final int len;
+  private NettyManagedBuffer underlying;
+
+  public TestManagedBuffer(int len) {
+    Preconditions.checkArgument(len <= Byte.MAX_VALUE);
+    this.len = len;
+    byte[] byteArray = new byte[len];
+    for (int i = 0; i < len; i++) {
+      byteArray[i] = (byte) i;
+    }
+    this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray));
+  }
+
+
+  @Override
+  public long size() {
+    return underlying.size();
+  }
+
+  @Override
+  public ByteBuffer nioByteBuffer() throws IOException {
+    return underlying.nioByteBuffer();
+  }
+
+  @Override
+  public InputStream inputStream() throws IOException {
+    return underlying.inputStream();
+  }
+
+  @Override
+  public ManagedBuffer retain() {
+    underlying.retain();
+    return this;
+  }
+
+  @Override
+  public ManagedBuffer release() {
+    underlying.release();
+    return this;
+  }
+
+  @Override
+  public Object convertToNetty() throws IOException {
+    return underlying.convertToNetty();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof ManagedBuffer) {
+      try {
+        ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer();
+        if (nioBuf.remaining() != len) {
+          return false;
+        } else {
+          for (int i = 0; i < len; i++) {
+            if (nioBuf.get() != i) {
+              return false;
+            }
+          }
+          return true;
+        }
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+    return false;
+  }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java
new file mode 100644
index 0000000000000..56a2b805f154c
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TestUtils.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.net.InetAddress;
+
+public class TestUtils {
+  public static String getLocalHost() {
+    try {
+      return InetAddress.getLocalHost().getHostAddress();
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+}
diff --git a/pom.xml b/pom.xml
index 7756c89b00cad..b0d39cfec1e8d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -91,6 +91,7 @@
     <module>graphx</module>
     <module>mllib</module>
     <module>tools</module>
+    <module>network/common</module>
     <module>streaming</module>
     <module>sql/catalyst</module>
     <module>sql/core</module>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8a1b2d3b91327..71041e7fe1a14 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -51,8 +51,6 @@ object MimaExcludes {
             // MapStatus should be private[spark]
             ProblemFilters.exclude[IncompatibleTemplateDefProblem](
               "org.apache.spark.scheduler.MapStatus"),
-            ProblemFilters.exclude[MissingMethodProblem](
-              "org.apache.spark.api.java.JavaRDDLike.foreachAsync"),
             ProblemFilters.exclude[MissingClassProblem](
              "org.apache.spark.network.netty.PathResolver"),
             ProblemFilters.exclude[MissingClassProblem](
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 7149dbc12a365..190373e0cb5f2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -122,7 +122,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
       sender: ActorRef
     ) {
     if (!receiverInputStreamMap.contains(streamId)) {
-      throw new Exception("Register received for unexpected id " + streamId)
+      throw new Exception("Register received for unexpected type " + streamId)
     }
     receiverInfo(streamId) = ReceiverInfo(
       streamId, s"${typ}-${streamId}", receiverActor, true, host)
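
A note on the testing pattern used above: SluiceClientHandlerSuite drives the client handler through Netty's EmbeddedChannel, which pushes inbound messages and exception events through a handler entirely in memory, with no real socket. The sketch below is a minimal, self-contained illustration of that pattern only; it is not part of this patch, and the handler and test names are made up for illustration.

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

public class EmbeddedChannelPatternSuite {
  /** Toy handler: counts inbound messages and records whether an exception was seen. */
  static class CountingHandler extends ChannelInboundHandlerAdapter {
    int received = 0;
    boolean failed = false;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
      received++;  // consume the message; do not forward it further down the pipeline
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
      failed = true;
    }
  }

  @Test
  public void deliversEventsWithoutARealSocket() {
    CountingHandler handler = new CountingHandler();
    EmbeddedChannel channel = new EmbeddedChannel(handler);

    channel.writeInbound("first", "second");                        // synchronous, in-memory delivery
    channel.pipeline().fireExceptionCaught(new Exception("boom"));  // simulate a connection error

    assertEquals(2, handler.received);
    assertTrue(handler.failed);
    assertFalse(channel.finish());  // no messages left buffered in the channel
  }
}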