diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index e3e6d0171e971..2760434beb309 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -60,22 +60,56 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
         instead.
         """
         if sc is not None:
-            tempFile = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
-            cPickle.dump(value, tempFile, 2)
-            tempFile.close()
-            self._path = tempFile.name
+            f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
+            self._path = self.dump(value, f)
             self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
             self._pickle_registry = pickle_registry
         else:
             self._jbroadcast = None
             self._path = path
 
+    def dump(self, value, f):
+        """
+        Serialize `value` into the open file `f`, close `f` and return its name.
+
+        The first byte written is a type flag that load() reads back: 'U' for
+        a unicode value (stored UTF-8 encoded), 'S' for any other string
+        (stored as-is) and 'P' for everything else (pickled with protocol 2).
+        """
+        try:
+            if isinstance(value, basestring):
+                if isinstance(value, unicode):
+                    f.write('U')
+                    value = value.encode('utf8')
+                else:
+                    f.write('S')
+                f.write(value)
+            else:
+                f.write('P')
+                cPickle.dump(value, f, 2)
+        finally:
+            # always release the handle, even if writing or pickling fails
+            f.close()
+        return f.name
+
+    def load(self, path):
+        """
+        Read back a value that dump() wrote to the file at `path`.
+        """
+        with open(path, 'rb', 1 << 20) as f:
+            flag = f.read(1)
+            data = f.read()
+            if flag == 'P':
+                return cPickle.loads(data)
+            else:
+                return data.decode('utf8') if flag == 'U' else data
+
     @property
     def value(self):
         """ Return the broadcasted value
         """
         if not hasattr(self, "_value") and self._path is not None:
-            self._value = cPickle.loads(open(self._path, 'rb', 4 << 20).read())
+            self._value = self.load(self._path)
         return self._value
 
     def unpersist(self, blocking=False):