[SPARK-3660][STREAMING] Initial RDD for updateStateByKey transformation
soumitrak committed Oct 25, 2014
1 parent 8f40ca0 commit 4efa58b
Showing 4 changed files with 94 additions and 82 deletions.
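
For context, the example and tests in this commit call an updateStateByKey overload that accepts an initial state RDD. A sketch of that overload, assuming the PairDStreamFunctions signature added elsewhere in this pull request (its diff is not part of this commit):

// Assumed overload on a DStream of (K, V) pairs; not shown in this commit's diff.
// updateFunc is invoked with (key, new values in this batch, previous state) for every
// key present either in the batch or in the existing state, and returns updated pairs.
// initialRDD seeds the state before the first batch is processed.
def updateStateByKey[S: ClassTag](
    updateFunc: Iterator[(K, Seq[V], Option[S])] => Iterator[(K, S)],
    partitioner: Partitioner,
    rememberPartitioner: Boolean,
    initialRDD: RDD[(K, S)]
  ): DStream[(K, S)]
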
@@ -18,12 +18,13 @@
package org.apache.spark.examples.streaming

import org.apache.spark.SparkConf
import org.apache.spark.HashPartitioner
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._

/**
* Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
* second.
* second, starting with an initial value for the word counts.
* Usage: StatefulNetworkWordCount <hostname> <port>
* <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
* data.
@@ -51,11 +52,18 @@ object StatefulNetworkWordCount {
Some(currentCount + previousCount)
}

val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
}

val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
// Create the context with a 1 second batch size
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint(".")

// Initial RDD input to updateStateByKey
val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

// Create a NetworkInputDStream on target ip:port and count the
// words in the input stream of \n delimited text (e.g. generated by 'nc')
val lines = ssc.socketTextStream(args(0), args(1).toInt)
@@ -64,7 +72,8 @@ object StatefulNetworkWordCount {

// Update the cumulative count using updateStateByKey
// This will give a DStream made of state (which is the cumulative count of the words)
val stateDstream = wordDstream.updateStateByKey[Int](updateFunc)
val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)
stateDstream.print()
ssc.start()
ssc.awaitTermination()
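
To make the effect of the initial RDD concrete, suppose (purely as an illustration) that the first one-second batch contains only the word "hello". For "hello" the update function receives values Seq(1) and previous state Some(1) from the initial RDD, so the printed state becomes ("hello", 2); "world" receives no new values but is still visited by the iterator-based update function, so it is emitted unchanged as ("world", 1). Without the initial RDD, both counts would have started from zero.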

This file was deleted.

@@ -21,6 +21,7 @@ import org.apache.spark.streaming.StreamingContext._

import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.SparkContext._
import org.apache.spark.HashPartitioner

import util.ManualClock
import org.apache.spark.{SparkException, SparkConf}
@@ -349,6 +350,43 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(inputData, updateStateOperation, outputData, true)
}

test("updateStateByKey - with initial value RDD") {
val initial = Seq(("a", 1), ("c", 2))

val inputData =
Seq(
Seq("a"),
Seq("a", "b"),
Seq("a", "b", "c"),
Seq("a", "b"),
Seq("a"),
Seq()
)

val outputData =
Seq(
Seq(("a", 2), ("c", 2)),
Seq(("a", 3), ("b", 1), ("c", 2)),
Seq(("a", 4), ("b", 2), ("c", 3)),
Seq(("a", 5), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3)),
Seq(("a", 6), ("b", 3), ("c", 3))
)

val updateStateOperation = (s: DStream[String], initialRDD: RDD[(String, Int)]) => {
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
Some(values.sum + state.getOrElse(0))
}
val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
}
s.map(x => (x, 1)).updateStateByKey[Int](newUpdateFunc,
new HashPartitioner(numInputPartitions), true, initialRDD)
}

testOperationWithInitial(initial, inputData, updateStateOperation, outputData, true)
}
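
A quick check of the expected values: "a" starts from its initial count of 1 and gains one per occurrence (2, 3, 4, 5, 6, 6), "b" has no initial value and simply accumulates (1, 2, 3, 3, 3), and "c" keeps its initial 2 until its single occurrence in the third batch raises it to 3. Keys that appear only in the initial RDD or in previous state are still emitted each batch, which is why ("c", 2) shows up from the first batch onward.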

test("updateStateByKey - object lifecycle") {
val inputData =
Seq(
@@ -212,6 +212,34 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
ssc
}

/**
* Set up the DStreams required to test a DStream operation that takes an initial value RDD,
* using the given sequence of input collections and the sequence of initial values.
*/
def setupStreamsWithInitial[U: ClassTag, V: ClassTag](
initial: Seq[V],
input: Seq[Seq[U]],
operation: (DStream[U], RDD[V]) => DStream[V],
numPartitions: Int = numInputPartitions
): StreamingContext = {
// Create StreamingContext
val ssc = new StreamingContext(conf, batchDuration)
if (checkpointDir != null) {
ssc.checkpoint(checkpointDir)
}

// Create initial value RDD
val initialRDD = ssc.sc.makeRDD(initial, numInputPartitions)

// Setup the stream computation
val inputStream = new TestInputStream(ssc, input, numPartitions)
val operatedStream = operation(inputStream, initialRDD)
val outputStream = new TestOutputStreamWithPartitions(operatedStream,
new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
outputStream.register()
ssc
}

/**
* Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
* returns the collected output. It will wait until `numExpectedOutput` number of
@@ -321,6 +349,23 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
logInfo("Output verified successfully")
}

/**
* Test a unary DStream operation that takes an initial value RDD, with a list of inputs
* and initial values; the number of batches run equals the number of expected output values.
*/
def testOperationWithInitial[U: ClassTag, V: ClassTag](
initial: Seq[V],
input: Seq[Seq[U]],
operation: (DStream[U], RDD[V]) => DStream[V],
expectedOutput: Seq[Seq[V]],
useSet: Boolean = false
) {
val numBatches_ = expectedOutput.size
val ssc = setupStreamsWithInitial[U, V](initial, input, operation)
val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
verifyOutput[V](output, expectedOutput, useSet)
}

/**
* Test unary DStream operation with a list of inputs, with number of
* batches to run same as the number of expected output values
