From 67a2ac87fb6e2d3fd4a5f260047a37bd2858228d Mon Sep 17 00:00:00 2001 From: schintap Date: Wed, 28 Nov 2018 11:20:55 -0500 Subject: [PATCH 1/4] [SPARK-26201] Fix python broadcast with encryption --- .../apache/spark/api/python/PythonRDD.scala | 29 ++++++++++++++++--- python/pyspark/broadcast.py | 23 +++++++++++---- python/pyspark/tests/test_broadcast.py | 14 +++++++++ 3 files changed, 56 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8b5a7a9aefea5..5ed5070558af7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -660,6 +660,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial with Logging { private var encryptionServer: PythonServer[Unit] = null + private var decryptionServer: PythonServer[Unit] = null /** * Read data from disks, then copy it to `out` @@ -708,16 +709,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial override def handleConnection(sock: Socket): Unit = { val env = SparkEnv.get val in = sock.getInputStream() - val dir = new File(Utils.getLocalDir(env.conf)) - val file = File.createTempFile("broadcast", "", dir) - path = file.getAbsolutePath - val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path)) + val abspath = new File(path).getAbsolutePath + val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath)) DechunkedInputStream.dechunkAndCopyToOutput(in, out) } } Array(encryptionServer.port, encryptionServer.secret) } + def setupDecryptionServer(): Array[Any] = { + decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") { + override def handleConnection(sock: Socket): Unit = { + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream())) + Utils.tryWithSafeFinally { + val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path)) + Utils.tryWithSafeFinally { + Utils.copyStream(in, out, false) + } { + in.close() + } + out.flush() + } { + JavaUtils.closeQuietly(out) + } + } + } + Array(decryptionServer.port, decryptionServer.secret) + } + + def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult() + def waitTillDataReceived(): Unit = encryptionServer.getResult() } // scalastyle:on no.finalize diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 1c7f2a7418df0..508d7300325e6 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -77,11 +77,12 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, # we're on the driver. We want the pickled data to end up in a file (maybe encrypted) f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) self._path = f.name - python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) + self._sc = sc + self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) if sc._encryption_enabled: # with encryption, we ask the jvm to do the encryption for us, we send it data # over a socket - port, auth_secret = python_broadcast.setupEncryptionServer() + port, auth_secret = self._python_broadcast.setupEncryptionServer() (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret) broadcast_out = ChunkedStream(encryption_sock_file, 8192) else: @@ -89,12 +90,14 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, broadcast_out = f self.dump(value, broadcast_out) if sc._encryption_enabled: - python_broadcast.waitTillDataReceived() - self._jbroadcast = sc._jsc.broadcast(python_broadcast) + self._python_broadcast.waitTillDataReceived() + self._jbroadcast = sc._jsc.broadcast(self._python_broadcast) self._pickle_registry = pickle_registry else: # we're on an executor self._jbroadcast = None + self._sc = None + self._python_broadcast = None if sock_file is not None: # the jvm is doing decryption for us. Read the value # immediately from the sock_file @@ -118,8 +121,16 @@ def dump(self, value, f): f.close() def load_from_path(self, path): - with open(path, 'rb', 1 << 20) as f: - return self.load(f) + # we only need to decrypt it here if its on the driver since executor + # decryption handled already + if self._sc is not None and self._sc._encryption_enabled: + port, auth_secret = self._python_broadcast.setupDecryptionServer() + (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret) + self._python_broadcast.waitTillBroadcastDataSent() + return self.load(decrypted_sock_file) + else: + with open(path, 'rb', 1 << 20) as f: + return self.load(f) def load(self, file): # "file" could also be a socket diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py index a98626e8f4bc9..8e587a793d4b8 100644 --- a/python/pyspark/tests/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -67,6 +67,20 @@ def test_broadcast_with_encryption(self): def test_broadcast_no_encryption(self): self._test_multiple_broadcasts() + def _test_broadcast_on_driver(self, *extra_confs): + conf = SparkConf() + for key, value in extra_confs: + conf.set(key, value) + conf.setMaster("local-cluster[2,1,1024]") + self.sc = SparkContext(conf=conf) + bs = self.sc.broadcast(value=5) + self.assertEqual(5, bs.value) + + def test_broadcast_value_driver_no_encryption(self): + self._test_broadcast_on_driver() + + def test_broadcast_value_driver_encryption(self): + self.self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true")) class BroadcastFrameProtocolTest(unittest.TestCase): From 605ed934fa109f7473a19d23af49e345f19392c8 Mon Sep 17 00:00:00 2001 From: schintap Date: Wed, 28 Nov 2018 14:05:03 -0500 Subject: [PATCH 2/4] Fix style and remove self --- python/pyspark/tests/test_broadcast.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py index 8e587a793d4b8..11d31d24bb011 100644 --- a/python/pyspark/tests/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -80,7 +80,8 @@ def test_broadcast_value_driver_no_encryption(self): self._test_broadcast_on_driver() def test_broadcast_value_driver_encryption(self): - self.self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true")) + self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true")) + class BroadcastFrameProtocolTest(unittest.TestCase): From d9994b7f9b6aaaf9f87ff09b1d45f3c204f7b4d3 Mon Sep 17 00:00:00 2001 From: schintap Date: Wed, 28 Nov 2018 15:08:56 -0500 Subject: [PATCH 3/4] Move code block from load_from_path to value --- python/pyspark/broadcast.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 508d7300325e6..b8b349d593f05 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -121,16 +121,8 @@ def dump(self, value, f): f.close() def load_from_path(self, path): - # we only need to decrypt it here if its on the driver since executor - # decryption handled already - if self._sc is not None and self._sc._encryption_enabled: - port, auth_secret = self._python_broadcast.setupDecryptionServer() - (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret) - self._python_broadcast.waitTillBroadcastDataSent() - return self.load(decrypted_sock_file) - else: - with open(path, 'rb', 1 << 20) as f: - return self.load(f) + with open(path, 'rb', 1 << 20) as f: + return self.load(f) def load(self, file): # "file" could also be a socket @@ -145,7 +137,15 @@ def value(self): """ Return the broadcasted value """ if not hasattr(self, "_value") and self._path is not None: - self._value = self.load_from_path(self._path) + # we only need to decrypt it here when encryption is enabled and + # if its on the driver, since executor decryption is handled already + if self._sc._encryption_enabled: + port, auth_secret = self._python_broadcast.setupDecryptionServer() + (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret) + self._python_broadcast.waitTillBroadcastDataSent() + return self.load(decrypted_sock_file) + else: + self._value = self.load_from_path(self._path) return self._value def unpersist(self, blocking=False): From 001dfff1b444d3f90febe1487eac4a1411a582de Mon Sep 17 00:00:00 2001 From: schintap Date: Wed, 28 Nov 2018 15:39:24 -0500 Subject: [PATCH 4/4] add back _sc condition check --- python/pyspark/broadcast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index b8b349d593f05..29358b5740e51 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -139,7 +139,7 @@ def value(self): if not hasattr(self, "_value") and self._path is not None: # we only need to decrypt it here when encryption is enabled and # if its on the driver, since executor decryption is handled already - if self._sc._encryption_enabled: + if self._sc is not None and self._sc._encryption_enabled: port, auth_secret = self._python_broadcast.setupDecryptionServer() (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret) self._python_broadcast.waitTillBroadcastDataSent()