[SPARK-3660][STREAMING] Initial RDD for updateStateByKey transformation
SPARK-3660 : Initial RDD for updateStateByKey transformation

I have added a sample StatefulNetworkWordCountWithInitial inspired by StatefulNetworkWordCount.

Please let me know if any changes are required.
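
For orientation, here is a minimal, hedged sketch of how the new Scala overload added by this patch is meant to be called. The host/port, variable names, and checkpoint directory are illustrative assumptions, not part of the patch, and the snippet relies on the same streaming and HashPartitioner imports as the example below.

// Sketch only: seed per-key state before the first batch is processed.
// Assumes a running StreamingContext `ssc` and the usual streaming imports.
ssc.checkpoint(".")  // updateStateByKey requires a checkpoint directory
val initialRDD = ssc.sparkContext.parallelize(Seq(("hello", 1), ("world", 1)))

val updateFunc = (values: Seq[Int], state: Option[Int]) => {
  Some(values.sum + state.getOrElse(0))
}

val wordCounts = ssc.socketTextStream("localhost", 9999)
  .flatMap(_.split(" "))
  .map(word => (word, 1))
  .updateStateByKey[Int](updateFunc,
    new HashPartitioner(ssc.sparkContext.defaultParallelism), initialRDD)
wordCounts.print()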

Author: Soumitra Kumar <kumar.soumitra@gmail.com>

Closes apache#2665 from soumitrak/master and squashes the following commits:

ee8980b [Soumitra Kumar] Fixed copy/paste issue.
304f636 [Soumitra Kumar] Added simpler version of updateStateByKey API with initialRDD and test.
9781135 [Soumitra Kumar] Fixed test, and renamed variable.
3da51a2 [Soumitra Kumar] Adding updateStateByKey with initialRDD API to JavaPairDStream.
2f78f7e [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
d4fdd18 [Soumitra Kumar] Renamed variable and moved method.
d0ce2cd [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
31399a4 [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
4efa58b [Soumitra Kumar] [SPARK-3660][STREAMING] Initial RDD for updateStateByKey transformation
8f40ca0 [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
dde4271 [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
fdd7db3 [Soumitra Kumar] Adding support of initial value for state update. SPARK-3660 : Initial RDD for updateStateByKey transformation
soumitrak authored and tianyi committed Dec 4, 2014
1 parent 4815caa commit 1791de8
Showing 6 changed files with 241 additions and 39 deletions.
StatefulNetworkWordCount.scala
@@ -18,12 +18,13 @@
package org.apache.spark.examples.streaming

import org.apache.spark.SparkConf
import org.apache.spark.HashPartitioner
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._

/**
* Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
* second.
* second, starting with an initial value of the word count.
* Usage: StatefulNetworkWordCount <hostname> <port>
* <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
* data.
@@ -51,20 +52,28 @@ object StatefulNetworkWordCount {
Some(currentCount + previousCount)
}

val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
}

val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
// Create the context with a 1 second batch size
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint(".")

// Create a ReceiverInputDStream on target ip:port and count the
// Initial RDD input to updateStateByKey
val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

// Create a NetworkInputDStream on target ip:port and count the
// words in input stream of \n delimited text (e.g. generated by 'nc')
val lines = ssc.socketTextStream(args(0), args(1).toInt)
val words = lines.flatMap(_.split(" "))
val wordDstream = words.map(x => (x, 1))

// Update the cumulative count using updateStateByKey
// This will give a Dstream made of state (which is the cumulative count of the words)
val stateDstream = wordDstream.updateStateByKey[Int](updateFunc)
val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD)
stateDstream.print()
ssc.start()
ssc.awaitTermination()
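
As a hedged illustration of what the initial RDD changes in this example: if the first one-second batch happens to contain the single line "hello world world" (an assumed input), the seeded counts are already folded into the printed state.

// Illustration only, assuming the first batch is "hello world world":
//   seeded state:            ("hello", 1), ("world", 1)
//   counts in first batch:   ("hello", 1), ("world", 2)
//   state after first batch: ("hello", 2), ("world", 3)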
JavaPairDStream.scala
@@ -492,6 +492,25 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner)
}

/**
* Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of the key.
* org.apache.spark.Partitioner is used to control the partitioning of each RDD.
* @param updateFunc State update function. If this function returns None, then the
* corresponding state key-value pair will be eliminated.
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new
* DStream.
* @param initialRDD initial state value of each key.
* @tparam S State type
*/
def updateStateByKey[S](
updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
partitioner: Partitioner,
initialRDD: JavaPairRDD[K, S]
): JavaPairDStream[K, S] = {
implicit val cm: ClassTag[S] = fakeClassTag
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner, initialRDD)
}

/**
* Return a new DStream by applying a map function to the value of each key-value pairs in
PairDStreamFunctions.scala
@@ -410,7 +410,54 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
partitioner: Partitioner,
rememberPartitioner: Boolean
): DStream[(K, S)] = {
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
}

/**
* Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of the key.
* org.apache.spark.Partitioner is used to control the partitioning of each RDD.
* @param updateFunc State update function. If this function returns None, then the
* corresponding state key-value pair will be eliminated.
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new
* DStream.
* @param initialRDD initial state value of each key.
* @tparam S State type
*/
def updateStateByKey[S: ClassTag](
updateFunc: (Seq[V], Option[S]) => Option[S],
partitioner: Partitioner,
initialRDD: RDD[(K, S)]
): DStream[(K, S)] = {
val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
}
updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
}
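
The overload above lifts a per-key (Seq[V], Option[S]) => Option[S] function into the iterator-based form and delegates with rememberPartitioner set to true. A rough, self-contained illustration of that lifting with assumed values (not part of the patch):

// Per-key function and its lifted, iterator-based equivalent (sketch):
val perKey = (values: Seq[Int], state: Option[Int]) => Some(values.sum + state.getOrElse(0))
val lifted = (it: Iterator[(String, Seq[Int], Option[Int])]) =>
  it.flatMap { case (k, vs, s) => perKey(vs, s).map(v => (k, v)) }

// lifted(Iterator(("a", Seq(1, 1), Some(3)))).toList == List(("a", 5))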

/**
* Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* org.apache.spark.Partitioner is used to control the partitioning of each RDD.
* @param updateFunc State update function. If this function returns None, then the
* corresponding state key-value pair will be eliminated. Note that
* this function may generate a tuple with a different key
* than the input key. It is up to the developer to decide whether to
* remember the partitioner despite the key being changed.
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new
* DStream
* @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs.
* @param initialRDD initial state value of each key.
* @tparam S State type
*/
def updateStateByKey[S: ClassTag](
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
rememberPartitioner: Boolean,
initialRDD: RDD[(K, S)]
): DStream[(K, S)] = {
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
rememberPartitioner, Some(initialRDD))
}

/**
StateDStream.scala
@@ -30,7 +30,8 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
parent: DStream[(K, V)],
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
preservePartitioning: Boolean
preservePartitioning: Boolean,
initialRDD : Option[RDD[(K, S)]]
) extends DStream[(K, S)](parent.ssc) {

super.persist(StorageLevel.MEMORY_ONLY_SER)
@@ -41,6 +42,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](

override val mustCheckpoint = true

private [this] def computeUsingPreviousRDD (
parentRDD : RDD[(K, V)], prevStateRDD : RDD[(K, S)]) = {
// Define the function for the mapPartition operation on cogrouped RDD;
// first map the cogrouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
val i = iterator.map(t => {
val itr = t._2._2.iterator
val headOption = if(itr.hasNext) Some(itr.next) else None
(t._1, t._2._1.toSeq, headOption)
})
updateFuncLocal(i)
}
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
Some(stateRDD)
}
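
To make the helper above easier to follow, here are the shapes it works with, shown on plain values rather than RDDs (illustrative data, not from the commit):

// parent batch (K, V):       ("a", 1), ("a", 1), ("b", 1)
// previous or initial state: ("a", 3)
// cogrouped (K, (Iterable[V], Iterable[S])):
//                            ("a", (Seq(1, 1), Seq(3))), ("b", (Seq(1), Seq()))
// handed to updateFunc:      ("a", Seq(1, 1), Some(3)), ("b", Seq(1), None)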

override def compute(validTime: Time): Option[RDD[(K, S)]] = {

// Try to get the previous state RDD
@@ -51,25 +71,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
// Try to get the parent RDD
parent.getOrCompute(validTime) match {
case Some(parentRDD) => { // If parent RDD exists, then compute as usual

// Define the function for the mapPartition operation on cogrouped RDD;
// first map the cogrouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
val i = iterator.map(t => {
val itr = t._2._2.iterator
val headOption = itr.hasNext match {
case true => Some(itr.next())
case false => None
}
(t._1, t._2._1.toSeq, headOption)
})
updateFuncLocal(i)
}
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
Some(stateRDD)
computeUsingPreviousRDD (parentRDD, prevStateRDD)
}
case None => { // If parent RDD does not exist

@@ -90,19 +92,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
// Try to get the parent RDD
parent.getOrCompute(validTime) match {
case Some(parentRDD) => { // If parent RDD exists, then compute as usual
initialRDD match {
case None => {
// Define the function for the mapPartition operation on grouped RDD;
// first map the grouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator : Iterator[(K, Iterable[V])]) => {
updateFuncLocal (iterator.map (tuple => (tuple._1, tuple._2.toSeq, None)))
}

// Define the function for the mapPartition operation on grouped RDD;
// first map the grouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
val groupedRDD = parentRDD.groupByKey (partitioner)
val sessionRDD = groupedRDD.mapPartitions (finalFunc, preservePartitioning)
// logDebug("Generating state RDD for time " + validTime + " (first)")
Some (sessionRDD)
}
case Some (initialStateRDD) => {
computeUsingPreviousRDD(parentRDD, initialStateRDD)
}
}

val groupedRDD = parentRDD.groupByKey(partitioner)
val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
// logDebug("Generating state RDD for time " + validTime + " (first)")
Some(sessionRDD)
}
case None => { // If parent RDD does not exist, then nothing to do!
// logDebug("Not generating state RDD (no previous state, no parent)")
Streaming Java API test suite
@@ -806,15 +806,17 @@ public void testUnion() {
* Performs an order-invariant comparison of lists representing two RDD streams. This allows
* us to account for ordering variation within individual RDD's which occurs during windowing.
*/
public static <T extends Comparable<T>> void assertOrderInvariantEquals(
public static <T> void assertOrderInvariantEquals(
List<List<T>> expected, List<List<T>> actual) {
List<Set<T>> expectedSets = new ArrayList<Set<T>>();
for (List<T> list: expected) {
Collections.sort(list);
expectedSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
}
List<Set<T>> actualSets = new ArrayList<Set<T>>();
for (List<T> list: actual) {
Collections.sort(list);
actualSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
}
Assert.assertEquals(expected, actual);
Assert.assertEquals(expectedSets, actualSets);
}


@@ -1239,6 +1241,49 @@ public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
Assert.assertEquals(expected, result);
}

@SuppressWarnings("unchecked")
@Test
public void testUpdateStateByKeyWithInitial() {
List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;

List<Tuple2<String, Integer>> initial = Arrays.asList (
new Tuple2<String, Integer> ("california", 1),
new Tuple2<String, Integer> ("new york", 2));

JavaRDD<Tuple2<String, Integer>> tmpRDD = ssc.sparkContext().parallelize(initial);
JavaPairRDD<String, Integer> initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD);

List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
Arrays.asList(new Tuple2<String, Integer>("california", 5),
new Tuple2<String, Integer>("new york", 7)),
Arrays.asList(new Tuple2<String, Integer>("california", 15),
new Tuple2<String, Integer>("new york", 11)),
Arrays.asList(new Tuple2<String, Integer>("california", 15),
new Tuple2<String, Integer>("new york", 11)));

JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);

JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
@Override
public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
int out = 0;
if (state.isPresent()) {
out = out + state.get();
}
for (Integer v: values) {
out = out + v;
}
return Optional.of(out);
}
}, new HashPartitioner(1), initialRDD);
JavaTestUtils.attachTestOutputStream(updated);
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);

assertOrderInvariantEquals(expected, result);
}

@SuppressWarnings("unchecked")
@Test
public void testReduceByKeyAndWindowWithInverse() {
BasicOperationsSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.{DStream, WindowedDStream}
import org.apache.spark.HashPartitioner

class BasicOperationsSuite extends TestSuiteBase {
test("map") {
@@ -350,6 +351,79 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(inputData, updateStateOperation, outputData, true)
}

test("updateStateByKey - simple with initial value RDD") {
val initial = Seq(("a", 1), ("c", 2))

val inputData =
Seq(
Seq("a"),
Seq("a", "b"),
Seq("a", "b", "c"),
Seq("a", "b"),
Seq("a"),
Seq()
)

val outputData =
Seq(
Seq(("a", 2), ("c", 2)),
Seq(("a", 3), ("b", 1), ("c", 2)),
Seq(("a", 4), ("b", 2), ("c", 3)),
Seq(("a", 5), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3))
)

val updateStateOperation = (s: DStream[String]) => {
val initialRDD = s.context.sparkContext.makeRDD(initial)
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
Some(values.sum + state.getOrElse(0))
}
s.map(x => (x, 1)).updateStateByKey[Int](updateFunc,
new HashPartitioner (numInputPartitions), initialRDD)
}

testOperation(inputData, updateStateOperation, outputData, true)
}
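
As a worked check of the expected output in the test above (the same data is reused by the next test): the state is seeded with ("a", 1) and ("c", 2), and each occurrence of a word adds 1.

// batch 1, input Seq("a"):      a -> 1 + 1 = 2;  c -> 2 (initial value, no new occurrences)  => Seq(("a", 2), ("c", 2))
// batch 2, input Seq("a", "b"): a -> 2 + 1 = 3;  b -> 0 + 1 = 1;  c -> 2                     => Seq(("a", 3), ("b", 1), ("c", 2))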

test("updateStateByKey - with initial value RDD") {
val initial = Seq(("a", 1), ("c", 2))

val inputData =
Seq(
Seq("a"),
Seq("a", "b"),
Seq("a", "b", "c"),
Seq("a", "b"),
Seq("a"),
Seq()
)

val outputData =
Seq(
Seq(("a", 2), ("c", 2)),
Seq(("a", 3), ("b", 1), ("c", 2)),
Seq(("a", 4), ("b", 2), ("c", 3)),
Seq(("a", 5), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3))
)

val updateStateOperation = (s: DStream[String]) => {
val initialRDD = s.context.sparkContext.makeRDD(initial)
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
Some(values.sum + state.getOrElse(0))
}
val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
}
s.map(x => (x, 1)).updateStateByKey[Int](newUpdateFunc,
new HashPartitioner (numInputPartitions), true, initialRDD)
}

testOperation(inputData, updateStateOperation, outputData, true)
}

test("updateStateByKey - object lifecycle") {
val inputData =
Seq(