[SPARK-3721] [PySpark] broadcast objects larger than 2G
This patch adds support for broadcasting objects larger than 2G.

pickle, zlib, FramedSerializer, and Array[Byte] cannot handle objects larger than 2G, so this patch introduces a LargeObjectSerializer for broadcast objects: each object is serialized and compressed into small chunks. It also changes the broadcast type from Broadcast[Array[Byte]] to Broadcast[Array[Array[Byte]]].

Testing broadcasts of objects larger than 2G is slow and memory hungry, so this was tested manually; it could be added to SparkPerf.
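
A rough sketch of the chunking idea (illustrative only, not code from this patch; the helper name chunk_and_compress is made up): split the serialized bytes into ~1MB blocks and feed them to zlib incrementally, so no single buffer handed to zlib ever approaches the 2G limit.

    import zlib

    MAX_BATCH = 1 << 20  # 1MB, the same block size CompressedStream uses below

    def chunk_and_compress(data, block_size=MAX_BATCH):
        # Yield compressed pieces of `data`, never passing more than
        # `block_size` bytes to zlib in a single call.
        compresser = zlib.compressobj(1)
        for i in xrange(0, len(data), block_size):
            out = compresser.compress(data[i:i + block_size])
            if out:
                yield out
        yield compresser.flush()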

Author: Davies Liu <davies@databricks.com>
Author: Davies Liu <davies.liu@gmail.com>

Closes apache#2659 from davies/huge and squashes the following commits:

7b57a14 [Davies Liu] add more tests for broadcast
28acff9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
a2f6a02 [Davies Liu] bug fix
4820613 [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
5875c73 [Davies Liu] address comments
10a349b [Davies Liu] address comments
0c33016 [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
6182c8f [Davies Liu] Merge branch 'master' into huge
d94b68f [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
2514848 [Davies Liu] address comments
fda395b [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
1c2d928 [Davies Liu] fix scala style
091b107 [Davies Liu] broadcast objects larger than 2G
Davies Liu authored and JoshRosen committed Nov 19, 2014
1 parent d2e2951 commit 4a377af
Showing 9 changed files with 257 additions and 27 deletions.
24 changes: 16 additions & 8 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -47,7 +47,7 @@ private[spark] class PythonRDD(
pythonIncludes: JList[String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {

@@ -230,8 +230,8 @@ private[spark] class PythonRDD(
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
dataOut.writeLong(broadcast.value.map(_.length.toLong).sum)
broadcast.value.foreach(dataOut.write)
oldBids.add(broadcast.id)
}
}
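
The matching reader in python/pyspark/worker.py is not shown in this excerpt. A minimal sketch of what it has to do, assuming the read_long helper from pyspark.serializers and the SizeLimitedStream and LargeObjectSerializer classes introduced further down (read_one_broadcast is a made-up name):

    from pyspark.serializers import read_long, SizeLimitedStream, LargeObjectSerializer

    def read_one_broadcast(infile):
        # Mirrors the Scala write side above: a long broadcast id, a long
        # total length, then the raw chunks written back to back.
        bid = read_long(infile)
        size = read_long(infile)
        bounded = SizeLimitedStream(infile, size)  # stop reading after `size` bytes
        value = LargeObjectSerializer().load_stream(bounded).next()
        return bid, value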
@@ -368,16 +368,24 @@ private[spark] object PythonRDD extends Logging {
}
}

def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
def readBroadcastFromFile(
sc: JavaSparkContext,
filename: String): Broadcast[Array[Array[Byte]]] = {
val size = new File(filename).length()
val file = new DataInputStream(new FileInputStream(filename))
val blockSize = 1 << 20
val n = ((size + blockSize - 1) / blockSize).toInt
val obj = new Array[Array[Byte]](n)
try {
val length = file.readInt()
val obj = new Array[Byte](length)
file.readFully(obj)
sc.broadcast(obj)
for (i <- 0 until n) {
val length = if (i < (n - 1)) blockSize else (size % blockSize).toInt
obj(i) = new Array[Byte](length)
file.readFully(obj(i))
}
} finally {
file.close()
}
sc.broadcast(obj)
}
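
The same chunked read, sketched in Python for clarity (illustrative only; here the final block length is computed as size - i * block_size, which also covers a file whose size is an exact multiple of the block size):

    import os

    def read_chunked(filename, block_size=1 << 20):
        # Read `filename` into a list of blocks of at most `block_size` bytes,
        # so no single byte array has to hold the whole (possibly > 2G) file.
        size = os.path.getsize(filename)
        n = (size + block_size - 1) // block_size  # ceil(size / block_size)
        chunks = []
        with open(filename, 'rb') as f:
            for i in range(n):
                length = block_size if i < n - 1 else size - i * block_size
                chunks.append(f.read(length))
        return chunks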

def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
4 changes: 2 additions & 2 deletions python/pyspark/broadcast.py
@@ -29,7 +29,7 @@
"""
import os

from pyspark.serializers import CompressedSerializer, PickleSerializer
from pyspark.serializers import LargeObjectSerializer


__all__ = ['Broadcast']
@@ -73,7 +73,7 @@ def value(self):
""" Return the broadcasted value
"""
if not hasattr(self, "_value") and self.path is not None:
ser = CompressedSerializer(PickleSerializer())
ser = LargeObjectSerializer()
self._value = ser.load_stream(open(self.path)).next()
return self._value

5 changes: 3 additions & 2 deletions python/pyspark/context.py
@@ -29,7 +29,7 @@
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer
PairDeserializer, AutoBatchedSerializer, NoOpSerializer, LargeObjectSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
@@ -624,7 +624,8 @@ def broadcast(self, value):
object for reading it in distributed functions. The variable will
be sent to each cluster only once.
"""
ser = CompressedSerializer(PickleSerializer())
ser = LargeObjectSerializer()

# pass large object by py4j is very slow and need much memory
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
ser.dump_stream([value], tempFile)
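
The rest of broadcast() is folded in this view. The overall flow it implements, sketched under the assumption that the temp file is handed to the JVM through the readBroadcastFromFile method added to PythonRDD above (the wrapper name broadcast_sketch is made up, and the exact py4j call is an assumption based on that method's signature):

    from tempfile import NamedTemporaryFile
    from pyspark.serializers import LargeObjectSerializer

    def broadcast_sketch(sc, value):
        # Dump the value into a local temp file with LargeObjectSerializer,
        # then let the JVM read it back as chunked Array[Array[Byte]] blocks;
        # pushing a large object through py4j directly is slow and memory hungry.
        ser = LargeObjectSerializer()
        tempFile = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
        ser.dump_stream([value], tempFile)
        tempFile.close()
        return sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, tempFile.name)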
185 changes: 178 additions & 7 deletions python/pyspark/serializers.py
@@ -133,6 +133,8 @@ def load_stream(self, stream):

def _write_with_length(self, obj, stream):
serialized = self.dumps(obj)
if len(serialized) > (1 << 31):
raise ValueError("can not serialize object larger than 2G")
write_int(len(serialized), stream)
if self._only_write_strings:
stream.write(str(serialized))
@@ -446,20 +448,184 @@ def loads(self, obj):
raise ValueError("invalid sevialization type: %s" % _type)


class CompressedSerializer(FramedSerializer):
class SizeLimitedStream(object):
"""
Compress the serialized data
Read at most `limit` bytes from underlying stream
>>> from StringIO import StringIO
>>> io = StringIO()
>>> io.write("Hello world")
>>> io.seek(0)
>>> lio = SizeLimitedStream(io, 5)
>>> lio.read()
'Hello'
"""
def __init__(self, stream, limit):
self.stream = stream
self.limit = limit

def read(self, n=0):
if n > self.limit or n == 0:
n = self.limit
buf = self.stream.read(n)
self.limit -= len(buf)
return buf


class CompressedStream(object):
"""
Compress the data using zlib
>>> from StringIO import StringIO
>>> io = StringIO()
>>> wio = CompressedStream(io, 'w')
>>> wio.write("Hello world")
>>> wio.flush()
>>> io.seek(0)
>>> rio = CompressedStream(io, 'r')
>>> rio.read()
'Hello world'
>>> rio.read()
''
"""
MAX_BATCH = 1 << 20 # 1MB

def __init__(self, stream, mode='w', level=1):
self.stream = stream
self.mode = mode
if mode == 'w':
self.compresser = zlib.compressobj(level)
elif mode == 'r':
self.decompresser = zlib.decompressobj()
self.buf = ''
else:
raise ValueError("can only support mode 'w' or 'r' ")

def write(self, buf):
assert self.mode == 'w', "It's not opened for write"
if len(buf) > self.MAX_BATCH:
# zlib can not compress string larger than 2G
batches = len(buf) / self.MAX_BATCH + 1 # last one may be empty
for i in xrange(batches):
self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH])
else:
compressed = self.compresser.compress(buf)
self.stream.write(compressed)

def flush(self, mode=zlib.Z_FULL_FLUSH):
if self.mode == 'w':
d = self.compresser.flush(mode)
self.stream.write(d)
self.stream.flush()

def close(self):
if self.mode == 'w':
self.flush(zlib.Z_FINISH)
self.stream.close()

def read(self, size=0):
assert self.mode == 'r', "It's not opened for read"
if not size:
data = self.stream.read()
result = self.decompresser.decompress(data)
last = self.decompresser.flush()
return self.buf + result + last

# fast path for small read()
if size <= len(self.buf):
result = self.buf[:size]
self.buf = self.buf[size:]
return result

result = [self.buf]
size -= len(self.buf)
self.buf = ''
while size:
need = min(size, self.MAX_BATCH)
input = self.stream.read(need)
if input:
buf = self.decompresser.decompress(input)
else:
buf = self.decompresser.flush()

if len(buf) >= size:
self.buf = buf[size:]
result.append(buf[:size])
return ''.join(result)

size -= len(buf)
result.append(buf)
if not input:
return ''.join(result)

def readline(self):
"""
This is needed for pickle, but not used in protocol 2
"""
line = []
b = self.read(1)
while b and b != '\n':
line.append(b)
b = self.read(1)
line.append(b)
return ''.join(line)


class LargeObjectSerializer(Serializer):
"""
Serialize large object which could be larger than 2G
It uses cPickle to serialize the objects
"""
def dump_stream(self, iterator, stream):
stream = CompressedStream(stream, 'w')
for value in iterator:
if isinstance(value, basestring):
if isinstance(value, unicode):
stream.write('U')
value = value.encode("utf-8")
else:
stream.write('S')
write_long(len(value), stream)
stream.write(value)
else:
stream.write('P')
cPickle.dump(value, stream, 2)
stream.flush()

def load_stream(self, stream):
stream = CompressedStream(stream, 'r')
while True:
type = stream.read(1)
if not type:
return
if type in ('S', 'U'):
length = read_long(stream)
value = stream.read(length)
if type == 'U':
value = value.decode('utf-8')
yield value
elif type == 'P':
yield cPickle.load(stream)
else:
raise ValueError("unknown type: %s" % type)


class CompressedSerializer(Serializer):
"""
Compress the serialized data
"""
def __init__(self, serializer):
FramedSerializer.__init__(self)
self.serializer = serializer

def dumps(self, obj):
return zlib.compress(self.serializer.dumps(obj), 1)
def load_stream(self, stream):
stream = CompressedStream(stream, "r")
return self.serializer.load_stream(stream)

def loads(self, obj):
return self.serializer.loads(zlib.decompress(obj))
def dump_stream(self, iterator, stream):
stream = CompressedStream(stream, "w")
self.serializer.dump_stream(iterator, stream)
stream.flush()


class UTF8Deserializer(Serializer):
@@ -517,3 +683,8 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)


if __name__ == '__main__':
import doctest
doctest.testmod()
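
A small usage sketch of the new serializer (essentially what the new test below does): round-trip a few values through an in-memory buffer.

    from StringIO import StringIO
    from pyspark.serializers import LargeObjectSerializer

    ser = LargeObjectSerializer()
    buf = StringIO()
    ser.dump_stream(["abc", u"123", range(5)], buf)  # tagged 'S', 'U' and pickled
    buf.seek(0)
    assert list(ser.load_stream(buf)) == ["abc", u"123", range(5)]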
52 changes: 50 additions & 2 deletions python/pyspark/tests.py
@@ -32,6 +32,7 @@
import zipfile
import random
import threading
import hashlib

if sys.version_info[:2] <= (2, 6):
try:
@@ -47,7 +48,7 @@
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer
CloudPickleSerializer, SizeLimitedStream, CompressedSerializer, LargeObjectSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType
@@ -236,6 +237,27 @@ def foo():
self.assertTrue("exit" in foo.func_code.co_names)
ser.dumps(foo)

def _test_serializer(self, ser):
from StringIO import StringIO
io = StringIO()
ser.dump_stream(["abc", u"123", range(5)], io)
io.seek(0)
self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
size = io.tell()
ser.dump_stream(range(1000), io)
io.seek(0)
first = SizeLimitedStream(io, size)
self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(first)))
self.assertEqual(range(1000), list(ser.load_stream(io)))

def test_compressed_serializer(self):
ser = CompressedSerializer(PickleSerializer())
self._test_serializer(ser)

def test_large_object_serializer(self):
ser = LargeObjectSerializer()
self._test_serializer(ser)


class PySparkTestCase(unittest.TestCase):

@@ -440,7 +462,7 @@ def test_sampling_default_seed(self):
subset = data.takeSample(False, 10)
self.assertEqual(len(subset), 10)

def testAggregateByKey(self):
def test_aggregate_by_key(self):
data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)

def seqOp(x, y):
@@ -478,6 +500,32 @@ def test_large_broadcast(self):
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEquals(N, m)

def test_multiple_broadcasts(self):
N = 1 << 21
b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
r = range(1 << 15)
random.shuffle(r)
s = str(r)
checksum = hashlib.md5(s).hexdigest()
b2 = self.sc.broadcast(s)
r = list(set(self.sc.parallelize(range(10), 10).map(
lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
self.assertEqual(1, len(r))
size, csum = r[0]
self.assertEqual(N, size)
self.assertEqual(checksum, csum)

random.shuffle(r)
s = str(r)
checksum = hashlib.md5(s).hexdigest()
b2 = self.sc.broadcast(s)
r = list(set(self.sc.parallelize(range(10), 10).map(
lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
self.assertEqual(1, len(r))
size, csum = r[0]
self.assertEqual(N, size)
self.assertEqual(checksum, csum)

def test_large_closure(self):
N = 1000000
data = [float(i) for i in xrange(N)]