Streaming KMeans with decay

- Used trainOn and predictOn pattern, similar to StreamingLinearAlgorithm
- Decay factor can be set explicitly, or via fractional decay parameters expressed in units of number of batches, or number of points
- Unit tests for basic functionality and decay settings

freeman-lab committed Aug 28, 2014
1 parent 31f0b07 commit b93350f
Showing 2 changed files with 337 additions and 0 deletions.
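
To illustrate the trainOn/predictOn pattern described in the commit message, here is a minimal usage sketch (not part of this commit). The object name, input path, five-dimensional data, two clusters, and one-second batch interval are all illustrative placeholders, not values taken from the change itself.

import org.apache.spark.SparkConf
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StreamingKMeansExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("StreamingKMeansExample")
    val ssc = new StreamingContext(conf, Seconds(1))

    // parse whitespace-delimited numeric lines into dense vectors
    val trainingData = ssc.textFileStream("/tmp/training")
      .map(line => Vectors.dense(line.split(' ').map(_.toDouble)))

    // two clusters, decay factor 1.0 (equal weight on every batch, no forgetting),
    // random initial centers in five dimensions
    val model = new StreamingKMeans()
      .setK(2)
      .setDecayFactor(1.0)
      .setRandomCenters(5)

    model.trainOn(trainingData)             // update cluster centers on each batch
    model.predictOn(trainingData).print()   // emit a cluster index for each point

    ssc.start()
    ssc.awaitTermination()
  }
}

A decay factor below 1.0 would down-weight older data, and setUnits("points") would apply that discount per point rather than per batch, as implemented in the update method below.
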
@@ -0,0 +1,143 @@
package org.apache.spark.mllib.clustering

import breeze.linalg.{Vector => BV}

import scala.reflect.ClassTag
import scala.util.Random._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.StreamingContext._

@DeveloperApi
class StreamingKMeansModel(
override val clusterCenters: Array[Vector],
val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) {

/** Do a sequential KMeans update on a batch of data. */
def update(data: RDD[Vector], a: Double, units: String): StreamingKMeansModel = {

val centers = clusterCenters
val counts = clusterCounts

// find nearest cluster to each point
val closest = data.map(point => (this.predict(point), (point.toBreeze, 1.toLong)))

// get sums and counts for updating each cluster
type WeightedPoint = (BV[Double], Long)
def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
(p1._1 += p2._1, p1._2 + p2._2)
}
val pointStats: Array[(Int, (BV[Double], Long))] =
closest.reduceByKey(mergeContribs).collectAsMap().toArray

// implement update rule
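// (equivalently, each new center is a weighted average of the old center and the
// batch mean, with the old center's count discounted according to the decay setting)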
for (newP <- pointStats) {
// store old count and centroid
val oldCount = counts(newP._1)
val oldCentroid = centers(newP._1).toBreeze
// get new count and centroid
val newCount = newP._2._2
val newCentroid = newP._2._1 / newCount.toDouble
// compute the normalized scale factor that controls forgetting
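// with units = "batches" the old count is discounted by a once per batch;
// with units = "points" the discount is applied once per point in the batch (a^newCount);
// a decay factor of 1.0 recovers the exact cumulative mean over all data seen so far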
val decayFactor = units match {
case "batches" => newCount / (a * oldCount + newCount)
case "points" => newCount / (math.pow(a, newCount) * oldCount + newCount)
}
// perform the update
val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * decayFactor
// store the new counts and centers
counts(newP._1) = oldCount + newCount
centers(newP._1) = Vectors.fromBreeze(updatedCentroid)
}

new StreamingKMeansModel(centers, counts)
}

}

@DeveloperApi
class StreamingKMeans(
var k: Int,
var a: Double,
var units: String) extends Logging {

protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)

def this() = this(2, 1.0, "batches")

def setK(k: Int): this.type = {
this.k = k
this
}

def setDecayFactor(a: Double): this.type = {
this.a = a
this
}

def setUnits(units: String): this.type = {
this.units = units
this
}

def setDecayFractionBatches(q: Double): this.type = {
this.a = math.log(1 - q) / math.log(0.5)
this.units = "batches"
this
}

def setDecayFractionPoints(q: Double, m: Double): this.type = {
this.a = math.pow(math.log(1 - q) / math.log(0.5), 1/m)
this.units = "points"
this
}

def setInitialCenters(initialCenters: Array[Vector]): this.type = {
val clusterCounts = Array.fill(this.k)(0L)
this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
this
}

def setRandomCenters(d: Int): this.type = {
val initialCenters = (0 until k).map(_ => Vectors.dense(Array.fill(d)(nextGaussian()))).toArray
val clusterCounts = Array.fill(this.k)(0L)
this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
this
}

def latestModel(): StreamingKMeansModel = {
model
}

def trainOn(data: DStream[Vector]) {
this.isInitialized
data.foreachRDD { (rdd, time) =>
model = model.update(rdd, this.a, this.units)
}
}

def predictOn(data: DStream[Vector]): DStream[Int] = {
this.isInitialized
data.map(model.predict)
}

def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
this.isInitialized
data.mapValues(model.predict)
}

def isInitialized: Boolean = {
if (model.clusterCenters == null) {
logError("Initial cluster centers must be set before starting predictions")
throw new IllegalArgumentException("Initial cluster centers must be set before starting predictions")
} else {
true
}
}

}
@@ -0,0 +1,194 @@
package org.apache.spark.mllib.clustering

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.TestSuiteBase

class StreamingKMeansSuite extends FunSuite with TestSuiteBase {

override def maxWaitTimeMillis = 30000

test("accuracy for single center and equivalence to grand average") {

// set parameters
val numBatches = 10
val numPoints = 50
val k = 1
val d = 5
val r = 0.1

// create model with one cluster
val model = new StreamingKMeans()
.setK(1)
.setDecayFactor(1.0)
.setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)))

// generate random data for kmeans
val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)

// setup and run the model training
val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
model.trainOn(inputDStream)
inputDStream.count()
})
runStreams(ssc, numBatches, numBatches)

// estimated center should be close to true center
assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)

// estimated center from streaming should exactly match the arithmetic mean of all data points
val grandMean = input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble
assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)

}

test("accuracy for two centers") {

val numBatches = 10
val numPoints = 5
val k = 2
val d = 5
val r = 0.1

// create model with two clusters
val model = new StreamingKMeans()
.setK(2)
.setDecayFactor(1.0)
.setInitialCenters(Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1),
Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)))

// generate random data for kmeans
val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)

// setup and run the model training
val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
model.trainOn(inputDStream)
inputDStream.count()
})
runStreams(ssc, numBatches, numBatches)

// check that estimated centers are close to true centers
// NOTE this depends on the initialization! allow for binary flip
assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
assert(centers(1) ~== model.latestModel().clusterCenters(1) absTol 1E-1)

}

test("drifting with fractional decay in units of batches") {

val numBatches1 = 50
val numBatches2 = 50
val numPoints = 1
val q = 0.25
val k = 1
val d = 1
val r = 2.0

// create model with one cluster
val model = new StreamingKMeans()
.setK(1)
.setDecayFractionBatches(q)
.setInitialCenters(Array(Vectors.dense(0.0)))

// create two sequences of batches with different, pre-specified centers
// to simulate a transition from one cluster to another
val (input1, centers1) = StreamingKMeansDataGenerator(
numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
val (input2, centers2) = StreamingKMeansDataGenerator(
numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))

// store the history
val history = new ArrayBuffer[Double](numBatches1 + numBatches2)

// setup and run the model training
val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
model.trainOn(inputDStream)
// extract the center (in this case one-dimensional)
inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
inputDStream.count()
})
runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)

// check that the fraction of batches required to reach 50
// equals the setting of q, by finding the index of the first batch
// below 50 and comparing to total number of batches received
val halvedIndex = history.zipWithIndex.filter(x => x._1 < 50)(0)._2.toDouble
val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
assert(fraction ~== q absTol 1E-1)

}

test("drifting with fractional decay in units of points") {

val numBatches1 = 50
val numBatches2 = 50
val numPoints = 10
val q = 0.25
val k = 1
val d = 1
val r = 2.0

// create model with one cluster
val model = new StreamingKMeans()
.setK(1)
.setDecayFractionPoints(q, numPoints)
.setInitialCenters(Array(Vectors.dense(0.0)))

// create two sequences of batches with different, pre-specified centers
// to simulate a transition from one cluster to another
val (input1, centers1) = StreamingKMeansDataGenerator(
numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
val (input2, centers2) = StreamingKMeansDataGenerator(
numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))

// store the history
val history = new ArrayBuffer[Double](numBatches1 + numBatches2)

// setup and run the model training
val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
model.trainOn(inputDStream)
// extract the center (in this case one-dimensional)
inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
inputDStream.count()
})
runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)

// check that the fraction of batches required to reach 50
// equals the setting of q, by finding the index of the first batch
// below 50 and comparing to total number of batches received
val halvedIndex = history.zipWithIndex.filter(x => x._1 < 50)(0)._2.toDouble
val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
assert(fraction ~== q absTol 1E-1)

}

def StreamingKMeansDataGenerator(
numPoints: Int,
numBatches: Int,
k: Int,
d: Int,
r: Double,
seed: Int,
initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = {
val rand = new Random(seed)
val centers = initCenters match {
case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian())))
case _ => initCenters
}
val data = (0 until numBatches).map { i =>
(0 until numPoints).map { idx =>
val center = centers(idx % k)
Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
}
}
(data, centers)
}


}
