change updateStateByKey() to easy API
davies committed Oct 10, 2014
1 parent 182be73 commit 3e2492b
Showing 3 changed files with 72 additions and 17 deletions.
57 changes: 57 additions & 0 deletions examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -0,0 +1,57 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Counts words in UTF8 encoded, '\n' delimited text received from the
network every second.
Usage: stateful_network_wordcount.py <hostname> <port>
<hostname> and <port> describe the TCP server that Spark Streaming
would connect to receive data.
To run this on your local machine, you need to first run a Netcat server
`$ nc -lk 9999`
and then run the example
`$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
localhost 9999`
"""

import sys

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print >> sys.stderr, "Usage: stateful_network_wordcount.py <hostname> <port>"
        exit(-1)
    sc = SparkContext(appName="PythonStreamingNetworkWordCount")
    ssc = StreamingContext(sc, 1)
    ssc.checkpoint("checkpoint")

    def updateFunc(new_values, last_sum):
        return sum(new_values) + (last_sum or 0)

    lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
    running_counts = lines.flatMap(lambda line: line.split(" "))\
                          .map(lambda word: (word, 1))\
                          .updateStateByKey(updateFunc)

    running_counts.pprint()

    ssc.start()
    ssc.awaitTermination()
10 changes: 5 additions & 5 deletions python/pyspark/streaming/dstream.py
@@ -564,19 +564,19 @@ def updateStateByKey(self, updateFunc, numPartitions=None):
         Return a new "state" DStream where the state for each key is updated by applying
         the given function on the previous state of the key and the new values of the key.
-        @param updateFunc: State update function ([(k, vs, s)] -> [(k, s)]).
-                           If `s` is None, then `k` will be eliminated.
+        @param updateFunc: State update function. If this function returns None, then
+                           corresponding state key-value pair will be eliminated.
         """
         if numPartitions is None:
             numPartitions = self._sc.defaultParallelism

         def reduceFunc(t, a, b):
             if a is None:
-                g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
+                g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
             else:
                 g = a.cogroup(b, numPartitions)
-                g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None))
-            state = g.mapPartitions(lambda x: updateFunc(x))
+                g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
+            state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
             return state.filter(lambda (k, v): v is not None)

         jreduceFunc = TransformFunction(self._sc, reduceFunc,
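Taken together, the dstream.py hunk above changes the contract of the updateFunc argument: before this commit it was an iterator-style function mapping [(k, vs, s)] to [(k, s)], applied per partition via mapPartitions; afterwards it is called once per key with the batch's new values and the previous state, and returns the new state directly (returning None drops the key). A rough before/after sketch of the two callback shapes, reusing the running word count from the example file above; the names updater_old and updater_new are illustrative only:

    # Old iterator-based style (removed by this commit): one call per partition,
    # iterating over (key, new_values, old_state) tuples and yielding
    # (key, new_state) pairs.
    def updater_old(it):
        for key, new_values, old_state in it:
            yield (key, sum(new_values) + (old_state or 0))

    # New per-key style (introduced by this commit): one call per key, returning
    # the new state directly; returning None removes that key from the state.
    def updater_new(new_values, old_state):
        return sum(new_values) + (old_state or 0)

    # Usage on a DStream of (word, 1) pairs, e.g.:
    #     running_counts = pairs.updateStateByKey(updater_new)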
22 changes: 10 additions & 12 deletions python/pyspark/streaming/tests.py
@@ -119,7 +119,7 @@ def _sort_result_based_on_key(self, outputs):
             output.sort(key=lambda x: x[0])


-class TestBasicOperations(PySparkStreamingTestCase):
+class BasicOperationTests(PySparkStreamingTestCase):

     def test_map(self):
         """Basic operation test for DStream.map."""
@@ -340,15 +340,13 @@ def func(a, b):
         expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
         self._test_func(input, func, expected, True, input2)

-    def update_state_by_key(self):
+    def test_update_state_by_key(self):

-        def updater(it):
-            for k, vs, s in it:
-                if not s:
-                    s = vs
-                else:
-                    s.extend(vs)
-                yield (k, s)
+        def updater(vs, s):
+            if not s:
+                s = []
+            s.extend(vs)
+            return s

         input = [[('k', i)] for i in range(5)]

@@ -360,7 +358,7 @@ def func(dstream):
         self._test_func(input, func, expected)


-class TestWindowFunctions(PySparkStreamingTestCase):
+class WindowFunctionTests(PySparkStreamingTestCase):

     timeout = 20

@@ -417,7 +415,7 @@ def test_reduce_by_invalid_window(self):
         self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))


-class TestStreamingContext(PySparkStreamingTestCase):
+class StreamingContextTests(PySparkStreamingTestCase):

     duration = 0.1

@@ -480,7 +478,7 @@ def func(rdds):
         self.assertEqual([2, 3, 1], self._take(dstream, 3))


-class TestCheckpoint(PySparkStreamingTestCase):
+class CheckpointTests(PySparkStreamingTestCase):

     def setUp(self):
         pass
