diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index d0b258b1678b6..f9bb0d9989de7 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -174,7 +174,13 @@ to the cluster thus far, `$x_t$` is the new cluster center from the current batc
 is the number of points added to the cluster in the current batch. The decay factor `$\alpha$`
 can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning;
 with `$\alpha$=0` only the most recent data will be used. This is analogous to an
-exponentially-weighted moving average.
+exponentially-weighted moving average.
+
+The decay can be specified using a `halfLife` parameter, which determines the
+correct decay factor `a` such that, for data acquired
+at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5.
+The unit of time can be specified either as `batches` or `points`, and the update rule
+will be adjusted accordingly.
 
 ### Examples
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 3c8a84c3db104..3a6451118ca5e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -39,28 +39,28 @@ import org.apache.spark.util.Utils
  *
  * The update algorithm uses the "mini-batch" KMeans rule,
  * generalized to incorporate forgetfullness (i.e. decay).
- * The basic update rule (for each cluster) is:
+ * The update rule (for each cluster) is:
  *
- * c_t+1 = [(c_t * n_t) + (x_t * m_t)] / [n_t + m_t]
- * n_t+t = n_t + m_t
+ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
+ * n_t+1 = n_t * a + m_t
  *
  * Where c_t is the previously estimated centroid for that cluster,
  * n_t is the number of points assigned to it thus far, x_t is the centroid
  * estimated on the current batch, and m_t is the number of points assigned
  * to that centroid in the current batch.
  *
- * This update rule is modified with a decay factor 'a' that scales
- * the contribution of the clusters as estimated thus far.
- * If a=1, all batches are weighted equally. If a=0, new centroids
+ * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+ * by applying 'a' as a discount weighting on the prior estimate when evaluating
+ * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
  * are determined entirely by recent data. Lower values correspond to
  * more forgetting.
  *
- * Decay can optionally be specified as a decay fraction 'q',
- * which corresponds to the fraction of batches (or points)
- * after which the past will be reduced to a contribution of 0.5.
- * This decay fraction can be specified in units of 'points' or 'batches'.
- * if 'batches', behavior will be independent of the number of points per batch;
- * if 'points', the expected number of points per batch must be specified.
+ * Decay can optionally be specified by a half life and associated
+ * time unit. The time unit can either be a batch of data or a single
+ * data point. For data acquired at time t, the half life h is defined
+ * such that at time t + h the discount applied to the data from t is 0.5.
+ * The definition remains the same whether the time unit is given
+ * as batches or points.
  *
  */
 @DeveloperApi
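As an aside on the update rule just documented: the short sketch below works one decayed update step for a single one-dimensional cluster, so the roles of `a`, `n_t`, and `m_t` are easy to see. It is a standalone illustration, not part of the patch; the object and parameter names are invented for the example.

```scala
// One decayed mini-batch update for a single 1-D cluster, following the
// rule in the scaladoc above. All names here are illustrative only.
object DecayedUpdateSketch {

  /** Returns (c_{t+1}, n_{t+1}) given the prior centroid c and weight n,
   *  the batch centroid x and count m, and the decay factor a. */
  def step(c: Double, n: Double, x: Double, m: Double, a: Double): (Double, Double) = {
    val nNext = n * a + m
    val cNext = (c * n * a + x * m) / nNext
    (cNext, nNext)
  }

  def main(args: Array[String]): Unit = {
    // a = 1.0 keeps all history: the centroid moves only slightly toward the batch.
    println(step(c = 0.0, n = 100.0, x = 10.0, m = 10.0, a = 1.0)) // (~0.909, 110.0)
    // a = 0.0 forgets all history: the centroid jumps straight to the batch centroid.
    println(step(c = 0.0, n = 100.0, x = 10.0, m = 10.0, a = 0.0)) // (10.0, 10.0)
  }
}
```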
@@ -69,7 +69,7 @@ class StreamingKMeansModel(
     val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging {
 
   /** Perform a k-means update on a batch of data. */
-  def update(data: RDD[Vector], a: Double, units: String): StreamingKMeansModel = {
+  def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
 
     val centers = clusterCenters
     val counts = clusterCounts
@@ -94,12 +94,12 @@ class StreamingKMeansModel(
         val newCount = count
         val newCentroid = mean / newCount.toDouble
         // compute the normalized scale factor that controls forgetting
-        val decayFactor = units match {
-          case "batches" => newCount / (a * oldCount + newCount)
-          case "points" => newCount / (math.pow(a, newCount) * oldCount + newCount)
+        val lambda = timeUnit match {
+          case "batches" => newCount / (decayFactor * oldCount + newCount)
+          case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
         }
         // perform the update
-        val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * decayFactor
+        val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda
         // store the new counts and centers
         counts(label) = oldCount + newCount
         centers(label) = Vectors.fromBreeze(updatedCentroid)
@@ -134,8 +134,8 @@ class StreamingKMeansModel(
 @DeveloperApi
 class StreamingKMeans(
     var k: Int,
-    var a: Double,
-    var units: String) extends Logging {
+    var decayFactor: Double,
+    var timeUnit: String) extends Logging {
 
   protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
 
@@ -149,30 +149,18 @@ class StreamingKMeans(
 
   /** Set the decay factor directly (for forgetful algorithms). */
   def setDecayFactor(a: Double): this.type = {
-    this.a = a
+    this.decayFactor = a
     this
   }
 
-  /** Set the decay units for forgetful algorithms ("batches" or "points"). */
-  def setUnits(units: String): this.type = {
-    if (units != "batches" && units != "points") {
-      throw new IllegalArgumentException("Invalid units for decay: " + units)
+  /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
+  def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+    if (timeUnit != "batches" && timeUnit != "points") {
+      throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
     }
-    this.units = units
-    this
-  }
-
-  /** Set decay fraction in units of batches. */
-  def setDecayFractionBatches(q: Double): this.type = {
-    this.a = math.log(1 - q) / math.log(0.5)
-    this.units = "batches"
-    this
-  }
-
-  /** Set decay fraction in units of points. Must specify expected number of points per batch. */
-  def setDecayFractionPoints(q: Double, m: Double): this.type = {
-    this.a = math.pow(math.log(1 - q) / math.log(0.5), 1/m)
-    this.units = "points"
+    this.decayFactor = math.exp(math.log(0.5) / halfLife)
+    logInfo("Setting decay factor to: %g".format(this.decayFactor))
+    this.timeUnit = timeUnit
    this
  }
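Caller-side, the new setter composes with the existing builder methods. The sketch below shows hypothetical usage as the API stands in this revision; `trainingData` is an assumed `DStream[Vector]` coming from an existing `StreamingContext`, and the constants are arbitrary. Note the conversion performed above: `setHalfLife(h, unit)` stores `decayFactor = exp(log(0.5) / h)`, i.e. `0.5^(1/h)`.

```scala
// Hypothetical usage of the half-life API introduced above. The stream
// wiring (StreamingContext, trainingData) is assumed to exist elsewhere.
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors

val model = new StreamingKMeans()
  .setK(2)
  // old data loses half its weight every 5 batches:
  // decayFactor = exp(log(0.5) / 5) = 0.5^(1/5) ~= 0.871
  .setHalfLife(5.0, "batches")
  .setInitialCenters(Array(Vectors.dense(0.0), Vectors.dense(1.0)))

// model.trainOn(trainingData)  // then the centers update on every incoming batch
```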
@@ -216,7 +204,7 @@ class StreamingKMeans(
   def trainOn(data: DStream[Vector]) {
     this.assertInitialized()
     data.foreachRDD { (rdd, time) =>
-      model = model.update(rdd, this.a, this.units)
+      model = model.update(rdd, this.decayFactor, this.timeUnit)
     }
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
index 5c23b04961df2..de79c7026a696 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.mllib.clustering
 
-import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.scalatest.FunSuite
@@ -98,94 +97,6 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
 
   }
 
-  test("drifting with fractional decay in units of batches") {
-
-    val numBatches1 = 50
-    val numBatches2 = 50
-    val numPoints = 1
-    val q = 0.25
-    val k = 1
-    val d = 1
-    val r = 2.0
-
-    // create model with two clusters
-    val model = new StreamingKMeans()
-      .setK(1)
-      .setDecayFractionBatches(q)
-      .setInitialCenters(Array(Vectors.dense(0.0)))
-
-    // create two batches of data with different, pre-specified centers
-    // to simulate a transition from one cluster to another
-    val (input1, centers1) = StreamingKMeansDataGenerator(
-      numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
-    val (input2, centers2) = StreamingKMeansDataGenerator(
-      numPoints, numBatches2, k, d, r, 84, initCenters = Array(Vectors.dense(0.0)))
-
-    // store the history
-    val history = new ArrayBuffer[Double](numBatches1 + numBatches2)
-
-    // setup and run the model training
-    val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
-      model.trainOn(inputDStream)
-      // extract the center (in this case one-dimensional)
-      inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
-      inputDStream.count()
-    })
-    runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)
-
-    // check that the fraction of batches required to reach 50
-    // equals the setting of q, by finding the index of the first batch
-    // below 50 and comparing to total number of batches received
-    val halvedIndex = history.zipWithIndex.filter( x => x._1 < 50)(0)._2.toDouble
-    val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
-    assert(fraction ~== q absTol 1E-1)
-
-  }
-
-  test("drifting with fractional decay in units of points") {
-
-    val numBatches1 = 50
-    val numBatches2 = 50
-    val numPoints = 10
-    val q = 0.25
-    val k = 1
-    val d = 1
-    val r = 2.0
-
-    // create model with two clusters
-    val model = new StreamingKMeans()
-      .setK(1)
-      .setDecayFractionPoints(q, numPoints)
-      .setInitialCenters(Array(Vectors.dense(0.0)))
-
-    // create two batches of data with different, pre-specified centers
-    // to simulate a transition from one cluster to another
-    val (input1, centers1) = StreamingKMeansDataGenerator(
-      numPoints, numBatches1, k, d, r, 42, initCenters = Array(Vectors.dense(100.0)))
-    val (input2, centers2) = StreamingKMeansDataGenerator(
-      numPoints, numBatches2, k, d, r, 84,
-      initCenters = Array(Vectors.dense(0.0)))
-
-    // store the history
-    val history = new ArrayBuffer[Double](numBatches1 + numBatches2)
-
-    // setup and run the model training
-    val ssc = setupStreams(input1 ++ input2, (inputDStream: DStream[Vector]) => {
-      model.trainOn(inputDStream)
-      // extract the center (in this case one-dimensional)
-      inputDStream.foreachRDD(x => history.append(model.latestModel().clusterCenters(0)(0)))
-      inputDStream.count()
-    })
-    runStreams(ssc, numBatches1 + numBatches2, numBatches1 + numBatches2)
-
-    // check that the fraction of batches required to reach 50
-    // equals the setting of q, by finding the index of the first batch
-    // below 50 and comparing to total number of batches received
-    val halvedIndex = history.zipWithIndex.filter( x => x._1 < 50)(0)._2.toDouble
-    val fraction = (halvedIndex - numBatches1.toDouble) / halvedIndex
-    assert(fraction ~== q absTol 1E-1)
-
-  }
-
   def StreamingKMeansDataGenerator(
       numPoints: Int,
       numBatches: Int,
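The deleted drift tests above exercised the old fractional-decay setters, which `setHalfLife` supersedes. The half-life semantics themselves can still be sanity-checked without any streaming machinery; the standalone sketch below (illustrative names, not part of the patch) confirms the conversion used by `setHalfLife`.

```scala
// Sanity check of the half-life conversion used by setHalfLife:
// decayFactor = exp(log(0.5) / halfLife), hence decayFactor^halfLife == 0.5.
object HalfLifeSanityCheck {
  def main(args: Array[String]): Unit = {
    val halfLife = 5.0
    val decayFactor = math.exp(math.log(0.5) / halfLife)
    println(f"decay factor: $decayFactor%.4f") // ~0.8706
    // After halfLife time units, the accumulated discount on old data is 0.5.
    println(f"discount after half life: ${math.pow(decayFactor, halfLife)}%.4f") // 0.5000
    // With timeUnit = "points", a batch of m points applies decayFactor^m,
    // so the 0.5 discount is likewise reached after halfLife points.
  }
}
```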