Skip to content

Commit

Permalink
Attempt to make comm. bidirectional
Browse files Browse the repository at this point in the history
  • Loading branch information
aarondav committed Oct 17, 2014
1 parent aa58f67 commit 939f276
Show file tree
Hide file tree
Showing 43 changed files with 702 additions and 427 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ private[spark]
trait BlockFetchingListener extends EventListener {

/**
* Called once per successfully fetched block.
* Called once per successfully fetched block. After this call returns, data will be released
* automatically. If the data will be passed to another thread, the receiver should retain()
* and release() the buffer on its own, or copy the data to a new buffer.
*/
def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
package org.apache.spark.network

import java.io.Closeable

import org.apache.spark.network.buffer.ManagedBuffer
import java.nio.ByteBuffer

import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils

private[spark]
Expand Down Expand Up @@ -72,7 +72,7 @@ abstract class BlockTransferService extends Closeable with Logging {
def uploadBlock(
hostname: String,
port: Int,
blockId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]

Expand All @@ -94,7 +94,10 @@ abstract class BlockTransferService extends Closeable with Logging {
}
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
lock.synchronized {
result = Left(data)
val ret = ByteBuffer.allocate(data.size.toInt)
ret.put(data.nioByteBuffer())
ret.flip()
result = Left(new NioManagedBuffer(ret))
lock.notify()
}
}
Expand Down Expand Up @@ -126,7 +129,7 @@ abstract class BlockTransferService extends Closeable with Logging {
def uploadBlockSync(
hostname: String,
port: Int,
blockId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Unit = {
Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ package org.apache.spark.network.netty
import java.nio.ByteBuffer
import java.util

import org.apache.spark.Logging
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.network.BlockFetchingListener
import org.apache.spark.serializer.Serializer
import org.apache.spark.network.netty.NettyMessages._
import org.apache.spark.serializer.{JavaSerializer, Serializer}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, SluiceClient}
import org.apache.spark.storage.BlockId
Expand Down Expand Up @@ -52,7 +53,6 @@ class NettyBlockFetcher(
val chunkCallback = new ChunkReceivedCallback {
// On receipt of a chunk, pass it upwards as a block.
def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions {
buffer.retain()
listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,24 @@ import java.nio.ByteBuffer
import org.apache.spark.Logging
import org.apache.spark.network.BlockDataManager
import org.apache.spark.serializer.Serializer
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
import org.apache.spark.network.client.{SluiceClient, RpcResponseCallback}
import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler}
import org.apache.spark.storage.BlockId
import org.apache.spark.storage.{StorageLevel, BlockId}

import scala.collection.JavaConversions._

/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
case class OpenBlocks(blockIds: Seq[BlockId])
object NettyMessages {

/** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */
case class ShuffleStreamHandle(streamId: Long, numChunks: Int)
/** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
case class OpenBlocks(blockIds: Seq[BlockId])

/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)

/** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */
case class ShuffleStreamHandle(streamId: Long, numChunks: Int)
}

/**
* Serves requests to open blocks by simply registering one chunk per block requested.
Expand All @@ -44,16 +50,27 @@ class NettyBlockRpcServer(
blockManager: BlockDataManager)
extends RpcHandler with Logging {

override def receive(messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = {
import NettyMessages._

override def receive(
client: SluiceClient,
messageBytes: Array[Byte],
responseContext: RpcResponseCallback): Unit = {
val ser = serializer.newInstance()
val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
logTrace(s"Received request: $message")

message match {
case OpenBlocks(blockIds) =>
val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
val streamId = streamManager.registerStream(blocks.iterator)
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
responseContext.onSuccess(
ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())

case UploadBlock(blockId, blockData, level) =>
blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level)
responseContext.onSuccess(new Array[Byte](0))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,61 +17,89 @@

package org.apache.spark.network.netty

import scala.concurrent.{Promise, Future}

import org.apache.spark.SparkConf
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{SluiceClient, SluiceClientFactory}
import org.apache.spark.network.server.{DefaultStreamManager, SluiceServer}
import org.apache.spark.network.client.{RpcResponseCallback, SluiceClient, SluiceClientFactory}
import org.apache.spark.network.netty.NettyMessages.UploadBlock
import org.apache.spark.network.server._
import org.apache.spark.network.util.{ConfigProvider, SluiceConfig}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils

import scala.concurrent.Future

/**
* A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/
class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
var client: SluiceClient = _

// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
val serializer = new JavaSerializer(conf)

// Create a SluiceConfig using SparkConf.
private[this] val sluiceConf = new SluiceConfig(
new ConfigProvider { override def get(name: String) = conf.get(name) })

private[this] var sluiceContext: SluiceContext = _
private[this] var server: SluiceServer = _
private[this] var clientFactory: SluiceClientFactory = _

override def init(blockDataManager: BlockDataManager): Unit = {
val streamManager = new DefaultStreamManager
val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager)
server = new SluiceServer(sluiceConf, streamManager, rpcHandler)
clientFactory = new SluiceClientFactory(sluiceConf)
sluiceContext = new SluiceContext(sluiceConf, streamManager, rpcHandler)
clientFactory = sluiceContext.createClientFactory()
server = sluiceContext.createServer()
}

override def fetchBlocks(
hostName: String,
hostname: String,
port: Int,
blockIds: Seq[String],
listener: BlockFetchingListener): Unit = {
val client = clientFactory.createClient(hostName, port)
val client = clientFactory.createClient(hostname, port)
new NettyBlockFetcher(serializer, client, blockIds, listener)
}

override def hostName: String = Utils.localHostName()

override def port: Int = server.getPort

// TODO: Implement
override def uploadBlock(
hostname: String,
port: Int,
blockId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit] = ???
level: StorageLevel): Future[Unit] = {
val result = Promise[Unit]()
val client = clientFactory.createClient(hostname, port)

// Convert or copy nio buffer into array in order to serialize it.
val nioBuffer = blockData.nioByteBuffer()
val array = if (nioBuffer.hasArray) {
nioBuffer.array()
} else {
val data = new Array[Byte](nioBuffer.remaining())
nioBuffer.get(data)
data
}

val ser = serializer.newInstance()
client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(),
new RpcResponseCallback {
override def onSuccess(response: Array[Byte]): Unit = {
logTrace(s"Successfully uploaded block $blockId")
result.success()
}
override def onFailure(e: Throwable): Unit = {
logError(s"Error while uploading block $blockId", e)
result.failure(e)
}
})

result.future
}

override def close(): Unit = server.close()
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
override def uploadBlock(
hostname: String,
port: Int,
blockId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel)
: Future[Unit] = {
checkInit()
val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level)
val msg = PutBlock(blockId, blockData.nioByteBuffer(), level)
val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg))
val remoteCmId = new ConnectionManagerId(hostName, port)
val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage)
Expand Down
51 changes: 48 additions & 3 deletions core/src/main/scala/org/apache/spark/serializer/Serializer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

package org.apache.spark.serializer

import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
import java.io._
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.SparkEnv
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -142,3 +142,48 @@ abstract class DeserializationStream {
}
}
}


/**
 * A [[Serializer]] whose read side is a no-op: instances produced here serialize
 * normally but yield nothing on deserialization (see [[NoOpReadSerializerInstance]]).
 * The `conf` parameter is accepted for signature parity with other serializers and
 * is not consulted here.
 */
class NoOpReadSerializer(conf: SparkConf) extends Serializer with Serializable {
  override def newInstance(): SerializerInstance = new NoOpReadSerializerInstance()
}

/**
 * SerializerInstance whose write path is real (Java serialization) but whose read
 * path is empty: `deserialize` returns null, and deserialization streams signal
 * end-of-stream on the first read (see [[NoOpDeserializationStream]]).
 */
private[spark] class NoOpReadSerializerInstance() extends SerializerInstance {

  /** Serialize `t` through [[serializeStream]] and return the resulting bytes. */
  override def serialize[T: ClassTag](t: T): ByteBuffer = {
    val byteStream = new ByteArrayOutputStream()
    val serStream = serializeStream(byteStream)
    serStream.writeObject(t)
    serStream.close()
    ByteBuffer.wrap(byteStream.toByteArray)
  }

  /** No-op read: the input bytes are ignored and null is returned. */
  override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
    null.asInstanceOf[T]

  /** No-op read: both the input bytes and the class loader are ignored. */
  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
    null.asInstanceOf[T]

  // Writing is delegated to Java serialization; the second argument (100) is
  // presumably an object-count reset interval — confirm against JavaSerializationStream.
  override def serializeStream(s: OutputStream): SerializationStream =
    new JavaSerializationStream(s, 100)

  override def deserializeStream(s: InputStream): DeserializationStream =
    new NoOpDeserializationStream(s, Utils.getContextOrSparkClassLoader)

  /** Variant allowing the caller to supply a class loader (also unused by the no-op stream). */
  def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream =
    new NoOpDeserializationStream(s, loader)
}

/**
 * A DeserializationStream that behaves as if it were empty: every `readObject`
 * call throws [[java.io.EOFException]]. The constructor arguments `in` and
 * `loader` are accepted but never read from in this implementation.
 */
private[spark] class NoOpDeserializationStream(in: InputStream, loader: ClassLoader)
  extends DeserializationStream {

  // First read already reports end-of-stream.
  def readObject[T: ClassTag](): T = throw new EOFException()

  // Nothing held here to release; note the underlying stream is not closed.
  def close(): Unit = {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -855,9 +855,9 @@ private[spark] class BlockManager(
data.rewind()
logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
blockTransferService.uploadBlockSync(
peer.host, peer.port, blockId.toString, new NioManagedBuffer(data), tLevel)
logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms"
.format((System.currentTimeMillis - onePeerStartTime)))
peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel)
logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
.format(System.currentTimeMillis - onePeerStartTime))
peersReplicatedTo += peer
peersForReplication -= peer
replicationFailed = false
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private[spark] object AkkaUtils extends Logging {

val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off"

val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600)
val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000)
val akkaFailureDetector =
conf.getDouble("spark.akka.failure-detector.threshold", 300.0)
val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000)
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with Netty shuffle mode.

override def beforeAll() {
System.setProperty("spark.shuffle.use.netty", "true")
System.setProperty("spark.shuffle.blockTransferService", "netty")
}

override def afterAll() {
System.clearProperty("spark.shuffle.use.netty")
System.clearProperty("spark.shuffle.blockTransferService")
}
}
Loading

0 comments on commit 939f276

Please sign in to comment.