SPARK-1255: Allow user to pass Serializer object instead of class name for shuffle.

This is more general than simply passing a class name as a string, and it leaves more room for performance optimizations.

Note that this is technically an API-breaking change in the following two ways:
1. The shuffle serializer specification in ShuffleDependency now requires an object instead of a String (the class name), but I suspect nobody else in this world has used this API other than me in GraphX and Shark.
2. Serializers in Spark are now required to be serializable.
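
As a minimal usage sketch of the new API (pairs: RDD[(Int, String)] and conf: SparkConf are assumed to already exist; the names are illustrative, not from this patch):

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.ShuffledRDD
    import org.apache.spark.serializer.KryoSerializer

    // Old API: the shuffle serializer was named by a class-name String, e.g.
    //   .setSerializer("org.apache.spark.serializer.KryoSerializer")
    // New API: pass a Serializer instance; null still falls back to spark.serializer.
    val shuffled = new ShuffledRDD[Int, String, (Int, String)](pairs, new HashPartitioner(8))
      .setSerializer(new KryoSerializer(conf))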

Author: Reynold Xin <rxin@apache.org>

Closes apache#149 from rxin/serializer and squashes the following commits:

5acaccd [Reynold Xin] Properly call serializer's constructors.
2a8d75a [Reynold Xin] Added more documentation for the serializer option in ShuffleDependency.
7420185 [Reynold Xin] Allow user to pass Serializer object instead of class name for shuffle.
rxin authored and pwendell committed Mar 16, 2014
1 parent 97e4459 commit f5486e9
Showing 18 changed files with 125 additions and 171 deletions.
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -18,6 +18,7 @@
package org.apache.spark

import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer

/**
* Base class for dependencies.
@@ -43,12 +44,13 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializerClass class name of the serializer to use
* @param serializer [[Serializer]] to use. If set to null, the default serializer, as specified
* by `spark.serializer` config option, will be used.
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializerClass: String = null)
val serializer: Serializer = null)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {

val shuffleId: Int = rdd.context.newShuffleId()
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -29,7 +29,7 @@ private[spark] abstract class ShuffleFetcher {
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
serializer: Serializer = SparkEnv.get.serializer): Iterator[T]

/** Stop the fetcher */
def stop() {}
24 changes: 14 additions & 10 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.storage.{BlockManager, BlockManagerMaster, BlockManagerMasterActor}
import org.apache.spark.network.ConnectionManager
import org.apache.spark.serializer.{Serializer, SerializerManager}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.{AkkaUtils, Utils}

/**
@@ -41,7 +41,6 @@ import org.apache.spark.util.{AkkaUtils, Utils}
class SparkEnv private[spark] (
val executorId: String,
val actorSystem: ActorSystem,
val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -139,16 +138,22 @@ object SparkEnv extends Logging {
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
val cls = Class.forName(name, true, classLoader)
// First try with the constructor that takes SparkConf. If we can't find one,
// use a no-arg constructor instead.
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
}
}
val serializerManager = new SerializerManager

val serializer = serializerManager.setDefault(
conf.get("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf)
val serializer = instantiateClass[Serializer](
"spark.serializer", "org.apache.spark.serializer.JavaSerializer")

val closureSerializer = serializerManager.get(
conf.get("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"),
conf)
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")

def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
@@ -220,7 +225,6 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
serializerManager,
serializer,
closureSerializer,
cacheManager,
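A hedged sketch of what the constructor lookup above means for a user-supplied serializer: MySerializer is hypothetical (not part of this patch); instantiateClass would use its SparkConf constructor and only fall back to the no-arg one if the former were missing.

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerInstance}

    // Hypothetical user serializer, selected via e.g.
    //   conf.set("spark.serializer", "com.example.MySerializer")
    class MySerializer(conf: SparkConf) extends Serializer with Serializable {
      def this() = this(new SparkConf(false))          // no-arg fallback constructor
      // Read what we need from conf up front; JavaSerializer is serializable after this patch,
      // so the instance as a whole can be shipped with a ShuffleDependency.
      private val delegate = new JavaSerializer(conf)
      def newInstance(): SerializerInstance = delegate.newInstance()
    }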
18 changes: 9 additions & 9 deletions core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
import org.apache.spark.serializer.Serializer

private[spark] sealed trait CoGroupSplitDep extends Serializable

@@ -66,10 +67,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]

private var serializerClass: String = null
private var serializer: Serializer = null

def setSerializer(cls: String): CoGroupedRDD[K] = {
serializerClass = cls
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
this.serializer = serializer
this
}

@@ -80,7 +81,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
new ShuffleDependency[Any, Any](rdd, part, serializerClass)
new ShuffleDependency[Any, Any](rdd, part, serializer)
}
}
}
@@ -113,18 +114,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
// A list of (rdd iterator, dependency number) pairs
val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
// Read them from the parent
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))
}
case ShuffleCoGroupSplitDep(shuffleId) => {

case ShuffleCoGroupSplitDep(shuffleId) =>
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf)
val ser = Serializer.getSerializer(serializer)
val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
rddIterators += ((it, depNum))
}
}

if (!externalSorting) {
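Serializer.getSerializer, used above, is not shown in this diff; presumably it resolves a null argument to the SparkEnv default, along these lines (an assumed sketch, not the code from this commit):

    // Assumed behavior of the companion-object helper referenced above.
    object Serializer {
      def getSerializer(serializer: Serializer): Serializer =
        if (serializer == null) SparkEnv.get.serializer else serializer
    }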
core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -44,6 +44,7 @@ import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.SerializableHyperLogLog

/**
@@ -73,7 +74,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true,
serializerClass: String = null): RDD[(K, C)] = {
serializer: Serializer = null): RDD[(K, C)] = {
require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
if (getKeyClass().isArray) {
if (mapSideCombine) {
Expand All @@ -93,13 +94,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
aggregator.combineValuesByKey(iter, context)
}, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
.setSerializer(serializerClass)
.setSerializer(serializer)
partitioned.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context))
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer)
values.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
13 changes: 7 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer

private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
@@ -38,15 +39,15 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
part: Partitioner)
extends RDD[P](prev.context, Nil) {

private var serializerClass: String = null
private var serializer: Serializer = null

def setSerializer(cls: String): ShuffledRDD[K, V, P] = {
serializerClass = cls
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
this.serializer = serializer
this
}

override def getDependencies: Seq[Dependency[_]] = {
List(new ShuffleDependency(prev, part, serializerClass))
List(new ShuffleDependency(prev, part, serializer))
}

override val partitioner = Some(part)
@@ -57,8 +58,8 @@

override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf))
val ser = Serializer.getSerializer(serializer)
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
}

override def clearDependencies() {
20 changes: 10 additions & 10 deletions core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -30,6 +30,7 @@ import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
import org.apache.spark.serializer.Serializer

/**
* An optimized version of cogroup for set difference/subtraction.
@@ -53,10 +54,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {

private var serializerClass: String = null
private var serializer: Serializer = null

def setSerializer(cls: String): SubtractedRDD[K, V, W] = {
serializerClass = cls
def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
this.serializer = serializer
this
}

@@ -67,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
new ShuffleDependency(rdd, part, serializerClass)
new ShuffleDependency(rdd, part, serializer)
}
}
}
@@ -92,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](

override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -105,14 +106,13 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
}
def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
}
case ShuffleCoGroupSplitDep(shuffleId) => {

case ShuffleCoGroupSplitDep(shuffleId) =>
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
context, serializer)
context, ser)
iter.foreach(op)
}
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -26,6 +26,7 @@ import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}

@@ -153,7 +154,7 @@ private[spark] class ShuffleMapTask(

try {
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf)
val ser = Serializer.getSerializer(dep.serializer)
shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)

// Write the map output to its associated buckets.
core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -23,11 +23,10 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.util.ByteBufferInputStream

private[spark] class JavaSerializationStream(out: OutputStream, conf: SparkConf)
private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
extends SerializationStream {
val objOut = new ObjectOutputStream(out)
var counter = 0
val counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000)
private val objOut = new ObjectOutputStream(out)
private var counter = 0

/**
* Calling reset to avoid memory leak:
@@ -51,7 +50,7 @@

private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream {
val objIn = new ObjectInputStream(in) {
private val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
@@ -60,7 +59,7 @@ extends DeserializationStream {
def close() { objIn.close() }
}

private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerInstance {
private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -82,7 +81,7 @@ private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerI
}

def serializeStream(s: OutputStream): SerializationStream = {
new JavaSerializationStream(s, conf)
new JavaSerializationStream(s, counterReset)
}

def deserializeStream(s: InputStream): DeserializationStream = {
@@ -97,6 +96,16 @@ private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerI
/**
* A Spark serializer that uses Java's built-in serialization.
*/
class JavaSerializer(conf: SparkConf) extends Serializer {
def newInstance(): SerializerInstance = new JavaSerializerInstance(conf)
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000)

def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)

override def writeExternal(out: ObjectOutput) {
out.writeInt(counterReset)
}

override def readExternal(in: ObjectInput) {
counterReset = in.readInt()
}
}
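Because the Serializer now rides inside ShuffleDependency (and therefore inside serialized tasks), serializer implementations must themselves be serializable, as the commit message notes; that is what JavaSerializer's Externalizable hooks above and KryoSerializer's Serializable mix-in below provide. A quick REPL-style sanity sketch (the conf value here is an assumption for illustration):

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer

    val conf = new SparkConf(false)
    // After this patch the serializer instance round-trips through Java serialization
    // rather than throwing NotSerializableException.
    new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(new KryoSerializer(conf))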
core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -34,10 +34,14 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*/
class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging {
private val bufferSize = {
conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
}
class KryoSerializer(conf: SparkConf)
extends org.apache.spark.serializer.Serializer
with Logging
with Serializable {

private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
private val registrator = conf.getOption("spark.kryo.registrator")

def newKryoOutput() = new KryoOutput(bufferSize)

@@ -48,7 +52,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial

// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true))
kryo.setReferences(referenceTracking)

for (cls <- KryoSerializer.toRegister) kryo.register(cls)

@@ -58,7 +62,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial

// Allow the user to register their own classes by setting spark.kryo.registrator
try {
for (regCls <- conf.getOption("spark.kryo.registrator")) {
for (regCls <- registrator) {
logDebug("Running user registrator: " + regCls)
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]
(Diffs for the remaining changed files are not shown.)