diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 19cdbe679fd35..8051b221ac3d1 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -52,6 +52,7 @@ private[spark] class PythonRDD( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { + // create a new PythonRDD with the same Python settings but a different parent. def copyTo(rdd: RDD[_]): PythonRDD = { new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, pythonExec, broadcastVars, accumulator) diff --git a/examples/src/main/python/streaming/wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py similarity index 100% rename from examples/src/main/python/streaming/wordcount.py rename to examples/src/main/python/streaming/hdfs_wordcount.py diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index db5b97f8472d1..9c70fa5c16d0c 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -23,7 +23,6 @@ import platform from subprocess import Popen, PIPE from threading import Thread - from py4j.java_gateway import java_import, JavaGateway, GatewayClient diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 425b0a96aa832..ae4a1d5b6b069 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -31,6 +31,11 @@ def _daemonize_callback_server(): """ Hack Py4J to daemonize callback server + + The callback server thread has daemon=False, so it will block the driver + from exiting if it is not shut down. The following code replaces `start()` + of CallbackServer with a new version, which sets daemon=True for this + thread. """ # TODO: create a patch for Py4J import socket @@ -47,7 +52,6 @@ def start(self): 1) try: self.server_socket.bind((self.address, self.port)) - # self.port = self.server_socket.getsockname()[1] except Exception: msg = 'An error occurred while trying to start the callback server' logger.exception(msg) @@ -63,19 +67,21 @@ def start(self): class StreamingContext(object): """ - Main entry point for Spark Streaming functionality. A StreamingContext represents the - connection to a Spark cluster, and can be used to create L{DStream}s and - broadcast variables on that cluster. + Main entry point for Spark Streaming functionality. A StreamingContext + represents the connection to a Spark cluster, and can be used to create + L{DStream}s from various input sources. It can be created from an existing L{SparkContext}. + After creating and transforming DStreams, the streaming computation can + be started and stopped using `context.start()` and `context.stop()`, + respectively. `context.awaitTermination()` allows the current thread + to wait for the termination of the context by `stop()` or by an exception. """ def __init__(self, sparkContext, duration): """ - Create a new StreamingContext. At least the master and app name and duration - should be set, either through the named parameters here or through C{conf}. + Create a new StreamingContext. @param sparkContext: L{SparkContext} object. - @param duration: seconds for SparkStreaming. - + @param duration: batch interval in seconds.
""" self._sc = sparkContext self._jvm = self._sc._jvm @@ -127,8 +133,12 @@ def awaitTermination(self, timeout=None): def stop(self, stopSparkContext=True, stopGraceFully=False): """ - Stop the execution of the streams immediately (does not wait for all received data - to be processed). + Stop the execution of the streams, with option of ensuring all + received data has been processed. + + @param stopSparkContext Stop the associated SparkContext or not + @param stopGracefully Stop gracefully by waiting for the processing + of all received data to be completed """ self._jssc.stop(stopSparkContext, stopGraceFully) if stopSparkContext: @@ -140,7 +150,7 @@ def remember(self, duration): in the last given duration. DStreams remember RDDs only for a limited duration of time and releases them for garbage collection. This method allows the developer to specify how to long to remember - the RDDs ( if the developer wishes to query old data outside the + the RDDs (if the developer wishes to query old data outside the DStream computation). @param duration Minimum duration (in seconds) that each DStream diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index d866f8c9687fb..4e3f07e26953b 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -30,6 +30,24 @@ class DStream(object): + """ + A Discretized Stream (DStream), the basic abstraction in Spark Streaming, + is a continuous sequence of RDDs (of the same type) representing a + continuous stream of data (see L{RDD} in the Spark core documentation + for more details on RDDs). + + DStreams can either be created from live data (such as, data from TCP + sockets, Kafka, Flume, etc.) using a L{StreamingContext} or it can be + generated by transforming existing DStreams using operations such as + `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming + program is running, each DStream periodically generates a RDD, either + from live data or by transforming the RDD generated by a parent DStream. + + DStreams internally is characterized by a few basic properties: + - A list of other DStreams that the DStream depends on + - A time interval at which the DStream generates an RDD + - A function that is used to generate an RDD after each time interval + """ def __init__(self, jdstream, ssc, jrdd_deserializer): self._jdstream = jdstream self._ssc = ssc @@ -46,11 +64,12 @@ def context(self): def count(self): """ - Return a new DStream which contains the number of elements in this DStream. + Return a new DStream in which each RDD has a single element + generated by counting each RDD of this DStream. """ - return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() + return self.mapPartitions(lambda i: [sum(1 for _ in i)])._sum() - def sum(self): + def _sum(self): """ Add up the elements in this DStream. """ @@ -66,8 +85,8 @@ def func(iterator): def flatMap(self, f, preservesPartitioning=False): """ - Pass each value in the key-value pair DStream through flatMap function - without changing the keys: this also retains the original RDD's partition. + Return a new DStream by applying a function to all elements of + this DStream, and then flattening the results """ def func(s, iterator): return chain.from_iterable(imap(f, iterator)) @@ -83,7 +102,8 @@ def func(iterator): def mapPartitions(self, f, preservesPartitioning=False): """ - Return a new DStream by applying a function to each partition of this DStream. 
+ Return a new DStream in which each RDD is generated by applying + mapPartitions() to each RDDs of this DStream. """ def func(s, iterator): return f(iterator) @@ -91,56 +111,51 @@ def func(s, iterator): def mapPartitionsWithIndex(self, f, preservesPartitioning=False): """ - Return a new DStream by applying a function to each partition of this DStream, - while tracking the index of the original partition. + Return a new DStream in which each RDD is generated by applying + mapPartitionsWithIndex() to each RDDs of this DStream. """ return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning)) def reduce(self, func): """ - Return a new DStream by reduceing the elements of this RDD using the specified - commutative and associative binary operator. + Return a new DStream in which each RDD has a single element + generated by reducing each RDD of this DStream. """ return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1]) def reduceByKey(self, func, numPartitions=None): """ - Merge the value for each key using an associative reduce function. - - This will also perform the merging locally on each mapper before - sending results to reducer, similarly to a "combiner" in MapReduce. - - Output will be hash-partitioned with C{numPartitions} partitions, or - the default parallelism level if C{numPartitions} is not specified. + Return a new DStream by applying reduceByKey to each RDD. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.combineByKey(lambda x: x, func, func, numPartitions) def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions=None): """ - Count the number of elements for each key, and return the result to the - master as a dictionary + Return a new DStream by applying combineByKey to each RDD. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism + def func(rdd): return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) return self.transform(func) def partitionBy(self, numPartitions, partitionFunc=portable_hash): """ - Return a copy of the DStream partitioned using the specified partitioner. + Return a copy of the DStream in which each RDD are partitioned + using the specified partitioner. """ return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc)) - def foreach(self, func): - return self.foreachRDD(lambda _, rdd: rdd.foreach(func)) + # def foreach(self, func): + # return self.foreachRDD(lambda _, rdd: rdd.foreach(func)) def foreachRDD(self, func): """ - Apply userdefined function to all RDD in a DStream. - This python implementation could be expensive because it uses callback server - in order to apply function to RDD in DStream. - This is an output operator, so this DStream will be registered as an output - stream and there materialized. + Apply a function to each RDD in this DStream. """ jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream @@ -148,13 +163,12 @@ def foreachRDD(self, func): def pprint(self): """ - Print the first ten elements of each RDD generated in this DStream. This is an output - operator, so this DStream will be registered as an output stream and there materialized. + Print the first ten elements of each RDD generated in this DStream. 
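# Editor's note: a hedged sketch of the foreachRDD() contract shown above.
# In this version the callback receives (time, rdd), where `time` is a
# datetime (see util.py below); `pairs` is an assumed DStream of
# (word, count) pairs.
def print_counts(time, rdd):
    # collect() pulls the batch to the driver; fine for small example data
    for (word, count) in rdd.collect():
        print "%s: %s -> %d" % (time, word, count)

pairs.foreachRDD(print_counts)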
""" - def takeAndPrint(timestamp, rdd): + def takeAndPrint(time, rdd): taken = rdd.take(11) print "-------------------------------------------" - print "Time: %s" % datetime.fromtimestamp(timestamp / 1000.0) + print "Time: %s" % time print "-------------------------------------------" for record in taken[:10]: print record @@ -164,58 +178,18 @@ def takeAndPrint(timestamp, rdd): self.foreachRDD(takeAndPrint) - def _first(self): - """ - Return the first RDD in the stream. - """ - return self._take(1)[0] - - def _take(self, n): - """ - Return the first `n` RDDs in the stream (will start and stop). - """ - results = [] - - def take(_, rdd): - if rdd and len(results) < n: - results.extend(rdd.take(n - len(results))) - - self.foreachRDD(take) - - self._ssc.start() - while len(results) < n: - time.sleep(0.01) - self._ssc.stop(False, True) - return results - - def _collect(self): - """ - Collect each RDDs into the returned list. - - :return: list, which will have the collected items. - """ - result = [] - - def get_output(_, rdd): - r = rdd.collect() - result.append(r) - self.foreachRDD(get_output) - return result - def mapValues(self, f): """ - Pass each value in the key-value pair RDD through a map function - without changing the keys; this also retains the original RDD's - partitioning. + Return a new DStream by applying a map function to the value of + each key-value pairs in 'this' DStream without changing the key. """ map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) def flatMapValues(self, f): """ - Pass each value in the key-value pair RDD through a flatMap function - without changing the keys; this also retains the original RDD's - partitioning. + Return a new DStream by applying a flatmap function to the value + of each key-value pairs in 'this' DStream without changing the key. """ flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) return self.flatMap(flat_map_fn, preservesPartitioning=True) @@ -223,8 +197,7 @@ def flatMapValues(self, f): def glom(self): """ Return a new DStream in which RDD is generated by applying glom() - to RDD of this DStream. Applying glom() to an RDD coalesces all - elements within each partition into an list. + to RDD of this DStream. """ def func(iterator): yield list(iterator) @@ -232,7 +205,8 @@ def func(iterator): def cache(self): """ - Persist this DStream with the default storage level (C{MEMORY_ONLY_SER}). + Persist the RDDs of this DStream with the default storage level + (C{MEMORY_ONLY_SER}). """ self.is_cached = True self.persist(StorageLevel.MEMORY_ONLY_SER) @@ -240,9 +214,7 @@ def cache(self): def persist(self, storageLevel): """ - Set this DStream's storage level to persist its values across operations - after the first time it is computed. This can only be used to assign - a new storage level if the DStream does not have a storage level set yet. + Persist the RDDs of this DStream with the given storage level """ self.is_cached = True javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) @@ -251,11 +223,10 @@ def persist(self, storageLevel): def checkpoint(self, interval): """ - Mark this DStream for checkpointing. 
It will be saved to a file inside the - checkpoint directory set with L{SparkContext.setCheckpointDir()} + Enable periodic checkpointing of RDDs of this DStream - @param interval: time in seconds, after which generated RDD will - be checkpointed + @param interval: checkpoint interval in seconds, after each of which the + generated RDD will be checkpointed """ self.is_checkpointed = True self._jdstream.checkpoint(self._ssc._jduration(interval)) @@ -263,85 +234,76 @@ def checkpoint(self, interval): def groupByKey(self, numPartitions=None): """ - Return a new DStream which contains group the values for each key in the - DStream into a single sequence. - Hash-partitions the resulting RDD with into numPartitions partitions in - the DStream. - - Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey will provide much - better performance. + Return a new DStream by applying groupByKey on each RDD. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) def countByValue(self): """ - Return new DStream which contains the count of each unique value in this - DStreeam as a (value, count) pairs. + Return a new DStream in which each RDD contains the counts of each + distinct value in each RDD of this DStream. """ return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() def saveAsTextFiles(self, prefix, suffix=None): """ - Save this DStream as a text file, using string representations of elements. + Save each RDD in this DStream as a text file, using string + representations of elements. """ - def saveAsTextFile(time, rdd): - """ - Closure to save element in RDD in DStream as Pickled data in file. - This closure is called by py4j callback server. - """ path = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(path) - return self.foreachRDD(saveAsTextFile) - def saveAsPickleFiles(self, prefix, suffix=None): + def _saveAsPickleFiles(self, prefix, suffix=None): """ - Save this DStream as a SequenceFile of serialized objects. The serializer - used is L{pyspark.serializers.PickleSerializer}, default batch size - is 10. + Save each RDD in this DStream as a binary file in which the elements + are serialized by pickle. """ - def saveAsPickleFile(time, rdd): - """ - Closure to save element in RDD in the DStream as Pickled data in file. - This closure is called by py4j callback server. - """ path = rddToFileName(prefix, suffix, time) rdd.saveAsPickleFile(path) - return self.foreachRDD(saveAsPickleFile) def transform(self, func): """ Return a new DStream in which each RDD is generated by applying a function on each RDD of 'this' DStream. - """ - return TransformedDStream(self, lambda t, a: func(a), True) - def transformWithTime(self, func): + `func` can have one argument of `rdd`, or have two arguments of (`time`, `rdd`) """ - Return a new DStream in which each RDD is generated by applying a function - on each RDD of 'this' DStream. - """ - return TransformedDStream(self, func, False) + reuse = False + if func.func_code.co_argcount == 1: + reuse = True + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.func_code.co_argcount == 2, "func should take one or two arguments" + return TransformedDStream(self, func, reuse) def transformWith(self, func, other, keepSerializer=False): """ Return a new DStream in which each RDD is generated by applying a function on each RDD of 'this' DStream and 'other' DStream.
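# Editor's note: a hedged sketch of the two call styles accepted by
# transform() above; `stream` is an assumed DStream of integers.
evens = stream.transform(lambda rdd: rdd.filter(lambda x: x % 2 == 0))
stamped = stream.transform(lambda t, rdd: rdd.map(lambda x: (t, x)))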
+ + `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three + arguments of (`time`, `rdd_a`, `rdd_b`) """ - jfunc = RDDFunction(self.ctx, lambda t, a, b: func(a, b), self._jrdd_deserializer) + if func.func_code.co_argcount == 2: + oldfunc = func + func = lambda t, a, b: oldfunc(a, b) + assert func.func_code.co_argcount == 3, "func should take two or three arguments" + jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer) dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) - def repartitions(self, numPartitions): + def repartition(self, numPartitions): """ - Return a new DStream with an increased or decreased level of parallelism. Each RDD in the - returned DStream has exactly numPartitions partitions. + Return a new DStream with an increased or decreased level of parallelism. """ return self.transform(lambda rdd: rdd.repartition(numPartitions)) @@ -355,7 +317,8 @@ def _slideDuration(self): def union(self, other): """ Return a new DStream by unifying data of another DStream with this DStream. - @param other Another DStream having the same interval (i.e., slideDuration) as this DStream. + @param other Another DStream having the same interval (i.e., slideDuration) + as this DStream. """ if self._slideDuration != other._slideDuration: raise ValueError("the two DStream should have same slide duration") @@ -368,6 +331,8 @@ def cogroup(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) def join(self, other, numPartitions=None): @@ -378,6 +343,8 @@ def join(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.join(b, numPartitions), other) def leftOuterJoin(self, other, numPartitions=None): @@ -388,6 +355,8 @@ def leftOuterJoin(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) def rightOuterJoin(self, other, numPartitions=None): @@ -398,6 +367,8 @@ def rightOuterJoin(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) def fullOuterJoin(self, other, numPartitions=None): @@ -408,6 +379,8 @@ def fullOuterJoin(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. 
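# Editor's note: a hedged sketch of the per-batch joins described above;
# `clicks` and `views` are assumed DStreams of (key, value) pairs with the
# same batch interval.
joined = clicks.join(views)             # (key, (click, view)) for matching keys
with_all = clicks.fullOuterJoin(views)  # unmatched sides become None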
""" + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) def _jtime(self, timestamp): @@ -426,7 +399,7 @@ def slice(self, begin, end): jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds] - def _check_window(self, window, slide): + def _validate_window_param(self, window, slide): duration = self._jdstream.dstream().slideDuration().milliseconds() if int(window * 1000) % duration != 0: raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" @@ -446,7 +419,7 @@ def window(self, windowDuration, slideDuration=None): the new DStream will generate RDDs); must be a multiple of this DStream's batching interval """ - self._check_window(windowDuration, slideDuration) + self._validate_window_param(windowDuration, slideDuration) d = self._ssc._jduration(windowDuration) if slideDuration is None: return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer) @@ -547,23 +520,22 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None only pairs that satisfy the function are retained set this to null if you do not want to filter """ - self._check_window(windowDuration, slideDuration) - reduced = self.reduceByKey(func) + self._validate_window_param(windowDuration, slideDuration) + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism + + reduced = self.reduceByKey(func, numPartitions) def reduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) - # use the average of number of partitions, or it will keep increasing - partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2 - r = a.union(b).reduceByKey(func, partitions) if a else b + r = a.union(b).reduceByKey(func, numPartitions) if a else b if filterFunc: r = r.filter(filterFunc) return r def invReduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) - # use the average of number of partitions, or it will keep increasing - partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2 - joined = a.leftOuterJoin(b, partitions) + joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) @@ -587,13 +559,14 @@ def updateStateByKey(self, updateFunc, numPartitions=None): @param updateFunc State update function ([(k, vs, s)] -> [(k, s)]). If `s` is None, then `k` will be eliminated. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None)) else: - # use the average of number of partitions, or it will keep increasing - partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2 - g = a.cogroup(b, partitions) + g = a.cogroup(b, numPartitions) g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None)) state = g.mapPartitions(lambda x: updateFunc(x)) return state.filter(lambda (k, v): v is not None) @@ -605,6 +578,13 @@ def reduceFunc(t, a, b): class TransformedDStream(DStream): + """ + TransformedDStream is an DStream generated by an Python function + transforming each RDD of an DStream to another RDDs. + + Multiple continuous transformations of DStream can be combined into + one transformation. 
+ """ def __init__(self, prev, func, reuse=False): ssc = prev._ssc self._ssc = ssc diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0dc6b3d675397..698978e61ffad 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -29,17 +29,50 @@ class PySparkStreamingTestCase(unittest.TestCase): timeout = 10 # seconds + duration = 1 def setUp(self): class_name = self.__class__.__name__ self.sc = SparkContext(appName=class_name) self.sc.setCheckpointDir("/tmp") # TODO: decrease duration to speed up tests - self.ssc = StreamingContext(self.sc, duration=1) + self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): self.ssc.stop() + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + while len(results) < n: + time.sleep(0.01) + self.ssc.stop(False, True) + return results + + def _collect(self, dstream): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + r = rdd.collect() + result.append(r) + dstream.foreachRDD(get_output) + return result + def _test_func(self, input, func, expected, sort=False, input2=None): """ @param input: dataset for the test. This should be list of lists. @@ -59,7 +92,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None): else: stream = func(input_stream) - result = stream._collect() + result = self._collect(stream) self.ssc.start() start_time = time.time() @@ -89,16 +122,6 @@ def _sort_result_based_on_key(self, outputs): class TestBasicOperations(PySparkStreamingTestCase): - def test_take(self): - input = [range(i) for i in range(3)] - dstream = self.ssc.queueStream(input) - self.assertEqual([0, 0, 1], dstream._take(3)) - - def test_first(self): - input = [range(10)] - dstream = self.ssc.queueStream(input) - self.assertEqual(0, dstream._first()) - def test_map(self): """Basic operation test for DStream.map.""" input = [range(1, 5), range(5, 9), range(9, 13)] @@ -248,7 +271,7 @@ def test_repartition(self): rdds = [self.sc.parallelize(r, 2) for r in input] def func(dstream): - return dstream.repartitions(1).glom() + return dstream.repartition(1).glom() expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]] self._test_func(rdds, func, expected) @@ -395,15 +418,9 @@ def func(dstream): self._test_func(input, func, expected) -class TestStreamingContext(unittest.TestCase): - def setUp(self): - self.sc = SparkContext(master="local[2]", appName=self.__class__.__name__) - self.batachDuration = 0.1 - self.ssc = StreamingContext(self.sc, self.batachDuration) +class TestStreamingContext(PySparkStreamingTestCase): - def tearDown(self): - self.ssc.stop() - self.sc.stop() + duration = 0.1 def test_stop_only_streaming_context(self): self._addInputStream() @@ -421,12 +438,12 @@ def _addInputStream(self): # Make sure each length of input is over 3 inputs = map(lambda x: range(1, x), range(5, 101)) stream = self.ssc.queueStream(inputs) - stream._collect() + self._collect(stream) def test_queueStream(self): input = [range(i) for i in range(3)] dstream = self.ssc.queueStream(input) - result = dstream._collect() + result = self._collect(dstream) self.ssc.start() time.sleep(1) self.assertEqual(input, result[:3]) @@ -445,7 +462,7 @@ def test_queueStream(self): # # self.ssc = 
StreamingContext(self.sc, self.batachDuration) # dstream2 = self.ssc.textFileStream(d) - # result = dstream2._collect() + # result = self._collect(dstream2) # self.ssc.start() # time.sleep(2) # self.assertEqual(input, result[:3]) @@ -455,7 +472,7 @@ def test_union(self): dstream = self.ssc.queueStream(input) dstream2 = self.ssc.queueStream(input) dstream3 = self.ssc.union(dstream, dstream2) - result = dstream3._collect() + result = self._collect(dstream3) self.ssc.start() time.sleep(1) expected = [i * 2 for i in input] @@ -472,7 +489,7 @@ def func(rdds): dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) - self.assertEqual([2, 3, 1], dstream._take(3)) + self.assertEqual([2, 3, 1], self._take(dstream, 3)) if __name__ == "__main__": diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 57791805e8f9f..4838ec6c8c6e9 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -15,6 +15,8 @@ # limitations under the License. # +from datetime import datetime + from pyspark.rdd import RDD @@ -40,7 +42,8 @@ def call(self, milliseconds, jrdds): rdds = [RDD(jrdd, self.ctx, ser) if jrdd else self.emptyRDD for jrdd, ser in zip(jrdds, sers)] - r = self.func(milliseconds, *rdds) + t = datetime.fromtimestamp(milliseconds / 1000.0) + r = self.func(t, *rdds) if r: return r._jrdd except Exception: diff --git a/python/run-tests b/python/run-tests index e8796838c22c1..e86e0729cf65e 100755 --- a/python/run-tests +++ b/python/run-tests @@ -93,9 +93,9 @@ fi echo "Testing with Python version:" $PYSPARK_PYTHON --version -#run_core_tests -#run_sql_tests -#run_mllib_tests +run_core_tests +run_sql_tests +run_mllib_tests run_streaming_tests # Try to test with PyPy diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 30c52c15e9e68..658715eb456dd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -34,7 +34,8 @@ import org.apache.spark.streaming.api.java._ /** * Interface for Python callback function with three arguments */ -trait PythonRDDFunction { +private[spark] trait PythonRDDFunction { + // callback in Python def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] } @@ -44,38 +45,30 @@ trait PythonRDDFunction { private[python] class RDDFunction(pfunc: PythonRDDFunction) extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { - def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { - if (rdd.isDefined) { - JavaRDD.fromRDD(rdd.get) - } else { - null - } - } - - def some(jrdd: JavaRDD[Array[Byte]]): Option[RDD[Array[Byte]]] = { - if (jrdd != null) { - Some(jrdd.rdd) - } else { - None - } - } - def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - some(pfunc.call(time.milliseconds, List(wrapRDD(rdd)).asJava)) + PythonDStream.some(pfunc.call(time.milliseconds, List(PythonDStream.wrapRDD(rdd)).asJava)) } def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - some(pfunc.call(time.milliseconds, List(wrapRDD(rdd), wrapRDD(rdd2)).asJava)) + val rdds = List(PythonDStream.wrapRDD(rdd), PythonDStream.wrapRDD(rdd2)).asJava + PythonDStream.some(pfunc.call(time.milliseconds, rdds)) } - // for JFunction2 + // for function.Function2 def call(rdds: JList[JavaRDD[_]], time: Time): 
JavaRDD[Array[Byte]] = { pfunc.call(time.milliseconds, rdds) } } + +/** + * Base class for PythonDStream with some common methods + */ private[python] -abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (parent.ssc) { +abstract class PythonDStream(parent: DStream[_], pfunc: PythonRDDFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new RDDFunction(pfunc) override def dependencies = List(parent) @@ -84,12 +77,33 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p val asJavaDStream = JavaDStream.fromDStream(this) } +/** + * Helper functions + */ private[spark] object PythonDStream { + // convert Option[RDD[_]] to JavaRDD, handle null gracefully + def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { + if (rdd.isDefined) { + JavaRDD.fromRDD(rdd.get) + } else { + null + } + } + + // convert JavaRDD to Option[RDD[Array[Byte]]], handle null gracefully + def some(jrdd: JavaRDD[Array[Byte]]): Option[RDD[Array[Byte]]] = { + if (jrdd != null) { + Some(jrdd.rdd) + } else { + None + } + } + // helper function for DStream.foreachRDD(), // cannot be `foreachRDD`, it will confusing py4j - def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction){ - val func = new RDDFunction(pyfunc) + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonRDDFunction){ + val func = new RDDFunction(pfunc) jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) } @@ -112,34 +126,36 @@ private[spark] object PythonDStream { /** * Transformed DStream in Python. * - * If the result RDD is PythonRDD, then it will cache it as an template for future use, - * this can reduce the Python callbacks. + * If `reuse` is true and the result of `func` is a PythonRDD, then it will be cached + * as a template for future use; this can reduce the number of Python callbacks.
*/ private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction, var reuse: Boolean = false) - extends PythonDStream(parent) { + extends PythonDStream(parent, pfunc) { - val func = new RDDFunction(pfunc) + // rdd returned by func var lastResult: PythonRDD = _ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { - val rdd1 = parent.getOrCompute(validTime) - if (rdd1.isEmpty) { + val rdd = parent.getOrCompute(validTime) + if (rdd.isEmpty) { return None } if (reuse && lastResult != null) { - Some(lastResult.copyTo(rdd1.get)) + // use the previous result as the template to generate new RDD + Some(lastResult.copyTo(rdd.get)) } else { - val r = func(rdd1, validTime) + val r = func(rdd, validTime) if (reuse && r.isDefined && lastResult == null) { + // try to use the result as a template r.get match { - case rdd: PythonRDD => - if (rdd.parent(0) == rdd1) { + case pyrdd: PythonRDD => + if (pyrdd.parent(0) == rdd) { // only one PythonRDD - lastResult = rdd + lastResult = pyrdd } else { - // may have multiple stages + // maybe have multiple stages, don't check it anymore reuse = false } } @@ -174,10 +190,8 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], * similar to StateDStream */ private[spark] -class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFunction) - extends PythonDStream(parent) { - - val reduceFunc = new RDDFunction(preduceFunc) +class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunction) + extends PythonDStream(parent, reduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true @@ -186,7 +200,7 @@ class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFun val lastState = getOrCompute(validTime - slideDuration) val rdd = parent.getOrCompute(validTime) if (rdd.isDefined) { - reduceFunc(lastState, rdd, validTime) + func(lastState, rdd, validTime) } else { lastState } @@ -244,7 +258,7 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], // add the RDDs of the reduced values in "new time steps" val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime) if (newRDDs.size > 0) { - reduceFunc(subtracted, Some(ssc.sc.union(newRDDs)), validTime) + func(subtracted, Some(ssc.sc.union(newRDDs)), validTime) } else { subtracted } @@ -252,7 +266,7 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], // Get the RDDs of the reduced values in current window val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime) if (currentRDDs.size > 0) { - reduceFunc(None, Some(ssc.sc.union(currentRDDs)), validTime) + func(None, Some(ssc.sc.union(currentRDDs)), validTime) } else { None }
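# Editor's note: a hedged Python-side usage sketch for the windowed reduction
# implemented by PythonReducedWindowedDStream above; `word_pairs` is an
# assumed DStream of (word, 1) pairs on a context with a 1-second batch
# interval and a checkpoint directory set.
windowed_counts = word_pairs.reduceByKeyAndWindow(
    lambda a, b: a + b,     # reduce new values entering the window
    lambda a, b: a - b,     # inverse reduce for values leaving the window
    windowDuration=30,
    slideDuration=10)
windowed_counts.pprint()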