From f6ce899abd435f36f7c5907523c643cc8b0e61ed Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 8 Jan 2015 13:28:39 -0800
Subject: [PATCH] add example and fix bugs

---
 .../apache/spark/api/python/PythonRDD.scala   | 55 ++++++++++++-------
 .../main/python/streaming/kafka_wordcount.py  | 55 +++++++++++++++++++
 python/pyspark/serializers.py                 |  7 ++-
 python/pyspark/streaming/kafka.py             |  8 ++-
 python/pyspark/streaming/mqtt.py              | 53 ------------------
 5 files changed, 100 insertions(+), 78 deletions(-)
 create mode 100644 examples/src/main/python/streaming/kafka_wordcount.py
 delete mode 100644 python/pyspark/streaming/mqtt.py

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index bad40e6529f74..b47b381374dc7 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -313,6 +313,7 @@ private object SpecialLengths {
   val PYTHON_EXCEPTION_THROWN = -2
   val TIMING_DATA = -3
   val END_OF_STREAM = -4
+  val NULL = -5
 }
 
 private[spark] object PythonRDD extends Logging {
@@ -374,49 +375,61 @@ private[spark] object PythonRDD extends Logging {
     // The right way to implement this would be to use TypeTags to get the full
     // type of T. Since I don't want to introduce breaking changes throughout the
     // entire Spark API, I have to use this hacky approach:
+    def write(bytes: Array[Byte]) {
+      if (bytes == null) {
+        dataOut.writeInt(SpecialLengths.NULL)
+      } else {
+        dataOut.writeInt(bytes.length)
+        dataOut.write(bytes)
+      }
+    }
+
+    def writeS(str: String) {
+      if (str == null) {
+        dataOut.writeInt(SpecialLengths.NULL)
+      } else {
+        writeUTF(str, dataOut)
+      }
+    }
+
     if (iter.hasNext) {
       val first = iter.next()
       val newIter = Seq(first).iterator ++ iter
       first match {
         case arr: Array[Byte] =>
-          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
-            dataOut.writeInt(bytes.length)
-            dataOut.write(bytes)
-          }
+          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(write)
         case string: String =>
-          newIter.asInstanceOf[Iterator[String]].foreach { str =>
-            writeUTF(str, dataOut)
-          }
+          newIter.asInstanceOf[Iterator[String]].foreach(writeS)
         case stream: PortableDataStream =>
           newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
-            val bytes = stream.toArray()
-            dataOut.writeInt(bytes.length)
-            dataOut.write(bytes)
+            write(stream.toArray())
           }
         case (key: String, stream: PortableDataStream) =>
           newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
             case (key, stream) =>
-              writeUTF(key, dataOut)
-              val bytes = stream.toArray()
-              dataOut.writeInt(bytes.length)
-              dataOut.write(bytes)
+              writeS(key)
+              write(stream.toArray())
           }
         case (key: String, value: String) =>
           newIter.asInstanceOf[Iterator[(String, String)]].foreach {
             case (key, value) =>
-              writeUTF(key, dataOut)
-              writeUTF(value, dataOut)
+              writeS(key)
+              writeS(value)
           }
         case (key: Array[Byte], value: Array[Byte]) =>
           newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
             case (key, value) =>
-              dataOut.writeInt(key.length)
-              dataOut.write(key)
-              dataOut.writeInt(value.length)
-              dataOut.write(value)
+              write(key)
+              write(value)
+          }
+        // key is null
+        case (null, v: Array[Byte]) =>
+          newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
+            case (key, value) =>
+              write(key)
+              write(value)
           }
+
         case other =>
-          throw new SparkException("Unexpected element type " + first.getClass)
+          throw new SparkException("Unexpected element type " + other.getClass)
       }
     }
   }
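Note on the framing above: write() and writeS() emit each element as a 4-byte big-endian length followed by the payload, and SpecialLengths.NULL (-5) takes the place of the length when a key or value is null. As a rough sketch only (not part of this patch; the real consumer is the change to python/pyspark/serializers.py below, and read_frame is a made-up name), a Python reader for one such frame could look like:

    import struct

    NULL = -5  # mirrors SpecialLengths.NULL written by the JVM side

    def read_frame(stream):
        # DataOutputStream.writeInt writes a big-endian 4-byte int
        length = struct.unpack("!i", stream.read(4))[0]
        if length == NULL:
            return None  # null key or value
        return stream.read(length)
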
diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py
new file mode 100644
index 0000000000000..400c05fb7a05b
--- /dev/null
+++ b/examples/src/main/python/streaming/kafka_wordcount.py
@@ -0,0 +1,55 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from Kafka every second.
+ Usage: kafka_wordcount.py <zk> <topic>
+
+ To run this on your local machine, you need to set up Kafka and create a producer first:
+    $ bin/zookeeper-server-start.sh config/zookeeper.properties
+    $ bin/kafka-server-start.sh config/server.properties
+    $ bin/kafka-console-producer.sh --broker-list localhost:9092 --topic test
+
+ and then run the example
+    `$ bin/spark-submit --driver-class-path lib_managed/jars/kafka_*.jar:\
+      external/kafka/target/scala-*/spark-streaming-kafka_*.jar examples/src/main/python/\
+      streaming/kafka_wordcount.py localhost:2181 test`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+from pyspark.streaming.kafka import KafkaUtils
+
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        print >> sys.stderr, "Usage: kafka_wordcount.py <zk> <topic>"
+        exit(-1)
+
+    sc = SparkContext(appName="PythonStreamingKafkaWordCount")
+    ssc = StreamingContext(sc, 1)
+
+    zkQuorum, topic = sys.argv[1:]
+    lines = KafkaUtils.createStream(ssc, zkQuorum, "spark-streaming-consumer", {topic: 1})
+    counts = lines.map(lambda x: x[1]).flatMap(lambda line: line.split(" ")) \
+        .map(lambda word: (word, 1)) \
+        .reduceByKey(lambda a, b: a+b)
+    counts.pprint()
+
+    ssc.start()
+    ssc.awaitTermination()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index bd08c9a6d20d6..3cec646f3336d 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -70,6 +70,7 @@ class SpecialLengths(object):
     PYTHON_EXCEPTION_THROWN = -2
     TIMING_DATA = -3
     END_OF_STREAM = -4
+    NULL = -5
 
 
 class Serializer(object):
@@ -145,8 +146,10 @@ def _read_with_length(self, stream):
         length = read_int(stream)
         if length == SpecialLengths.END_OF_DATA_SECTION:
             raise EOFError
+        if length == SpecialLengths.NULL:
+            return None
         obj = stream.read(length)
-        if obj == "":
+        if len(obj) < length:
             raise EOFError
         return self.loads(obj)
 
@@ -480,6 +483,8 @@ def loads(self, stream):
         length = read_int(stream)
         if length == SpecialLengths.END_OF_DATA_SECTION:
             raise EOFError
+        if length == SpecialLengths.NULL:
+            return None
         s = stream.read(length)
         return s.decode("utf-8") if self.use_unicode else s
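The reworked truncation check in _read_with_length above matters because a broken or exhausted stream can hand back some bytes but fewer than requested; only a fully drained stream returns the empty string, which is all the old `obj == ""` test could catch. A small illustration of the failure mode (not from the patch, just a sketch):

    from io import BytesIO

    stream = BytesIO(b"abc")   # pretend the worker died after sending 3 of 10 bytes
    obj = stream.read(10)
    assert obj == b"abc"       # non-empty, so the old check would not raise
    assert len(obj) < 10       # the new check sees the truncation and raises EOFError
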
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index c27699fee6b83..f52d0b535094f 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -22,11 +22,12 @@
 from pyspark.serializers import PairDeserializer, NoOpSerializer
 from pyspark.streaming import DStream
 
-__all__ = ['KafkaUtils']
+__all__ = ['KafkaUtils', 'utf8_decoder']
 
 
 def utf8_decoder(s):
-    return s.decode('utf-8')
+    """ Decode the unicode as UTF-8 """
+    return s and s.decode('utf-8')
 
 
 class KafkaUtils(object):
@@ -70,7 +71,8 @@ def getClassByName(name):
         jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder,
                                                    jparam, jtopics, jlevel)
     except Py4JError, e:
-        if 'call a package' in e.message:
+        # TODO: use --jars once it also works on the driver
+        if not e.message or 'call a package' in e.message:
             print "No kafka package, please build it and add it into classpath:"
             print "  $ sbt/sbt streaming-kafka/package"
             print "  $ bin/submit --driver-class-path lib_managed/jars/kafka_2.10-0.8.0.jar:" \
diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py
deleted file mode 100644
index 78e4cb7f14649..0000000000000
--- a/python/pyspark/streaming/mqtt.py
+++ /dev/null
@@ -1,53 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-from py4j.java_gateway import java_import, Py4JError
-
-from pyspark.storagelevel import StorageLevel
-from pyspark.serializers import UTF8Deserializer
-from pyspark.streaming import DStream
-
-__all__ = ['MQTTUtils']
-
-
-class MQTTUtils(object):
-
-    @staticmethod
-    def createStream(ssc, brokerUrl, topic, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
-        """
-        Create an input stream that receives messages pushed by a MQTT publisher.
-
-        :param ssc: StreamingContext object
-        :param brokerUrl: Url of remote MQTT publisher
-        :param topic: Topic name to subscribe to
-        :param storageLevel: RDD storage level.
-        :return: A DStream object
-        """
-        java_import(ssc._jvm, "org.apache.spark.streaming.mqtt.MQTTUtils")
-        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-        try:
-            jstream = ssc._jvm.MQTTUtils.createStream(ssc._jssc, brokerUrl, topic, jlevel)
-        except Py4JError, e:
-            if 'call a package' in e.message:
-                print "No MQTT package, please build it and add it into classpath:"
-                print "  $ sbt/sbt streaming-mqtt/package"
-                print "  $ bin/submit --driver-class-path external/mqtt/target/scala-2.10/" \
-                    "spark-streaming-mqtt_2.10-1.3.0-SNAPSHOT.jar"
-                raise Exception("No mqtt package")
-            raise e
-        return DStream(jstream, ssc, UTF8Deserializer())
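With utf8_decoder now exported and made None-safe, a Kafka message whose key is null decodes to (None, value) instead of raising AttributeError. A quick sketch of the behavior (illustrative, not part of the patch):

    from pyspark.streaming.kafka import utf8_decoder

    assert utf8_decoder(None) is None           # `s and ...` short-circuits on a null key
    assert utf8_decoder("hello") == u"hello"    # normal payloads still decode as UTF-8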