From fce0ef5ffdf7d43052978a35b238bbc4ee434cc0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 22:41:04 -0700 Subject: [PATCH] rafactor of foreachRDD() --- python/pyspark/streaming/dstream.py | 3 +- .../streaming/api/python/PythonDStream.scala | 55 ++++++++----------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index d41eca020feb1..8a9e2dab7fb07 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -142,7 +142,8 @@ def foreachRDD(self, func): stream and there materialized. """ jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer) - self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) def pprint(self): """ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index d7dd0a0c5c88b..66cf0c968478c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -20,9 +20,10 @@ package org.apache.spark.streaming.api.python import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ -import org.apache.spark.rdd.RDD import org.apache.spark.api.java._ +import org.apache.spark.api.java.function.{Function2 => JFunction2} import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Interval, Duration, Time} import org.apache.spark.streaming.dstream._ @@ -35,19 +36,22 @@ trait PythonRDDFunction { def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]] } -class RDDFunction(pfunc: PythonRDDFunction) { - def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - val jrdd = if (rdd.isDefined) { +class RDDFunction(pfunc: PythonRDDFunction) extends Serializable { + + def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + apply(rdd, None, time) + } + + def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { + if (rdd.isDefined) { JavaRDD.fromRDD(rdd.get) } else { null } - val jrdd2 = if (rdd2.isDefined) { - JavaRDD.fromRDD(rdd2.get) - } else { - null - } - val r = pfunc.call(jrdd, jrdd2, time.milliseconds) + } + + def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + val r = pfunc.call(wrapRDD(rdd), wrapRDD(rdd2), time.milliseconds) if (r != null) { Some(r.rdd) } else { @@ -66,7 +70,13 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p val asJavaDStream = JavaDStream.fromDStream(this) } -object PythonDStream { +private[spark] object PythonDStream { + + // helper function for DStream.foreachRDD(), + // cannot be `foreachRDD`, it will confusing py4j + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction): Unit = { + jdstream.dstream.foreachRDD((rdd, time) => pyfunc.call(rdd, null, time.milliseconds)) + } // convert list of RDD into queue of RDDs, for ssc.queueStream() def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { @@ -97,7 +107,7 @@ private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: Python if (reuse && lastResult != null) { Some(lastResult.copyTo(rdd1.get)) } else { - val r = func(rdd1, None, validTime) + val r = func(rdd1, validTime) if (reuse && r.isDefined && lastResult == null) { r.get match { case rdd: PythonRDD => @@ -206,8 +216,9 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], // Get the RDD of the reduced value of the previous window val previousWindowRDD = getOrCompute(previousWindow.endTime) + // for small window, reduce once will be better than twice if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) { - // subtle the values from old RDDs + // subtract the values from old RDDs val oldRDDs = parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) val subbed = if (oldRDDs.size > 0) { @@ -236,22 +247,4 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], } } } -} - -/** - * This is used for foreachRDD() in Python - */ -class PythonForeachDStream( - prev: DStream[Array[Byte]], - foreachFunction: PythonRDDFunction - ) extends ForEachDStream[Array[Byte]]( - prev, - (rdd: RDD[Array[Byte]], time: Time) => { - if (rdd != null) { - foreachFunction.call(rdd, null, time.milliseconds) - } - } - ) { - - this.register() } \ No newline at end of file