Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jan 30, 2015
1 parent e6d0427 commit 4280d04
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -376,19 +376,16 @@ private[spark] object PythonRDD extends Logging {
/**
 * Writes one element to `dataOut` in the framed binary format consumed by the
 * Python side. Dispatches on the element's runtime type; nulls and nested
 * pair components are supported.
 */
def write(obj: Any): Unit = obj match {
  case null =>
    // A null element is encoded as the SpecialLengths.NULL marker.
    dataOut.writeInt(SpecialLengths.NULL)
  case bytes: Array[Byte] =>
    // Byte arrays: 4-byte length prefix followed by the raw payload.
    dataOut.writeInt(bytes.length)
    dataOut.write(bytes)
  case s: String =>
    // Strings are delegated to the writeUTF helper.
    writeUTF(s, dataOut)
  case pds: PortableDataStream =>
    // Materialize the stream and recurse through the Array[Byte] case.
    write(pds.toArray())
  case (k, v) =>
    // A pair is written as its two components back-to-back; either side
    // may itself be null and is handled by the null case above.
    write(k)
    write(v)
  case unexpected =>
    throw new SparkException("Unexpected element type " + unexpected.getClass)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ class PythonRDDSuite extends FunSuite {

test("Handle nulls gracefully") {
  val buffer = new DataOutputStream(new ByteArrayOutputStream)
  // Should not throw NPE when writing an Iterator with nulls in it;
  // the correctness of the serialized bytes is tested on the Python side.
  PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
  PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
  PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer)
  PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer)
  PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
  PythonRDD.writeIteratorToStream(
    Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer)
}
}
4 changes: 3 additions & 1 deletion python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from pyspark.rdd import RDD
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer, CompressedSerializer, UTF8Deserializer
CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType
def test_null_in_rdd(self):
    # JVM-side helper builds an RDD containing a null element among strings.
    jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
    # The null must round-trip as None through the UTF8 deserializer...
    rdd = RDD(jrdd, self.sc, UTF8Deserializer())
    self.assertEqual([u"a", None, u"b"], rdd.collect())
    # ...and also through the raw-bytes (no-op) serializer.
    rdd = RDD(jrdd, self.sc, NoOpSerializer())
    self.assertEqual(["a", None, "b"], rdd.collect())

def test_multiple_python_java_RDD_conversions(self):
# Regression test for SPARK-5361
Expand Down

0 comments on commit 4280d04

Please sign in to comment.