Skip to content

Commit

Permalink
[SPARK-26201] Fix python broadcast with encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
schintap committed Nov 28, 2018
1 parent fa0d4bf commit 67a2ac8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 10 deletions.
29 changes: 25 additions & 4 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
with Logging {

private var encryptionServer: PythonServer[Unit] = null
private var decryptionServer: PythonServer[Unit] = null

/**
* Read data from disks, then copy it to `out`
Expand Down Expand Up @@ -708,16 +709,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
override def handleConnection(sock: Socket): Unit = {
val env = SparkEnv.get
val in = sock.getInputStream()
val dir = new File(Utils.getLocalDir(env.conf))
val file = File.createTempFile("broadcast", "", dir)
path = file.getAbsolutePath
val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
val abspath = new File(path).getAbsolutePath
val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath))
DechunkedInputStream.dechunkAndCopyToOutput(in, out)
}
}
Array(encryptionServer.port, encryptionServer.secret)
}

/**
* Create a `PythonServer` named "broadcast-decrypt-server-for-driver" and store it in
* `decryptionServer`. Its connection handler copies the file at `path` to the connecting
* socket, reading the file through `serializerManager.wrapForEncryption`.
* NOTE(review): presumably that wrapper yields plaintext when IO encryption is enabled and is a
* pass-through otherwise - confirm against SerializerManager.
*
* @return a two-element array holding the server's port followed by its secret
*/
def setupDecryptionServer(): Array[Any] = {
decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") {
override def handleConnection(sock: Socket): Unit = {
// Buffer writes to the socket; the buffer is flushed explicitly once the copy is done.
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream()))
Utils.tryWithSafeFinally {
// Opened inside the outer block so that `out` is still closed if opening the file fails.
val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path))
Utils.tryWithSafeFinally {
// NOTE(review): the `false` flag presumably tells copyStream not to close the streams
// itself (they are closed by the surrounding finally blocks) - confirm.
Utils.copyStream(in, out, false)
} {
in.close()
}
// Push any bytes still held by the BufferedOutputStream before `out` is closed below.
out.flush()
} {
JavaUtils.closeQuietly(out)
}
}
}
Array(decryptionServer.port, decryptionServer.secret)
}

/**
* Calls `getResult()` on `decryptionServer`.
* NOTE(review): presumably this blocks until the server's connection handler has finished
* writing the broadcast data - confirm against PythonServer.
* `decryptionServer` is initialised to null and only assigned by `setupDecryptionServer()`, so
* that method must be called first.
*/
def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult()

/**
* Calls `getResult()` on `encryptionServer`.
* NOTE(review): presumably this blocks until the server's connection handler has finished
* reading the data sent from Python - confirm against PythonServer.
* `encryptionServer` is initialised to null, so the encryption server has to be set up before
* this is called.
*/
def waitTillDataReceived(): Unit = encryptionServer.getResult()
}
// scalastyle:on no.finalize
Expand Down
23 changes: 17 additions & 6 deletions python/pyspark/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,27 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
# we're on the driver. We want the pickled data to end up in a file (maybe encrypted)
f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
self._path = f.name
python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
self._sc = sc
self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
if sc._encryption_enabled:
# with encryption, we ask the jvm to do the encryption for us, we send it data
# over a socket
port, auth_secret = python_broadcast.setupEncryptionServer()
port, auth_secret = self._python_broadcast.setupEncryptionServer()
(encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
broadcast_out = ChunkedStream(encryption_sock_file, 8192)
else:
# no encryption, we can just write pickled data directly to the file from python
broadcast_out = f
self.dump(value, broadcast_out)
if sc._encryption_enabled:
python_broadcast.waitTillDataReceived()
self._jbroadcast = sc._jsc.broadcast(python_broadcast)
self._python_broadcast.waitTillDataReceived()
self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
self._pickle_registry = pickle_registry
else:
# we're on an executor
self._jbroadcast = None
self._sc = None
self._python_broadcast = None
if sock_file is not None:
# the jvm is doing decryption for us. Read the value
# immediately from the sock_file
Expand All @@ -118,8 +121,16 @@ def dump(self, value, f):
f.close()

# Load the broadcast value stored at `path` and return whatever `self.load` returns.
# NOTE(review): in this diff view the two lines directly after the `def` appear to be the
# removed pre-change body; the same two statements reappear in the `else` branch below.
def load_from_path(self, path):
with open(path, 'rb', 1 << 20) as f:
return self.load(f)
# We only need to decrypt here when we are on the driver (self._sc is set to None
# on executors); on executors decryption has already been handled.
if self._sc is not None and self._sc._encryption_enabled:
# Ask the JVM-side PythonBroadcast to start its decryption server, then connect to it.
port, auth_secret = self._python_broadcast.setupDecryptionServer()
(decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret)
# NOTE(review): presumably blocks until the JVM has finished sending - confirm.
self._python_broadcast.waitTillBroadcastDataSent()
# NOTE(review): decrypted_sock_file is not closed in this method - confirm whether
# load() or the caller is expected to close it.
return self.load(decrypted_sock_file)
else:
# No encryption (or not on the driver): read the pickled data straight from the file,
# using a 1 MiB (1 << 20) buffer.
with open(path, 'rb', 1 << 20) as f:
return self.load(f)

def load(self, file):
# "file" could also be a socket
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def test_broadcast_with_encryption(self):
def test_broadcast_no_encryption(self):
self._test_multiple_broadcasts()

def _test_broadcast_on_driver(self, *extra_confs):
    """Broadcast the value 5 and check that reading it back on the driver gives 5.

    Each item of ``extra_confs`` is a ``(key, value)`` pair that is set on the
    SparkConf before the SparkContext is created. The context is stored on
    ``self.sc``.
    """
    spark_conf = SparkConf()
    for conf_key, conf_value in extra_confs:
        spark_conf.set(conf_key, conf_value)
    spark_conf.setMaster("local-cluster[2,1,1024]")
    self.sc = SparkContext(conf=spark_conf)
    broadcast_var = self.sc.broadcast(value=5)
    self.assertEqual(5, broadcast_var.value)

# Runs the driver-side broadcast check with no extra configuration entries.
def test_broadcast_value_driver_no_encryption(self):
self._test_broadcast_on_driver()

def test_broadcast_value_driver_encryption(self):
self.self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))

class BroadcastFrameProtocolTest(unittest.TestCase):

Expand Down

0 comments on commit 67a2ac8

Please sign in to comment.