[SPARK-45545][CORE] Pass SSLOptions wherever we create a SparkTransportConf

### What changes were proposed in this pull request?

This change ensures that RPC SSL settings inheritance works properly after #43238: we now pass `sslOptions` wherever we call `fromSparkConf`.
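
As a minimal sketch of that mechanical change (assuming the `sslOptions` parameter added to `SparkTransportConf.fromSparkConf` in #43238 and the `getRpcSSLOptions()` accessor added to `SecurityManager` in this PR), the call-site pattern applied throughout looks like:

```scala
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network.netty.SparkTransportConf

val conf = new SparkConf()
val securityManager = new SecurityManager(conf)

// Before: SSL options were not threaded through, so spark.ssl.rpc settings
// had no effect on the resulting TransportConf.
val withoutSsl = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0)

// After: the RPC SSL options resolved by SecurityManager are passed explicitly.
val withSsl = SparkTransportConf.fromSparkConf(
  conf,
  "shuffle",
  numUsableCores = 0,
  sslOptions = Some(securityManager.getRpcSSLOptions()))
```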

In addition to that minor mechanical change, we duplicate or add tests for every place that calls this method, so that each call site also has a test case that runs with SSL support enabled in the config.
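
The SSL variants typically just subclass the existing suite and layer the SSL settings onto the shared `SparkConf`, as in the new `SslShuffleNettySuite` added in this diff:

```scala
package org.apache.spark

// SSL variant of an existing suite: reuse all of the base suite's tests and only
// enable RPC SSL in the configuration via the shared SslTestUtils helper.
class SslShuffleNettySuite extends ShuffleNettySuite {

  override def beforeAll(): Unit = {
    super.beforeAll()
    SslTestUtils.updateWithSSLConfig(conf)
  }
}
```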

### Why are the changes needed?

These changes are needed so that the RPC SSL functionality works properly with settings inheritance. In addition, the new tests ensure that future changes to these modules are also exercised with SSL support, which guards against regressions.
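
Concretely, the inheritance at stake is that options set under the `spark.ssl.*` namespace act as defaults for the RPC-specific `spark.ssl.rpc.*` namespace. A rough sketch of that resolution, mirroring what the new `SslExternalShuffleServiceSuite` below exercises (the object name here is illustrative only):

```scala
package org.apache.spark

import org.apache.spark.deploy.SparkHadoopUtil

// Illustrative sketch, not part of the PR: resolve the RPC SSL options with the
// generic spark.ssl namespace providing defaults for spark.ssl.rpc.
object RpcSslInheritanceSketch {
  def resolve(conf: SparkConf): SSLOptions = {
    val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
    // Options defined under spark.ssl ...
    val defaults = SSLOptions.parse(conf, hadoopConf, "spark.ssl")
    // ... are used as fallbacks for anything not overridden under spark.ssl.rpc.
    SSLOptions.parse(conf, hadoopConf, "spark.ssl.rpc", defaults = Some(defaults))
  }
}
```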

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Full integration testing was also done as part of #42685.

Added some tests and ran them:

```
build/sbt
> project core
> testOnly org.apache.spark.*Ssl*
> testOnly org.apache.spark.network.netty.NettyBlockTransferSecuritySuite
```

and

```
build/sbt -Pyarn
> project yarn
> testOnly org.apache.spark.network.yarn.SslYarnShuffleServiceWithRocksDBBackendSuite
```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43387 from hasnain-db/spark-tls-integrate-everywhere.

Authored-by: Hasnain Lakhani <hasnain.lakhani@databricks.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
hasnain-db authored and Mridul Muralidharan committed Oct 25, 2023
1 parent 093b7e1 commit 08c7bac
Showing 22 changed files with 379 additions and 86 deletions.
6 changes: 6 additions & 0 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -292,6 +292,12 @@ private[spark] class SecurityManager(
*/
def isSslRpcEnabled(): Boolean = sslRpcEnabled

/**
* Returns the SSLOptions object for the RPC namespace
* @return the SSLOptions object for the RPC namespace
*/
def getRpcSSLOptions(): SSLOptions = rpcSSLOptions

/**
* Gets the user used for authenticating SASL connections.
* For now use a single hardcoded user.
7 changes: 6 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -374,7 +374,12 @@ object SparkEnv extends Logging {
}

val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) {
val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
val transConf = SparkTransportConf.fromSparkConf(
conf,
"shuffle",
numUsableCores,
sslOptions = Some(securityManager.getRpcSSLOptions())
)
Some(new ExternalBlockStoreClient(transConf, securityManager,
securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)))
} else {
core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -53,7 +53,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
private val registeredExecutorsDB = "registeredExecutors"

private val transportConf =
SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0)
SparkTransportConf.fromSparkConf(
sparkConf,
"shuffle",
numUsableCores = 0,
sslOptions = Some(securityManager.getRpcSSLOptions()))
private val blockHandler = newShuffleBlockHandler(transportConf)
private var transportContext: TransportContext = _

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -90,7 +90,9 @@ private[spark] class CoarseGrainedExecutorBackend(

logInfo("Connecting to driver: " + driverUrl)
try {
val shuffleClientTransportConf = SparkTransportConf.fromSparkConf(env.conf, "shuffle")
val securityManager = new SecurityManager(env.conf)
val shuffleClientTransportConf = SparkTransportConf.fromSparkConf(
env.conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions()))
if (NettyUtils.preferDirectBufs(shuffleClientTransportConf) &&
PlatformDependent.maxDirectMemory() < env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)) {
throw new SparkException(s"Netty direct memory should at least be bigger than " +
core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -70,7 +70,11 @@ private[spark] class NettyBlockTransferService(
val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
var serverBootstrap: Option[TransportServerBootstrap] = None
var clientBootstrap: Option[TransportClientBootstrap] = None
this.transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores)
this.transportConf = SparkTransportConf.fromSparkConf(
conf,
"shuffle",
numCores,
sslOptions = Some(securityManager.getRpcSSLOptions()))
if (authEnabled) {
serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager))
clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager))
10 changes: 8 additions & 2 deletions core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -56,7 +56,9 @@ private[netty] class NettyRpcEnv(
conf.clone.set(RPC_IO_NUM_CONNECTIONS_PER_PEER, 1),
"rpc",
conf.get(RPC_IO_THREADS).getOrElse(numUsableCores),
role)
role,
sslOptions = Some(securityManager.getRpcSSLOptions())
)

private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)

@@ -391,7 +393,11 @@ private[netty] class NettyRpcEnv(
}

val ioThreads = clone.getInt("spark.files.io.threads", 1)
val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)
val downloadConf = SparkTransportConf.fromSparkConf(
clone,
module,
ioThreads,
sslOptions = Some(securityManager.getRpcSSLOptions()))
val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)
fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())
}
core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -24,7 +24,7 @@ import java.nio.file.Files

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkConf, SparkEnv, SparkException}
import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException}
import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.NioBufferedFileInputStream
@@ -58,7 +58,11 @@ private[spark] class IndexShuffleBlockResolver(

private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)

private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
private val transportConf = {
val securityManager = new SecurityManager(conf)
SparkTransportConf.fromSparkConf(
conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions()))
}

private val remoteShuffleMaxDisk: Option[Long] =
conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_DISK_SIZE)
core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -25,7 +25,7 @@ import java.util.concurrent.ExecutorService
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
import scala.util.control.NonFatal

import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.{SecurityManager, ShuffleDependency, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.annotation.Since
import org.apache.spark.executor.{CoarseGrainedExecutorBackend, ExecutorBackend}
import org.apache.spark.internal.Logging
@@ -108,7 +108,9 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
dep: ShuffleDependency[_, _, _],
mapIndex: Int): Unit = {
val numPartitions = dep.partitioner.numPartitions
val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
val securityManager = new SecurityManager(conf)
val transportConf = SparkTransportConf.fromSparkConf(
conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions()))
this.shuffleId = dep.shuffleId
this.shuffleMergeId = dep.shuffleMergeId
this.mapIndex = mapIndex
core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -1252,7 +1252,8 @@ private[spark] class BlockManager(
new EncryptedBlockData(file, blockSize, conf, key))

case _ =>
val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
val transportConf = SparkTransportConf.fromSparkConf(
conf, "shuffle", sslOptions = Some(securityManager.getRpcSSLOptions()))
new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
}
Some(managedBuffer)
core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -49,8 +49,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi
var transportContext: TransportContext = _
var rpcHandler: ExternalBlockHandler = _

override def beforeAll(): Unit = {
super.beforeAll()
protected def initializeHandlers(): Unit = {
val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2)
rpcHandler = new ExternalBlockHandler(transportConf, null)
transportContext = new TransportContext(transportConf, rpcHandler)
@@ -61,6 +60,11 @@
conf.set(config.SHUFFLE_SERVICE_PORT, server.getPort)
}

override def beforeAll(): Unit = {
super.beforeAll()
initializeHandlers()
}

override def afterAll(): Unit = {
Utils.tryLogNonFatalError{
server.close()
core/src/test/scala/org/apache/spark/SslExternalShuffleServiceSuite.scala
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.config
import org.apache.spark.network.TransportContext
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalBlockHandler

/**
* This suite creates an external shuffle server and routes all shuffle fetches through it.
* Note that failures in this suite may arise due to changes in Spark that invalidate expectations
* set up in `ExternalShuffleBlockHandler`, such as changing the format of shuffle files or how
* we hash files into folders.
*/
class SslExternalShuffleServiceSuite extends ExternalShuffleServiceSuite {

override def initializeHandlers(): Unit = {
SslTestUtils.updateWithSSLConfig(conf)
val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf);
// Show that we can successfully inherit options defined in the `spark.ssl` namespace
val defaultSslOptions = SSLOptions.parse(conf, hadoopConf, "spark.ssl")
val sslOptions = SSLOptions.parse(
conf, hadoopConf, "spark.ssl.rpc", defaults = Some(defaultSslOptions))
val transportConf = SparkTransportConf.fromSparkConf(
conf, "shuffle", numUsableCores = 2, sslOptions = Some(sslOptions))

rpcHandler = new ExternalBlockHandler(transportConf, null)
transportContext = new TransportContext(transportConf, rpcHandler)
server = transportContext.createServer()

conf.set(config.SHUFFLE_MANAGER, "sort")
conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
conf.set(config.SHUFFLE_SERVICE_PORT, server.getPort)
}
}
26 changes: 26 additions & 0 deletions core/src/test/scala/org/apache/spark/SslShuffleNettySuite.scala
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

class SslShuffleNettySuite extends ShuffleNettySuite {

override def beforeAll(): Unit = {
super.beforeAll()
SslTestUtils.updateWithSSLConfig(conf)
}
}
core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -52,8 +52,12 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite

implicit val formats = DefaultFormats

def createSparkConf(): SparkConf = {
new SparkConf()
}

test("parsing no resources") {
val conf = new SparkConf
val conf = createSparkConf()
val resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf)
val serializer = new JavaSerializer(conf)
val env = createMockEnv(conf, serializer)
@@ -75,7 +79,7 @@
}

test("parsing one resource") {
val conf = new SparkConf
val conf = createSparkConf()
conf.set(EXECUTOR_GPU_ID.amountConf, "2")
val serializer = new JavaSerializer(conf)
val env = createMockEnv(conf, serializer)
@@ -100,11 +104,11 @@
val ereqs = new ExecutorResourceRequests().resource(GPU, 2)
ereqs.resource(FPGA, 3)
val rp = rpBuilder.require(ereqs).build()
testParsingMultipleResources(new SparkConf, rp)
testParsingMultipleResources(createSparkConf(), rp)
}

test("parsing multiple resources") {
val conf = new SparkConf
val conf = createSparkConf()
conf.set(EXECUTOR_GPU_ID.amountConf, "2")
conf.set(EXECUTOR_FPGA_ID.amountConf, "3")
testParsingMultipleResources(conf, ResourceProfile.getOrCreateDefaultProfile(conf))
@@ -136,7 +140,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
}

test("error checking parsing resources and executor and task configs") {
val conf = new SparkConf
val conf = createSparkConf()
conf.set(EXECUTOR_GPU_ID.amountConf, "2")
val serializer = new JavaSerializer(conf)
val env = createMockEnv(conf, serializer)
@@ -178,11 +182,11 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
val ereqs = new ExecutorResourceRequests().resource(GPU, 4)
val treqs = new TaskResourceRequests().resource(GPU, 1)
val rp = rpBuilder.require(ereqs).require(treqs).build()
testExecutorResourceFoundLessThanRequired(new SparkConf, rp)
testExecutorResourceFoundLessThanRequired(createSparkConf(), rp)
}

test("executor resource found less than required") {
val conf = new SparkConf()
val conf = createSparkConf()
conf.set(EXECUTOR_GPU_ID.amountConf, "4")
conf.set(TASK_GPU_ID.amountConf, "1")
testExecutorResourceFoundLessThanRequired(conf, ResourceProfile.getOrCreateDefaultProfile(conf))
@@ -213,7 +217,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
}

test("use resource discovery") {
val conf = new SparkConf
val conf = createSparkConf()
conf.set(EXECUTOR_FPGA_ID.amountConf, "3")
assume(!(Utils.isWindows))
withTempDir { dir =>
@@ -246,7 +250,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
val ereqs = new ExecutorResourceRequests().resource(FPGA, 3, scriptPath)
ereqs.resource(GPU, 2)
val rp = rpBuilder.require(ereqs).build()
allocatedFileAndConfigsResourceDiscoveryTestFpga(dir, new SparkConf, rp)
allocatedFileAndConfigsResourceDiscoveryTestFpga(dir, createSparkConf(), rp)
}
}

@@ -255,7 +259,7 @@
withTempDir { dir =>
val scriptPath = createTempScriptWithExpectedOutput(dir, "fpgaDiscoverScript",
"""{"name": "fpga","addresses":["f1", "f2", "f3"]}""")
val conf = new SparkConf
val conf = createSparkConf()
conf.set(EXECUTOR_FPGA_ID.amountConf, "3")
conf.set(EXECUTOR_FPGA_ID.discoveryScriptConf, scriptPath)
conf.set(EXECUTOR_GPU_ID.amountConf, "2")
@@ -289,7 +293,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
}

test("track allocated resources by taskId") {
val conf = new SparkConf
val conf = createSparkConf()
val securityMgr = new SecurityManager(conf)
val serializer = new JavaSerializer(conf)
var backend: CoarseGrainedExecutorBackend = null
@@ -389,7 +393,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
* being executed in [[Executor.TaskRunner]].
*/
test(s"Tasks launched should always be cancelled.") {
val conf = new SparkConf
val conf = createSparkConf()
val securityMgr = new SecurityManager(conf)
val serializer = new JavaSerializer(conf)
val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
@@ -478,7 +482,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
* it has not been launched yet.
*/
test(s"Tasks not launched should always be cancelled.") {
val conf = new SparkConf
val conf = createSparkConf()
val securityMgr = new SecurityManager(conf)
val serializer = new JavaSerializer(conf)
val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
@@ -567,7 +571,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
* [[SparkUncaughtExceptionHandler]] and [[Executor]] can exit by itself.
*/
test("SPARK-40320 Executor should exit when initialization failed for fatal error") {
val conf = new SparkConf()
val conf = createSparkConf()
.setMaster("local-cluster[1, 1, 1024]")
.set(PLUGINS, Seq(classOf[TestFatalErrorPlugin].getName))
.setAppName("test")
@@ -628,3 +632,11 @@ private class TestErrorExecutorPlugin extends ExecutorPlugin {
// scalastyle:on throwerror
}
}

class SslCoarseGrainedExecutorBackendSuite extends CoarseGrainedExecutorBackendSuite
with LocalSparkContext with MockitoSugar {

override def createSparkConf(): SparkConf = {
SslTestUtils.updateWithSSLConfig(super.createSparkConf())
}
}