refactor serializer in scala
Davies Liu committed Oct 24, 2014
1 parent 8d77ef2 commit eb3938d
Showing 6 changed files with 99 additions and 134 deletions.
101 changes: 0 additions & 101 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -23,12 +23,9 @@ import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
@@ -738,104 +735,6 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}


/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
*/
@deprecated("PySpark does not use it anymore", "1.1")
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
SerDeUtil.initialize()
iter.flatMap { row =>
unpickle.loads(row) match {
// in case the objects were pickled in batch mode
case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
// not in batch mode
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
}
}

/**
* Convert an RDD of serialized Python tuples to Arrays (no recursive conversions).
* It is only used by pyspark.sql.
*/
def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {

def toArray(obj: Any): Array[_] = {
obj match {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}
}

pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].map(toArray)
} else {
Seq(toArray(obj))
}
}
}.toJavaRDD()
}

private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
private val pickle = new Pickler()
private var batch = 1
private val buffer = new mutable.ArrayBuffer[Any]

override def hasNext(): Boolean = iter.hasNext

override def next(): Array[Byte] = {
while (iter.hasNext && buffer.length < batch) {
buffer += iter.next()
}
val bytes = pickle.dumps(buffer.toArray)
val size = bytes.length
// keep the pickled batch size between 1 MB and 10 MB
if (size < 1024 * 1024) {
batch *= 2
} else if (size > 1024 * 1024 * 10 && batch > 1) {
batch /= 2
}
buffer.clear()
bytes
}
}

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects that is usable by
* PySpark.
*/
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}

/**
* Convert an RDD of serialized Python objects to an RDD of objects that is usable by PySpark.
*/
def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
pyRDD.rdd.mapPartitions { iter =>
SerDeUtil.initialize()
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].asScala
} else {
Seq(obj)
}
}
}.toJavaRDD()
}
}

private
74 changes: 73 additions & 1 deletion core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python

import java.nio.ByteOrder
import java.util.{ArrayList => JArrayList}

import org.apache.spark.api.java.JavaRDD

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Failure
import scala.util.Try

@@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging {
}
initialize()


/**
* Convert an RDD of Java objects to an RDD of Arrays (no recursive conversions).
* It is only used by pyspark.sql.
*/
private[python] def toJavaArray(jrdd: JavaRDD[_]): JavaRDD[Array[_]] = {
jrdd.rdd.map {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}.toJavaRDD()
}
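
toJavaArray normalizes each element of the RDD to a JVM array; a plain-Scala sketch of the two cases it matches on (not part of the diff, sample rows are hypothetical, e.g. for the REPL):

    import java.util.{ArrayList => JArrayList}

    def toArray(obj: Any): Array[_] = obj match {
      case objs: JArrayList[_]     => objs.toArray                // row unpickled from a Python list/tuple
      case a if a.getClass.isArray => a.asInstanceOf[Array[_]]    // row that is already a JVM array
    }

    val fromPickle = new JArrayList[Any]()
    fromPickle.add(1)
    fromPickle.add("Alice")
    println(toArray(fromPickle).mkString(", "))            // prints: 1, Alice
    println(toArray(Array[Any](2, "Bob")).mkString(", "))  // prints: 2, Bob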

/**
* Choose the batch size based on the size of the pickled objects
*/
private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
private val pickle = new Pickler()
private var batch = 1
private val buffer = new mutable.ArrayBuffer[Any]

override def hasNext: Boolean = iter.hasNext

override def next(): Array[Byte] = {
while (iter.hasNext && buffer.length < batch) {
buffer += iter.next()
}
val bytes = pickle.dumps(buffer.toArray)
val size = bytes.length
// keep the pickled batch size between 1 MB and 10 MB
if (size < 1024 * 1024) {
batch *= 2
} else if (size > 1024 * 1024 * 10 && batch > 1) {
batch /= 2
}
buffer.clear()
bytes
}
}
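
The heuristic in next() adjusts the batch after every chunk it emits; a standalone sketch of just that sizing rule (plain Scala, e.g. for the REPL; the ~200 bytes per pickled row is an assumed figure, not from the diff):

    // Double the batch while a pickled chunk stays under 1 MB,
    // halve it once a chunk exceeds 10 MB.
    def nextBatch(batch: Int, pickledSize: Int): Int =
      if (pickledSize < 1024 * 1024) batch * 2
      else if (pickledSize > 1024 * 1024 * 10 && batch > 1) batch / 2
      else batch

    var batch = 1
    var size = 200                 // ~200 bytes per pickled row (assumed)
    while (size < 1024 * 1024) {
      batch = nextBatch(batch, size)
      size = batch * 200
    }
    println(s"settles on batches of about $batch rows (~${size / 1024} KB pickled)")
    // prints: settles on batches of about 8192 rows (~1600 KB pickled)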

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects that is usable by
* PySpark.
*/
def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}

/**
* Convert an RDD of serialized Python objects to an RDD of objects that is usable by PySpark.
*/
def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
pyRDD.rdd.mapPartitions { iter =>
initialize()
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].asScala
} else {
Seq(obj)
}
}
}.toJavaRDD()
}
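
A hypothetical Scala-side usage sketch of the javaToPython helper above (SerDeUtil is private[spark], so this would have to live under an org.apache.spark package; a local SparkContext and the net.razorvine.pickle (Pyrolite) jar that Spark ships are assumed, and the object and app names here are made up):

    package org.apache.spark.example

    import org.apache.spark.SparkContext
    import org.apache.spark.api.java.JavaRDD
    import org.apache.spark.api.python.SerDeUtil

    object JavaToPythonDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local[2]", "java-to-python-demo")
        val jrdd: JavaRDD[Any] = sc.parallelize(1 to 100000).map(i => i: Any).toJavaRDD()

        // JVM objects -> pickled bytes, one Array[Byte] per auto-sized batch.
        val pickled: JavaRDD[Array[Byte]] = SerDeUtil.javaToPython(jrdd)

        // Inspect how large the pickled batches ended up.
        pickled.rdd.map(_.length).collect().foreach(len => println(s"batch: $len bytes"))
        sc.stop()
      }
    }

pythonToJava runs the opposite direction: it is what pyspark's _to_java_object_rdd now calls (see the rdd.py change below) to turn batches pickled on the Python side back into JVM objects.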

private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
val kt = Try {
@@ -148,7 +220,7 @@ private[spark] object SerDeUtil extends Logging {
*/
def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
pyRDD.mapPartitions { iter =>
@@ -685,7 +685,7 @@ private[spark] object SerDe extends Serializable {
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter =>
initialize() // make sure it is called in the executor
new PythonRDD.AutoBatchedPickler(iter)
new SerDeUtil.AutoBatchedPickler(iter)
}
}

2 changes: 1 addition & 1 deletion python/pyspark/rdd.py
@@ -1938,7 +1938,7 @@ def _to_java_object_rdd(self):
RDD is serialized in batch or not.
"""
rdd = self._pickled()
return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, True)
return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True)

def countApprox(self, timeout, confidence=0.95):
"""
45 changes: 21 additions & 24 deletions python/pyspark/sql.py
@@ -938,7 +938,6 @@ def __init__(self, sparkContext, sqlContext=None):
self._sc = sparkContext
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
self._scala_SQLContext = sqlContext

@property
@@ -1124,8 +1123,7 @@ def applySchema(self, rdd, schema):
for row in rows:
_verify_type(row, schema)

batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
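
On the JVM side, the new applySchema path composes two of the SerDeUtil helpers added above; a rough Scala equivalent (the wrapper object and method names are hypothetical, and it must be compiled inside org.apache.spark.api.python because toJavaArray is private[python]):

    // Compiled inside this package so that the private[python] toJavaArray is visible.
    package org.apache.spark.api.python

    import org.apache.spark.api.java.JavaRDD

    object ApplySchemaPath {
      // pickled Python rows -> JVM objects (_to_java_object_rdd) -> Array[_] rows
      def pickledRowsToArrays(pyRdd: JavaRDD[Array[Byte]]): JavaRDD[Array[_]] =
        SerDeUtil.toJavaArray(SerDeUtil.pythonToJava(pyRdd, batched = true))
    }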

@@ -1381,26 +1379,26 @@ class LocalHiveContext(HiveContext):
An in-process metadata database is created with data stored in ./metadata.
Warehouse data is stored in ./warehouse.
>>> import os
>>> hiveCtx = LocalHiveContext(sc)
>>> try:
... supress = hiveCtx.sql("DROP TABLE src")
... except Exception:
... pass
>>> kv1 = os.path.join(os.environ["SPARK_HOME"],
... 'examples/src/main/resources/kv1.txt')
>>> supress = hiveCtx.sql(
... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
>>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src"
... % kv1)
>>> results = hiveCtx.sql("FROM src SELECT value"
... ).map(lambda r: int(r.value.split('_')[1]))
>>> num = results.count()
>>> reduce_sum = results.reduce(lambda x, y: x + y)
>>> num
500
>>> reduce_sum
130091
# >>> import os
# >>> hiveCtx = LocalHiveContext(sc)
# >>> try:
# ... supress = hiveCtx.sql("DROP TABLE src")
# ... except Exception:
# ... pass
# >>> kv1 = os.path.join(os.environ["SPARK_HOME"],
# ... 'examples/src/main/resources/kv1.txt')
# >>> supress = hiveCtx.sql(
# ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
# >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src"
# ... % kv1)
# >>> results = hiveCtx.sql("FROM src SELECT value"
# ... ).map(lambda r: int(r.value.split('_')[1]))
# >>> num = results.count()
# >>> reduce_sum = results.reduce(lambda x, y: x + y)
# >>> num
# 500
# >>> reduce_sum
# 130091
"""

def __init__(self, sparkContext, sqlContext=None):
@@ -1771,7 +1769,6 @@ def subtract(self, other, numPartitions=None):

def _test():
import doctest
from array import array
from pyspark.context import SparkContext
# let doctest run in pyspark.sql, so DataTypes can be picklable
import pyspark.sql
9 changes: 3 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql

import java.util.{Map => JMap, List => JList}

import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.storage.StorageLevel

import scala.collection.JavaConversions._
@@ -414,12 +415,8 @@ class SchemaRDD(
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
this.mapPartitions { iter =>
val pickle = new Pickler
iter.map { row =>
rowToJArray(row, rowSchema)
}.grouped(100).map(batched => pickle.dumps(batched.toArray))
}
val jrdd = this.map(rowToJArray(_, rowSchema)).toJavaRDD()
SerDeUtil.javaToPython(jrdd)
}

/**
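
For comparison, the code removed from SchemaRDD.javaToPython pickled rows in fixed batches of 100; a minimal standalone sketch of that older strategy (plain Scala, using the net.razorvine.pickle Pickler that Spark already ships; the function name is made up):

    import net.razorvine.pickle.Pickler

    def fixedBatchPickle(rows: Iterator[Array[Any]]): Iterator[Array[Byte]] = {
      val pickle = new Pickler()
      // Always 100 rows per pickled chunk, however large the rows are.
      rows.grouped(100).map(batch => pickle.dumps(batch.toArray))
    }

The new path instead maps each Row to a Java array and hands the whole partition to SerDeUtil.javaToPython, so batch sizes track the pickled output (roughly 1 MB to 10 MB per chunk) rather than a fixed row count.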
