[SPARK-3886] [PySpark] simplify serializer, use AutoBatchedSerializer by default.

This PR simplifies the serializer: a batched serializer (AutoBatchedSerializer by default) is now always used, even when the batch size is 1.
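For context, the adaptive batching heuristic behind AutoBatchedSerializer/AutoBatchedPickler can be sketched as follows. This is a minimal, dependency-free illustration of the logic added in this diff, not code from the patch; AutoBatchedIterator and the serialize parameter are made-up names standing in for the real Pickler-based implementation in SerDeUtil.scala below.

// Sketch only: mirrors the batch-size adaptation in AutoBatchedPickler.
// The batch doubles while a serialized batch stays under 1 MB and halves once it exceeds 10 MB.
import scala.collection.mutable

class AutoBatchedIterator[T](iter: Iterator[T], serialize: Seq[T] => Array[Byte])
  extends Iterator[Array[Byte]] {
  private var batch = 1
  private val buffer = new mutable.ArrayBuffer[T]

  override def hasNext: Boolean = iter.hasNext

  override def next(): Array[Byte] = {
    // fill the buffer up to the current batch size
    while (iter.hasNext && buffer.length < batch) {
      buffer += iter.next()
    }
    val bytes = serialize(buffer.toSeq)
    // adapt the batch size so each serialized chunk lands roughly between 1 MB and 10 MB
    if (bytes.length < 1024 * 1024) {
      batch *= 2
    } else if (bytes.length > 1024 * 1024 * 10 && batch > 1) {
      batch /= 2
    }
    buffer.clear()
    bytes
  }
}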

Author: Davies Liu <davies@databricks.com>

This patch had conflicts when merged, resolved by
Committer: Josh Rosen <joshrosen@databricks.com>

Closes apache#2920 from davies/fix_autobatch and squashes the following commits:

e544ef9 [Davies Liu] revert unrelated change
6880b14 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch
1d557fc [Davies Liu] fix tests
8180907 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch
76abdce [Davies Liu] clean up
53fa60b [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch
d7ac751 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch
2cc2497 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_autobatch
b4292ce [Davies Liu] fix bug in master
d79744c [Davies Liu] recover hive tests
be37ece [Davies Liu] refactor
eb3938d [Davies Liu] refactor serializer in scala
8d77ef2 [Davies Liu] simplify serializer, use AutoBatchedSerializer by default.
Davies Liu authored and JoshRosen committed Nov 4, 2014
1 parent b671ce0 commit e4f4263
Showing 14 changed files with 201 additions and 338 deletions.
@@ -61,8 +61,7 @@ private[python] object Converter extends Logging {
* Other objects are passed through without conversion.
*/
private[python] class WritableToJavaConverter(
conf: Broadcast[SerializableWritable[Configuration]],
batchSize: Int) extends Converter[Any, Any] {
conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] {

/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
@@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter(
map.put(convertWritable(k), convertWritable(v))
}
map
case w: Writable =>
if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w
case w: Writable => WritableUtils.clone(w, conf.value.value)
case other => other
}
}
110 changes: 5 additions & 105 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -22,12 +22,10 @@ import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials

import com.google.common.base.Charsets.UTF_8
import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -442,7 +440,7 @@ private[spark] object PythonRDD extends Logging {
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration()))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -468,7 +466,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -494,7 +492,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -537,7 +535,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -563,7 +561,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new WritableToJavaConverter(confBroadcasted, batchSize))
new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}

@@ -746,104 +744,6 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}


/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
*/
@deprecated("PySpark does not use it anymore", "1.1")
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
SerDeUtil.initialize()
iter.flatMap { row =>
unpickle.loads(row) match {
// in case of objects are pickled in batch mode
case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
// not in batch mode
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
}
}

/**
* Convert an RDD of serialized Python tuple to Array (no recursive conversions).
* It is only used by pyspark.sql.
*/
def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {

def toArray(obj: Any): Array[_] = {
obj match {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}
}

pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].map(toArray)
} else {
Seq(toArray(obj))
}
}
}.toJavaRDD()
}

private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
private val pickle = new Pickler()
private var batch = 1
private val buffer = new mutable.ArrayBuffer[Any]

override def hasNext(): Boolean = iter.hasNext

override def next(): Array[Byte] = {
while (iter.hasNext && buffer.length < batch) {
buffer += iter.next()
}
val bytes = pickle.dumps(buffer.toArray)
val size = bytes.length
// let 1M < size < 10M
if (size < 1024 * 1024) {
batch *= 2
} else if (size > 1024 * 1024 * 10 && batch > 1) {
batch /= 2
}
buffer.clear()
bytes
}
}

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}

/**
* Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
*/
def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
pyRDD.rdd.mapPartitions { iter =>
SerDeUtil.initialize()
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].asScala
} else {
Seq(obj)
}
}
}.toJavaRDD()
}
}

private
121 changes: 90 additions & 31 deletions core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python

import java.nio.ByteOrder
import java.util.{ArrayList => JArrayList}

import org.apache.spark.api.java.JavaRDD

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Failure
import scala.util.Try

@@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging {
}
initialize()


/**
* Convert an RDD of Java objects to Array (no recursive conversions).
* It is only used by pyspark.sql.
*/
def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
jrdd.rdd.map {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}.toJavaRDD()
}

/**
* Choose batch size based on size of objects
*/
private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
private val pickle = new Pickler()
private var batch = 1
private val buffer = new mutable.ArrayBuffer[Any]

override def hasNext: Boolean = iter.hasNext

override def next(): Array[Byte] = {
while (iter.hasNext && buffer.length < batch) {
buffer += iter.next()
}
val bytes = pickle.dumps(buffer.toArray)
val size = bytes.length
// let 1M < size < 10M
if (size < 1024 * 1024) {
batch *= 2
} else if (size > 1024 * 1024 * 10 && batch > 1) {
batch /= 2
}
buffer.clear()
bytes
}
}

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}

/**
* Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
*/
def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
pyRDD.rdd.mapPartitions { iter =>
initialize()
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].asScala
} else {
Seq(obj)
}
}
}.toJavaRDD()
}

private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
val kt = Try {
@@ -128,54 +200,41 @@
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())

rdd.mapPartitions { iter =>
val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
if (batchSize > 1) {
cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
if (batchSize == 0) {
new AutoBatchedPickler(cleaned)
} else {
cleaned.map(pickle.dumps(_))
val pickle = new Pickler
cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}

/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
pyRDD.mapPartitions { iter =>
initialize()
val unpickle = new Unpickler
val unpickled =
if (batchSerialized) {
iter.flatMap { batch =>
unpickle.loads(batch) match {
case objs: java.util.List[_] => collectionAsScalaIterable(objs)
case other => throw new SparkException(
s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
}
}
} else {
iter.map(unpickle.loads(_))
}
unpickled.map {
case obj if isPair(obj) =>
// we only accept (K, V)
val arr = obj.asInstanceOf[Array[_]]
(arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
case other => throw new SparkException(
s"RDD element of type ${other.getClass.getName} cannot be used")
}

val rdd = pythonToJava(pyRDD, batched).rdd
rdd.first match {
case obj if isPair(obj) =>
// we only accept (K, V)
case other => throw new SparkException(
s"RDD element of type ${other.getClass.getName} cannot be used")
}
rdd.map { obj =>
val arr = obj.asInstanceOf[Array[_]]
(arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}

}
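The behavioral change in pairRDDToPython above is that a batch size of 0 now means "choose automatically", while any positive value keeps fixed-size grouping. A hedged sketch of that dispatch, reusing the illustrative AutoBatchedIterator from the sketch after the commit message (batchIterator and serialize are made-up names, not part of the patch):

// Sketch only: how the batchSize argument is interpreted when shipping records to Python.
def batchIterator[T](cleaned: Iterator[T], batchSize: Int,
                     serialize: Seq[T] => Array[Byte]): Iterator[Array[Byte]] = {
  if (batchSize == 0) {
    new AutoBatchedIterator(cleaned, serialize)   // adaptive batching (roughly 1 MB - 10 MB per batch)
  } else {
    cleaned.grouped(batchSize).map(serialize)     // fixed-size batches; batchSize == 1 still yields batches of one
  }
}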

@@ -176,11 +176,11 @@ object WriteInputFormatTestDataGenerator {

// Create test data for arbitrary custom writable TestWritable
val testClass = Seq(
("1", TestWritable("test1", 123, 54.0)),
("2", TestWritable("test2", 456, 8762.3)),
("1", TestWritable("test3", 123, 423.1)),
("3", TestWritable("test56", 456, 423.5)),
("2", TestWritable("test2", 123, 5435.2))
("1", TestWritable("test1", 1, 1.0)),
("2", TestWritable("test2", 2, 2.3)),
("3", TestWritable("test3", 3, 3.1)),
("5", TestWritable("test56", 5, 5.5)),
("4", TestWritable("test4", 4, 4.2))
)
val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
rdd.saveAsNewAPIHadoopFile(classPath,
@@ -736,7 +736,7 @@ private[spark] object SerDe extends Serializable {
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter =>
initialize() // let it called in executor
new PythonRDD.AutoBatchedPickler(iter)
new SerDeUtil.AutoBatchedPickler(iter)
}
}
