From e108ec114eb1a14c6e2387761da8e55bee4b3c83 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Oct 2014 22:51:18 -0700 Subject: [PATCH] address comments --- .../apache/spark/api/python/PythonRDD.scala | 8 -- python/pyspark/rdd.py | 2 + python/pyspark/streaming/context.py | 38 +++--- python/pyspark/streaming/dstream.py | 112 +++++++++--------- python/pyspark/streaming/tests.py | 4 +- .../streaming/api/python/PythonDStream.scala | 2 +- 6 files changed, 80 insertions(+), 86 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fd6e3406a3b7e..f36a651dc2d8f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -25,8 +25,6 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials -import scala.reflect.ClassTag -import scala.util.{Try, Success, Failure} import net.razorvine.pickle.{Pickler, Unpickler} @@ -52,12 +50,6 @@ private[spark] class PythonRDD( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { - // create a new PythonRDD with same Python setting but different parent. - def copyTo(rdd: RDD[_]): PythonRDD = { - new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, - pythonExec, broadcastVars, accumulator) - } - val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index dc6497772e502..77e8fb1773fd1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -787,6 +787,8 @@ def sum(self): >>> sc.parallelize([1.0, 2.0, 3.0]).sum() 6.0 """ + if not self.getNumPartitions(): + return 0 # empty RDD can not been reduced return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) def count(self): diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 7f99d38771ce8..dc9dc41121935 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -84,17 +84,18 @@ class StreamingContext(object): """ _transformerSerializer = None - def __init__(self, sparkContext, duration=None, jssc=None): + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @param sparkContext: L{SparkContext} object. - @param duration: number of seconds. + @param batchDuration: the time interval (in seconds) at which streaming + data will be divided into batches """ self._sc = sparkContext self._jvm = self._sc._jvm - self._jssc = jssc or self._initialize_context(self._sc, duration) + self._jssc = jssc or self._initialize_context(self._sc, batchDuration) def _initialize_context(self, sc, duration): self._ensure_initialized() @@ -134,26 +135,27 @@ def _ensure_initialized(cls): SparkContext._active_spark_context, CloudPickleSerializer(), gw) @classmethod - def getOrCreate(cls, path, setupFunc): + def getOrCreate(cls, checkpointPath, setupFunc): """ - Get the StreamingContext from checkpoint file at `path`, or setup - it by `setupFunc`. + Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + recreated from the checkpoint data. If the data does not exist, then the provided setupFunc + will be used to create a JavaStreamingContext. - :param path: directory of checkpoint - :param setupFunc: a function used to create StreamingContext and - setup DStreams. - :return: a StreamingContext + @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + @param setupFunc Function to create a new JavaStreamingContext and setup DStreams """ - if not os.path.exists(path) or not os.path.isdir(path) or not os.listdir(path): + # TODO: support checkpoint in HDFS + if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): ssc = setupFunc() - ssc.checkpoint(path) + ssc.checkpoint(checkpointPath) return ssc cls._ensure_initialized() gw = SparkContext._gateway try: - jssc = gw.jvm.JavaStreamingContext(path) + jssc = gw.jvm.JavaStreamingContext(checkpointPath) except Exception: print >>sys.stderr, "failed to load StreamingContext from checkpoint" raise @@ -249,12 +251,12 @@ def textFileStream(self, directory): """ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) - def _check_serialzers(self, rdds): + def _check_serializers(self, rdds): # make sure they have same serializer if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: for i in range(len(rdds)): # reset them to sc.serializer - rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True) + rdds[i] = rdds[i]._reserialize() def queueStream(self, rdds, oneAtATime=True, default=None): """ @@ -275,7 +277,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None): if rdds and not isinstance(rdds[0], RDD): rdds = [self._sc.parallelize(input) for input in rdds] - self._check_serialzers(rdds) + self._check_serializers(rdds) jrdds = ListConverter().convert([r._jrdd for r in rdds], SparkContext._gateway._gateway_client) @@ -313,6 +315,10 @@ def union(self, *dstreams): raise ValueError("should have at least one DStream to union") if len(dstreams) == 1: return dstreams[0] + if len(set(s._jrdd_deserializer for s in dstreams)) > 1: + raise ValueError("All DStreams should have same serializer") + if len(set(s._slideDuration for s in dstreams)) > 1: + raise ValueError("All DStreams should have same slide duration") first = dstreams[0] jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], SparkContext._gateway._gateway_client) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index fddfd757b8674..824131739cce3 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -53,7 +53,7 @@ class DStream(object): def __init__(self, jdstream, ssc, jrdd_deserializer): self._jdstream = jdstream self._ssc = ssc - self.ctx = ssc._sc + self._sc = ssc._sc self._jrdd_deserializer = jrdd_deserializer self.is_cached = False self.is_checkpointed = False @@ -69,13 +69,7 @@ def count(self): Return a new DStream in which each RDD has a single element generated by counting each RDD of this DStream. """ - return self.mapPartitions(lambda i: [sum(1 for _ in i)])._sum() - - def _sum(self): - """ - Add up the elements in this DStream. - """ - return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add) def filter(self, f): """ @@ -130,7 +124,7 @@ def reduceByKey(self, func, numPartitions=None): Return a new DStream by applying reduceByKey to each RDD. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.combineByKey(lambda x: x, func, func, numPartitions) def combineByKey(self, createCombiner, mergeValue, mergeCombiners, @@ -139,7 +133,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, Return a new DStream by applying combineByKey to each RDD. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism def func(rdd): return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) @@ -156,7 +150,7 @@ def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. """ - jfunc = TransformFunction(self.ctx, func, self._jrdd_deserializer) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc) @@ -216,7 +210,7 @@ def persist(self, storageLevel): Persist the RDDs of this DStream with the given storage level """ self.is_cached = True - javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) self._jdstream.persist(javaStorageLevel) return self @@ -236,7 +230,7 @@ def groupByKey(self, numPartitions=None): Return a new DStream by applying groupByKey on each RDD. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) def countByValue(self): @@ -262,21 +256,22 @@ def saveAsTextFile(t, rdd): raise return self.foreachRDD(saveAsTextFile) - def _saveAsPickleFiles(self, prefix, suffix=None): - """ - Save each RDD in this DStream as at binary file, the elements are - serialized by pickle. - """ - def saveAsPickleFile(t, rdd): - path = rddToFileName(prefix, suffix, t) - try: - rdd.saveAsPickleFile(path) - except Py4JJavaError as e: - # after recovered from checkpointing, the foreachRDD may - # be called twice - if 'FileAlreadyExistsException' not in str(e): - raise - return self.foreachRDD(saveAsPickleFile) + # TODO: uncomment this until we have ssc.pickleFileStream() + # def saveAsPickleFiles(self, prefix, suffix=None): + # """ + # Save each RDD in this DStream as at binary file, the elements are + # serialized by pickle. + # """ + # def saveAsPickleFile(t, rdd): + # path = rddToFileName(prefix, suffix, t) + # try: + # rdd.saveAsPickleFile(path) + # except Py4JJavaError as e: + # # after recovered from checkpointing, the foreachRDD may + # # be called twice + # if 'FileAlreadyExistsException' not in str(e): + # raise + # return self.foreachRDD(saveAsPickleFile) def transform(self, func): """ @@ -304,10 +299,10 @@ def transformWith(self, func, other, keepSerializer=False): oldfunc = func func = lambda t, a, b: oldfunc(a, b) assert func.func_code.co_argcount == 3, "func should take two or three arguments" - jfunc = TransformFunction(self.ctx, func, self._jrdd_deserializer, other._jrdd_deserializer) - dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) - jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer + jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) def repartition(self, numPartitions): @@ -336,61 +331,61 @@ def union(self, other): def cogroup(self, other, numPartitions=None): """ - Return a new DStream by applying 'cogroup' between RDDs of `this` + Return a new DStream by applying 'cogroup' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) def join(self, other, numPartitions=None): """ - Return a new DStream by applying 'join' between RDDs of `this` DStream and + Return a new DStream by applying 'join' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.join(b, numPartitions), other) def leftOuterJoin(self, other, numPartitions=None): """ - Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and + Return a new DStream by applying 'left outer join' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) def rightOuterJoin(self, other, numPartitions=None): """ - Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and + Return a new DStream by applying 'right outer join' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) def fullOuterJoin(self, other, numPartitions=None): """ - Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + Return a new DStream by applying 'full outer join' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) def _jtime(self, timestamp): @@ -398,7 +393,7 @@ def _jtime(self, timestamp): """ if isinstance(timestamp, datetime): timestamp = time.mktime(timestamp.timetuple()) - return self.ctx._jvm.Time(long(timestamp * 1000)) + return self._sc._jvm.Time(long(timestamp * 1000)) def slice(self, begin, end): """ @@ -407,7 +402,7 @@ def slice(self, begin, end): `begin`, `end` could be datetime.datetime() or unix_timestamp """ jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) - return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds] + return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds] def _validate_window_param(self, window, slide): duration = self._jdstream.dstream().slideDuration().milliseconds() @@ -532,7 +527,7 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None """ self._validate_window_param(windowDuration, slideDuration) if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism reduced = self.reduceByKey(func, numPartitions) @@ -548,18 +543,18 @@ def invReduceFunc(t, a, b): joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) - jreduceFunc = TransformFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) if invReduceFunc: - jinvReduceFunc = TransformFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer) + jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) else: jinvReduceFunc = None if slideDuration is None: slideDuration = self._slideDuration - dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), + dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), jreduceFunc, jinvReduceFunc, self._ssc._jduration(windowDuration), self._ssc._jduration(slideDuration)) - return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) def updateStateByKey(self, updateFunc, numPartitions=None): """ @@ -570,7 +565,7 @@ def updateStateByKey(self, updateFunc, numPartitions=None): If `s` is None, then `k` will be eliminated. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism def reduceFunc(t, a, b): if a is None: @@ -581,10 +576,10 @@ def reduceFunc(t, a, b): state = g.mapPartitions(lambda x: updateFunc(x)) return state.filter(lambda (k, v): v is not None) - jreduceFunc = TransformFunction(self.ctx, reduceFunc, - self.ctx.serializer, self._jrdd_deserializer) - dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) - return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer) + jreduceFunc = TransformFunction(self._sc, reduceFunc, + self._sc.serializer, self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) class TransformedDStream(DStream): @@ -596,10 +591,9 @@ class TransformedDStream(DStream): one transformation. """ def __init__(self, prev, func): - ssc = prev._ssc - self._ssc = ssc - self.ctx = ssc._sc - self._jrdd_deserializer = self.ctx.serializer + self._ssc = prev._ssc + self._sc = self._ssc._sc + self._jrdd_deserializer = self._sc.serializer self.is_cached = False self.is_checkpointed = False self._jdstream_val = None @@ -618,7 +612,7 @@ def _jdstream(self): if self._jdstream_val is not None: return self._jdstream_val - jfunc = TransformFunction(self.ctx, self.func, self.prev._jrdd_deserializer) - dstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) self._jdstream_val = dstream.asJavaDStream() return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a839faecf9a16..9f5cdff5ed809 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -496,11 +496,11 @@ def updater(it): def setup(): conf = SparkConf().set("spark.default.parallelism", 1) sc = SparkContext(conf=conf) - ssc = StreamingContext(sc, 0.2) + ssc = StreamingContext(sc, 0.5) dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") - wc.checkpoint(.2) + wc.checkpoint(.5) return ssc cpd = tempfile.mkdtemp("test_streaming_cps") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index e171fb5730616..696dfb969a48a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -36,7 +36,7 @@ import org.apache.spark.streaming.api.java._ /** - * Interface for Python callback function with three arguments + * Interface for Python callback function which is used to transform RDDs */ private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]