forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Used trainOn and predictOn pattern, similar to StreamingLinearAlgorithm - Decay factor can be set explicitly, or via fractional decay parameters expressed in units of number of batches, or number of points - Unit tests for basic functionality and decay settings
- Loading branch information
1 parent
31f0b07
commit b93350f
Showing
2 changed files
with
337 additions
and
0 deletions.
There are no files selected for viewing
143 changes: 143 additions & 0 deletions
143
mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
package org.apache.spark.mllib.clustering | ||
|
||
import breeze.linalg.{Vector => BV} | ||
|
||
import scala.reflect.ClassTag | ||
import scala.util.Random._ | ||
|
||
import org.apache.spark.annotation.DeveloperApi | ||
import org.apache.spark.Logging | ||
import org.apache.spark.mllib.linalg.{Vectors, Vector} | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.SparkContext._ | ||
import org.apache.spark.streaming.dstream.DStream | ||
import org.apache.spark.streaming.StreamingContext._ | ||
|
||
/**
 * A k-means model whose centers can be refined one batch of data at a time.
 *
 * @param clusterCenters center of each cluster
 * @param clusterCounts number of points folded into each center so far, indexed like
 *                      `clusterCenters`
 */
@DeveloperApi
class StreamingKMeansModel(
    override val clusterCenters: Array[Vector],
    val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) {

  /**
   * Do a sequential k-means update on one batch of data.
   *
   * Every point is assigned to a cluster via `predict`. Each cluster that received points has
   * its center moved toward the mean of those points and its count increased by the number of
   * points received; clusters that received no points are left untouched.
   *
   * NOTE: `clusterCenters` and `clusterCounts` are updated in place, and the returned model
   * shares both arrays with this one.
   *
   * @param data batch of points
   * @param a decay factor that discounts a cluster's old count when it is weighed against the
   *          newly received points
   * @param units "batches" to apply `a` once per batch, or "points" to apply it once per newly
   *              received point (`a` raised to the number of new points)
   * @return a model holding the updated centers and counts
   * @throws IllegalArgumentException if `units` is neither "batches" nor "points"
   */
  def update(data: RDD[Vector], a: Double, units: String): StreamingKMeansModel = {

    // Resolve the discount on a cluster's old count before touching the data, so that an
    // unsupported `units` value fails immediately instead of as a MatchError mid-update.
    val discount: Long => Double = units match {
      case "batches" => _ => a
      case "points" => numNewPoints => math.pow(a, numNewPoints)
      case other =>
        throw new IllegalArgumentException(
          s"""Unsupported units "$other": expected "batches" or "points"""")
    }

    val centers = clusterCenters
    val counts = clusterCounts

    // Find the nearest cluster to each point. The Breeze vector is copied because toBreeze may
    // share storage with the input point, and the merge below accumulates in place; without
    // the copy a reused (e.g. cached) batch could be silently modified.
    val closest = data.map(point => (this.predict(point), (point.toBreeze.copy, 1L)))

    // Get per-cluster sums and counts.
    type WeightedPoint = (BV[Double], Long)
    def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
      (p1._1 += p2._1, p1._2 + p2._2)
    }
    val pointStats: Array[(Int, (BV[Double], Long))] =
      closest.reduceByKey{mergeContribs}.collectAsMap().toArray

    // Apply the update rule to every cluster that received points.
    for ((cluster, (sum, newCount)) <- pointStats) {
      val oldCount = counts(cluster)
      val oldCentroid = centers(cluster).toBreeze
      val newCentroid = sum / newCount.toDouble
      // Normalized scale factor that controls forgetting: the weight given to the new batch
      // mean relative to the (discounted) history.
      val decayFactor = newCount / (discount(newCount) * oldCount + newCount)
      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * decayFactor
      counts(cluster) = oldCount + newCount
      centers(cluster) = Vectors.fromBreeze(updatedCentroid)
    }

    new StreamingKMeansModel(centers, counts)
  }

}
|
||
/**
 * Configures and drives a [[StreamingKMeansModel]] that is updated on every batch of a
 * DStream.
 *
 * Initial centers must be supplied through `setInitialCenters` or `setRandomCenters` before
 * `trainOn`, `predictOn` or `predictOnValues` is called.
 *
 * @param k number of clusters
 * @param a decay factor passed to each model update
 * @param units unit in which the decay factor is applied: "batches" or "points"
 */
@DeveloperApi
class StreamingKMeans(
    var k: Int,
    var a: Double,
    var units: String) extends Logging {

  // Placeholder with null centers; isInitialized uses the null centers to detect that no
  // initial centers have been set yet.
  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)

  /** Create an instance with k = 2, decay factor 1.0 and units "batches". */
  def this() = this(2, 1.0, "batches")

  /** Set the number of clusters. */
  def setK(k: Int): this.type = {
    this.k = k
    this
  }

  /** Set the decay factor directly. */
  def setDecayFactor(a: Double): this.type = {
    this.a = a
    this
  }

  /** Set the unit in which the decay factor is applied: "batches" or "points". */
  def setUnits(units: String): this.type = {
    this.units = units
    this
  }

  /**
   * Derive the decay factor from a fraction `q` as log(1 - q) / log(0.5), and switch the
   * units to "batches".
   */
  def setDecayFractionBatches(q: Double): this.type = {
    this.a = math.log(1 - q) / math.log(0.5)
    this.units = "batches"
    this
  }

  /**
   * Derive the decay factor from a fraction `q` and a number of points `m` as
   * (log(1 - q) / log(0.5)) ^ (1 / m), and switch the units to "points".
   */
  def setDecayFractionPoints(q: Double, m: Double): this.type = {
    this.a = math.pow(math.log(1 - q) / math.log(0.5), 1.0 / m)
    this.units = "points"
    this
  }

  /**
   * Start from the given centers, with all cluster counts at zero.
   *
   * @throws IllegalArgumentException if `initialCenters` is null
   */
  def setInitialCenters(initialCenters: Array[Vector]): this.type = {
    require(initialCenters != null, "Initial cluster centers must not be null")
    if (initialCenters.length != k) {
      logWarning(s"Number of initial centers (${initialCenters.length}) differs from k ($k)")
    }
    // Size the counts from the centers actually supplied rather than from k, so that the two
    // arrays always line up even when setK was not called with a matching value.
    val clusterCounts = Array.fill(initialCenters.length)(0L)
    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
    this
  }

  /**
   * Start from `k` random centers of dimension `d`, each coordinate drawn from a standard
   * Gaussian, with all cluster counts at zero.
   */
  def setRandomCenters(d: Int): this.type = {
    val initialCenters = Array.fill(k)(Vectors.dense(Array.fill(d)(nextGaussian())))
    // One zero count per cluster.
    val clusterCounts = Array.fill(k)(0L)
    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
    this
  }

  /** Return the current model. */
  def latestModel(): StreamingKMeansModel = {
    model
  }

  /**
   * Update the model with every RDD of the stream, using the current decay factor and units.
   *
   * @throws IllegalArgumentException if no initial centers have been set
   */
  def trainOn(data: DStream[Vector]): Unit = {
    this.isInitialized
    data.foreachRDD { (rdd, time) =>
      model = model.update(rdd, this.a, this.units)
    }
  }

  /**
   * Map every vector of the stream through the model's `predict`.
   *
   * @throws IllegalArgumentException if no initial centers have been set
   */
  def predictOn(data: DStream[Vector]): DStream[Int] = {
    this.isInitialized
    data.map(model.predict)
  }

  /**
   * Map the value of every (key, vector) pair of the stream through the model's `predict`,
   * keeping the key.
   *
   * @throws IllegalArgumentException if no initial centers have been set
   */
  def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
    this.isInitialized
    data.mapValues(model.predict)
  }

  /**
   * Check that initial cluster centers have been set.
   *
   * @return true when centers are set; never returns false
   * @throws IllegalArgumentException if the centers have not been set
   */
  def isInitialized: Boolean = {
    if (model.clusterCenters == null) {
      val message = "Initial cluster centers must be set before starting predictions"
      logError(message)
      throw new IllegalArgumentException(message)
    }
    true
  }

}
194 changes: 194 additions & 0 deletions
194
mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
package org.apache.spark.mllib.clustering | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
import scala.util.Random | ||
|
||
import org.scalatest.FunSuite | ||
|
||
import org.apache.spark.mllib.linalg.{Vectors, Vector} | ||
import org.apache.spark.mllib.util.TestingUtils._ | ||
import org.apache.spark.streaming.dstream.DStream | ||
import org.apache.spark.streaming.TestSuiteBase | ||
|
||
/** Tests for the basic accuracy and the decay settings of [[StreamingKMeans]]. */
class StreamingKMeansSuite extends FunSuite with TestSuiteBase {

  override def maxWaitTimeMillis: Int = 30000

  test("accuracy for single center and equivalence to grand average") {

    // set parameters
    val numBatches = 10
    val numPoints = 50
    val k = 1
    val d = 5
    val r = 0.1

    // create model with one cluster
    val model = new StreamingKMeans()
      .setK(1)
      .setDecayFactor(1.0)
      .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)))

    // generate random data for kmeans
    val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)

    // setup and run the model training
    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
      model.trainOn(inputDStream)
      inputDStream.count()
    })
    runStreams(ssc, numBatches, numBatches)

    // estimated center should be close to true center
    assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)

    // estimated center from streaming should exactly match the arithmetic mean of all data
    // points
    val grandMean =
      input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble
    assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)

  }

  test("accuracy for two centers") {

    val numBatches = 10
    val numPoints = 5
    val k = 2
    val d = 5
    val r = 0.1

    // create model with two clusters
    val model = new StreamingKMeans()
      .setK(2)
      .setDecayFactor(1.0)
      .setInitialCenters(Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1),
        Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)))

    // generate random data for kmeans
    val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)

    // setup and run the model training
    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
      model.trainOn(inputDStream)
      inputDStream.count()
    })
    runStreams(ssc, numBatches, numBatches)

    // check that estimated centers are close to true centers
    // NOTE this depends on the initialization! allow for binary flip
    assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
    assert(centers(1) ~== model.latestModel().clusterCenters(1) absTol 1E-1)

  }

  test("drifting with fractional decay in units of batches") {

    val numPoints = 1
    val q = 0.25

    // create model with one cluster
    val model = new StreamingKMeans()
      .setK(1)
      .setDecayFractionBatches(q)
      .setInitialCenters(Array(Vectors.dense(0.0)))

    assertDriftFraction(model, numPoints, q)

  }

  test("drifting with fractional decay in units of points") {

    val numPoints = 10
    val q = 0.25

    // create model with one cluster
    val model = new StreamingKMeans()
      .setK(1)
      .setDecayFractionPoints(q, numPoints)
      .setInitialCenters(Array(Vectors.dense(0.0)))

    assertDriftFraction(model, numPoints, q)

  }

  /**
   * Train `model` on 50 batches of one-dimensional data centered at 100 followed by 50 batches
   * centered at 0, and assert that the fraction of batches needed for the estimated center to
   * drop below 50 matches `q` within 0.1.
   *
   * @param model single-cluster, one-dimensional model to train
   * @param numPoints number of points per batch
   * @param q expected fraction
   */
  private def assertDriftFraction(model: StreamingKMeans, numPoints: Int, q: Double): Unit = {

    val numBatches1 = 50
    val numBatches2 = 50
    val totalBatches = numBatches1 + numBatches2
    val k = 1
    val d = 1
    val r = 2.0

    // create two runs of batches with different, pre-specified centers
    // to simulate a transition from one cluster to another
    val (input1, _) = StreamingKMeansDataGenerator(
      numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
    val (input2, _) = StreamingKMeansDataGenerator(
      numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))

    // store the history
    val history = new ArrayBuffer[Double](totalBatches)

    // setup and run the model training
    val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
      model.trainOn(inputDStream)
      // extract the center (in this case one-dimensional)
      inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
      inputDStream.count()
    })
    runStreams(ssc, totalBatches, totalBatches)

    // check that the fraction of batches required to reach 50
    // equals the setting of q, by finding the index of the first batch
    // below 50 and comparing to total number of batches received
    val firstBelow = history.indexWhere(_ < 50)
    assert(firstBelow >= 0, "estimated center never dropped below 50")
    val halvedIndex = firstBelow.toDouble
    val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
    assert(fraction ~== q absTol 1E-1)

  }

  /**
   * Generate batches of points scattered around a set of centers.
   *
   * Point `idx` of every batch is drawn around center `idx % k`, each coordinate being the
   * center's coordinate plus Gaussian noise scaled by `r`.
   *
   * @param numPoints number of points per batch
   * @param numBatches number of batches
   * @param k number of centers
   * @param d dimension of the points
   * @param r scale of the Gaussian noise
   * @param seed seed of the random generator
   * @param initCenters centers to use; when null, `k` centers with standard Gaussian
   *                    coordinates are drawn
   * @return the batches and the centers they were generated from
   */
  def StreamingKMeansDataGenerator(
      numPoints: Int,
      numBatches: Int,
      k: Int,
      d: Int,
      r: Double,
      seed: Int,
      initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = {
    val rand = new Random(seed)
    val centers = Option(initCenters).getOrElse(
      Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian()))))
    val data = (0 until numBatches).map { i =>
      (0 until numPoints).map { idx =>
        val center = centers(idx % k)
        Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
      }
    }
    (data, centers)
  }

}