[SPARK-4187] [Core] Switch to binary protocol for external shuffle service messages #3146

Status: Closed · 3 commits

This patch replaces the JavaSerializer-encoded Scala case classes previously used for Netty block-transfer RPCs with the hand-rolled binary protocol under org.apache.spark.network.shuffle.protocol, giving the external shuffle service a stable, language-neutral wire format. It also threads appId and execId through the block upload path and factors the length-prefixed framing into a shared Encoders utility.
**BlockTransferService.scala**

```diff
@@ -73,6 +73,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Logging
   def uploadBlock(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel): Future[Unit]
@@ -110,9 +111,10 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Logging
   def uploadBlockSync(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel): Unit = {
-    Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
+    Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
   }
 }
```
**NettyBlockRpcServer.scala**

```diff
@@ -26,18 +26,10 @@ import org.apache.spark.network.BlockDataManager
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
-import org.apache.spark.network.shuffle.ShuffleStreamHandle
+import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 
-object NettyMessages {
-  /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
-  case class OpenBlocks(blockIds: Seq[BlockId])
-
-  /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
-  case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)
-}
-
 /**
  * Serves requests to open blocks by simply registering one chunk per block requested.
  * Handles opening and uploading arbitrary BlockManager blocks.
@@ -50,28 +42,29 @@ class NettyBlockRpcServer(
     blockManager: BlockDataManager)
   extends RpcHandler with Logging {
 
-  import NettyMessages._
-
   private val streamManager = new OneForOneStreamManager()
 
   override def receive(
       client: TransportClient,
       messageBytes: Array[Byte],
       responseContext: RpcResponseCallback): Unit = {
-    val ser = serializer.newInstance()
-    val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
+    val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
     logTrace(s"Received request: $message")
 
     message match {
-      case OpenBlocks(blockIds) =>
-        val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
+      case openBlocks: OpenBlocks =>
+        val blocks: Seq[ManagedBuffer] =
+          openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
         val streamId = streamManager.registerStream(blocks.iterator)
         logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
-        responseContext.onSuccess(
-          ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())
+        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
 
-      case UploadBlock(blockId, blockData, level) =>
-        blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level)
+      case uploadBlock: UploadBlock =>
+        // StorageLevel is serialized as bytes using our JavaSerializer.
+        val level: StorageLevel =
+          serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
+        blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
         responseContext.onSuccess(new Array[Byte](0))
     }
   }
```
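After this change, every block-transfer RPC payload is a self-describing `BlockTransferMessage` carrying its own type tag, so `receive` can decode and dispatch without a JavaSerializer round trip. Below is a minimal round-trip sketch; `toByteArray` and `BlockTransferMessage.Decoder.fromByteArray` appear in this diff, while the `(appId, execId, blockIds)` constructor shape of `OpenBlocks` is an assumption inferred from the `OneForOneBlockFetcher` call site in NettyBlockTransferService further down.

```java
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;

public class OpenBlocksRoundTrip {
  public static void main(String[] args) {
    // Assumed constructor arguments: (appId, execId, blockIds), as implied by
    // the OneForOneBlockFetcher call site in this PR.
    OpenBlocks request =
        new OpenBlocks("app-id", "exec-0", new String[] { "shuffle_0_1_2", "shuffle_0_1_3" });

    // toByteArray() frames the message together with its type tag.
    byte[] wire = request.toByteArray();

    // The server side recovers the typed message, mirroring NettyBlockRpcServer.receive.
    BlockTransferMessage decoded = BlockTransferMessage.Decoder.fromByteArray(wire);
    assert decoded instanceof OpenBlocks;
  }
}
```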
**NettyBlockTransferService.scala**

```diff
@@ -24,10 +24,10 @@ import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
-import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
 import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.protocol.UploadBlock
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.Utils
@@ -46,6 +46,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
   private[this] var transportContext: TransportContext = _
   private[this] var server: TransportServer = _
   private[this] var clientFactory: TransportClientFactory = _
+  private[this] var appId: String = _
 
   override def init(blockDataManager: BlockDataManager): Unit = {
     val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
@@ -60,6 +61,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
     transportContext = new TransportContext(transportConf, rpcHandler)
     clientFactory = transportContext.createClientFactory(bootstrap.toList)
     server = transportContext.createServer()
+    appId = conf.getAppId
     logInfo("Server created on " + server.getPort)
   }
 
@@ -74,8 +76,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
     val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
       override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
         val client = clientFactory.createClient(host, port)
-        new OneForOneBlockFetcher(client, blockIds.toArray, listener)
-          .start(OpenBlocks(blockIds.map(BlockId.apply)))
+        new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
       }
     }
 
@@ -101,12 +102,17 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
   override def uploadBlock(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel): Future[Unit] = {
     val result = Promise[Unit]()
     val client = clientFactory.createClient(hostname, port)
 
+    // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
+    // using our binary protocol.
+    val levelBytes = serializer.newInstance().serialize(level).array()
+
     // Convert or copy nio buffer into array in order to serialize it.
     val nioBuffer = blockData.nioByteBuffer()
     val array = if (nioBuffer.hasArray) {
@@ -117,8 +123,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
       data
     }
 
-    val ser = serializer.newInstance()
-    client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(),
+    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
       new RpcResponseCallback {
         override def onSuccess(response: Array[Byte]): Unit = {
           logTrace(s"Successfully uploaded block $blockId")
```
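A design note worth calling out: the network module is independent of Spark core, so `UploadBlock` cannot reference `StorageLevel` directly. Instead the level rides along as opaque `metadata` bytes produced by `JavaSerializer` and is only deserialized back on the Spark side (see NettyBlockRpcServer above). A minimal framing sketch, using the constructor and fields visible in this diff (the class name and stand-in byte values are ours):

```java
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.UploadBlock;

public class UploadBlockFraming {
  public static void main(String[] args) {
    // Stand-in bytes; in Spark these come from JavaSerializer.serialize(level).
    byte[] levelBytes = new byte[] { 0 };
    byte[] blockBytes = new byte[] { 1, 2, 3 };

    UploadBlock msg = new UploadBlock("app-id", "exec-0", "rdd_0_0", levelBytes, blockBytes);
    byte[] wire = msg.toByteArray();

    // The receiver gets the same fields back; metadata stays opaque at this layer.
    UploadBlock decoded = (UploadBlock) BlockTransferMessage.Decoder.fromByteArray(wire);
    assert decoded.blockData.length == 3;
  }
}
```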
**NioBlockTransferService.scala**

```diff
@@ -137,6 +137,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
   override def uploadBlock(
       hostname: String,
       port: Int,
+      execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
       level: StorageLevel)
```
**BlockManager.scala**

```diff
@@ -35,7 +35,8 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
-import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient}
+import org.apache.spark.network.shuffle.ExternalShuffleClient
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.network.util.{ConfigProvider, TransportConf}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.ShuffleManager
@@ -939,7 +940,7 @@ private[spark] class BlockManager(
         data.rewind()
         logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
         blockTransferService.uploadBlockSync(
-          peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel)
+          peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel)
         logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
           .format(System.currentTimeMillis - onePeerStartTime))
         peersReplicatedTo += peer
```
**NettyBlockTransferSecuritySuite.scala**

```diff
@@ -36,7 +36,9 @@ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers}
 
 class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {
   test("security default off") {
-    testConnection(new SparkConf, new SparkConf) match {
+    val conf = new SparkConf()
+      .set("spark.app.id", "app-id")
+    testConnection(conf, conf) match {
       case Success(_) => // expected
       case Failure(t) => fail(t)
     }
```
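The test now sets `spark.app.id` explicitly because `NettyBlockTransferService.init` reads `conf.getAppId` (see above). A running `SparkContext` normally sets this property itself; anything constructing the transfer service standalone, like this suite, must supply it. A small illustration (class name ours):

```java
import org.apache.spark.SparkConf;

public class AppIdConf {
  public static void main(String[] args) {
    // SparkContext sets spark.app.id automatically at startup; manual users of
    // NettyBlockTransferService must provide it themselves or getAppId fails.
    SparkConf conf = new SparkConf().set("spark.app.id", "app-id");
    System.out.println(conf.get("spark.app.id"));
  }
}
```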
**ChunkFetchFailure.java**

```diff
@@ -38,23 +38,19 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
 
   @Override
   public int encodedLength() {
-    return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+    return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     streamChunkId.encode(buf);
-    byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
-    buf.writeInt(errorBytes.length);
-    buf.writeBytes(errorBytes);
+    Encoders.Strings.encode(buf, errorString);
   }
 
   public static ChunkFetchFailure decode(ByteBuf buf) {
     StreamChunkId streamChunkId = StreamChunkId.decode(buf);
-    int numErrorStringBytes = buf.readInt();
-    byte[] errorBytes = new byte[numErrorStringBytes];
-    buf.readBytes(errorBytes);
-    return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+    String errorString = Encoders.Strings.decode(buf);
+    return new ChunkFetchFailure(streamChunkId, errorString);
   }
 
   @Override
```
**Encoders.java** (new file, @@ -0,0 +1,93 @@)

```java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.protocol;

import com.google.common.base.Charsets;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

/** Provides a canonical set of Encoders for simple types. */
public class Encoders {

  /** Strings are encoded with their length followed by UTF-8 bytes. */
  public static class Strings {
    public static int encodedLength(String s) {
      return 4 + s.getBytes(Charsets.UTF_8).length;
    }

    public static void encode(ByteBuf buf, String s) {
      byte[] bytes = s.getBytes(Charsets.UTF_8);
      buf.writeInt(bytes.length);
      buf.writeBytes(bytes);
    }

    public static String decode(ByteBuf buf) {
      int length = buf.readInt();
      byte[] bytes = new byte[length];
      buf.readBytes(bytes);
      return new String(bytes, Charsets.UTF_8);
    }
  }

  /** Byte arrays are encoded with their length followed by bytes. */
  public static class ByteArrays {
    public static int encodedLength(byte[] arr) {
      return 4 + arr.length;
    }

    public static void encode(ByteBuf buf, byte[] arr) {
      buf.writeInt(arr.length);
      buf.writeBytes(arr);
    }

    public static byte[] decode(ByteBuf buf) {
      int length = buf.readInt();
      byte[] bytes = new byte[length];
      buf.readBytes(bytes);
      return bytes;
    }
  }

  /** String arrays are encoded with the number of strings followed by per-String encoding. */
  public static class StringArrays {
    public static int encodedLength(String[] strings) {
      int totalLength = 4;
      for (String s : strings) {
        totalLength += Strings.encodedLength(s);
      }
      return totalLength;
    }

    public static void encode(ByteBuf buf, String[] strings) {
      buf.writeInt(strings.length);
      for (String s : strings) {
        Strings.encode(buf, s);
      }
    }

    public static String[] decode(ByteBuf buf) {
      int numStrings = buf.readInt();
      String[] strings = new String[numStrings];
      for (int i = 0; i < strings.length; i ++) {
        strings[i] = Strings.decode(buf);
      }
      return strings;
    }
  }
}
```
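These helpers centralize the length-prefixed framing that the message classes in this diff now delegate to. A short usage sketch showing a write/read round trip through a heap `ByteBuf` (the `EncodersExample` class name is ours; all `Encoders` calls are from the file above):

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encoders;

public class EncodersExample {
  public static void main(String[] args) {
    ByteBuf buf = Unpooled.buffer();

    // Writes are length-prefixed: "hi" becomes [00 00 00 02 | 68 69].
    Encoders.Strings.encode(buf, "hi");
    Encoders.StringArrays.encode(buf, new String[] { "shuffle_0_0_0", "shuffle_0_0_1" });

    // Reads must mirror the write order; each field is self-delimiting via its prefix.
    String s = Encoders.Strings.decode(buf);          // "hi"
    String[] ids = Encoders.StringArrays.decode(buf); // two block ids
    System.out.println(s + ", " + ids.length);
  }
}
```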
**RpcFailure.java**

```diff
@@ -36,23 +36,19 @@ public RpcFailure(long requestId, String errorString) {
 
   @Override
   public int encodedLength() {
-    return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+    return 8 + Encoders.Strings.encodedLength(errorString);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
-    buf.writeInt(errorBytes.length);
-    buf.writeBytes(errorBytes);
+    Encoders.Strings.encode(buf, errorString);
   }
 
   public static RpcFailure decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    int numErrorStringBytes = buf.readInt();
-    byte[] errorBytes = new byte[numErrorStringBytes];
-    buf.readBytes(errorBytes);
-    return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+    String errorString = Encoders.Strings.decode(buf);
+    return new RpcFailure(requestId, errorString);
   }
 
   @Override
```
**RpcRequest.java**

```diff
@@ -44,21 +44,18 @@ public RpcRequest(long requestId, byte[] message) {
 
   @Override
   public int encodedLength() {
-    return 8 + 4 + message.length;
+    return 8 + Encoders.ByteArrays.encodedLength(message);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    buf.writeInt(message.length);
-    buf.writeBytes(message);
+    Encoders.ByteArrays.encode(buf, message);
   }
 
   public static RpcRequest decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    int messageLen = buf.readInt();
-    byte[] message = new byte[messageLen];
-    buf.readBytes(message);
+    byte[] message = Encoders.ByteArrays.decode(buf);
     return new RpcRequest(requestId, message);
   }
 
```
**RpcResponse.java**

```diff
@@ -36,20 +36,17 @@ public RpcResponse(long requestId, byte[] response) {
   public Type type() { return Type.RpcResponse; }
 
   @Override
-  public int encodedLength() { return 8 + 4 + response.length; }
+  public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    buf.writeInt(response.length);
-    buf.writeBytes(response);
+    Encoders.ByteArrays.encode(buf, response);
   }
 
   public static RpcResponse decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    int responseLen = buf.readInt();
-    byte[] response = new byte[responseLen];
-    buf.readBytes(response);
+    byte[] response = Encoders.ByteArrays.decode(buf);
     return new RpcResponse(requestId, response);
   }
 
```
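Because each message's `encode`/`decode` pair is symmetric, a frame written by one end is recovered exactly by the other. A minimal round-trip sketch for `RpcRequest` (class name and values are illustrative; all calls appear in the diff above):

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.RpcRequest;

public class RpcRequestRoundTrip {
  public static void main(String[] args) {
    RpcRequest request = new RpcRequest(42L, new byte[] { 7, 8, 9 });

    ByteBuf buf = Unpooled.buffer();
    request.encode(buf); // 8-byte requestId followed by the length-prefixed body
    assert buf.readableBytes() == request.encodedLength();

    RpcRequest decoded = RpcRequest.decode(buf);
    // decoded carries requestId == 42 and the same three message bytes.
  }
}
```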