Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][SPARK-44937][CORE] Add SSL/TLS support for RPC and Shuffle communications #42685

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion common/network-common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@
<artifactId>netty-transport-native-kqueue</artifactId>
<classifier>osx-x86_64</classifier>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<classifier>linux-x86_64</classifier>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<classifier>linux-aarch_64</classifier>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<classifier>osx-aarch_64</classifier>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<classifier>osx-x86_64</classifier>
</dependency>
<!-- Netty End -->

<dependency>
Expand Down Expand Up @@ -147,12 +167,24 @@
<artifactId>log4j-slf4j2-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15on</artifactId>
<version>${bouncycastle.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcpkix-jdk15on</artifactId>
<version>${bouncycastle.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-common-utils_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
</dependency>

<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,26 @@
import java.io.Closeable;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;

import com.codahale.metrics.Counter;
import io.netty.channel.Channel;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.handler.codec.MessageToMessageEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.SslMessageEncoder;
import org.apache.spark.network.protocol.MessageDecoder;
import org.apache.spark.network.protocol.MessageEncoder;
import org.apache.spark.network.server.ChunkFetchRequestHandler;
Expand All @@ -45,6 +51,7 @@
import org.apache.spark.network.server.TransportRequestHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.ssl.SSLFactory;
import org.apache.spark.network.util.IOMode;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.NettyLogger;
Expand Down Expand Up @@ -72,6 +79,8 @@ public class TransportContext implements Closeable {
private final TransportConf conf;
private final RpcHandler rpcHandler;
private final boolean closeIdleConnections;
// Non-null if SSL is enabled, null otherwise.
@Nullable private final SSLFactory sslFactory;
// Number of registered connections to the shuffle service
private Counter registeredConnections = new Counter();

Expand All @@ -87,7 +96,8 @@ public class TransportContext implements Closeable {
* RPC to load it and cause to load the non-exist matcher class again. JVM will report
* `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714)
*/
private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
private static final MessageToMessageEncoder<Message> ENCODER = MessageEncoder.INSTANCE;
private static final MessageToMessageEncoder<Message> SSL_ENCODER = SslMessageEncoder.INSTANCE;
private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;

// Separate thread pool for handling ChunkFetchRequest. This helps to enable throttling
Expand Down Expand Up @@ -125,6 +135,7 @@ public TransportContext(
this.conf = conf;
this.rpcHandler = rpcHandler;
this.closeIdleConnections = closeIdleConnections;
this.sslFactory = createSslFactory();

if (conf.getModuleName() != null &&
conf.getModuleName().equalsIgnoreCase("shuffle") &&
Expand Down Expand Up @@ -171,8 +182,12 @@ public TransportServer createServer() {
return createServer(0, new ArrayList<>());
}

public TransportChannelHandler initializePipeline(SocketChannel channel) {
return initializePipeline(channel, rpcHandler);
/**
 * Initializes the pipeline of {@code channel} using this context's own {@code rpcHandler}.
 *
 * @param channel the channel whose pipeline is being set up
 * @param isClient passed through unchanged to the three-argument overload
 * @return the handler returned by the three-argument overload
 */
public TransportChannelHandler initializePipeline(SocketChannel channel, boolean isClient) {
  return this.initializePipeline(channel, this.rpcHandler, isClient);
}

/**
 * Reports whether this context holds an SSL factory.
 *
 * @return {@code true} when {@code sslFactory} is non-null, {@code false} otherwise
 */
public boolean sslEncryptionEnabled() {
  return sslFactory != null;
}

/**
Expand All @@ -189,15 +204,32 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) {
*/
public TransportChannelHandler initializePipeline(
SocketChannel channel,
RpcHandler channelRpcHandler) {
RpcHandler channelRpcHandler,
boolean isClient) {
try {
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
ChannelPipeline pipeline = channel.pipeline();

if (nettyLogger.getLoggingHandler() != null) {
pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
}

if (sslEncryptionEnabled()) {
SslHandler sslHandler;
try {
sslHandler = new SslHandler(
sslFactory.createSSLEngine(isClient, pipeline.channel().alloc()));
} catch (Exception e) {
throw new RuntimeException("Error creating Netty SslHandler", e);
}
pipeline.addFirst("NettySslEncryptionHandler", sslHandler);
// Cannot use zero-copy with SSL/TLS, so we add in our ChunkedWriteHandler just before the
// MessageEncoder
pipeline.addLast("chunkedWriter", new ChunkedWriteHandler());
}

pipeline
.addLast("encoder", ENCODER)
.addLast("encoder", sslEncryptionEnabled()? SSL_ENCODER : ENCODER)
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
.addLast("decoder", getDecoder())
.addLast("idleStateHandler",
Expand All @@ -223,6 +255,37 @@ protected MessageToMessageDecoder<ByteBuf> getDecoder() {
return DECODER;
}

/**
 * Builds the {@link SSLFactory} for this context from the RPC SSL settings in {@code conf}.
 *
 * @return a configured factory when RPC SSL is enabled and its keys are valid; {@code null}
 *         when RPC SSL is disabled, or when the keys are not valid but the configuration
 *         allows falling back to unencrypted connections
 * @throws RuntimeException if RPC SSL is enabled, the keys are not valid, and the fallback
 *         is not allowed
 */
private SSLFactory createSslFactory() {
  // SSL not requested at all: nothing to build.
  if (!conf.sslRpcEnabled()) {
    return null;
  }

  // SSL requested but unusable keys: either fall back (loudly) or refuse to start.
  if (!conf.sslRpcEnabledAndKeysAreValid()) {
    if (!conf.sslRpcDangerouslyFallbackIfKeysNotPresent()) {
      logger.error("RPC SSL encryption enabled but keys not found!");
      throw new RuntimeException("RPC SSL encryption enabled but keys not found!");
    }
    logger.warn("SSL keys not found. Dangerously falling back to unencrypted connections");
    return null;
  }

  return new SSLFactory.Builder()
    .openSslEnabled(conf.sslRpcOpenSslEnabled())
    .requestedProtocol(conf.sslRpcProtocol())
    .requestedCiphers(conf.sslRpcRequestedCiphers())
    .keyStore(conf.sslRpcKeyStore(), conf.sslRpcKeyStorePassword())
    .privateKey(conf.sslRpcPrivateKey())
    .keyPassword(conf.sslRpcKeyPassword())
    .certChain(conf.sslRpcCertChain())
    .trustStore(
      conf.sslRpcTrustStore(),
      conf.sslRpcTrustStorePassword(),
      conf.sslRpcTrustStoreReloadingEnabled(),
      conf.sslRpctrustStoreReloadIntervalMs())
    .build();
}

/**
* Creates the server- and client-side handler which is used to handle both RequestMessages and
* ResponseMessages. The channel is expected to have been successfully created, though certain
Expand Down Expand Up @@ -255,5 +318,8 @@ public void close() {
if (chunkFetchWorkers != null) {
chunkFetchWorkers.shutdownGracefully();
}
if (sslFactory != null) {
sslFactory.destroy();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import com.google.common.io.ByteStreams;
import io.netty.channel.DefaultFileRegion;
import io.netty.handler.stream.ChunkedStream;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;

Expand Down Expand Up @@ -137,6 +138,12 @@ public Object convertToNetty() throws IOException {
}
}

/**
 * SSL variant of {@code convertToNetty()}: wraps the stream from {@code createInputStream()}
 * in a {@link ChunkedStream} whose chunk size is {@code conf.sslShuffleChunkSize()}.
 */
@Override
public Object convertToNettyForSsl() throws IOException {
// Zero-copy transfer cannot be used when the data is written out with SSL/TLS encryption,
// so the file segment is streamed in chunks instead.
return new ChunkedStream(createInputStream(), conf.sslShuffleChunkSize());
}

public File getFile() { return file; }

public long getOffset() { return offset; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,15 @@ public abstract class ManagedBuffer {
* the caller will be responsible for releasing this new reference.
*/
public abstract Object convertToNetty() throws IOException;

/**
 * Convert the buffer into a Netty object, used to write the data out with SSL encryption,
 * which cannot use {@link io.netty.channel.FileRegion}.
 * The return value is either a {@link io.netty.buffer.ByteBuf},
 * a {@link io.netty.handler.stream.ChunkedStream}, or a {@link java.io.InputStream}.
 *
 * If this method returns a ByteBuf, then that buffer's reference count will be incremented and
 * the caller will be responsible for releasing this new reference.
 *
 * @return the Netty-writable representation of this buffer, as one of the types listed above
 */
public abstract Object convertToNettyForSsl() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public Object convertToNetty() throws IOException {
return buf.duplicate().retain();
}

/**
 * SSL variant of {@code convertToNetty()}; returns the same kind of object as that method.
 */
@Override
public Object convertToNettyForSsl() throws IOException {
// Same expression as convertToNetty(): a duplicate of buf on which retain() has been called,
// so the caller owns a reference it must release.
return buf.duplicate().retain();
}

@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ public Object convertToNetty() throws IOException {
return Unpooled.wrappedBuffer(buf);
}

/**
 * SSL variant of {@code convertToNetty()}; returns the same kind of object as that method.
 */
@Override
public Object convertToNettyForSsl() throws IOException {
// Same expression as convertToNetty(): buf wrapped via Unpooled.wrappedBuffer.
return Unpooled.wrappedBuffer(buf);
}

@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,10 @@ public TransportResponseHandler getHandler() {

@Override
public void close() {
// close is a local operation and should finish with milliseconds; timeout just to be safe
// Mark the connection as timed out, so we do not return a connection that's being closed
// from the TransportClientFactory if closing takes some time (e.g. with SSL)
this.timedOut = true;
Comment on lines +328 to +330
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is necessary for SSL because SSL connections can take a long time to close, but I think it's also okay to have it apply for regular unencrypted connections (as you've done here): if we have a client that is in the process of closing then marking it as timed out during that close helps to avoid a race condition in which a pooled client might be handed to a requester only to immediately close. 👍

// close should not take this long; use a timeout just to be safe
channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -268,7 +271,7 @@ TransportClient createClient(InetSocketAddress address)
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
TransportChannelHandler clientHandler = context.initializePipeline(ch);
TransportChannelHandler clientHandler = context.initializePipeline(ch, true);
clientRef.set(clientHandler.getClient());
channelRef.set(ch);
}
Expand All @@ -293,6 +296,26 @@ public void initChannel(SocketChannel ch) {
} else if (cf.cause() != null) {
throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
}
if (context.sslEncryptionEnabled()) {
  // Wait for the TLS handshake before handing the client back, so callers never receive a
  // client whose channel has not finished (or has failed) negotiating encryption.
  final SslHandler sslHandler = cf.channel().pipeline().get(SslHandler.class);
  Future<Channel> future = sslHandler.handshakeFuture().addListener(
    new GenericFutureListener<Future<Channel>>() {
      @Override
      public void operationComplete(final Future<Channel> handshakeFuture) {
        if (handshakeFuture.isSuccess()) {
          // Placeholder goes at the end so the address is rendered after "to".
          logger.debug("successfully completed TLS handshake to {}", address);
        } else {
          if (logger.isDebugEnabled()) {
            logger.debug(
              "failed to complete TLS handshake to " + address,
              handshakeFuture.cause());
          }
          cf.channel().close();
        }
      }
    });
  // await(timeout) returns false when the handshake did not complete in time. In that case
  // the channel is unusable: close it and fail instead of returning a half-open client.
  if (!future.await(conf.connectionTimeoutMs())) {
    cf.channel().close();
    throw new IOException(
      String.format("Failed to complete TLS handshake to %s within connection timeout",
        address));
  }
}

TransportClient client = clientRef.get();
Channel channel = channelRef.get();
Expand Down Expand Up @@ -337,5 +360,6 @@ public void close() {
if (workerGroup != null && !workerGroup.isShuttingDown()) {
workerGroup.shutdownGracefully();
}

}
}
Loading