Skip to content

Commit

Permalink
Change blockify/unblockifyObject to accept serializer as argument.
Browse files Browse the repository at this point in the history
This makes them easier to test, since they now have no dependency on SparkEnv.
  • Loading branch information
JoshRosen committed Oct 19, 2014
1 parent 618a872 commit 33fc754
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import scala.util.Random

import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.io.ByteArrayChunkOutputStream
Expand Down Expand Up @@ -86,7 +87,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
* @return number of blocks this broadcast variable is divided into
*/
private def writeBlocks(): Int = {
val blocks = TorrentBroadcast.blockifyObject(_value, blockSize, compressionCodec)
val blocks =
TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec)
blocks.zipWithIndex.foreach { case (block, i) =>
SparkEnv.get.blockManager.putBytes(
BroadcastBlockId(id, "piece" + i),
Expand Down Expand Up @@ -164,7 +166,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
val time = (System.nanoTime() - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")

_value = TorrentBroadcast.unBlockifyObject[T](blocks, compressionCodec)
_value =
TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
SparkEnv.get.blockManager.putSingle(
Expand All @@ -180,23 +183,25 @@ private object TorrentBroadcast extends Logging {
/**
 * Serializes `obj` and returns the resulting bytes as an array of ByteBuffers.
 *
 * The object is written through a serialization stream into a
 * ByteArrayChunkOutputStream constructed with `blockSize`; the arrays that stream
 * reports via `toArrays` are each wrapped in a ByteBuffer.
 *
 * @param obj the object to serialize
 * @param blockSize value passed to the ByteArrayChunkOutputStream constructor
 *                  (presumably the size of each chunk -- confirm against that class)
 * @param serializer serializer whose `newInstance()` performs the serialization
 * @param compressionCodec if defined, the serialized bytes are written through the
 *                         codec's compressed output stream; if None they are written directly
 * @return the chunk arrays of the output stream, each wrapped in a ByteBuffer
 */
def blockifyObject[T: ClassTag](
obj: T,
blockSize: Int,
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
val bos = new ByteArrayChunkOutputStream(blockSize)
// Layer the codec's compressing stream over the chunk stream only when a codec is supplied.
val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
// NOTE(review): this listing looks like a rendered diff with its +/- markers stripped.
// The next line (SparkEnv lookup) appears to be the removed version and the line after it
// (using the `serializer` parameter) the added replacement -- confirm against the commit.
val ser = SparkEnv.get.serializer.newInstance()
val ser = serializer.newInstance()
val serOut = ser.serializeStream(out)
// Closing the serialization stream after the single write finishes the output
// before the chunks are read back below.
serOut.writeObject[T](obj).close()
bos.toArrays.map(ByteBuffer.wrap)
}

def unBlockifyObject[T: ClassTag](
blocks: Array[ByteBuffer],
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): T = {
require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
val is = new SequenceInputStream(
asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
val ser = SparkEnv.get.serializer.newInstance()
val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
serIn.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
import org.apache.spark.io.SnappyCompressionCodec
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage._

class BroadcastSuite extends FunSuite with LocalSparkContext with GeneratorDrivenPropertyChecks {
Expand Down Expand Up @@ -91,14 +92,18 @@ class BroadcastSuite extends FunSuite with LocalSparkContext with GeneratorDrive
// Property-based round-trip check: for randomly sized byte arrays, blockifying and then
// unblockifying with the same serializer and compression codec must yield the original array.
test("TorrentBroadcast's blockifyObject and unblockifyObject are inverses") {
import org.apache.spark.broadcast.TorrentBroadcast._
val blockSize = 1024
// NOTE(review): this listing looks like a rendered diff with its +/- markers stripped.
// The `snappy` definition on the next line appears to be the removed version; the three
// lines after it (conf / compressionCodec / serializer) appear to be its replacement.
val snappy = Some(new SnappyCompressionCodec(new SparkConf()))
val conf = new SparkConf()
val compressionCodec = Some(new SnappyCompressionCodec(conf))
// The serializer is constructed directly from the SparkConf and passed in explicitly.
val serializer = new JavaSerializer(conf)
// Generator of byte arrays whose length is chosen from 1 to 10240 and whose contents are random.
val objects = for (size <- Gen.choose(1, 1024 * 10)) yield {
val data: Array[Byte] = new Array[Byte](size)
Random.nextBytes(data)
data
}
forAll (objects) { (obj: Array[Byte]) =>
// NOTE(review): the single-line assert below appears to be the removed version of the
// check; the three lines following it appear to be the added version that threads the
// serializer through both calls.
assert(unBlockifyObject[Array[Byte]](blockifyObject(obj, blockSize, snappy), snappy) === obj)
val blocks = blockifyObject(obj, blockSize, serializer, compressionCodec)
val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec)
assert(unblockified === obj)
}
}

Expand Down

0 comments on commit 33fc754

Please sign in to comment.