diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 489a826b6ec91..d431ea72da758 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -26,6 +26,7 @@ import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.io.ByteArrayChunkOutputStream @@ -86,7 +87,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) * @return number of blocks this broadcast variable is divided into */ private def writeBlocks(): Int = { - val blocks = TorrentBroadcast.blockifyObject(_value, blockSize, compressionCodec) + val blocks = + TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec) blocks.zipWithIndex.foreach { case (block, i) => SparkEnv.get.blockManager.putBytes( BroadcastBlockId(id, "piece" + i), @@ -164,7 +166,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) val time = (System.nanoTime() - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") - _value = TorrentBroadcast.unBlockifyObject[T](blocks, compressionCodec) + _value = + TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec) // Store the merged copy in BlockManager so other tasks on this executor don't // need to re-fetch it. SparkEnv.get.blockManager.putSingle( @@ -180,10 +183,11 @@ private object TorrentBroadcast extends Logging { def blockifyObject[T: ClassTag]( obj: T, blockSize: Int, + serializer: Serializer, compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = { val bos = new ByteArrayChunkOutputStream(blockSize) val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos) - val ser = SparkEnv.get.serializer.newInstance() + val ser = serializer.newInstance() val serOut = ser.serializeStream(out) serOut.writeObject[T](obj).close() bos.toArrays.map(ByteBuffer.wrap) @@ -191,12 +195,13 @@ private object TorrentBroadcast extends Logging { def unBlockifyObject[T: ClassTag]( blocks: Array[ByteBuffer], + serializer: Serializer, compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") val is = new SequenceInputStream( asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) - val ser = SparkEnv.get.serializer.newInstance() + val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) val obj = serIn.readObject[T]() serIn.close() diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 13ce6d99d4e7a..73f9b59de4b78 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} import org.apache.spark.io.SnappyCompressionCodec +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ class BroadcastSuite extends FunSuite with LocalSparkContext with GeneratorDrivenPropertyChecks { @@ -91,14 +92,18 @@ class BroadcastSuite extends FunSuite with LocalSparkContext with GeneratorDrive test("TorrentBroadcast's blockifyObject and unblockifyObject are inverses") { import org.apache.spark.broadcast.TorrentBroadcast._ val blockSize = 1024 - val snappy = Some(new SnappyCompressionCodec(new SparkConf())) + val conf = new SparkConf() + val compressionCodec = Some(new SnappyCompressionCodec(conf)) + val serializer = new JavaSerializer(conf) val objects = for (size <- Gen.choose(1, 1024 * 10)) yield { val data: Array[Byte] = new Array[Byte](size) Random.nextBytes(data) data } forAll (objects) { (obj: Array[Byte]) => - assert(unBlockifyObject[Array[Byte]](blockifyObject(obj, blockSize, snappy), snappy) === obj) + val blocks = blockifyObject(obj, blockSize, serializer, compressionCodec) + val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec) + assert(unblockified === obj) } }