From eec401e4d187d65fb2e488cca72735df648cbd68 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 26 Sep 2014 00:17:33 -0700
Subject: [PATCH] refactor, combine TransformedRDD, fix reuse PythonRDD, fix union

---
 python/pyspark/streaming/context.py            |   7 +-
 python/pyspark/streaming/dstream.py            | 112 ++++++++++++------
 python/pyspark/streaming/tests.py              |  26 ++++
 python/pyspark/streaming/util.py               |  31 ++++-
 .../streaming/api/python/PythonDStream.scala   |  78 +++++++-----
 5 files changed, 178 insertions(+), 76 deletions(-)

diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index c2f8c9d3ff31d..fddef0d802670 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -15,12 +15,7 @@
 # limitations under the License.
 #
 
-import sys
-from signal import signal, SIGTERM, SIGINT
-import atexit
-import time
-
-from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
+from pyspark.serializers import UTF8Deserializer
 from pyspark.context import SparkContext
 from pyspark.streaming.dstream import DStream
 from pyspark.streaming.duration import Duration, Seconds
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 64088ae8e6e83..c51f39bc48428 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -15,21 +15,15 @@
 # limitations under the License.
 #
 
-from collections import defaultdict
 from itertools import chain, ifilter, imap
 import operator
 
 from pyspark import RDD
-from pyspark.serializers import NoOpSerializer,\
-    BatchedSerializer, CloudPickleSerializer, pack_long,\
-    CompressedSerializer
 from pyspark.storagelevel import StorageLevel
-from pyspark.resultiterable import ResultIterable
-from pyspark.streaming.util import rddToFileName, RDDFunction
-from pyspark.rdd import portable_hash, _parse_memory
-from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.streaming.util import rddToFileName, RDDFunction, RDDFunction2
+from pyspark.rdd import portable_hash
+from pyspark.streaming.duration import Seconds
 
-from py4j.java_collections import ListConverter, MapConverter
 
 
 __all__ = ["DStream"]
@@ -42,7 +36,6 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self._jrdd_deserializer = jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
-        self._partitionFunc = None
 
     def context(self):
         """
@@ -159,7 +152,7 @@ def foreachRDD(self, func):
         This is an output operator, so this DStream will be registered as an
         output stream and there materialized.
         """
-        jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, t), self._jrdd_deserializer)
+        jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
         self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc)
 
     def pyprint(self):
@@ -306,19 +299,19 @@ def get_output(rdd, time):
         return result
 
     def transform(self, func):
-        return TransformedRDD(self, lambda a, b, t: func(a), cache=True)
-
-    def transformWith(self, func, other):
-        return TransformedRDD(self, lambda a, b, t: func(a, b), other)
+        return TransformedRDD(self, lambda a, t: func(a), True)
 
     def transformWithTime(self, func):
-        return TransformedRDD(self, lambda a, b, t: func(a, t))
+        return TransformedRDD(self, func, False)
+
+    def transformWith(self, func, other, keepSerializer=False):
+        return Transformed2RDD(self, lambda a, b, t: func(a, b), other, keepSerializer)
 
     def repartitions(self, numPartitions):
         return self.transform(lambda rdd: rdd.repartition(numPartitions))
 
     def union(self, other):
-        return self.transformWith(lambda a, b: a.union(b), other)
+        return self.transformWith(lambda a, b: a.union(b), other, True)
 
     def cogroup(self, other):
         return self.transformWith(lambda a, b: a.cogroup(b), other)
@@ -329,10 +322,34 @@ def leftOuterJoin(self, other):
     def rightOuterJoin(self, other):
         return self.transformWith(lambda a, b: a.rightOuterJoin(b), other)
 
-    def slice(self, fromTime, toTime):
-        jrdds = self._jdstream.slice(fromTime._jtime, toTime._jtime)
-        # FIXME: serializer
-        return [RDD(jrdd, self.ctx, self.ctx.serializer) for jrdd in jrdds]
+    def _jtime(self, milliseconds):
+        return self.ctx._jvm.Time(milliseconds)
+
+    def slice(self, begin, end):
+        jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
+        return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds]
+
+    def window(self, windowDuration, slideDuration=None):
+        d = Seconds(windowDuration)
+        if slideDuration is None:
+            return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer)
+        s = Seconds(slideDuration)
+        return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
+
+    def reduceByWindow(self, reduceFunc, inReduceFunc, windowDuration, slideDuration):
+        pass
+
+    def countByWindow(self, window, slide):
+        pass
+
+    def countByValueAndWindow(self, window, slide, numPartitions=None):
+        pass
+
+    def groupByKeyAndWindow(self, window, slide, numPartitions=None):
+        pass
+
+    def reduceByKeyAndWindow(self, reduceFunc, inReduceFunc, window, slide, numPartitions=None):
+        pass
 
     def updateStateByKey(self, updateFunc):
         # FIXME: convert updateFunc to java JFunction2
@@ -340,21 +357,44 @@ def updateStateByKey(self, updateFunc):
         return self._jdstream.updateStateByKey(jFunc)
 
 
-# Window Operations
-# TODO: implement window
-# TODO: implement groupByKeyAndWindow
-# TODO: implement reduceByKeyAndWindow
-# TODO: implement countByValueAndWindow
-# TODO: implement countByWindow
-# TODO: implement reduceByWindow
+class TransformedRDD(DStream):
+    def __init__(self, prev, func, reuse=False):
+        ssc = prev._ssc
+        self._ssc = ssc
+        self.ctx = ssc._sc
+        self._jrdd_deserializer = self.ctx.serializer
+        self.is_cached = False
+        self.is_checkpointed = False
+
+        if isinstance(prev, TransformedRDD) and not prev.is_cached and not prev.is_checkpointed:
+            prev_func = prev.func
+            old_func = func
+            func = lambda rdd, t: old_func(prev_func(rdd, t), t)
+            reuse = reuse and prev.reuse
+            prev = prev.prev
+        self.prev = prev
+        self.func = func
+        self.reuse = reuse
+        self._jdstream_val = None
 
-class TransformedRDD(DStream):
-    # TODO: better name for cache
-    def __init__(self, prev, func, other=None, cache=False):
-        # TODO: combine transformed RDD
+    @property
+    def _jdstream(self):
+        if self._jdstream_val is not None:
+            return self._jdstream_val
+
+        jfunc = RDDFunction(self.ctx, self.func, self.prev._jrdd_deserializer)
+        jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(),
+                                                          jfunc, self.reuse).asJavaDStream()
+        self._jdstream_val = jdstream
+        return jdstream
+
+
+class Transformed2RDD(DStream):
+    def __init__(self, prev, func, other, keepSerializer=False):
         ssc = prev._ssc
-        t = RDDFunction(ssc._sc, func, prev._jrdd_deserializer)
-        jdstream = ssc._jvm.PythonTransformedDStream(prev._jdstream.dstream(),
-                                                     other and other._jdstream, t, cache)
-        DStream.__init__(self, jdstream.asJavaDStream(), ssc, ssc._sc.serializer)
+        jfunc = RDDFunction2(ssc._sc, func, prev._jrdd_deserializer)
+        jdstream = ssc._jvm.PythonTransformed2DStream(prev._jdstream.dstream(),
+                                                      other._jdstream.dstream(), jfunc)
+        jrdd_serializer = prev._jrdd_deserializer if keepSerializer else ssc._sc.serializer
+        DStream.__init__(self, jdstream.asJavaDStream(), ssc, jrdd_serializer)
 
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 8b355bf6b7d79..7f9c99c047bd4 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -213,6 +213,32 @@ def add(a, b):
                     [("a", "11"), ("b", "1"), ("", "111")]]
         self._test_func(input, func, expected, sort=True)
 
+    def test_union(self):
+        input1 = [range(3), range(5), range(1)]
+        input2 = [range(3, 6), range(5, 6), range(1, 6)]
+
+        d1 = self.ssc._makeStream(input1)
+        d2 = self.ssc._makeStream(input2)
+        d = d1.union(d2)
+        result = d.collect()
+        expected = [range(6), range(6), range(6)]
+
+        self.ssc.start()
+        start_time = time.time()
+        # Loop until get the expected the number of the result from the stream.
+        while True:
+            current_time = time.time()
+            # Check time out.
+            if (current_time - start_time) > self.timeout * 2:
+                break
+            # StreamingContext.awaitTermination is not used to wait because
+            # if py4j server is called every 50 milliseconds, it gets an error.
+            time.sleep(0.05)
+            # Check if the output is the same length of expected output.
+            if len(expected) == len(result):
+                break
+        self.assertEqual(expected, result)
+
     def _sort_result_based_on_key(self, outputs):
         """Sort the list base onf first value."""
         for output in outputs:
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index 3047763594ce5..4051732f25302 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -28,9 +28,36 @@ def __init__(self, ctx, func, jrdd_deserializer):
         self.func = func
         self.deserializer = jrdd_deserializer
 
-    def call(self, jrdd, jrdd2, milliseconds):
+    def call(self, jrdd, milliseconds):
         try:
             rdd = RDD(jrdd, self.ctx, self.deserializer)
+            r = self.func(rdd, milliseconds)
+            if r:
+                return r._jrdd
+        except:
+            import traceback
+            traceback.print_exc()
+
+    def __repr__(self):
+        return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
+
+
+class RDDFunction2(object):
+    """
+    This class is for py4j callback. This class is related with
+    org.apache.spark.streaming.api.python.PythonRDDFunction2.
+    """
+    def __init__(self, ctx, func, jrdd_deserializer):
+        self.ctx = ctx
+        self.func = func
+        self.deserializer = jrdd_deserializer
+
+    def call(self, jrdd, jrdd2, milliseconds):
+        try:
+            rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else None
             other = RDD(jrdd2, self.ctx, self.deserializer) if jrdd2 else None
             r = self.func(rdd, other, milliseconds)
             if r:
@@ -43,7 +70,7 @@ def __repr__(self):
         return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))
 
     class Java:
-        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2']
 
 
 def rddToFileName(prefix, suffix, time):
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index ae5c4ae2958fa..7aab10b027c84 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -28,69 +28,83 @@ import org.apache.spark.streaming.api.java._
 
 
 /**
- * Interface for Python callback function
+ * Interface for Python callback function with two arguments
  */
 trait PythonRDDFunction {
-  def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
+  def call(rdd: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
 }
 
+/**
+ * Interface for Python callback function with three arguments
+ */
+trait PythonRDDFunction2 {
+  def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
+}
 
 /**
  * Transformed DStream in Python.
  *
 * If the result RDD is PythonRDD, then it will cache it as an template for future use,
 * this can reduce the Python callbacks.
- *
- * @param parent
- * @param parent2
- * @param func
- * @param cache
  */
-class PythonTransformedDStream (parent: DStream[_], parent2: DStream[_], func: PythonRDDFunction,
-                                cache: Boolean = false)
+class PythonTransformedDStream (parent: DStream[_], func: PythonRDDFunction,
+                                var reuse: Boolean = false)
   extends DStream[Array[Byte]] (parent.ssc) {
 
   var lastResult: PythonRDD = _
 
-  override def dependencies = {
-    if (parent2 == null) {
-      List(parent)
-    } else {
-      List(parent, parent2)
-    }
-  }
+  override def dependencies = List(parent)
 
   override def slideDuration: Duration = parent.slideDuration
 
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
     val rdd1 = parent.getOrCompute(validTime).getOrElse(null)
-    val rdd2 = if (parent2 != null) parent2.getOrCompute(validTime).getOrElse(null) else null
-
-    val r = if (rdd2 != null) {
-      func.call(JavaRDD.fromRDD(rdd1), JavaRDD.fromRDD(rdd2), validTime.milliseconds)
-    } else if (cache && lastResult != null) {
-      lastResult.copyTo(rdd1).asJavaRDD
+    if (reuse && lastResult != null) {
+      Some(lastResult.copyTo(rdd1))
     } else {
-      func.call(JavaRDD.fromRDD(rdd1), null, validTime.milliseconds)
-    }
-    if (r != null) {
-      if (lastResult == null && r.isInstanceOf[PythonRDD]) {
-        lastResult = r.asInstanceOf[PythonRDD]
+      val r = func.call(JavaRDD.fromRDD(rdd1), validTime.milliseconds).rdd
+      if (reuse && lastResult == null) {
+        r match {
+          case rdd: PythonRDD =>
+            if (rdd.parent(0) == rdd1) {
+              // only one PythonRDD
+              lastResult = rdd
+            } else {
+              // may have multiple stages
+              reuse = false
+            }
+        }
       }
       Some(r)
-    } else {
-      None
     }
   }
 
   val asJavaDStream = JavaDStream.fromDStream(this)
 }
+/**
+ * Transformed from two DStreams in Python.
+ */
+class PythonTransformed2DStream (parent: DStream[_], parent2: DStream[_], func: PythonRDDFunction2)
+  extends DStream[Array[Byte]] (parent.ssc) {
+
+  override def dependencies = List(parent, parent2)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    def resultRdd(stream: DStream[_]): JavaRDD[_] = stream.getOrCompute(validTime) match {
+      case Some(rdd) => JavaRDD.fromRDD(rdd)
+      case None => null
+    }
+    Some(func.call(resultRdd(parent), resultRdd(parent2), validTime.milliseconds))
+  }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
 
 
 /**
 * This is used for foreachRDD() in Python
- * @param prev
- * @param foreachFunction
 */
 class PythonForeachDStream(
   prev: DStream[Array[Byte]],
@@ -98,7 +112,7 @@ class PythonForeachDStream(
 ) extends ForEachDStream[Array[Byte]](
   prev,
   (rdd: RDD[Array[Byte]], time: Time) => {
-    foreachFunction.call(rdd.toJavaRDD(), null, time.milliseconds)
+    foreachFunction.call(rdd.toJavaRDD(), time.milliseconds)
   }
 ) {