diff --git a/core/pom.xml b/core/pom.xml
index a5a178079bc57..aff0d989d01bb 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -44,6 +44,11 @@
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>network</artifactId>
+      <version>${project.version}</version>
+    </dependency>
     <dependency>
       <groupId>net.java.dev.jets3t</groupId>
       <artifactId>jets3t</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 373ce795a309e..867173e04714e 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
-import org.apache.spark.network.netty.NettyBlockTransferService
+import org.apache.spark.network.netty.{NettyBlockTransferService}
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
@@ -40,7 +40,6 @@ import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}
-
/**
* :: DeveloperApi ::
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -233,12 +232,14 @@ object SparkEnv extends Logging {
val shuffleMemoryManager = new ShuffleMemoryManager(conf)
- // TODO(rxin): Config option based on class name, similar to shuffle mgr and compression codec.
- val blockTransferService = if (conf.getBoolean("spark.shuffle.use.netty", false)) {
- new NettyBlockTransferService(conf)
- } else {
- new NioBlockTransferService(conf, securityManager)
- }
+ // TODO: This is only netty by default for initial testing -- it should not be merged as such!!!
+ val blockTransferService =
+ conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
+ case "netty" =>
+ new NettyBlockTransferService(conf)
+ case "nio" =>
+ new NioBlockTransferService(conf, securityManager)
+ }
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index 0eeffe0e7c5e6..1745d52c81923 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -17,8 +17,8 @@
package org.apache.spark.network
-import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.storage.{BlockId, StorageLevel}
private[spark]
trait BlockDataManager {
@@ -27,10 +27,10 @@ trait BlockDataManager {
* Interface to get local block data. Throws an exception if the block cannot be found or
* cannot be read successfully.
*/
- def getBlockData(blockId: String): ManagedBuffer
+ def getBlockData(blockId: BlockId): ManagedBuffer
/**
* Put the block locally, using the given storage level.
*/
- def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit
+ def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit
}
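With ids now typed as `BlockId`, implementations no longer parse raw strings. A toy implementation against the revised trait might look like this (`InMemoryBlockDataManager` is hypothetical; the real implementation in Spark is `BlockManager`):

```scala
import scala.collection.concurrent.TrieMap

import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.{BlockId, StorageLevel}

// Hypothetical in-memory store, only to illustrate the typed signatures.
class InMemoryBlockDataManager extends BlockDataManager {
  private val blocks = TrieMap.empty[BlockId, ManagedBuffer]

  override def getBlockData(blockId: BlockId): ManagedBuffer =
    blocks.getOrElse(blockId, sys.error(s"Block $blockId not found"))

  override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit =
    blocks.put(blockId, data) // this sketch ignores the storage level
}
```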
diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
index dd70e26647939..e35fdb4e95899 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
@@ -19,6 +19,8 @@ package org.apache.spark.network
import java.util.EventListener
+import org.apache.spark.network.buffer.ManagedBuffer
+
/**
* Listener callback interface for [[BlockTransferService.fetchBlocks]].
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index d3ed683c7e880..8287a0fc81cfe 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -18,16 +18,18 @@
package org.apache.spark.network
import java.io.Closeable
-import java.nio.ByteBuffer
+
+import org.apache.spark.network.buffer.ManagedBuffer
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration
+import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.util.Utils
private[spark]
-abstract class BlockTransferService extends Closeable {
+abstract class BlockTransferService extends Closeable with Logging {
/**
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
@@ -92,10 +94,7 @@ abstract class BlockTransferService extends Closeable {
}
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
lock.synchronized {
- val ret = ByteBuffer.allocate(data.size.toInt)
- ret.put(data.nioByteBuffer())
- ret.flip()
- result = Left(new NioManagedBuffer(ret))
+ result = Left(data)
lock.notify()
}
}
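Since the success callback now hands the `ManagedBuffer` through unchanged instead of copying it into a fresh NIO buffer, a synchronous caller receives whatever buffer type the transport produced. Assuming the blocking helper keeps its `(host, port, blockId)` shape, usage looks like this sketch (host, port, and block id are illustrative):

```scala
import org.apache.spark.network.BlockTransferService

// Sketch: blocks the calling thread until the block arrives or the fetch fails.
def fetchOne(service: BlockTransferService): Unit = {
  val buf = service.fetchBlockSync("remote-host", 12345, "shuffle_0_0_0")
  val bytes = buf.nioByteBuffer() // may be file- or Netty-backed, not a defensive copy
  println(s"fetched ${bytes.remaining()} bytes")
}
```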
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
deleted file mode 100644
index dd808d2500fbc..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.io._
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-import java.nio.channels.FileChannel.MapMode
-
-import scala.util.Try
-
-import com.google.common.io.ByteStreams
-import io.netty.buffer.{Unpooled, ByteBufInputStream, ByteBuf}
-import io.netty.channel.DefaultFileRegion
-
-import org.apache.spark.util.{ByteBufferInputStream, Utils}
-
-
-/**
- * This interface provides an immutable view for data in the form of bytes. The implementation
- * should specify how the data is provided:
- *
- * - [[FileSegmentManagedBuffer]]: data backed by part of a file
- * - [[NioManagedBuffer]]: data backed by a NIO ByteBuffer
- * - [[NettyManagedBuffer]]: data backed by a Netty ByteBuf
- *
- * The concrete buffer implementation might be managed outside the JVM garbage collector.
- * For example, in the case of [[NettyManagedBuffer]], the buffers are reference counted.
- * In that case, if the buffer is going to be passed around to a different thread, retain/release
- * should be called.
- */
-private[spark]
-abstract class ManagedBuffer {
- // Note that all the methods are defined with parenthesis because their implementations can
- // have side effects (io operations).
-
- /** Number of bytes of the data. */
- def size: Long
-
- /**
- * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
- * returned ByteBuffer should not affect the content of this buffer.
- */
- def nioByteBuffer(): ByteBuffer
-
- /**
- * Exposes this buffer's data as an InputStream. The underlying implementation does not
- * necessarily check for the length of bytes read, so the caller is responsible for making sure
- * it does not go over the limit.
- */
- def inputStream(): InputStream
-
- /**
- * Increment the reference count by one if applicable.
- */
- def retain(): this.type
-
- /**
- * If applicable, decrement the reference count by one and deallocates the buffer if the
- * reference count reaches zero.
- */
- def release(): this.type
-
- /**
- * Convert the buffer into an Netty object, used to write the data out.
- */
- private[network] def convertToNetty(): AnyRef
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a segment in a file
- */
-private[spark]
-final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
- extends ManagedBuffer {
-
- override def size: Long = length
-
- override def nioByteBuffer(): ByteBuffer = {
- var channel: FileChannel = null
- try {
- channel = new RandomAccessFile(file, "r").getChannel
- channel.map(MapMode.READ_ONLY, offset, length)
- } catch {
- case e: IOException =>
- Try(channel.size).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- } finally {
- if (channel != null) {
- Utils.tryLog(channel.close())
- }
- }
- }
-
- override def inputStream(): InputStream = {
- var is: FileInputStream = null
- try {
- is = new FileInputStream(file)
- is.skip(offset)
- ByteStreams.limit(is, length)
- } catch {
- case e: IOException =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- Try(file.length).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- case e: Throwable =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- throw e
- }
- }
-
- override def toString: String = s"${getClass.getName}($file, $offset, $length)"
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
- */
-private[spark]
-final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
-
- override def size: Long = buf.remaining()
-
- override def nioByteBuffer() = buf.duplicate()
-
- override def inputStream() = new ByteBufferInputStream(buf)
-
- private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf)
-
- // [[ByteBuffer]] is managed by the JVM garbage collector itself.
- override def retain(): this.type = this
- override def release(): this.type = this
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
- */
-private[spark]
-final class NettyManagedBuffer(buf: ByteBuf) extends ManagedBuffer {
-
- override def size: Long = buf.readableBytes()
-
- override def nioByteBuffer() = buf.nioBuffer()
-
- override def inputStream() = new ByteBufInputStream(buf)
-
- private[network] override def convertToNetty(): AnyRef = buf
-
- override def retain(): this.type = {
- buf.retain()
- this
- }
-
- override def release(): this.type = {
- buf.release()
- this
- }
-}
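Although this file moves out of core, the reference-counting contract its docstring describes carries over to the relocated `org.apache.spark.network.buffer.ManagedBuffer`: retain before handing a buffer to another thread, release when done. A sketch of that discipline (`handOff` is a hypothetical helper):

```scala
import java.util.concurrent.ExecutorService

import org.apache.spark.network.buffer.ManagedBuffer

// Hypothetical helper showing the retain/release pairing across threads.
def handOff(buf: ManagedBuffer, pool: ExecutorService): Unit = {
  buf.retain() // keep the (possibly Netty-backed) buffer alive for the worker
  pool.submit(new Runnable {
    override def run(): Unit = {
      try {
        val data = buf.nioByteBuffer() // read on the worker thread
        // ... process data ...
      } finally {
        buf.release() // balance the retain above
      }
    }
  })
}
```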
diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala
deleted file mode 100644
index 6bdbf88d337ce..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.Closeable
-import java.util.concurrent.TimeoutException
-
-import scala.concurrent.{Future, promise}
-
-import io.netty.channel.{ChannelFuture, ChannelFutureListener}
-
-import org.apache.spark.Logging
-import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener}
-import org.apache.spark.storage.StorageLevel
-
-
-/**
- * Client for [[NettyBlockTransferService]]. The connection to the server must have been
- * established using [[BlockClientFactory]] before instantiating this.
- *
- * This class is used to make requests to the server, while [[BlockClientHandler]] is responsible
- * for handling responses from the server.
- *
- * Concurrency: thread safe and can be called from multiple threads.
- *
- * @param cf the ChannelFuture for the connection.
- * @param handler [[BlockClientHandler]] for handling outstanding requests.
- */
-@throws[TimeoutException]
-private[netty]
-class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closeable with Logging {
-
- private[this] val serverAddr = cf.channel().remoteAddress().toString
-
- def isActive: Boolean = cf.channel().isActive
-
- /**
- * Ask the remote server for a sequence of blocks, and execute the callback.
- *
- * Note that this is asynchronous and returns immediately. Upstream caller should throttle the
- * rate of fetching; otherwise we could run out of memory due to large outstanding fetches.
- *
- * @param blockIds sequence of block ids to fetch.
- * @param listener callback to fire on fetch success / failure.
- */
- def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = {
- var startTime: Long = 0
- logTrace {
- startTime = System.currentTimeMillis()
- s"Sending request $blockIds to $serverAddr"
- }
-
- blockIds.foreach { blockId =>
- handler.addFetchRequest(blockId, listener)
- }
-
- cf.channel().writeAndFlush(BlockFetchRequest(blockIds)).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (future.isSuccess) {
- logTrace {
- val timeTaken = System.currentTimeMillis() - startTime
- s"Sending request $blockIds to $serverAddr took $timeTaken ms"
- }
- } else {
- // Fail all blocks.
- val errorMsg =
- s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}"
- logError(errorMsg, future.cause)
- blockIds.foreach { blockId =>
- handler.removeFetchRequest(blockId)
- listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg))
- }
- }
- }
- })
- }
-
- def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] =
- {
- var startTime: Long = 0
- logTrace {
- startTime = System.currentTimeMillis()
- s"Uploading block ($blockId) to $serverAddr"
- }
- val f = cf.channel().writeAndFlush(new BlockUploadRequest(blockId, data, storageLevel))
-
- val p = promise[Unit]()
- handler.addUploadRequest(blockId, p)
- f.addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (future.isSuccess) {
- logTrace {
- val timeTaken = System.currentTimeMillis() - startTime
- s"Uploading block ($blockId) to $serverAddr took $timeTaken ms"
- }
- } else {
- // Fail all blocks.
- val errorMsg =
- s"Failed to upload block $blockId to $serverAddr: ${future.cause.getMessage}"
- logError(errorMsg, future.cause)
- }
- }
- })
-
- p.future
- }
-
- /** Close the connection. This does NOT block till the connection is closed. */
- def close(): Unit = cf.channel().close()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala
deleted file mode 100644
index 8021cfdf42d1a..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.Closeable
-import java.util.concurrent.{ConcurrentHashMap, TimeoutException}
-
-import io.netty.bootstrap.Bootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel._
-import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.socket.nio.NioSocketChannel
-import io.netty.util.internal.PlatformDependent
-
-import org.apache.spark.{Logging, SparkConf}
-import org.apache.spark.util.Utils
-
-
-/**
- * Factory for creating [[BlockClient]] by using createClient.
- *
- * The factory maintains a connection pool to other hosts and should return the same [[BlockClient]]
- * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s.
- */
-private[netty]
-class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable {
-
- def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))
-
- /** A thread factory so the threads are named (for debugging). */
- private[this] val threadFactory = Utils.namedThreadFactory("spark-netty-client")
-
- /** Socket channel type, initialized by [[init]] depending on ioMode. */
- private[this] var socketChannelClass: Class[_ <: Channel] = _
-
- /** Thread pool shared by all clients. */
- private[this] var workerGroup: EventLoopGroup = _
-
- private[this] val connectionPool = new ConcurrentHashMap[(String, Int), BlockClient]
-
- // The encoders are stateless and can be shared among multiple clients.
- private[this] val encoder = new ClientRequestEncoder
- private[this] val decoder = new ServerResponseDecoder
-
- init()
-
- /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */
- private def init(): Unit = {
- def initNio(): Unit = {
- socketChannelClass = classOf[NioSocketChannel]
- workerGroup = new NioEventLoopGroup(conf.clientThreads, threadFactory)
- }
- def initEpoll(): Unit = {
- socketChannelClass = classOf[EpollSocketChannel]
- workerGroup = new EpollEventLoopGroup(conf.clientThreads, threadFactory)
- }
-
- // For auto mode, first try epoll (only available on Linux), then nio.
- conf.ioMode match {
- case "nio" => initNio()
- case "epoll" => initEpoll()
- case "auto" => if (Epoll.isAvailable) initEpoll() else initNio()
- }
- }
-
- /**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
- *
- * This blocks until a connection is successfully established.
- *
- * Concurrency: This method is safe to call from multiple threads.
- */
- def createClient(remoteHost: String, remotePort: Int): BlockClient = {
- // Get connection from the connection pool first.
- // If it is not found or not active, create a new one.
- val cachedClient = connectionPool.get((remoteHost, remotePort))
- if (cachedClient != null && cachedClient.isActive) {
- return cachedClient
- }
-
- logDebug(s"Creating new connection to $remoteHost:$remotePort")
-
- // There is a chance two threads are creating two different clients connecting to the same host.
- // But that's probably ok ...
-
- val handler = new BlockClientHandler
-
- val bootstrap = new Bootstrap
- bootstrap.group(workerGroup)
- .channel(socketChannelClass)
- // Disable Nagle's Algorithm since we don't want packets to wait
- .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
- .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
- .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs)
-
- // Use pooled buffers to reduce temporary buffer allocation
- bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator())
-
- bootstrap.handler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("clientRequestEncoder", encoder)
- .addLast("frameDecoder", ProtocolUtils.createFrameDecoder())
- .addLast("serverResponseDecoder", decoder)
- .addLast("handler", handler)
- }
- })
-
- // Connect to the remote server
- val cf: ChannelFuture = bootstrap.connect(remoteHost, remotePort)
- if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) {
- throw new TimeoutException(
- s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)")
- }
-
- val client = new BlockClient(cf, handler)
- connectionPool.put((remoteHost, remotePort), client)
- client
- }
-
- /** Close all connections in the connection pool, and shutdown the worker thread pool. */
- override def close(): Unit = {
- val iter = connectionPool.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- entry.getValue.close()
- connectionPool.remove(entry.getKey)
- }
-
- if (workerGroup != null) {
- workerGroup.shutdownGracefully()
- }
- }
-
- /**
- * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches
- * are disabled because the ByteBufs are allocated by the event loop thread, but released by the
- * executor thread rather than the event loop thread. Those thread-local caches actually delay
- * the recycling of buffers, leading to larger memory usage.
- */
- private def createPooledByteBufAllocator(): PooledByteBufAllocator = {
- def getPrivateStaticField(name: String): Int = {
- val f = PooledByteBufAllocator.DEFAULT.getClass.getDeclaredField(name)
- f.setAccessible(true)
- f.getInt(null)
- }
- new PooledByteBufAllocator(
- PlatformDependent.directBufferPreferred(),
- getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
- getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
- getPrivateStaticField("DEFAULT_PAGE_SIZE"),
- getPrivateStaticField("DEFAULT_MAX_ORDER"),
- 0, // tinyCacheSize
- 0, // smallCacheSize
- 0 // normalCacheSize
- )
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala
deleted file mode 100644
index 5e28a07a461fa..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.util.concurrent.ConcurrentHashMap
-
-import scala.concurrent.Promise
-
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
-
-import org.apache.spark.Logging
-import org.apache.spark.network.{BlockFetchFailureException, BlockUploadFailureException, BlockFetchingListener}
-
-
-/**
- * Handler that processes server responses, in response to requests issued from [[BlockClient]].
- * It works by tracking the list of outstanding requests (and their callbacks).
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-private[netty]
-class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging {
-
- /** Tracks the list of outstanding requests and their listeners on success/failure. */
- private[this] val outstandingFetches: java.util.Map[String, BlockFetchingListener] =
- new ConcurrentHashMap[String, BlockFetchingListener]
-
- private[this] val outstandingUploads: java.util.Map[String, Promise[Unit]] =
- new ConcurrentHashMap[String, Promise[Unit]]
-
- def addFetchRequest(blockId: String, listener: BlockFetchingListener): Unit = {
- outstandingFetches.put(blockId, listener)
- }
-
- def removeFetchRequest(blockId: String): Unit = {
- outstandingFetches.remove(blockId)
- }
-
- def addUploadRequest(blockId: String, promise: Promise[Unit]): Unit = {
- outstandingUploads.put(blockId, promise)
- }
-
- /**
- * Fire the failure callback for all outstanding requests. This is called when we have an
- * uncaught exception or premature connection termination.
- */
- private def failOutstandingRequests(cause: Throwable): Unit = {
- val iter1 = outstandingFetches.entrySet().iterator()
- while (iter1.hasNext) {
- val entry = iter1.next()
- entry.getValue.onBlockFetchFailure(entry.getKey, cause)
- }
- // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests
- // as well. But I guess that is ok given the caller will fail as soon as any requests fail.
- outstandingFetches.clear()
-
- val iter2 = outstandingUploads.entrySet().iterator()
- while (iter2.hasNext) {
- val entry = iter2.next()
- entry.getValue.failure(new RuntimeException(s"Failed to upload block ${entry.getKey}"))
- }
- outstandingUploads.clear()
- }
-
- override def channelUnregistered(ctx: ChannelHandlerContext): Unit = {
- if (outstandingFetches.size() > 0) {
- logError("Still have " + outstandingFetches.size() + " requests outstanding " +
- s"when connection from ${ctx.channel.remoteAddress} is closed")
- failOutstandingRequests(new RuntimeException(
- s"Connection from ${ctx.channel.remoteAddress} closed"))
- }
- }
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- if (outstandingFetches.size() > 0) {
- logError(
- s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}", cause)
- failOutstandingRequests(cause)
- }
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, response: ServerResponse) {
- val server = ctx.channel.remoteAddress.toString
- response match {
- case BlockFetchSuccess(blockId, buf) =>
- val listener = outstandingFetches.get(blockId)
- if (listener == null) {
- logWarning(s"Got a response for block $blockId from $server but it is not outstanding")
- buf.release()
- } else {
- outstandingFetches.remove(blockId)
- listener.onBlockFetchSuccess(blockId, buf)
- buf.release()
- }
- case BlockFetchFailure(blockId, errorMsg) =>
- val listener = outstandingFetches.get(blockId)
- if (listener == null) {
- logWarning(
- s"Got a response for block $blockId from $server ($errorMsg) but it is not outstanding")
- } else {
- outstandingFetches.remove(blockId)
- listener.onBlockFetchFailure(blockId, new BlockFetchFailureException(blockId, errorMsg))
- }
- case BlockUploadSuccess(blockId) =>
- val p = outstandingUploads.get(blockId)
- if (p == null) {
- logWarning(s"Got a response for upload $blockId from $server but it is not outstanding")
- } else {
- outstandingUploads.remove(blockId)
- p.success(Unit)
- }
- case BlockUploadFailure(blockId, error) =>
- val p = outstandingUploads.get(blockId)
- if (p == null) {
- logWarning(s"Got a response for upload $blockId from $server but it is not outstanding")
- } else {
- outstandingUploads.remove(blockId)
- p.failure(new BlockUploadFailureException(blockId))
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala
deleted file mode 100644
index e2eb7c379f14d..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.Closeable
-import java.net.InetSocketAddress
-
-import io.netty.bootstrap.ServerBootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.socket.nio.NioServerSocketChannel
-import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption}
-
-import org.apache.spark.Logging
-import org.apache.spark.network.BlockDataManager
-import org.apache.spark.util.Utils
-
-
-/**
- * Server for the [[NettyBlockTransferService]].
- */
-private[netty]
-class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager)
- extends Closeable with Logging {
-
- def port: Int = _port
-
- def hostName: String = _hostName
-
- private var _port: Int = conf.serverPort
- private var _hostName: String = ""
- private var bootstrap: ServerBootstrap = _
- private var channelFuture: ChannelFuture = _
-
- init()
-
- /** Initialize the server. */
- private def init(): Unit = {
- bootstrap = new ServerBootstrap
- val threadFactory = Utils.namedThreadFactory("spark-netty-server")
-
- // Use only one thread to accept connections, and 2 * num_cores for worker.
- def initNio(): Unit = {
- val bossGroup = new NioEventLoopGroup(conf.serverThreads, threadFactory)
- val workerGroup = bossGroup
- bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel])
- }
- def initEpoll(): Unit = {
- val bossGroup = new EpollEventLoopGroup(conf.serverThreads, threadFactory)
- val workerGroup = bossGroup
- bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel])
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "epoll" => initEpoll()
- case "auto" => if (Epoll.isAvailable) initEpoll() else initNio()
- }
-
- // Use pooled buffers to reduce temporary buffer allocation
- bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-
- // Various (advanced) user-configured settings.
- conf.backLog.foreach { backLog =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog)
- }
- conf.receiveBuf.foreach { receiveBuf =>
- bootstrap.childOption[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf)
- }
- conf.sendBuf.foreach { sendBuf =>
- bootstrap.childOption[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf)
- }
-
- bootstrap.childHandler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", ProtocolUtils.createFrameDecoder())
- .addLast("clientRequestDecoder", new ClientRequestDecoder)
- .addLast("serverResponseEncoder", new ServerResponseEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
- })
-
- channelFuture = bootstrap.bind(new InetSocketAddress(_port))
- channelFuture.sync()
-
- val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
- _port = addr.getPort
- // _hostName = addr.getHostName
- _hostName = Utils.localHostName()
-
- logInfo(s"Server started ${_hostName}:${_port}")
- }
-
- /** Shutdown the server. */
- def close(): Unit = {
- if (channelFuture != null) {
- channelFuture.channel().close().awaitUninterruptibly()
- channelFuture = null
- }
- if (bootstrap != null && bootstrap.group() != null) {
- bootstrap.group().shutdownGracefully()
- }
- if (bootstrap != null && bootstrap.childGroup() != null) {
- bootstrap.childGroup().shutdownGracefully()
- }
- bootstrap = null
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala
deleted file mode 100644
index 44687f0b770e9..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import io.netty.channel._
-
-import org.apache.spark.Logging
-import org.apache.spark.network.{ManagedBuffer, BlockDataManager}
-import org.apache.spark.storage.StorageLevel
-
-
-/**
- * A handler that processes requests from clients and writes block data back.
- *
- * The messages should have been processed by the pipeline set up by BlockServerChannelInitializer.
- */
-private[netty] class BlockServerHandler(dataProvider: BlockDataManager)
- extends SimpleChannelInboundHandler[ClientRequest] with Logging {
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause)
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, request: ClientRequest): Unit = {
- request match {
- case BlockFetchRequest(blockIds) =>
- blockIds.foreach(processFetchRequest(ctx, _))
- case BlockUploadRequest(blockId, data, level) =>
- processUploadRequest(ctx, blockId, data, level)
- }
- } // end of channelRead0
-
- private def processFetchRequest(ctx: ChannelHandlerContext, blockId: String): Unit = {
- // A helper function to send error message back to the client.
- def client = ctx.channel.remoteAddress.toString
-
- def respondWithError(error: String): Unit = {
- ctx.writeAndFlush(new BlockFetchFailure(blockId, error)).addListener(
- new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (!future.isSuccess) {
- // TODO: Maybe log the success case as well.
- logError(s"Error sending error back to $client", future.cause)
- ctx.close()
- }
- }
- }
- )
- }
-
- logTrace(s"Received request from $client to fetch block $blockId")
-
- // First make sure we can find the block. If not, send error back to the user.
- var buf: ManagedBuffer = null
- try {
- buf = dataProvider.getBlockData(blockId)
- } catch {
- case e: Exception =>
- logError(s"Error opening block $blockId for request from $client", e)
- respondWithError(e.getMessage)
- return
- }
-
- ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener(
- new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${buf.size} B) back to $client")
- } else {
- logError(
- s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- }
- )
- } // end of processBlockRequest
-
- private def processUploadRequest(
- ctx: ChannelHandlerContext,
- blockId: String,
- data: ManagedBuffer,
- level: StorageLevel): Unit = {
- // A helper function to send error message back to the client.
- def client = ctx.channel.remoteAddress.toString
-
- try {
- dataProvider.putBlockData(blockId, data, level)
- ctx.writeAndFlush(BlockUploadSuccess(blockId)).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (!future.isSuccess) {
- logError(s"Error sending an ACK back to client $client")
- }
- }
- })
- } catch {
- case e: Throwable =>
- logError(s"Error processing uploaded block $blockId", e)
- ctx.writeAndFlush(BlockUploadFailure(blockId, e.getMessage)).addListener(
- new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (!future.isSuccess) {
- logError(s"Error sending an ACK back to client $client")
- }
- }
- })
- }
- } // end of processUploadRequest
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
new file mode 100644
index 0000000000000..aefd8a6335b2a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+import java.util
+
+import org.apache.spark.Logging
+import org.apache.spark.network.BlockFetchingListener
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, SluiceClient}
+import org.apache.spark.storage.BlockId
+import org.apache.spark.util.Utils
+
+/**
+ * Holds the state for a single request for a set of blocks. This assumes that
+ * the chunks will be returned in the same order as requested, and that there will be exactly
+ * one chunk per block.
+ *
+ * Upon receipt of any block, the listener will be called back. Upon failure part way through,
+ * the listener will receive a failure callback for each outstanding block.
+ */
+class NettyBlockFetcher(
+ serializer: Serializer,
+ client: SluiceClient,
+ blockIds: Seq[String],
+ listener: BlockFetchingListener)
+ extends Logging {
+
+ require(blockIds.nonEmpty)
+
+ val ser = serializer.newInstance()
+
+ var streamHandle: ShuffleStreamHandle = _
+
+ val chunkCallback = new ChunkReceivedCallback {
+ // On receipt of a chunk, pass it upwards as a block.
+ def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions {
+ buffer.retain()
+ listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer)
+ }
+
+ // On receipt of a failure, fail every block from chunkIndex onwards.
+ def onFailure(chunkIndex: Int, e: Throwable): Unit = {
+ blockIds.drop(chunkIndex).foreach { blockId =>
+ listener.onBlockFetchFailure(blockId, e)
+ }
+ }
+ }
+
+ // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle.
+ client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(),
+ new RpcResponseCallback {
+ override def onSuccess(response: Array[Byte]): Unit = {
+ try {
+ streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response))
+ logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.")
+
+ // Immediately request all chunks -- we expect that the total size of the request is
+ // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
+ for (i <- 0 until streamHandle.numChunks) {
+ client.fetchChunk(streamHandle.streamId, i, chunkCallback)
+ }
+ } catch {
+ case e: Exception =>
+ logError("Failed while starting block fetches", e)
+ blockIds.foreach(listener.onBlockFetchFailure(_, e))
+ }
+ }
+
+ override def onFailure(e: Throwable): Unit = {
+ logError("Failed while starting block fetches", e)
+ blockIds.foreach(listener.onBlockFetchFailure(_, e))
+ }
+ })
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
new file mode 100644
index 0000000000000..c8658ec98b82c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.Logging
+import org.apache.spark.network.BlockDataManager
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.RpcResponseCallback
+import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler}
+import org.apache.spark.storage.BlockId
+
+import scala.collection.JavaConversions._
+
+/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
+case class OpenBlocks(blockIds: Seq[BlockId])
+
+/** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */
+case class ShuffleStreamHandle(streamId: Long, numChunks: Int)
+
+/**
+ * Serves requests to open blocks by simply registering one chunk per block requested.
+ */
+class NettyBlockRpcServer(
+ serializer: Serializer,
+ streamManager: DefaultStreamManager,
+ blockManager: BlockDataManager)
+ extends RpcHandler with Logging {
+
+ override def receive(messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = {
+ val ser = serializer.newInstance()
+ val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
+ logTrace(s"Received request: $message")
+ message match {
+ case OpenBlocks(blockIds) =>
+ val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
+ val streamId = streamManager.registerStream(blocks.iterator)
+ responseContext.onSuccess(
+ ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())
+ }
+ }
+}
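Both halves of this RPC round-trip the two case classes above through the configured `Serializer`; with the `JavaSerializer` used elsewhere in this patch, the exchange can be simulated in-process (the stream id and chunk count are illustrative):

```scala
import java.nio.ByteBuffer

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.BlockId

val ser = new JavaSerializer(new SparkConf()).newInstance()

// Client side: ask the server to open one shuffle block.
val request = ser.serialize(OpenBlocks(Seq(BlockId("shuffle_0_0_0")))).array()

// Server side: decode the request and answer with a stream handle.
val decoded = ser.deserialize[AnyRef](ByteBuffer.wrap(request))
val reply = ser.serialize(ShuffleStreamHandle(streamId = 7L, numChunks = 1)).array()

// Client side: recover the handle, then fetch chunks 0 until numChunks.
val handle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(reply))
```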
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index b7f979dccd0f5..7576d51e22175 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,38 +17,39 @@
package org.apache.spark.network.netty
-import scala.concurrent.Future
-
import org.apache.spark.SparkConf
import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{SluiceClient, SluiceClientFactory}
+import org.apache.spark.network.server.{DefaultStreamManager, SluiceServer}
+import org.apache.spark.network.util.{ConfigProvider, SluiceConfig}
+import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+import scala.concurrent.Future
/**
- * A [[BlockTransferService]] implementation based on Netty.
- *
- * See protocol.scala for the communication protocol between server and client
+ * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/
-private[spark]
-final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+ var client: SluiceClient = _
- private[this] val nettyConf: NettyConfig = new NettyConfig(conf)
+ // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
+ val serializer = new JavaSerializer(conf)
- private[this] var server: BlockServer = _
- private[this] var clientFactory: BlockClientFactory = _
+ // Create a SluiceConfig using SparkConf.
+ private[this] val sluiceConf = new SluiceConfig(
+ new ConfigProvider { override def get(name: String) = conf.get(name) })
- override def init(blockDataManager: BlockDataManager): Unit = {
- server = new BlockServer(nettyConf, blockDataManager)
- clientFactory = new BlockClientFactory(nettyConf)
- }
+ private[this] var server: SluiceServer = _
+ private[this] var clientFactory: SluiceClientFactory = _
- override def close(): Unit = {
- if (server != null) {
- server.close()
- }
- if (clientFactory != null) {
- clientFactory.close()
- }
+ override def init(blockDataManager: BlockDataManager): Unit = {
+ val streamManager = new DefaultStreamManager
+ val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager)
+ server = new SluiceServer(sluiceConf, streamManager, rpcHandler)
+ clientFactory = new SluiceClientFactory(sluiceConf)
}
override def fetchBlocks(
@@ -56,29 +57,21 @@ final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferServ
port: Int,
blockIds: Seq[String],
listener: BlockFetchingListener): Unit = {
- clientFactory.createClient(hostName, port).fetchBlocks(blockIds, listener)
+ val client = clientFactory.createClient(hostName, port)
+ new NettyBlockFetcher(serializer, client, blockIds, listener)
}
+ override def hostName: String = Utils.localHostName()
+
+ override def port: Int = server.getPort
+
+ // TODO: Implement
override def uploadBlock(
hostname: String,
port: Int,
blockId: String,
blockData: ManagedBuffer,
- level: StorageLevel): Future[Unit] = {
- clientFactory.createClient(hostName, port).uploadBlock(blockId, blockData, level)
- }
+ level: StorageLevel): Future[Unit] = ???
- override def hostName: String = {
- if (server == null) {
- throw new IllegalStateException("Server has not been started")
- }
- server.hostName
- }
-
- override def port: Int = {
- if (server == null) {
- throw new IllegalStateException("Server has not been started")
- }
- server.port
- }
+ override def close(): Unit = server.close()
}
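End to end, the reworked service is constructed, initialized with a `BlockDataManager`, queried for its bound host and port, and closed; `fetchBlocks` internally builds the `NettyBlockFetcher` shown earlier. A lifecycle sketch (`dataManager` is assumed to be an initialized `BlockDataManager`, e.g. the node's `BlockManager`; the block id is illustrative):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.network.{BlockDataManager, BlockFetchingListener}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.netty.NettyBlockTransferService

def demo(conf: SparkConf, dataManager: BlockDataManager): Unit = {
  val service = new NettyBlockTransferService(conf)
  service.init(dataManager) // starts the SluiceServer and client factory

  service.fetchBlocks(service.hostName, service.port, Seq("rdd_0_0"),
    new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit =
        println(s"fetched $blockId (${data.size} bytes)")
      override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit =
        exception.printStackTrace()
    })

  service.close() // real code would wait for the fetch to complete first
}
```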
diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala
deleted file mode 100644
index 13942f3d0adcd..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala
+++ /dev/null
@@ -1,326 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.nio.ByteBuffer
-import java.util.{List => JList}
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.ChannelHandlerContext
-import io.netty.channel.ChannelHandler.Sharable
-import io.netty.handler.codec._
-
-import org.apache.spark.Logging
-import org.apache.spark.network.{NioManagedBuffer, NettyManagedBuffer, ManagedBuffer}
-import org.apache.spark.storage.StorageLevel
-
-
-/** Messages from the client to the server. */
-private[netty]
-sealed trait ClientRequest {
- def id: Byte
-}
-
-/**
- * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can
- * correspond to multiple [[ServerResponse]]s.
- */
-private[netty]
-final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest {
- override def id = 0
-}
-
-/**
- * Request to upload a block to the server. Currently the server does not ack the upload request.
- */
-private[netty]
-final case class BlockUploadRequest(
- blockId: String,
- data: ManagedBuffer,
- level: StorageLevel)
- extends ClientRequest {
- require(blockId.length <= Byte.MaxValue)
- override def id = 1
-}
-
-
-/** Messages from server to client (usually in response to some [[ClientRequest]]). */
-private[netty]
-sealed trait ServerResponse {
- def id: Byte
-}
-
-/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */
-private[netty]
-final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse {
- require(blockId.length <= Byte.MaxValue)
- override def id = 0
-}
-
-/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */
-private[netty]
-final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse {
- require(blockId.length <= Byte.MaxValue)
- override def id = 1
-}
-
-/** Response to [[BlockUploadRequest]] when a block is successfully uploaded. */
-private[netty]
-final case class BlockUploadSuccess(blockId: String) extends ServerResponse {
- require(blockId.length <= Byte.MaxValue)
- override def id = 2
-}
-
-/** Response to [[BlockUploadRequest]] when there is an error uploading the block. */
-private[netty]
-final case class BlockUploadFailure(blockId: String, error: String) extends ServerResponse {
- require(blockId.length <= Byte.MaxValue)
- override def id = 3
-}
-
-
-/**
- * Encoder for [[ClientRequest]] used in client side.
- *
- * This encoder is stateless so it is safe to be shared by multiple threads.
- */
-@Sharable
-private[netty]
-final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] {
- override def encode(ctx: ChannelHandlerContext, in: ClientRequest, out: JList[Object]): Unit = {
- in match {
- case BlockFetchRequest(blocks) =>
- // 8 bytes: frame size
- // 1 byte: BlockFetchRequest vs BlockUploadRequest
- // 4 byte: num blocks
- // then for each block id write 1 byte for blockId.length and then blockId itself
- val frameLength = 8 + 1 + 4 + blocks.size + blocks.map(_.size).fold(0)(_ + _)
- val buf = ctx.alloc().buffer(frameLength)
-
- buf.writeLong(frameLength)
- buf.writeByte(in.id)
- buf.writeInt(blocks.size)
- blocks.foreach { blockId =>
- ProtocolUtils.writeBlockId(buf, blockId)
- }
-
- assert(buf.writableBytes() == 0)
- out.add(buf)
-
- case BlockUploadRequest(blockId, data, level) =>
- // 8 bytes: frame size
- // 1 byte: msg id (BlockFetchRequest vs BlockUploadRequest)
- // 1 byte: blockId.length
- // data itself (length can be derived from: frame size - 1 - blockId.length)
- val headerLength = 8 + 1 + 1 + blockId.length + 5
- val frameLength = headerLength + data.size
- val header = ctx.alloc().buffer(headerLength)
-
- // Call this before we add header to out so in case of exceptions
- // we don't send anything at all.
- val body = data.convertToNetty()
-
- header.writeLong(frameLength)
- header.writeByte(in.id)
- ProtocolUtils.writeBlockId(header, blockId)
- header.writeInt(level.toInt)
- header.writeByte(level.replication)
-
- assert(header.writableBytes() == 0)
- out.add(header)
- out.add(body)
- }
- }
-}
-
-
-/**
- * Decoder in the server side to decode client requests.
- * This decoder is stateless so it is safe to be shared by multiple threads.
- *
- * This assumes the inbound messages have been processed by a frame decoder created by
- * [[ProtocolUtils.createFrameDecoder()]].
- */
-@Sharable
-private[netty]
-final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] {
- override protected def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit =
- {
- val msgTypeId = in.readByte()
- val decoded = msgTypeId match {
- case 0 => // BlockFetchRequest
- val numBlocks = in.readInt()
- val blockIds = Seq.fill(numBlocks) { ProtocolUtils.readBlockId(in) }
- BlockFetchRequest(blockIds)
-
- case 1 => // BlockUploadRequest
- val blockId = ProtocolUtils.readBlockId(in)
- val level = new StorageLevel(in.readInt(), in.readByte())
-
- val ret = ByteBuffer.allocate(in.readableBytes())
- ret.put(in.nioBuffer())
- ret.flip()
- BlockUploadRequest(blockId, new NioManagedBuffer(ret), level)
- }
-
- assert(decoded.id == msgTypeId)
- out.add(decoded)
- }
-}
-
-
-/**
- * Encoder used by the server side to encode server-to-client responses.
- * This encoder is stateless so it is safe to be shared by multiple threads.
- */
-@Sharable
-private[netty]
-final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging {
- override def encode(ctx: ChannelHandlerContext, in: ServerResponse, out: JList[Object]): Unit = {
- in match {
- case BlockFetchSuccess(blockId, data) =>
- // Handle the body first so if we encounter an error getting the body, we can respond
- // with an error instead.
- var body: AnyRef = null
- try {
- body = data.convertToNetty()
- } catch {
- case e: Exception =>
- // Re-encode this message as BlockFetchFailure.
- logError(s"Error opening block $blockId for client ${ctx.channel.remoteAddress}", e)
- encode(ctx, new BlockFetchFailure(blockId, e.getMessage), out)
- return
- }
-
- // If we got here, body cannot be null
- // 8 bytes = long for frame length
- // 1 byte = message id (type)
- // 1 byte = block id length
- // followed by block id itself
- val headerLength = 8 + 1 + 1 + blockId.length
- val frameLength = headerLength + data.size
- val header = ctx.alloc().buffer(headerLength)
- header.writeLong(frameLength)
- header.writeByte(in.id)
- ProtocolUtils.writeBlockId(header, blockId)
-
- assert(header.writableBytes() == 0)
- out.add(header)
- out.add(body)
-
- case BlockFetchFailure(blockId, error) =>
- val frameLength = 8 + 1 + 1 + blockId.length + error.length
- val buf = ctx.alloc().buffer(frameLength)
- buf.writeLong(frameLength)
- buf.writeByte(in.id)
- ProtocolUtils.writeBlockId(buf, blockId)
- buf.writeBytes(error.getBytes)
-
- assert(buf.writableBytes() == 0)
- out.add(buf)
-
- case BlockUploadSuccess(blockId) =>
- val frameLength = 8 + 1 + 1 + blockId.length
- val buf = ctx.alloc().buffer(frameLength)
- buf.writeLong(frameLength)
- buf.writeByte(in.id)
- ProtocolUtils.writeBlockId(buf, blockId)
-
- assert(buf.writableBytes() == 0)
- out.add(buf)
-
- case BlockUploadFailure(blockId, error) =>
- val frameLength = 8 + 1 + 1 + blockId.length + error.length
- val buf = ctx.alloc().buffer(frameLength)
- buf.writeLong(frameLength)
- buf.writeByte(in.id)
- ProtocolUtils.writeBlockId(buf, blockId)
- buf.writeBytes(error.getBytes)
-
- assert(buf.writableBytes() == 0)
- out.add(buf)
- }
- }
-}
-
-
-/**
- * Decoder in the client side to decode server responses.
- * This decoder is stateless so it is safe to be shared by multiple threads.
- *
- * This assumes the inbound messages have been processed by a frame decoder created by
- * [[ProtocolUtils.createFrameDecoder()]].
- */
-@Sharable
-private[netty]
-final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] {
- override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = {
- val msgId = in.readByte()
- val decoded = msgId match {
- case 0 => // BlockFetchSuccess
- val blockId = ProtocolUtils.readBlockId(in)
- in.retain()
- BlockFetchSuccess(blockId, new NettyManagedBuffer(in))
-
- case 1 => // BlockFetchFailure
- val blockId = ProtocolUtils.readBlockId(in)
- val errorBytes = new Array[Byte](in.readableBytes())
- in.readBytes(errorBytes)
- BlockFetchFailure(blockId, new String(errorBytes))
-
- case 2 => // BlockUploadSuccess
- BlockUploadSuccess(ProtocolUtils.readBlockId(in))
-
- case 3 => // BlockUploadFailure
- val blockId = ProtocolUtils.readBlockId(in)
- val errorBytes = new Array[Byte](in.readableBytes())
- in.readBytes(errorBytes)
- BlockUploadFailure(blockId, new String(errorBytes))
- }
-
- assert(decoded.id == msgId)
- out.add(decoded)
- }
-}
-
-
-private[netty] object ProtocolUtils {
-
- /** LengthFieldBasedFrameDecoder used before all decoders. */
- def createFrameDecoder(): ByteToMessageDecoder = {
- // maxFrameLength = 2G
- // lengthFieldOffset = 0
- // lengthFieldLength = 8
- // lengthAdjustment = -8, i.e. exclude the 8 byte length itself
- // initialBytesToStrip = 8, i.e. strip out the length field itself
- new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8)
- }
-
- // TODO(rxin): Make sure these work for all charsets.
- def readBlockId(in: ByteBuf): String = {
- val numBytesToRead = in.readByte().toInt
- val bytes = new Array[Byte](numBytesToRead)
- in.readBytes(bytes)
- new String(bytes)
- }
-
- def writeBlockId(out: ByteBuf, blockId: String): Unit = {
- out.writeByte(blockId.length)
- out.writeBytes(blockId.getBytes)
- }
-}
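The framing removed here is a simple length-prefixed layout: an 8-byte long carrying the total frame length (including the length field itself), followed by a one-byte message type and the message body. Below is a minimal, self-contained sketch, not part of this patch, of how such a frame decodes with Netty's LengthFieldBasedFrameDecoder using the same constructor arguments as createFrameDecoder() above; the class name is illustrative.

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.CharsetUtil;

public class FrameDecoderDemo {
  public static void main(String[] args) {
    // Same arguments as createFrameDecoder(): an 8-byte length field at
    // offset 0 whose value includes the length field itself, hence the -8
    // adjustment; the 8 length bytes are stripped from the emitted frame.
    EmbeddedChannel channel = new EmbeddedChannel(
      new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8));

    byte[] body = "payload".getBytes(CharsetUtil.UTF_8);
    ByteBuf frame = Unpooled.buffer();
    frame.writeLong(8 + body.length);  // frame length includes the long itself
    frame.writeBytes(body);

    channel.writeInbound(frame);
    ByteBuf decoded = channel.readInbound();
    System.out.println(decoded.toString(CharsetUtil.UTF_8));  // prints "payload"
    decoded.release();
    channel.finish();
  }
}
```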
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index e942b43d9cc4a..bce1069548437 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -19,12 +19,13 @@ package org.apache.spark.network.nio
import java.nio.ByteBuffer
-import scala.concurrent.Future
-
-import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf}
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+
+import scala.concurrent.Future
/**
@@ -153,12 +154,11 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
- case e: Exception => {
+ case e: Exception =>
logError("Exception handling buffer message", e)
val errorMessage = Message.createBufferMessage(msg.id)
errorMessage.hasError = true
Some(errorMessage)
- }
}
case otherMessage: Any =>
@@ -174,13 +174,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
case BlockMessage.TYPE_PUT_BLOCK =>
val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
logDebug("Received [" + msg + "]")
- putBlock(msg.id.toString, msg.data, msg.level)
+ putBlock(msg.id, msg.data, msg.level)
None
case BlockMessage.TYPE_GET_BLOCK =>
val msg = new GetBlock(blockMessage.getId)
logDebug("Received [" + msg + "]")
- val buffer = getBlock(msg.id.toString)
+ val buffer = getBlock(msg.id)
if (buffer == null) {
return None
}
@@ -190,7 +190,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
}
}
- private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes)
blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level)
@@ -198,7 +198,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
+ " with data size: " + bytes.limit)
}
- private def getBlock(blockId: String): ByteBuffer = {
+ private def getBlock(blockId: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + blockId + " started from " + startTimeMs)
val buffer = blockDataManager.getBlockData(blockId)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 439981d232349..c35aa2481ad03 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -24,14 +24,14 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConversions._
-import org.apache.spark.{SparkEnv, SparkConf, Logging}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index 4ab34336d3f01..6a9fa4ec65d5d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -21,7 +21,7 @@ import java.io._
import java.nio.ByteBuffer
import org.apache.spark.SparkEnv
-import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.storage._
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
index 63863cc0250a3..b521f0c7fc77e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -18,8 +18,7 @@
package org.apache.spark.shuffle
import java.nio.ByteBuffer
-
-import org.apache.spark.network.ManagedBuffer
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.ShuffleBlockId
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index ac0599f30ef22..4d8b5c1e1b084 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,15 +17,13 @@
package org.apache.spark.storage
-import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
+import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.concurrent.ExecutionContext.Implicits.global
-
-import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.concurrent.{Await, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
+import scala.concurrent.{Await, Future}
import scala.util.Random
import akka.actor.{ActorSystem, Props}
@@ -35,11 +33,11 @@ import org.apache.spark._
import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.util._
-
private[spark] sealed trait BlockValues
private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
@@ -215,17 +213,17 @@ private[spark] class BlockManager(
* Interface to get local block data. Throws an exception if the block cannot be found or
* cannot be read successfully.
*/
- override def getBlockData(blockId: String): ManagedBuffer = {
- val bid = BlockId(blockId)
- if (bid.isShuffle) {
- shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])
+ override def getBlockData(blockId: BlockId): ManagedBuffer = {
+ if (blockId.isShuffle) {
+ shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
- val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
+ .asInstanceOf[Option[ByteBuffer]]
if (blockBytesOpt.isDefined) {
val buffer = blockBytesOpt.get
new NioManagedBuffer(buffer)
} else {
- throw new BlockNotFoundException(blockId)
+ throw new BlockNotFoundException(blockId.toString)
}
}
}
@@ -233,8 +231,8 @@ private[spark] class BlockManager(
/**
* Put the block locally, using the given storage level.
*/
- override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = {
- putBytes(BlockId(blockId), data.nioByteBuffer(), level)
+ override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
+ putBytes(blockId, data.nioByteBuffer(), level)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
index 9ef453605f4f1..81f5f2d31dbd8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
@@ -17,5 +17,4 @@
package org.apache.spark.storage
-
class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found")
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index d095452a261db..23313fe9271fd 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -19,14 +19,13 @@ package org.apache.spark.storage
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Queue
+import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService}
+import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
import org.apache.spark.serializer.Serializer
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.{Logging, TaskContext}
/**
@@ -228,7 +227,7 @@ final class ShuffleBlockFetcherIterator(
while (iter.hasNext) {
val blockId = iter.next()
try {
- val buf = blockManager.getBlockData(blockId.toString)
+ val buf = blockManager.getBlockData(blockId)
shuffleMetrics.localBlocksFetched += 1
buf.retain()
results.put(new FetchResult(blockId, 0, buf))
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 2fc7c7d9b8312..1e35abaab5353 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -42,7 +42,7 @@ class StorageLevel private(
extends Externalizable {
// TODO: Also add fields for caching priority, dataset ID, and flushing.
- private[spark] def this(flags: Int, replication: Int) {
+ private def this(flags: Int, replication: Int) {
this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
}
@@ -98,7 +98,6 @@ class StorageLevel private(
}
override def writeExternal(out: ObjectOutput) {
- /* If the wire protocol changes, please also update [[ClientRequestEncoder]] */
out.writeByte(toInt)
out.writeByte(_replication)
}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala
deleted file mode 100644
index 2d4baafcf03d0..0000000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import scala.concurrent.{Await, future}
-import scala.concurrent.duration._
-import scala.concurrent.ExecutionContext.Implicits.global
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-
-import org.apache.spark.SparkConf
-
-
-class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll {
-
- private val conf = new SparkConf
- private var server1: BlockServer = _
- private var server2: BlockServer = _
-
- override def beforeAll() {
- server1 = new BlockServer(new NettyConfig(conf), null)
- server2 = new BlockServer(new NettyConfig(conf), null)
- }
-
- override def afterAll() {
- if (server1 != null) {
- server1.close()
- }
- if (server2 != null) {
- server2.close()
- }
- }
-
- test("BlockClients created are active and reused") {
- val factory = new BlockClientFactory(conf)
- val c1 = factory.createClient(server1.hostName, server1.port)
- val c2 = factory.createClient(server1.hostName, server1.port)
- val c3 = factory.createClient(server2.hostName, server2.port)
- assert(c1.isActive)
- assert(c3.isActive)
- assert(c1 === c2)
- assert(c1 !== c3)
- factory.close()
- }
-
- test("never return inactive clients") {
- val factory = new BlockClientFactory(conf)
- val c1 = factory.createClient(server1.hostName, server1.port)
- c1.close()
-
- // Block until c1 is no longer active
- val f = future {
- while (c1.isActive) {
- Thread.sleep(10)
- }
- }
- Await.result(f, 3.seconds)
- assert(!c1.isActive)
-
- // Create c2, which should be different from c1
- val c2 = factory.createClient(server1.hostName, server1.port)
- assert(c1 !== c2)
- factory.close()
- }
-
-  test("BlockClients are closed when BlockClientFactory is stopped") {
- val factory = new BlockClientFactory(conf)
- val c1 = factory.createClient(server1.hostName, server1.port)
- val c2 = factory.createClient(server2.hostName, server2.port)
- assert(c1.isActive)
- assert(c2.isActive)
- factory.close()
- assert(!c1.isActive)
- assert(!c2.isActive)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala
deleted file mode 100644
index 4c3a649081574..0000000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.nio.ByteBuffer
-
-import io.netty.buffer.Unpooled
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.mockito.Mockito._
-import org.mockito.Matchers.{any, eq => meq}
-
-import org.scalatest.{FunSuite, PrivateMethodTester}
-
-import org.apache.spark.network._
-
-
-class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester {
-
- /** Helper method to get num. outstanding requests from a private field using reflection. */
- private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = {
- val f = handler.getClass.getDeclaredField(
- "org$apache$spark$network$netty$BlockClientHandler$$outstandingFetches")
- f.setAccessible(true)
- f.get(handler).asInstanceOf[java.util.Map[_, _]].size
- }
-
- test("handling block data (successful fetch)") {
- val blockId = "test_block"
- val blockData = "blahblahblahblahblah"
- val handler = new BlockClientHandler
- val listener = mock(classOf[BlockFetchingListener])
- handler.addFetchRequest(blockId, listener)
- assert(sizeOfOutstandingRequests(handler) === 1)
-
- val channel = new EmbeddedChannel(handler)
-    val buf = ByteBuffer.allocate(blockData.size)
- buf.put(blockData.getBytes)
- buf.flip()
-
- channel.writeInbound(BlockFetchSuccess(blockId, new NioManagedBuffer(buf)))
- verify(listener, times(1)).onBlockFetchSuccess(meq(blockId), any())
- assert(sizeOfOutstandingRequests(handler) === 0)
- assert(channel.finish() === false)
- }
-
- test("handling error message (failed fetch)") {
- val blockId = "test_block"
- val handler = new BlockClientHandler
- val listener = mock(classOf[BlockFetchingListener])
- handler.addFetchRequest(blockId, listener)
- assert(sizeOfOutstandingRequests(handler) === 1)
-
- val channel = new EmbeddedChannel(handler)
- channel.writeInbound(BlockFetchFailure(blockId, "some error msg"))
- verify(listener, times(0)).onBlockFetchSuccess(any(), any())
- verify(listener, times(1)).onBlockFetchFailure(meq(blockId), any())
- assert(sizeOfOutstandingRequests(handler) === 0)
- assert(channel.finish() === false)
- }
-
-  test("clear all outstanding requests upon uncaught exception") {
- val handler = new BlockClientHandler
- val listener = mock(classOf[BlockFetchingListener])
- handler.addFetchRequest("b1", listener)
- handler.addFetchRequest("b2", listener)
- handler.addFetchRequest("b3", listener)
- assert(sizeOfOutstandingRequests(handler) === 3)
-
- val channel = new EmbeddedChannel(handler)
- channel.writeInbound(BlockFetchSuccess("b1", new NettyManagedBuffer(Unpooled.buffer())))
- channel.pipeline().fireExceptionCaught(new Exception("duh duh duh"))
-
- // should fail both b2 and b3
- verify(listener, times(1)).onBlockFetchSuccess(any(), any())
- verify(listener, times(2)).onBlockFetchFailure(any(), any())
- assert(sizeOfOutstandingRequests(handler) === 0)
- assert(channel.finish() === false)
- }
-
-  test("clear all outstanding requests upon connection close") {
- val handler = new BlockClientHandler
- val listener = mock(classOf[BlockFetchingListener])
- handler.addFetchRequest("c1", listener)
- handler.addFetchRequest("c2", listener)
- handler.addFetchRequest("c3", listener)
- assert(sizeOfOutstandingRequests(handler) === 3)
-
- val channel = new EmbeddedChannel(handler)
- channel.writeInbound(BlockFetchSuccess("c1", new NettyManagedBuffer(Unpooled.buffer())))
- channel.finish()
-
-    // should fail both c2 and c3
- verify(listener, times(1)).onBlockFetchSuccess(any(), any())
- verify(listener, times(2)).onBlockFetchFailure(any(), any())
- assert(sizeOfOutstandingRequests(handler) === 0)
- assert(channel.finish() === false)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala
deleted file mode 100644
index 8d1b7276f4082..0000000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.api.java.StorageLevels
-
-
-/**
- * Test client/server encoder/decoder protocol.
- */
-class ProtocolSuite extends FunSuite {
-
- /**
- * Helper to test server to client message protocol by encoding a message and decoding it.
- */
- private def testServerToClient(msg: ServerResponse) {
- val serverChannel = new EmbeddedChannel(new ServerResponseEncoder)
- serverChannel.writeOutbound(msg)
-
- val clientChannel = new EmbeddedChannel(
- ProtocolUtils.createFrameDecoder(),
- new ServerResponseDecoder)
-
- // Drain all server outbound messages and write them to the client's server decoder.
- while (!serverChannel.outboundMessages().isEmpty) {
- clientChannel.writeInbound(serverChannel.readOutbound())
- }
-
- assert(clientChannel.inboundMessages().size === 1)
- // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is
- // overridden.
- assert(msg === clientChannel.readInbound())
- }
-
- /**
- * Helper to test client to server message protocol by encoding a message and decoding it.
- */
- private def testClientToServer(msg: ClientRequest) {
- val clientChannel = new EmbeddedChannel(new ClientRequestEncoder)
- clientChannel.writeOutbound(msg)
-
- val serverChannel = new EmbeddedChannel(
- ProtocolUtils.createFrameDecoder(),
- new ClientRequestDecoder)
-
- // Drain all client outbound messages and write them to the server's decoder.
- while (!clientChannel.outboundMessages().isEmpty) {
- serverChannel.writeInbound(clientChannel.readOutbound())
- }
-
- assert(serverChannel.inboundMessages().size === 1)
- // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is
- // overridden.
- assert(msg === serverChannel.readInbound())
- }
-
- test("server to client protocol - BlockFetchSuccess(\"a1234\", new TestManagedBuffer(10))") {
- testServerToClient(BlockFetchSuccess("a1234", new TestManagedBuffer(10)))
- }
-
- test("server to client protocol - BlockFetchSuccess(\"\", new TestManagedBuffer(0))") {
- testServerToClient(BlockFetchSuccess("", new TestManagedBuffer(0)))
- }
-
- test("server to client protocol - BlockFetchFailure(\"abcd\", \"this is an error\")") {
- testServerToClient(BlockFetchFailure("abcd", "this is an error"))
- }
-
- test("server to client protocol - BlockFetchFailure(\"\", \"\")") {
- testServerToClient(BlockFetchFailure("", ""))
- }
-
- test("client to server protocol - BlockFetchRequest(Seq.empty[String])") {
- testClientToServer(BlockFetchRequest(Seq.empty[String]))
- }
-
- test("client to server protocol - BlockFetchRequest(Seq(\"b1\"))") {
- testClientToServer(BlockFetchRequest(Seq("b1")))
- }
-
- test("client to server protocol - BlockFetchRequest(Seq(\"b1\", \"b2\", \"b3\"))") {
- testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3")))
- }
-
- test("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") {
- testClientToServer(
- BlockUploadRequest("", new TestManagedBuffer(0), StorageLevels.MEMORY_AND_DISK))
- }
-
- test("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") {
- testClientToServer(
- BlockUploadRequest("b_upload", new TestManagedBuffer(10), StorageLevels.MEMORY_AND_DISK_2))
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
deleted file mode 100644
index 35ff90a2dabc5..0000000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.{RandomAccessFile, File}
-import java.nio.ByteBuffer
-import java.util.{Collections, HashSet}
-import java.util.concurrent.{TimeUnit, Semaphore}
-
-import scala.collection.JavaConversions._
-
-import io.netty.buffer.Unpooled
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import org.scalatest.concurrent.Eventually._
-import org.scalatest.time.Span
-import org.scalatest.time.Seconds
-
-import org.apache.spark.SparkConf
-import org.apache.spark.network._
-import org.apache.spark.storage.{BlockNotFoundException, StorageLevel}
-
-
-/**
- * Test cases that create real clients and servers and connect.
- */
-class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
-
- val bufSize = 100000
- var buf: ByteBuffer = _
- var testFile: File = _
- var server: BlockServer = _
- var clientFactory: BlockClientFactory = _
-
- val bufferBlockId = "buffer_block"
- val fileBlockId = "file_block"
-
- val fileContent = new Array[Byte](1024)
- scala.util.Random.nextBytes(fileContent)
-
- override def beforeAll() = {
- buf = ByteBuffer.allocate(bufSize)
- for (i <- 1 to bufSize) {
- buf.put(i.toByte)
- }
- buf.flip()
-
- testFile = File.createTempFile("netty-test-file", "txt")
- val fp = new RandomAccessFile(testFile, "rw")
- fp.write(fileContent)
- fp.close()
-
- server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager {
- override def getBlockData(blockId: String): ManagedBuffer = {
- if (blockId == bufferBlockId) {
- new NioManagedBuffer(buf)
- } else if (blockId == fileBlockId) {
- new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25)
- } else {
- throw new BlockNotFoundException(blockId)
- }
- }
-
- /**
- * Put the block locally, using the given storage level.
- */
- def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = ???
- })
-
- clientFactory = new BlockClientFactory(new SparkConf)
- }
-
- override def afterAll() = {
- server.close()
- clientFactory.close()
- }
-
- /** A ByteBuf for buffer_block */
- lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf)
-
- /** A ByteBuf for file_block */
- lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25)
-
- def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = {
- val client = clientFactory.createClient(server.hostName, server.port)
- val sem = new Semaphore(0)
- val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
- val errorBlockIds = Collections.synchronizedSet(new HashSet[String])
- val receivedBuffers = Collections.synchronizedSet(new HashSet[ManagedBuffer])
-
- client.fetchBlocks(
- blockIds,
- new BlockFetchingListener {
- override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
- errorBlockIds.add(blockId)
- sem.release()
- }
-
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- data.retain()
- receivedBlockIds.add(blockId)
- receivedBuffers.add(data)
- sem.release()
- }
- }
- )
- if (!sem.tryAcquire(blockIds.size, 5, TimeUnit.SECONDS)) {
- fail("Timeout getting response from the server")
- }
- client.close()
- (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet)
- }
-
- test("fetch a ByteBuffer block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId))
- assert(blockIds === Set(bufferBlockId))
- assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch a FileSegment block via zero-copy send") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId))
- assert(blockIds === Set(fileBlockId))
- assert(buffers.map(_.convertToNetty()) === Set(fileBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch a non-existent block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block"))
- assert(blockIds.isEmpty)
- assert(buffers.isEmpty)
- assert(failBlockIds === Set("random-block"))
- buffers.foreach(_.release())
- }
-
- test("fetch both ByteBuffer block and FileSegment block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId))
- assert(blockIds === Set(bufferBlockId, fileBlockId))
- assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference, fileBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch both ByteBuffer block and a non-existent block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block"))
- assert(blockIds === Set(bufferBlockId))
- assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference))
- assert(failBlockIds === Set("random-block"))
- buffers.foreach(_.release())
- }
-
- test("shutting down server should also close client") {
- val client = clientFactory.createClient(server.hostName, server.port)
- server.close()
- eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) }
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala
deleted file mode 100644
index e47e4d03fa898..0000000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.InputStream
-import java.nio.ByteBuffer
-
-import io.netty.buffer.Unpooled
-
-import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer}
-
-
-/**
- * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1).
- *
- * Used for testing.
- */
-class TestManagedBuffer(len: Int) extends ManagedBuffer {
-
- require(len <= Byte.MaxValue)
-
- private val byteArray: Array[Byte] = Array.tabulate[Byte](len)(_.toByte)
-
- private val underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray))
-
- override def size: Long = underlying.size
-
- override private[network] def convertToNetty(): AnyRef = underlying.convertToNetty()
-
- override def nioByteBuffer(): ByteBuffer = underlying.nioByteBuffer()
-
- override def inputStream(): InputStream = underlying.inputStream()
-
- override def toString: String = s"${getClass.getName}($len)"
-
- override def equals(other: Any): Boolean = other match {
- case otherBuf: ManagedBuffer =>
- val nioBuf = otherBuf.nioByteBuffer()
- if (nioBuf.remaining() != len) {
- return false
- } else {
- var i = 0
- while (i < len) {
- if (nioBuf.get() != i) {
- return false
- }
- i += 1
- }
- return true
- }
- case _ => false
- }
-
- override def retain(): this.type = this
-
- override def release(): this.type = this
-}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index ba47fe5e25b9b..6790388f96603 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.FunSuite
import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FileShuffleBlockManager
import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
@@ -36,9 +36,9 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) {
assert(buffer.isInstanceOf[FileSegmentManagedBuffer])
val segment = buffer.asInstanceOf[FileSegmentManagedBuffer]
- assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath)
- assert(expected.offset === segment.offset)
- assert(expected.length === segment.length)
+ assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath)
+ assert(expected.offset === segment.getOffset)
+ assert(expected.length === segment.getLength)
}
test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 7d4086313fcc1..3beb503b206f2 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -31,6 +31,7 @@ import org.scalatest.FunSuite
import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.TestSerializer
@@ -71,7 +72,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]))
localBlocks.foreach { case (blockId, buf) =>
- doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString))
+ doReturn(buf).when(blockManager).getBlockData(meq(blockId))
}
// Make sure remote blocks would return
diff --git a/network/common/pom.xml b/network/common/pom.xml
new file mode 100644
index 0000000000000..e3b7e328701b4
--- /dev/null
+++ b/network/common/pom.xml
@@ -0,0 +1,94 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent</artifactId>
+    <version>1.2.0-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <groupId>org.apache.spark</groupId>
+  <artifactId>network</artifactId>
+  <packaging>jar</packaging>
+  <name>Shuffle Streaming Service</name>
+  <url>http://spark.apache.org/</url>
+  <properties>
+    <sbt.project.name>network</sbt.project.name>
+  </properties>
+
+  <dependencies>
+    <!-- Core dependencies -->
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-all</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-api</artifactId>
+    </dependency>
+
+    <!-- Provided dependencies -->
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+      <scope>provided</scope>
+    </dependency>
+
+    <!-- Test dependencies -->
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>log4j</groupId>
+      <artifactId>log4j</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/java/classes</outputDirectory>
+    <testOutputDirectory>target/java/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-surefire-plugin</artifactId>
+        <version>2.17</version>
+        <configuration>
+          <skipTests>false</skipTests>
+          <includes>
+            <include>**/Test*.java</include>
+            <include>**/*Test.java</include>
+            <include>**/*Suite.java</include>
+          </includes>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
new file mode 100644
index 0000000000000..224f1e6c515ea
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+
+import com.google.common.base.Objects;
+import com.google.common.io.ByteStreams;
+import io.netty.channel.DefaultFileRegion;
+
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * A {@link ManagedBuffer} backed by a segment in a file.
+ */
+public final class FileSegmentManagedBuffer extends ManagedBuffer {
+
+ /**
+ * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
+ * Avoid unless there's a good reason not to.
+ */
+ private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
+
+ private final File file;
+ private final long offset;
+ private final long length;
+
+ public FileSegmentManagedBuffer(File file, long offset, long length) {
+ this.file = file;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public long size() {
+ return length;
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ FileChannel channel = null;
+ try {
+ channel = new RandomAccessFile(file, "r").getChannel();
+ // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
+ if (length < MIN_MEMORY_MAP_BYTES) {
+ ByteBuffer buf = ByteBuffer.allocate((int) length);
+ channel.read(buf, offset);
+ buf.flip();
+ return buf;
+ } else {
+ return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
+ }
+ } catch (IOException e) {
+ try {
+ if (channel != null) {
+ long size = channel.size();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ }
+ throw new IOException("Error in opening " + this, e);
+ } finally {
+ JavaUtils.closeQuietly(channel);
+ }
+ }
+
+ @Override
+ public InputStream inputStream() throws IOException {
+ FileInputStream is = null;
+ try {
+ is = new FileInputStream(file);
+ is.skip(offset);
+ return ByteStreams.limit(is, length);
+ } catch (IOException e) {
+ try {
+ if (is != null) {
+ long size = file.length();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ } finally {
+ JavaUtils.closeQuietly(is);
+ }
+ throw new IOException("Error in opening " + this, e);
+ } catch (RuntimeException e) {
+ JavaUtils.closeQuietly(is);
+ throw e;
+ }
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ FileChannel fileChannel = new FileInputStream(file).getChannel();
+ return new DefaultFileRegion(fileChannel, offset, length);
+ }
+
+ public File getFile() { return file; }
+
+ public long getOffset() { return offset; }
+
+ public long getLength() { return length; }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("file", file)
+ .add("offset", offset)
+ .add("length", length)
+ .toString();
+ }
+}
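A short usage sketch for the class above, not part of the patch; it relies only on the constructor and nioByteBuffer() shown here, and the temp file and segment bounds are made up for illustration.

```java
import java.io.File;
import java.io.FileOutputStream;
import java.nio.ByteBuffer;

import org.apache.spark.network.buffer.FileSegmentManagedBuffer;

public class FileSegmentDemo {
  public static void main(String[] args) throws Exception {
    // Write 4 KB of zeros to a temp file.
    File file = File.createTempFile("seg-demo", ".bin");
    file.deleteOnExit();
    try (FileOutputStream out = new FileOutputStream(file)) {
      out.write(new byte[4096]);
    }

    // A segment covering bytes [1024, 2048). 1 KB is well below
    // MIN_MEMORY_MAP_BYTES, so nioByteBuffer() copies instead of mmap-ing.
    FileSegmentManagedBuffer buffer = new FileSegmentManagedBuffer(file, 1024, 1024);
    ByteBuffer bytes = buffer.nioByteBuffer();
    System.out.println(bytes.remaining());  // prints 1024
  }
}
```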
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
new file mode 100644
index 0000000000000..1735f5540c61b
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * This interface provides an immutable view for data in the form of bytes. The implementation
+ * should specify how the data is provided:
+ *
+ * - {@link FileSegmentManagedBuffer}: data backed by part of a file
+ * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer
+ * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf
+ *
+ * The concrete buffer implementation might be managed outside the JVM garbage collector.
+ * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted.
+ * In that case, if the buffer is going to be passed around to a different thread, retain/release
+ * should be called.
+ */
+public abstract class ManagedBuffer {
+
+ /** Number of bytes of the data. */
+ public abstract long size();
+
+ /**
+ * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
+ * returned ByteBuffer should not affect the content of this buffer.
+ */
+ public abstract ByteBuffer nioByteBuffer() throws IOException;
+
+ /**
+   * Exposes this buffer's data as an InputStream. The underlying implementation does not
+   * necessarily bound the number of bytes that can be read, so the caller is responsible for
+   * ensuring it does not read past the end of the buffer.
+ */
+ public abstract InputStream inputStream() throws IOException;
+
+ /**
+ * Increment the reference count by one if applicable.
+ */
+ public abstract ManagedBuffer retain();
+
+ /**
+   * If applicable, decrement the reference count by one, deallocating the buffer if the
+   * reference count reaches zero.
+ */
+ public abstract ManagedBuffer release();
+
+ /**
+   * Convert the buffer into a Netty object, used to write the data out.
+ */
+ public abstract Object convertToNetty() throws IOException;
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
new file mode 100644
index 0000000000000..d928980423f1f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufInputStream;
+
+/**
+ * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
+ */
+public final class NettyManagedBuffer extends ManagedBuffer {
+ private final ByteBuf buf;
+
+ public NettyManagedBuffer(ByteBuf buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.readableBytes();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.nioBuffer();
+ }
+
+ @Override
+ public InputStream inputStream() throws IOException {
+ return new ByteBufInputStream(buf);
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ buf.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ buf.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return buf;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
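A sketch of the retain/release contract described in ManagedBuffer's javadoc, shown against the Netty-backed implementation above. The class name is illustrative and the printed reference counts follow Netty's standard semantics (wrapped buffers start with a reference count of 1).

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.buffer.NettyManagedBuffer;

public class RefCountDemo {
  public static void main(String[] args) {
    ByteBuf byteBuf = Unpooled.wrappedBuffer("hello".getBytes());
    NettyManagedBuffer buffer = new NettyManagedBuffer(byteBuf);

    buffer.retain();                       // e.g. before handing off to another thread
    System.out.println(byteBuf.refCnt());  // 2

    buffer.release();                      // the consuming thread is done with it
    buffer.release();                      // original owner is done; buffer is deallocated
    System.out.println(byteBuf.refCnt());  // 0
  }
}
```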
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
new file mode 100644
index 0000000000000..3953ef89fbf88
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBufInputStream;
+import io.netty.buffer.Unpooled;
+
+/**
+ * A {@link ManagedBuffer} backed by a {@link ByteBuffer}.
+ */
+public final class NioManagedBuffer extends ManagedBuffer {
+ private final ByteBuffer buf;
+
+ public NioManagedBuffer(ByteBuffer buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.remaining();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public InputStream inputStream() throws IOException {
+ return new ByteBufInputStream(Unpooled.wrappedBuffer(buf));
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return Unpooled.wrappedBuffer(buf);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
+
diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
new file mode 100644
index 0000000000000..40a1fe67b1c5b
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * General exception caused by a remote exception while fetching a chunk.
+ */
+public class ChunkFetchFailureException extends RuntimeException {
+ private final int chunkIndex;
+
+ public ChunkFetchFailureException(int chunkIndex, String errorMsg, Throwable cause) {
+ super(errorMsg, cause);
+ this.chunkIndex = chunkIndex;
+ }
+
+ public ChunkFetchFailureException(int chunkIndex, String errorMsg) {
+ super(errorMsg);
+ this.chunkIndex = chunkIndex;
+ }
+
+ public int getChunkIndex() { return chunkIndex; }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
new file mode 100644
index 0000000000000..519e6cb470d0d
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Callback for the result of a single chunk fetch. For a single stream, the callbacks are
+ * guaranteed to be called by the same thread in the same order as the requests for chunks were
+ * made.
+ *
+ * Note that if a general stream failure occurs, all outstanding chunk requests may be failed.
+ */
+public interface ChunkReceivedCallback {
+ /**
+ * Called upon receipt of a particular chunk.
+ *
+ * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
+ * call returns. You must therefore either retain() the buffer or copy its contents before
+ * returning.
+ */
+ void onSuccess(int chunkIndex, ManagedBuffer buffer);
+
+ /**
+ * Called upon failure to fetch a particular chunk. Note that this may actually be called due
+ * to failure to fetch a prior chunk in this stream.
+ *
+ * After receiving a failure, the stream may or may not be valid. The client should not assume
+ * that the server's side of the stream has been closed.
+ */
+ void onFailure(int chunkIndex, Throwable e);
+}
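A minimal sketch, not part of the patch, of a callback that follows the retain-or-copy rule above: it retains each buffer before handing it to a consuming thread, which must then call release(). The queueing scheme and class name are illustrative.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;

public class QueueingChunkCallback implements ChunkReceivedCallback {
  private final BlockingQueue<ManagedBuffer> results =
    new LinkedBlockingQueue<ManagedBuffer>();

  @Override
  public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
    // The buffer is release()'d as soon as this call returns, so retain it
    // before handing it to the consuming thread.
    buffer.retain();
    results.offer(buffer);
  }

  @Override
  public void onFailure(int chunkIndex, Throwable e) {
    // May be triggered by a failure on an earlier chunk of the same stream.
    System.err.println("chunk " + chunkIndex + " failed: " + e);
  }

  /** Consumer side: must call release() once done with each buffer. */
  public ManagedBuffer take() throws InterruptedException {
    return results.take();
  }
}
```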
diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
new file mode 100644
index 0000000000000..6ec960d795420
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * Callback for the result of a single RPC. This will be invoked once with either success or
+ * failure.
+ */
+public interface RpcResponseCallback {
+ /** Successful serialized result from server. */
+ void onSuccess(byte[] response);
+
+ /** Exception either propagated from server or raised on client side. */
+ void onFailure(Throwable e);
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java
new file mode 100644
index 0000000000000..1f7d3b0234e38
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClient.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.util.UUID;
+
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.request.ChunkFetchRequest;
+import org.apache.spark.network.protocol.request.RpcRequest;
+
+/**
+ * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow
+ * efficient transfer of a large amount of data, broken up into chunks with size ranging from
+ * hundreds of KB to a few MB.
+ *
+ * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane),
+ * the actual setup of the streams is done outside the scope of Sluice. The convenience method
+ * "sendRPC" is provided to enable control plane communication between the client and server to
+ * perform this setup.
+ *
+ * For example, a typical workflow might be:
+ * client.sendRpc(new OpenFile("/foo")) --> returns StreamId = 100
+ * client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
+ * client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
+ * ...
+ * client.sendRpc(new CloseStream(100))
+ *
+ * Construct an instance of SluiceClient using {@link SluiceClientFactory}. A single SluiceClient
+ * may be used for multiple streams, but any given stream must be restricted to a single client,
+ * in order to avoid out-of-order responses.
+ *
+ * NB: This class is used to make requests to the server, while {@link SluiceClientHandler} is
+ * responsible for handling responses from the server.
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class SluiceClient implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(SluiceClient.class);
+
+ private final ChannelFuture cf;
+ private final SluiceClientHandler handler;
+
+ private final String serverAddr;
+
+ SluiceClient(ChannelFuture cf, SluiceClientHandler handler) {
+ this.cf = cf;
+ this.handler = handler;
+
+ if (cf != null && cf.channel() != null && cf.channel().remoteAddress() != null) {
+ serverAddr = cf.channel().remoteAddress().toString();
+ } else {
+ serverAddr = "";
+ }
+ }
+
+ public boolean isActive() {
+ return cf.channel().isActive();
+ }
+
+ /**
+ * Requests a single chunk from the remote side, from the pre-negotiated streamId.
+ *
+ * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
+ * some streams may not support this.
+ *
+ * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed
+ * to be returned in the same order that they were requested, assuming only a single SluiceClient
+ * is used to fetch the chunks.
+ *
+ * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
+ * be agreed upon by client and server beforehand.
+ * @param chunkIndex 0-based index of the chunk to fetch
+ * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
+ */
+ public void fetchChunk(
+ long streamId,
+ final int chunkIndex,
+ final ChunkReceivedCallback callback) {
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr);
+
+ final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
+ handler.addFetchRequest(streamChunkId, callback);
+
+ cf.channel().writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.debug("Sending request {} to {} took {} ms", streamChunkId, serverAddr,
+ timeTaken);
+ } else {
+ // Fail this chunk's callback.
+ String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
+ serverAddr, future.cause().getMessage());
+ logger.error(errorMsg, future.cause());
+ handler.removeFetchRequest(streamChunkId);
+ callback.onFailure(chunkIndex, new RuntimeException(errorMsg));
+ }
+ }
+ });
+ }
+
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
+ * with the server's response or upon any failure.
+ */
+ public void sendRpc(byte[] message, final RpcResponseCallback callback) {
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending RPC to {}", serverAddr);
+
+ final long tag = UUID.randomUUID().getLeastSignificantBits();
+ handler.addRpcRequest(tag, callback);
+
+ cf.channel().writeAndFlush(new RpcRequest(tag, message)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.debug("Sending request {} to {} took {} ms", tag, serverAddr, timeTaken);
+ } else {
+ // Fail this RPC's callback.
+ String errorMsg = String.format("Failed to send request %s to %s: %s", tag,
+ serverAddr, future.cause().getMessage());
+ logger.error(errorMsg, future.cause());
+ handler.removeRpcRequest(tag);
+ callback.onFailure(new RuntimeException(errorMsg));
+ }
+ }
+ });
+ }
+
+ @Override
+ public void close() {
+ cf.channel().close();
+ }
+}
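
Putting the workflow from the class comment into code, a caller might drive the control and data planes like this (localhost:12345, openFileMessage, decodeStreamId, and chunkCallback are placeholders for application-level pieces, not APIs in this patch; imports are elided):

    // Sketch of the control-plane / data-plane split described in the javadoc.
    void fetchFirstChunks(SluiceConfig conf, byte[] openFileMessage,
        final ChunkReceivedCallback chunkCallback) throws TimeoutException {
      SluiceClientFactory factory = new SluiceClientFactory(conf);
      final SluiceClient client = factory.createClient("localhost", 12345);

      // Control plane: negotiate a stream id via an application-defined RPC.
      client.sendRpc(openFileMessage, new RpcResponseCallback() {
        @Override
        public void onSuccess(byte[] response) {
          long streamId = decodeStreamId(response);  // hypothetical app-level decoding
          // Data plane: chunks are returned in request order on this client.
          client.fetchChunk(streamId, 0, chunkCallback);
          client.fetchChunk(streamId, 1, chunkCallback);
        }

        @Override
        public void onFailure(Throwable e) {
          // Stream setup failed; nothing to fetch.
        }
      });
    }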
diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java
new file mode 100644
index 0000000000000..17491dc3f8720
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientFactory.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.lang.reflect.Field;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeoutException;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.request.ClientRequestEncoder;
+import org.apache.spark.network.protocol.response.ServerResponseDecoder;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.SluiceConfig;
+
+/**
+ * Factory for creating {@link SluiceClient}s via its createClient method.
+ *
+ * The factory maintains a connection pool to other hosts and should return the same
+ * {@link SluiceClient} for the same remote address. It also shares a single worker thread pool for
+ * all {@link SluiceClient}s.
+ */
+public class SluiceClientFactory implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(SluiceClientFactory.class);
+
+ private final SluiceConfig conf;
+ private final Map<SocketAddress, SluiceClient> connectionPool;
+ private final ClientRequestEncoder encoder;
+ private final ServerResponseDecoder decoder;
+
+ private final Class<? extends Channel> socketChannelClass;
+ private final EventLoopGroup workerGroup;
+
+ public SluiceClientFactory(SluiceConfig conf) {
+ this.conf = conf;
+ this.connectionPool = new ConcurrentHashMap<SocketAddress, SluiceClient>();
+ this.encoder = new ClientRequestEncoder();
+ this.decoder = new ServerResponseDecoder();
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+ this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
+ }
+
+ /**
+ * Creates a new SluiceClient connecting to the given remote host / port.
+ *
+ * This blocks until a connection is successfully established.
+ *
+ * Concurrency: This method is safe to call from multiple threads.
+ */
+ public SluiceClient createClient(String remoteHost, int remotePort) throws TimeoutException {
+ // Get connection from the connection pool first.
+ // If it is not found or not active, create a new one.
+ InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+ SluiceClient cachedClient = connectionPool.get(address);
+ if (cachedClient != null && cachedClient.isActive()) {
+ return cachedClient;
+ }
+
+ logger.debug("Creating new connection to " + address);
+
+ // There is a chance two threads are creating two different clients connecting to the same host.
+ // But that's probably ok, as long as the caller hangs on to their client for a single stream.
+ final SluiceClientHandler handler = new SluiceClientHandler();
+
+ Bootstrap bootstrap = new Bootstrap();
+ bootstrap.group(workerGroup)
+ .channel(socketChannelClass)
+ // Disable Nagle's Algorithm since we don't want packets to wait
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs());
+
+ // Use pooled buffers to reduce temporary buffer allocation
+ bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
+
+ bootstrap.handler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(SocketChannel ch) {
+ ch.pipeline()
+ .addLast("clientRequestEncoder", encoder)
+ .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+ .addLast("serverResponseDecoder", decoder)
+ .addLast("handler", handler);
+ }
+ });
+
+ // Connect to the remote server
+ ChannelFuture cf = bootstrap.connect(address);
+ if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
+ throw new TimeoutException(
+ String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+ }
+
+ SluiceClient client = new SluiceClient(cf, handler);
+ connectionPool.put(address, client);
+ return client;
+ }
+
+ /** Close all connections in the connection pool, and shutdown the worker thread pool. */
+ @Override
+ public void close() {
+ for (SluiceClient client : connectionPool.values()) {
+ client.close();
+ }
+ connectionPool.clear();
+
+ if (workerGroup != null) {
+ workerGroup.shutdownGracefully();
+ }
+ }
+
+ /**
+ * Creates a pooled ByteBuf allocator with the thread-local caches disabled. The caches are
+ * disabled because the ByteBufs are allocated by the event loop thread but released by the
+ * executor thread rather than the event loop thread, so thread-local caching only delays
+ * the recycling of buffers and increases memory usage.
+ */
+ private PooledByteBufAllocator createPooledByteBufAllocator() {
+ return new PooledByteBufAllocator(
+ PlatformDependent.directBufferPreferred(),
+ getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
+ getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
+ getPrivateStaticField("DEFAULT_PAGE_SIZE"),
+ getPrivateStaticField("DEFAULT_MAX_ORDER"),
+ 0, // tinyCacheSize
+ 0, // smallCacheSize
+ 0 // normalCacheSize
+ );
+ }
+
+ /** Used to get defaults from Netty's private static fields. */
+ private int getPrivateStaticField(String name) {
+ try {
+ Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return f.getInt(null);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
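
One consequence of the connection pool worth noting: repeated createClient calls for the same address are cheap. A sketch, assuming a Sluice server is already listening on the given port:

    // The factory caches clients per remote address, so a second call for the
    // same host/port returns the cached instance while its channel is active.
    SluiceClientFactory factory = new SluiceClientFactory(conf);
    SluiceClient first = factory.createClient("localhost", 12345);
    SluiceClient second = factory.createClient("localhost", 12345);
    assert first == second;  // holds while the cached connection stays active
    factory.close();         // closes pooled clients and shuts down the event loop group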
diff --git a/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java
new file mode 100644
index 0000000000000..ed20b032931c3
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/SluiceClientHandler.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.net.SocketAddress;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.response.ChunkFetchFailure;
+import org.apache.spark.network.protocol.response.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.response.RpcFailure;
+import org.apache.spark.network.protocol.response.RpcResponse;
+import org.apache.spark.network.protocol.response.ServerResponse;
+
+/**
+ * Handler that processes server responses to requests issued by a {@link SluiceClient}. It works
+ * by tracking the set of outstanding requests (and their callbacks).
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class SluiceClientHandler extends SimpleChannelInboundHandler<ServerResponse> {
+ private final Logger logger = LoggerFactory.getLogger(SluiceClientHandler.class);
+
+ private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches =
+ new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+
+ private final Map<Long, RpcResponseCallback> outstandingRpcs =
+ new ConcurrentHashMap<Long, RpcResponseCallback>();
+
+ public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+ outstandingFetches.put(streamChunkId, callback);
+ }
+
+ public void removeFetchRequest(StreamChunkId streamChunkId) {
+ outstandingFetches.remove(streamChunkId);
+ }
+
+ public void addRpcRequest(long tag, RpcResponseCallback callback) {
+ outstandingRpcs.put(tag, callback);
+ }
+
+ public void removeRpcRequest(long tag) {
+ outstandingRpcs.remove(tag);
+ }
+
+ /**
+ * Fire the failure callback for all outstanding requests. This is called when we have an
+ * uncaught exception or premature connection termination.
+ */
+ private void failOutstandingRequests(Throwable cause) {
+ for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+ entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+ }
+ // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests
+ // as well. But I guess that is ok given the caller will fail as soon as any requests fail.
+ outstandingFetches.clear();
+ }
+
+ @Override
+ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+ if (outstandingFetches.size() > 0) {
+ SocketAddress remoteAddress = ctx.channel().remoteAddress();
+ logger.error("Still have {} requests outstanding when contention from {} is closed",
+ outstandingFetches.size(), remoteAddress);
+ failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+ }
+ super.channelUnregistered(ctx);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ if (outstandingFetches.size() > 0) {
+ logger.error(String.format("Exception in connection from %s: %s",
+ ctx.channel().remoteAddress(), cause.getMessage()), cause);
+ failOutstandingRequests(cause);
+ }
+ ctx.close();
+ }
+
+ @Override
+ public void channelRead0(ChannelHandlerContext ctx, ServerResponse message) {
+ String server = ctx.channel().remoteAddress().toString();
+ if (message instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Got a response for block {} from {} but it is not outstanding",
+ resp.streamChunkId, server);
+ resp.buffer.release();
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
+ resp.buffer.release();
+ }
+ } else if (message instanceof ChunkFetchFailure) {
+ ChunkFetchFailure resp = (ChunkFetchFailure) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Got a response for block {} from {} ({}) but it is not outstanding",
+ resp.streamChunkId, server, resp.errorString);
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onFailure(resp.streamChunkId.chunkIndex,
+ new ChunkFetchFailureException(resp.streamChunkId.chunkIndex, resp.errorString));
+ }
+ } else if (message instanceof RpcResponse) {
+ RpcResponse resp = (RpcResponse) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.tag);
+ if (listener == null) {
+ logger.warn("Got a response for RPC {} from {} ({} bytes) but it is not outstanding",
+ resp.tag, server, resp.response.length);
+ } else {
+ outstandingRpcs.remove(resp.tag);
+ listener.onSuccess(resp.response);
+ }
+ } else if (message instanceof RpcFailure) {
+ RpcFailure resp = (RpcFailure) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.tag);
+ if (listener == null) {
+ logger.warn("Got a response for RPC {} from {} ({}) but it is not outstanding",
+ resp.tag, server, resp.errorString);
+ } else {
+ outstandingRpcs.remove(resp.tag);
+ listener.onFailure(new RuntimeException(resp.errorString));
+ }
+ }
+ }
+
+ @VisibleForTesting
+ public int numOutstandingRequests() {
+ return outstandingFetches.size();
+ }
+}
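
The bookkeeping above is easy to exercise with Netty's EmbeddedChannel; a test-style sketch (it assumes ChunkFetchSuccess has a (StreamChunkId, ManagedBuffer) constructor, consistent with its field usage above, and callback/buffer stand in for real instances):

    // A ChunkFetchSuccess flowing through the handler should complete the
    // matching outstanding fetch and remove it from the map.
    import io.netty.channel.embedded.EmbeddedChannel;

    SluiceClientHandler handler = new SluiceClientHandler();
    EmbeddedChannel channel = new EmbeddedChannel(handler);

    StreamChunkId id = new StreamChunkId(100, 0);
    handler.addFetchRequest(id, callback);
    assert handler.numOutstandingRequests() == 1;

    channel.writeInbound(new ChunkFetchSuccess(id, buffer));
    assert handler.numOutstandingRequests() == 0;  // callback.onSuccess was invoked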
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
new file mode 100644
index 0000000000000..363ea5ecfa936
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are
+ * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length.
+ */
+public interface Encodable {
+ /** Number of bytes of the encoded form of this object. */
+ int encodedLength();
+
+ /**
+ * Serializes this object by writing into the given ByteBuf.
+ * This method must write exactly encodedLength() bytes.
+ */
+ void encode(ByteBuf buf);
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
new file mode 100644
index 0000000000000..d46a263884807
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Encapsulates a request for a particular chunk of a stream.
+ */
+public final class StreamChunkId implements Encodable {
+ public final long streamId;
+ public final int chunkIndex;
+
+ public StreamChunkId(long streamId, int chunkIndex) {
+ this.streamId = streamId;
+ this.chunkIndex = chunkIndex;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buffer) {
+ buffer.writeLong(streamId);
+ buffer.writeInt(chunkIndex);
+ }
+
+ public static StreamChunkId decode(ByteBuf buffer) {
+ assert buffer.readableBytes() >= 8 + 4;
+ long streamId = buffer.readLong();
+ int chunkIndex = buffer.readInt();
+ return new StreamChunkId(streamId, chunkIndex);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, chunkIndex);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamChunkId) {
+ StreamChunkId o = (StreamChunkId) other;
+ return streamId == o.streamId && chunkIndex == o.chunkIndex;
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("chunkIndex", chunkIndex)
+ .toString();
+ }
+}
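
A quick check of the Encodable contract with this class: encoding then decoding should round-trip exactly, writing precisely encodedLength() bytes (the sketch uses io.netty.buffer.Unpooled for a heap buffer):

    StreamChunkId original = new StreamChunkId(100L, 3);
    ByteBuf buf = Unpooled.buffer(original.encodedLength());
    original.encode(buf);
    assert buf.readableBytes() == original.encodedLength();  // exactly 12 bytes
    assert StreamChunkId.decode(buf).equals(original);
    buf.release();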
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java
new file mode 100644
index 0000000000000..a79eb363cf58c
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ChunkFetchRequest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol.request;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.StreamChunkId;
+
+/**
+ * Request to fetch a single chunk of a stream. This will correspond to a single
+ * {@link org.apache.spark.network.protocol.response.ServerResponse} (either success or failure).
+ */
+public final class ChunkFetchRequest implements ClientRequest {
+ public final StreamChunkId streamChunkId;
+
+ public ChunkFetchRequest(StreamChunkId streamChunkId) {
+ this.streamChunkId = streamChunkId;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchRequest; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ public static ChunkFetchRequest decode(ByteBuf buf) {
+ return new ChunkFetchRequest(StreamChunkId.decode(buf));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchRequest) {
+ ChunkFetchRequest o = (ChunkFetchRequest) other;
+ return streamChunkId.equals(o.streamChunkId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java
new file mode 100644
index 0000000000000..db075c44b4cda
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequest.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol.request;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encodable;
+
+/** Messages from the client to the server. */
+public interface ClientRequest extends Encodable {
+ /** Used to identify this request type. */
+ Type type();
+
+ /**
+ * Preceding every serialized ClientRequest is the type, which allows us to deserialize
+ * the request.
+ */
+ public static enum Type implements Encodable {
+ ChunkFetchRequest(0), RpcRequest(1);
+
+ private final byte id;
+
+ private Type(int id) {
+ assert id < 128 : "Cannot have more than 128 request types";
+ this.id = (byte) id;
+ }
+
+ public byte id() { return id; }
+
+ @Override public int encodedLength() { return 1; }
+
+ @Override public void encode(ByteBuf buf) { buf.writeByte(id); }
+
+ public static Type decode(ByteBuf buf) {
+ byte id = buf.readByte();
+ switch(id) {
+ case 0: return ChunkFetchRequest;
+ case 1: return RpcRequest;
+ default: throw new IllegalArgumentException("Unknown request type: " + id);
+ }
+ }
+ }
+}
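
The framing this type tag implies (one byte of type, then the encoded body) can be sketched directly; this mirrors what ClientRequestEncoder and ClientRequestDecoder are expected to do, using the ChunkFetchRequest defined earlier:

    ByteBuf buf = Unpooled.buffer();
    ClientRequest request = new ChunkFetchRequest(new StreamChunkId(100, 0));
    request.type().encode(buf);  // 1-byte type tag
    request.encode(buf);         // message body

    // Server side: dispatch on the tag, then decode the body.
    ClientRequest.Type type = ClientRequest.Type.decode(buf);
    assert type == ClientRequest.Type.ChunkFetchRequest;
    ClientRequest decoded = ChunkFetchRequest.decode(buf);
    assert decoded.equals(request);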
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java
new file mode 100644
index 0000000000000..a937da4cecae0
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/request/ClientRequestDecoder.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol.request;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageDecoder;
+
+/**
+ * Decoder in the server side to decode client requests.
+ * This decoder is stateless so it is safe to be shared by multiple threads.
+ *
+ * This assumes the inbound messages have been processed by a frame decoder created by
+ * {@link org.apache.spark.network.util.NettyUtils#createFrameDecoder()}.
+ */
+@ChannelHandler.Sharable
+public final class ClientRequestDecoder extends MessageToMessageDecoder<ByteBuf> {
+
+ @Override
+ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {