From 90ae568e4fe63338d60b92fe105090a67bb15f9b Mon Sep 17 00:00:00 2001 From: giwa Date: Sun, 10 Aug 2014 18:43:09 -0700 Subject: [PATCH] WIP added test case --- .../apache/spark/api/python/PythonRDD.scala | 2 - .../main/python/streaming/test_oprations.py | 25 +++++--- python/pyspark/streaming/context.py | 16 +++-- python/pyspark/streaming/dstream.py | 22 +++++-- python/pyspark/streaming_tests.py | 62 +++++++++++++++++-- python/pyspark/worker.py | 2 +- .../streaming/api/java/JavaDStreamLike.scala | 9 +++ .../streaming/api/python/PythonDStream.scala | 19 +++--- .../spark/streaming/dstream/DStream.scala | 17 +++++ 9 files changed, 134 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 668e318e7a545..b4ce4b88ca65d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -306,8 +306,6 @@ private[spark] object PythonRDD extends Logging { } catch { case eof: EOFException => {} } - println("RDDDD ==================") - println(objs) JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } diff --git a/examples/src/main/python/streaming/test_oprations.py b/examples/src/main/python/streaming/test_oprations.py index 5ee0bd4b31253..24ebe23d63166 100644 --- a/examples/src/main/python/streaming/test_oprations.py +++ b/examples/src/main/python/streaming/test_oprations.py @@ -9,15 +9,22 @@ conf = SparkConf() conf.setAppName("PythonStreamingNetworkWordCount") ssc = StreamingContext(conf=conf, duration=Seconds(1)) - ssc.checkpoint("/tmp/spark_ckp") - test_input = ssc._testInputStream([[1],[1],[1]]) -# ssc.checkpoint("/tmp/spark_ckp") - fm_test = test_input.flatMap(lambda x: x.split(" ")) - mapped_test = fm_test.map(lambda x: (x, 1)) + test_input = ssc._testInputStream([1,2,3]) + class buff: + pass + + fm_test = test_input.map(lambda x: (x, 1)) + fm_test.test_output(buff) - - mapped_test.print_() ssc.start() -# ssc.awaitTermination() -# ssc.stop() + while True: + ssc.awaitTermination(50) + try: + buff.result + break + except AttributeError: + pass + + ssc.stop() + print buff.result diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 882db547faa39..0d7665d645be8 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -100,10 +100,10 @@ def awaitTermination(self, timeout=None): """ Wait for the execution to stop. """ - if timeout: - self._jssc.awaitTermination(timeout) - else: + if timeout is None: self._jssc.awaitTermination() + else: + self._jssc.awaitTermination(timeout) # start from simple one. storageLevel is not passed for now. def socketTextStream(self, hostname, port): @@ -137,6 +137,7 @@ def stop(self, stopSparkContext=True): def checkpoint(self, directory): """ + Not tested """ self._jssc.checkpoint(directory) @@ -147,8 +148,7 @@ def _testInputStream(self, test_input, numSlices=None): # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). - #tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) - tempFile = open("/tmp/spark_rdd", "wb") + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(test_input): @@ -160,10 +160,8 @@ def _testInputStream(self, test_input, numSlices=None): else: serializer = self._sc._unbatched_serializer serializer.dump_stream(test_input, tempFile) - tempFile.flush() - tempFile.close() - print tempFile.name + jinput_stream = self._jvm.PythonTestInputStream(self._jssc, tempFile.name, numSlices).asJavaDStream() - return DStream(jinput_stream, self, UTF8Deserializer()) + return DStream(jinput_stream, self, PickleSerializer()) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 77c9a22239c69..47196196466db 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -47,7 +47,7 @@ def _sum(self): """ return self._mapPartitions(lambda x: [sum(x)]).reduce(operator.add) - def print_(self): + def print_(self, label=None): """ Since print is reserved name for python, we cannot define a print method function. This function prints serialized data in RDD in DStream because Scala and Java cannot @@ -56,7 +56,7 @@ def print_(self): Call DStream.print(). """ # a hack to call print function in DStream - getattr(self._jdstream, "print")() + getattr(self._jdstream, "print")(label) def filter(self, f): """ @@ -217,6 +217,7 @@ def pyprint(self): """ def takeAndPrint(rdd, time): + print "take and print ===================" taken = rdd.take(11) print "-------------------------------------------" print "Time: %s" % (str(time)) @@ -229,11 +230,24 @@ def takeAndPrint(rdd, time): self.foreachRDD(takeAndPrint) - #def transform(self, func): + #def transform(self, func): - TD # from utils import RDDFunction # wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func) # jdstream = self.ctx._jvm.PythonTransformedDStream(self._jdstream.dstream(), wrapped_func).toJavaDStream - # return DStream(jdstream, self._ssc, ...) ## DO NOT KNOW HOW + # return DStream(jdstream, self._ssc, ...) ## DO NOT KNOW HOW + + def _test_output(self, buff): + """ + This function is only for testcase. + Store data in dstream to buffer to valify the result in tesecase + """ + def get_output(rdd, time): + taken = rdd.take(11) + buff.result = taken + self.foreachRDD(get_output) + + def output(self): + self._jdstream.outputToFile() class PipelinedDStream(DStream): diff --git a/python/pyspark/streaming_tests.py b/python/pyspark/streaming_tests.py index 95c5489a5695b..0660be10b027b 100644 --- a/python/pyspark/streaming_tests.py +++ b/python/pyspark/streaming_tests.py @@ -19,12 +19,13 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. -This file will merged to tests.py. But for now, this file is separated to -focus to streaming test case +This file will merged to tests.py. But for now, this file is separated due +to focusing to streaming test case """ from fileinput import input from glob import glob +from itertools import chain import os import re import shutil @@ -41,18 +42,69 @@ SPARK_HOME = os.environ["SPARK_HOME"] +class buff: + """ + Buffer for store the output from stream + """ + result = None class PySparkStreamingTestCase(unittest.TestCase): - def setUp(self): - self._old_sys_path = list(sys.path) + print "set up" class_name = self.__class__.__name__ self.ssc = StreamingContext(appName=class_name, duration=Seconds(1)) def tearDown(self): + print "tear donw" self.ssc.stop() - sys.path = self._old_sys_path + time.sleep(10) + +class TestBasicOperationsSuite(PySparkStreamingTestCase): + def setUp(self): + PySparkStreamingTestCase.setUp(self) + buff.result = None + self.timeout = 10 # seconds + + def tearDown(self): + PySparkStreamingTestCase.tearDown(self) + + def test_map(self): + test_input = [range(1,5), range(5,9), range(9, 13)] + def test_func(dstream): + return dstream.map(lambda x: str(x)) + expected = map(str, test_input) + output = self.run_stream(test_input, test_func) + self.assertEqual(output, expected) + + def test_flatMap(self): + test_input = [range(1,5), range(5,9), range(9, 13)] + def test_func(dstream): + return dstream.flatMap(lambda x: (x, x * 2)) + # Maybe there be good way to create flatmap + excepted = map(lambda x: list(chain.from_iterable((map(lambda y:[y, y*2], x)))), + test_input) + output = self.run_stream(test_input, test_func) + + def run_stream(self, test_input, test_func): + # Generate input stream with user-defined input + test_input_stream = self.ssc._testInputStream(test_input) + # Applyed test function to stream + test_stream = test_func(test_input_stream) + # Add job to get outpuf from stream + test_stream._test_output(buff) + self.ssc.start() + start_time = time.time() + while True: + current_time = time.time() + # check time out + if (current_time - start_time) > self.timeout: + self.ssc.stop() + break + self.ssc.awaitTermination(50) + if buff.result is not None: + break + return buff.result if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index f43210c6c0301..7ca3252270d5a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -58,7 +58,7 @@ def main(infile, outfile): # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH sys.path.append(spark_files_dir) # *.py files that were added will be copied here - num_python_includes = read_int(infile) + num_python_includes = read_int(infile) for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a6184de4e83c1..7a002bbe74ca9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -54,6 +54,15 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T dstream.print() } + def print(label: String = null): Unit = { + dstream.print(label) + } + + def outputToFile(): Unit = { + dstream.outputToFile() + } + + /** * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 96440b15d0285..94c644fa81d45 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -17,9 +17,14 @@ package org.apache.spark.streaming.api.python +import java.io._ +import java.io.{ObjectInputStream, IOException} import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -51,6 +56,8 @@ class PythonDStream[T: ClassTag]( override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { parent.getOrCompute(validTime) match{ case Some(rdd) => + logInfo("RDD ID in python DStream ===========") + logInfo("RDD id " + rdd.id) val pythonRDD = new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, pythonExec, broadcastVars, accumulator) Some(pythonRDD.asJavaRDD.rdd) case None => None @@ -77,7 +84,7 @@ DStream[Array[Byte]](prev.ssc){ val pairwiseRDD = new PairwiseRDD(rdd) /* * Since python operation is executed by Scala after StreamingContext.start. - * What PairwiseDStream does is equivalent to following python code in pySpark. + * What PythonPairwiseDStream does is equivalent to python code in pySpark. * * with _JavaStackTrace(self.context) as st: * pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() @@ -142,18 +149,10 @@ class PythonTestInputStream(ssc_ : JavaStreamingContext, filename: String, numPa def compute(validTime: Time): Option[RDD[Array[Byte]]] = { logInfo("Computing RDD for time " + validTime) - //val index = ((validTime - zeroTime) / slideDuration - 1).toInt - //val selectedInput = if (index < input.size) input(index) else Seq[T]() - - // lets us test cases where RDDs are not created - //if (filename == null) - // return None - - //val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) val rdd = PythonRDD.readRDDFromFile(JavaSparkContext.fromSparkContext(ssc_.sparkContext), filename, numPartitions).rdd logInfo("Created RDD " + rdd.id + " with " + filename) Some(rdd) } val asJavaDStream = JavaDStream.fromDStream(this) -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index d8dbdf59e7ff1..bafff80adc54b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -623,6 +623,23 @@ abstract class DStream[T: ClassTag] ( new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() } + + def print(label: String = null) { + def foreachFunc = (rdd: RDD[T], time: Time) => { + val first11 = rdd.take(11) + println ("-------------------------------------------") + println ("Time: " + time) + println ("-------------------------------------------") + if(label != null){ + println (label) + } + first11.take(10).foreach(println) + if (first11.size > 10) println("...") + println() + } + new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() + } + /** * Return a new DStream in which each RDD contains all the elements in seen in a * sliding window of time over this DStream. The new DStream generates RDDs with