[SPARK-26201] Fix python broadcast with encryption #23166

Closed · wants to merge 4 commits
29 changes: 25 additions & 4 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -660,6 +660,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
with Logging {

private var encryptionServer: PythonServer[Unit] = null
private var decryptionServer: PythonServer[Unit] = null

/**
* Read data from disks, then copy it to `out`
@@ -708,16 +709,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
override def handleConnection(sock: Socket): Unit = {
val env = SparkEnv.get
val in = sock.getInputStream()
- val dir = new File(Utils.getLocalDir(env.conf))
- val file = File.createTempFile("broadcast", "", dir)
- path = file.getAbsolutePath
- val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
+ val abspath = new File(path).getAbsolutePath
+ val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath))
Review discussion on this change:

Contributor:
Just want to make sure I understand this part -- this change isn't necessary, right? Even in the old version, path gets updated here, so setupDecryptionServer would know where to read the data from.
That said, I do think your change makes more sense -- not sure why I didn't just use the supplied path in the first place.

Author (@redsanket, Nov 29, 2018):
In the old version we generated a random path with encryption turned off, so with encryption off it reads and writes from that random path. When the encryption-related code was written, it introduced a new "broadcast" path; the problem is that when we try to decrypt on the driver side, it looks at the random path reference lying around and tries to decrypt from it, but the actual data is in the new "broadcast" path location. So, by just passing the randomly generated path through (f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)), we make sure all the places are in sync with and without encryption.

Contributor:
Yeah, I see how it was wrong before. I'm saying that after you add setupDecryptionServer, that decryption server would still be reading from the value of path, which gets updated here, since it's the same object in the driver's JVM.
Anyway, this isn't a big deal; I think it's better with your change.

Contributor:
OK, I think we agree it's good this way (just to verify, though, I won't commit until you +1 it). But yes, you are correct: now that we are using the decryption server, which reads from the path in PythonBroadcast, the path change isn't strictly necessary. Still, the value of self._path in broadcast.py wouldn't match the path in PythonBroadcast, so I think it's better to have those match.

Contributor:
Yes, +1. Sorry, I didn't mean to get things stuck on this; I just wanted to make sure I was actually following what was happening correctly.
DechunkedInputStream.dechunkAndCopyToOutput(in, out)
}
}
Array(encryptionServer.port, encryptionServer.secret)
}

def setupDecryptionServer(): Array[Any] = {
decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") {
override def handleConnection(sock: Socket): Unit = {
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream()))
Utils.tryWithSafeFinally {
val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path))
Utils.tryWithSafeFinally {
Utils.copyStream(in, out, false)
} {
in.close()
}
out.flush()
} {
JavaUtils.closeQuietly(out)
}
}
}
Array(decryptionServer.port, decryptionServer.secret)
}

def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult()

def waitTillDataReceived(): Unit = encryptionServer.getResult()
}
// scalastyle:on no.finalize
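Taken together, the Scala changes above define a small driver-side protocol: setupDecryptionServer returns a one-shot (port, secret) pair, handleConnection streams the decrypted broadcast file over that socket, and waitTillBroadcastDataSent blocks until the transfer is done. A minimal Python sketch of the client side follows; the helper name read_broadcast_value_on_driver and the jvm_broadcast parameter (a Py4J handle to a PythonBroadcast) are illustrative only, and the local_connect_and_auth import location is assumed from how broadcast.py uses it in this patch.

# Sketch (not part of the patch): client side of the new decryption server.
import pickle

from pyspark.java_gateway import local_connect_and_auth


def read_broadcast_value_on_driver(jvm_broadcast):
    # The JVM starts a one-shot server that decrypts the broadcast file and
    # streams the plain pickled bytes back over an authenticated socket.
    port, auth_secret = jvm_broadcast.setupDecryptionServer()
    sock_file, _ = local_connect_and_auth(port, auth_secret)
    # Block until the JVM has finished sending (same order as value() below).
    jvm_broadcast.waitTillBroadcastDataSent()
    return pickle.load(sock_file)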
21 changes: 16 additions & 5 deletions python/pyspark/broadcast.py
@@ -77,24 +77,27 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
# we're on the driver. We want the pickled data to end up in a file (maybe encrypted)
f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
self._path = f.name
- python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
+ self._sc = sc
+ self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
if sc._encryption_enabled:
# with encryption, we ask the jvm to do the encryption for us, we send it data
# over a socket
- port, auth_secret = python_broadcast.setupEncryptionServer()
+ port, auth_secret = self._python_broadcast.setupEncryptionServer()
(encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
broadcast_out = ChunkedStream(encryption_sock_file, 8192)
else:
# no encryption, we can just write pickled data directly to the file from python
broadcast_out = f
self.dump(value, broadcast_out)
if sc._encryption_enabled:
- python_broadcast.waitTillDataReceived()
- self._jbroadcast = sc._jsc.broadcast(python_broadcast)
+ self._python_broadcast.waitTillDataReceived()
+ self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
self._pickle_registry = pickle_registry
else:
# we're on an executor
self._jbroadcast = None
self._sc = None
self._python_broadcast = None
if sock_file is not None:
# the jvm is doing decryption for us. Read the value
# immediately from the sock_file
@@ -134,7 +137,15 @@ def value(self):
""" Return the broadcasted value
"""
if not hasattr(self, "_value") and self._path is not None:
- self._value = self.load_from_path(self._path)
+ # we only need to decrypt it here when encryption is enabled and
+ # if it's on the driver, since executor decryption is handled already
+ if self._sc is not None and self._sc._encryption_enabled:
+     port, auth_secret = self._python_broadcast.setupDecryptionServer()
+     (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret)
+     self._python_broadcast.waitTillBroadcastDataSent()
+     return self.load(decrypted_sock_file)
+ else:
+     self._value = self.load_from_path(self._path)
return self._value

def unpersist(self, blocking=False):
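For symmetry with the read path, here is a sketch of the encrypted write path that __init__ now takes on the driver. The call order mirrors the diff above; the helper name, the import locations, and the use of pickle.HIGHEST_PROTOCOL are assumptions for illustration, while the 8192-byte chunk size is the one __init__ passes to ChunkedStream.

# Sketch (not part of the patch): driver-side write path with encryption on.
import pickle

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import ChunkedStream


def write_broadcast_value_from_driver(jvm_broadcast, value):
    # The JVM opens a one-shot server; bytes written to this socket are
    # encrypted by the JVM before they reach the broadcast file on disk.
    port, auth_secret = jvm_broadcast.setupEncryptionServer()
    sock_file, _ = local_connect_and_auth(port, auth_secret)
    out = ChunkedStream(sock_file, 8192)
    pickle.dump(value, out, pickle.HIGHEST_PROTOCOL)
    out.close()
    # Block until the JVM has drained the socket and finished the file.
    jvm_broadcast.waitTillDataReceived()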
15 changes: 15 additions & 0 deletions python/pyspark/tests/test_broadcast.py
@@ -67,6 +67,21 @@ def test_broadcast_with_encryption(self):
def test_broadcast_no_encryption(self):
self._test_multiple_broadcasts()

def _test_broadcast_on_driver(self, *extra_confs):
conf = SparkConf()
for key, value in extra_confs:
conf.set(key, value)
conf.setMaster("local-cluster[2,1,1024]")
self.sc = SparkContext(conf=conf)
bs = self.sc.broadcast(value=5)
self.assertEqual(5, bs.value)

def test_broadcast_value_driver_no_encryption(self):
self._test_broadcast_on_driver()

def test_broadcast_value_driver_encryption(self):
self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))


class BroadcastFrameProtocolTest(unittest.TestCase):

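Outside the unittest harness, the behavior the new driver-side tests pin down can be reproduced with a few lines. The master string and the spark.io.encryption.enabled key come from _test_broadcast_on_driver above; the rest is a hedged usage sketch, not part of the patch.

# Standalone sketch of what the new tests check: with encryption enabled,
# reading a broadcast value back on the driver goes through the JVM
# decryption server instead of unpickling the (now encrypted) file directly.
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setMaster("local-cluster[2,1,1024]")
        .set("spark.io.encryption.enabled", "true"))
sc = SparkContext(conf=conf)
try:
    bs = sc.broadcast(5)
    # Driver-side read: the case this PR fixes.
    assert bs.value == 5
    # Executor-side read keeps working as before.
    assert sc.parallelize([1, 2]).map(lambda x: x + bs.value).collect() == [6, 7]
finally:
    sc.stop()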