Adding support for an initial value for state updates.
SPARK-3660: Initial RDD for updateStateByKey transformation
soumitrak committed Oct 6, 2014
1 parent 8d22dbb commit fdd7db3
Showing 3 changed files with 148 additions and 32 deletions.
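
In brief, the change threads an optional initial state RDD from the public updateStateByKey API down into StateDStream, so that on the first batch each key's state can start from a seeded value instead of None. A condensed view of the new overload (signature as added in the diff below):

def updateStateByKey[S: ClassTag](
    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
    partitioner: Partitioner,
    rememberPartitioner: Boolean,
    initial: RDD[(K, S)]): DStream[(K, S)]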
New file: StatefulNetworkWordCountWithInitial.scala (package org.apache.spark.examples.streaming)
@@ -0,0 +1,80 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.streaming

import org.apache.spark.{HashPartitioner, SparkConf}
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._

/**
 * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network
 * every second, starting from an initial value for each word's count.
 * Usage: StatefulNetworkWordCountWithInitial <hostname> <port>
 *   <hostname> and <port> describe the TCP server that Spark Streaming would connect to
 *   receive data.
 *
 * To run this on your local machine, you need to first run a Netcat server
 *    `$ nc -lk 9999`
 * and then run the example
 *    `$ bin/run-example
 *    org.apache.spark.examples.streaming.StatefulNetworkWordCountWithInitial localhost 9999`
 */
object StatefulNetworkWordCountWithInitial {
  def main(args: Array[String]) {
    if (args.length < 2) {
      System.err.println("Usage: StatefulNetworkWordCountWithInitial <hostname> <port>")
      System.exit(1)
    }

    StreamingExamples.setStreamingLogLevels()

    val updateFunc = (values: Seq[Int], state: Option[Int]) => {
      val currentCount = values.sum
      val previousCount = state.getOrElse(0)
      Some(currentCount + previousCount)
    }

    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
    }

    val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCountWithInitial")
    // Create the context with a 1 second batch size
    val ssc = new StreamingContext(sparkConf, Seconds(1))
    ssc.checkpoint(".")

    // Initial RDD input to updateStateByKey
    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

    // Create a NetworkInputDStream on the target ip:port and count the
    // words in the input stream of \n delimited text (e.g. generated by 'nc')
    val lines = ssc.socketTextStream(args(0), args(1).toInt)
    val words = lines.flatMap(_.split(" "))
    val wordDstream = words.map(x => (x, 1))

    // Update the cumulative count using updateStateByKey
    // This will give a DStream made of state (which is the cumulative count of the words)
    val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
      new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)
    stateDstream.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
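
The seeding effect of initialRDD can be checked without running a cluster, since the per-key update logic is just the updateFunc above. A minimal sketch in plain Scala (hypothetical object name and batch values; no Spark required):

object UpdateFuncCheck {
  def main(args: Array[String]): Unit = {
    val updateFunc = (values: Seq[Int], state: Option[Int]) =>
      Some(values.sum + state.getOrElse(0))

    // "hello" is seeded with state 1 by initialRDD, then typed once in the first batch.
    println(updateFunc(Seq(1), Some(1)))  // Some(2)
    // "spark" has no seeded state and is typed twice in the first batch.
    println(updateFunc(Seq(1, 1), None))  // Some(2)
  }
}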
Changed file: PairDStreamFunctions.scala (package org.apache.spark.streaming.dstream)
@@ -394,6 +394,31 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
    updateStateByKey(newUpdateFunc, partitioner, true)
  }

  /**
   * Return a new "state" DStream where the state for each key is updated by applying
   * the given function on the previous state of the key and the new values of each key.
   * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
   * @param updateFunc State update function. If this function returns None, then the
   *                   corresponding state key-value pair will be eliminated. Note that
   *                   this function may generate a tuple with a different key than the
   *                   input key. It is up to the developer to decide whether to
   *                   remember the partitioner despite the key being changed.
   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
   *                    DStream
   * @param rememberPartitioner Whether to remember the partitioner object in the generated
   *                            RDDs.
   * @param initial Initial state value of each key.
   * @tparam S State type
   */
  def updateStateByKey[S: ClassTag](
      updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
      partitioner: Partitioner,
      rememberPartitioner: Boolean,
      initial: RDD[(K, S)]
    ): DStream[(K, S)] = {
    new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
      rememberPartitioner, Some(initial))
  }

  /**
   * Return a new "state" DStream where the state for each key is updated by applying
   * the given function on the previous state of the key and the new values of each key.
@@ -413,7 +438,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
      partitioner: Partitioner,
      rememberPartitioner: Boolean
    ): DStream[(K, S)] = {
-   new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
+   new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
  }

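For reference, a hedged sketch of calling this new overload directly; `ssc` and `wordDstream` are assumed to be set up as in the example file above:

val initial = ssc.sparkContext.parallelize(Seq(("hello", 1), ("world", 1)))

val counts = wordDstream.updateStateByKey[Int](
  (it: Iterator[(String, Seq[Int], Option[Int])]) =>
    it.map { case (word, ones, state) => (word, ones.sum + state.getOrElse(0)) },
  new HashPartitioner(ssc.sparkContext.defaultParallelism),
  true,      // rememberPartitioner
  initial)   // seeds per-key state before the first batch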
Changed file: StateDStream.scala (package org.apache.spark.streaming.dstream)
@@ -30,7 +30,8 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
    parent: DStream[(K, V)],
    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
    partitioner: Partitioner,
-   preservePartitioning: Boolean
+   preservePartitioning: Boolean,
+   initial: Option[RDD[(K, S)]]
  ) extends DStream[(K, S)](parent.ssc) {

  super.persist(StorageLevel.MEMORY_ONLY_SER)
@@ -41,6 +42,28 @@

  override val mustCheckpoint = true

  private[this] def computeUsingPreviousRDD(
      parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
    // Define the function for the mapPartition operation on cogrouped RDD;
    // first map the cogrouped tuple to tuples of required type,
    // and then apply the update function
    val updateFuncLocal = updateFunc
    val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
      val i = iterator.map(t => {
        val itr = t._2._2.iterator
        val headOption = itr.hasNext match {
          case true => Some(itr.next())
          case false => None
        }
        (t._1, t._2._1.toSeq, headOption)
      })
      updateFuncLocal(i)
    }
    val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
    val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
    Some(stateRDD)
  }

  override def compute(validTime: Time): Option[RDD[(K, S)]] = {

    // Try to get the previous state RDD
@@ -51,25 +74,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
        // Try to get the parent RDD
        parent.getOrCompute(validTime) match {
          case Some(parentRDD) => { // If parent RDD exists, then compute as usual
-
-           // Define the function for the mapPartition operation on cogrouped RDD;
-           // first map the cogrouped tuple to tuples of required type,
-           // and then apply the update function
-           val updateFuncLocal = updateFunc
-           val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
-             val i = iterator.map(t => {
-               val itr = t._2._2.iterator
-               val headOption = itr.hasNext match {
-                 case true => Some(itr.next())
-                 case false => None
-               }
-               (t._1, t._2._1.toSeq, headOption)
-             })
-             updateFuncLocal(i)
-           }
-           val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
-           val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
-           Some(stateRDD)
+           computeUsingPreviousRDD(parentRDD, prevStateRDD)
          }
          case None => { // If parent RDD does not exist

@@ -90,19 +95,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
        // Try to get the parent RDD
        parent.getOrCompute(validTime) match {
          case Some(parentRDD) => { // If parent RDD exists, then compute as usual
+           initial match {
+             case None => {
+               // Define the function for the mapPartition operation on grouped RDD;
+               // first map the grouped tuple to tuples of required type,
+               // and then apply the update function
+               val updateFuncLocal = updateFunc
+               val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
+                 updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+               }
+
+               val groupedRDD = parentRDD.groupByKey(partitioner)
+               val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
+               // logDebug("Generating state RDD for time " + validTime + " (first)")
+               Some(sessionRDD)
+             }
+             case Some(initialRDD) => {
+               computeUsingPreviousRDD(parentRDD, initialRDD)
+             }
+           }
-
-           // Define the function for the mapPartition operation on grouped RDD;
-           // first map the grouped tuple to tuples of required type,
-           // and then apply the update function
-           val updateFuncLocal = updateFunc
-           val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
-             updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
-           }
-
-           val groupedRDD = parentRDD.groupByKey(partitioner)
-           val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
-           // logDebug("Generating state RDD for time " + validTime + " (first)")
-           Some(sessionRDD)
          }
          case None => { // If parent RDD does not exist, then nothing to do!
            // logDebug("Not generating state RDD (no previous state, no parent)")
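To make the cogroup step in computeUsingPreviousRDD concrete: cogrouping the batch RDD with the previous (or initial) state RDD yields pairs (K, (Iterable[V], Iterable[S])), where the state side holds at most one element per key, which is what the headOption extraction above relies on. A small illustration with plain RDDs (hypothetical data; assumes an existing SparkContext `sc` and the usual pair-RDD implicits in scope):

val batch = sc.parallelize(Seq(("hello", 1), ("hello", 1), ("new", 1)))
val state = sc.parallelize(Seq(("hello", 5)))  // previous or initial state

val updated = batch.cogroup(state).map { case (key, (values, states)) =>
  // states has at most one element per key, mirroring headOption in the diff
  (key, values.sum + states.headOption.getOrElse(0))
}
// updated contains ("hello", 7) and ("new", 1)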
