update KMeans to use fastSquaredDistance
mengxr committed Mar 11, 2014
1 parent f355411 commit 0ff8046
Showing 3 changed files with 92 additions and 46 deletions.
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala (58 additions & 25 deletions)
@@ -19,17 +19,25 @@ package org.apache.spark.mllib.clustering

import scala.collection.mutable.ArrayBuffer

import breeze.linalg.{DenseVector => BDV, Vector => BV, squaredDistance => breezeSquaredDistance}
import breeze.linalg.{DenseVector => BDV, Vector => BV, squaredDistance => breezeSquaredDistance,
norm => breezeNorm}

import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom

private[clustering]
case class BreezeVectorWithSquaredNorm(vector: BV[Double], squaredNorm: Double)
class BreezeVectorWithSquaredNorm(val vector: BV[Double], val squaredNorm: Double)
extends Serializable {
def this(vector: BV[Double]) = {
this(vector, {val nrm = breezeNorm(vector, 2.0); nrm * nrm})
}
/** Converts the vector to a dense vector. */
def toDense = new BreezeVectorWithSquaredNorm(vector.toDenseVector, squaredNorm)
}
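
The squared norm is worth carrying around because the squared Euclidean distance expands as ||a - b||^2 = ||a||^2 + ||b||^2 - 2 (a dot b): with both norms precomputed, only a dot product touches the vectors, and for sparse vectors that dot product visits only the stored entries. A minimal sketch of the identity (not the MLUtils implementation, which additionally falls back to an exact computation when the two norms nearly cancel and precision suffers):

import breeze.linalg.{DenseVector => BDV, Vector => BV}

// ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a dot b); clamp at zero because
// floating-point cancellation can leave a tiny negative remainder.
def sketchSquaredDistance(
    a: BV[Double], aSquaredNorm: Double,
    b: BV[Double], bSquaredNorm: Double): Double = {
  math.max(aSquaredNorm + bSquaredNorm - 2.0 * (a dot b), 0.0)
}

// sketchSquaredDistance(BDV(1.0, 2.0), 5.0, BDV(1.0, 3.0), 10.0) == 1.0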

/**
* K-means clustering with support for multiple parallel runs and a k-means++ like initialization
@@ -114,24 +122,31 @@ class KMeans private (
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Array[Double]]): KMeansModel = {
val breezeData = data.map(v => new BDV[Double](v).asInstanceOf[BV[Double]])
runBreeze(breezeData)
run(data.map(v => Vectors.dense(v)))
}

/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Vector])(implicit d: DummyImplicit): KMeansModel = {
val breezeData = data.map(v => v.toBreeze)
runBreeze(breezeData)
val squaredNorms = data.map { v =>
val nrm = breezeNorm(v.toBreeze, 2.0)
nrm * nrm
}
squaredNorms.persist()
val breezeData = data.map(_.toBreeze).zip(squaredNorms).map { case (v, squaredNorm) =>
new BreezeVectorWithSquaredNorm(v, squaredNorm)
}
val model = runBreeze(breezeData)
squaredNorms.unpersist()
model
}
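
The overload above makes one extra pass to compute squared norms, persists them so every Lloyd iteration can reuse them through the zip, and unpersists once training is done. A condensed standalone sketch of that pattern, assuming only Spark core and Breeze (attachSquaredNorms is an illustrative name, not part of this commit):

import breeze.linalg.{norm => breezeNorm, Vector => BV}
import org.apache.spark.rdd.RDD

def attachSquaredNorms(data: RDD[BV[Double]]): RDD[(BV[Double], Double)] = {
  val squaredNorms = data.map { v =>
    val nrm = breezeNorm(v, 2.0)
    nrm * nrm
  }
  // Persist so the norms are not recomputed on each of the many passes an
  // iterative algorithm makes over the data.
  squaredNorms.persist()
  // zip is safe here: squaredNorms was derived from data by a map, so both
  // RDDs have identical partitioning and per-partition element counts.
  data.zip(squaredNorms)
}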

/**
* Implementation using Breeze.
*/
private def runBreeze(data: RDD[BV[Double]]): KMeansModel = {
// TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable
private def runBreeze(data: RDD[BreezeVectorWithSquaredNorm]): KMeansModel = {

val sc = data.sparkContext

@@ -149,7 +164,7 @@

// Execute iterations of Lloyd's algorithm until all runs have converged
while (iteration < maxIterations && !activeRuns.isEmpty) {
type WeightedPoint = (BDV[Double], Long)
type WeightedPoint = (BV[Double], Long)
def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
(p1._1 += p2._1, p1._2 + p2._2)
}
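
Widening WeightedPoint from BDV to BV lets the running sums stay polymorphic over dense and sparse Breeze vectors, and mergeContribs adds into its left argument in place (+=) so reduceByKey can fold contributions without allocating a fresh vector per merge. A toy sketch of that fold, with a local collection standing in for the shuffled data:

import breeze.linalg.{DenseVector => BDV, Vector => BV}

type WeightedPoint = (BV[Double], Long)

def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint =
  (p1._1 += p2._1, p1._2 + p2._2)

// Folding three unit-count points into one running (sum, count) pair.
val contribs: Seq[WeightedPoint] = Seq(BDV(1.0, 2.0), BDV(3.0, 4.0), BDV(5.0, 6.0))
  .map(v => (v.asInstanceOf[BV[Double]], 1L))
val (sum, count) = contribs.reduce(mergeContribs) // sum = [9.0, 12.0], count = 3
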
@@ -161,15 +176,15 @@
val totalContribs = data.mapPartitions { points =>
val runs = activeCenters.length
val k = activeCenters(0).length
val dims = activeCenters(0)(0).length
val dims = activeCenters(0)(0).vector.length

val sums = Array.fill(runs, k)(BDV.zeros[Double](dims))
val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
val counts = Array.fill(runs, k)(0L)

for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex) {
val (bestCenter, cost) = KMeans.findClosest(centers, point)
costAccums(runIndex) += cost
sums(runIndex)(bestCenter) += point
sums(runIndex)(bestCenter) += point.vector
counts(runIndex)(bestCenter) += 1
}

@@ -186,8 +201,8 @@
val (sum, count) = totalContribs((i, j))
if (count != 0) {
sum /= count.toDouble
val newCenter = sum
if (breezeSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
val newCenter = new BreezeVectorWithSquaredNorm(sum)
if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
changed = true
}
centers(run)(j) = newCenter
@@ -206,18 +221,19 @@

val bestRun = costs.zipWithIndex.min._2
new KMeansModel(centers(bestRun).map { v =>
v.toArray
v.vector.toArray
})
}

/**
* Initialize `runs` sets of cluster centers at random.
*/
private def initRandom(data: RDD[BV[Double]]): Array[Array[BV[Double]]] = {
private def initRandom(data: RDD[BreezeVectorWithSquaredNorm])
: Array[Array[BreezeVectorWithSquaredNorm]] = {
// Sample all the cluster centers in one pass to avoid repeated scans
val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
v.toDenseVector
new BreezeVectorWithSquaredNorm(v.vector.toDenseVector, v.squaredNorm)
}.toArray)
}

@@ -230,11 +246,12 @@
*
* The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
*/
private def initKMeansParallel(data: RDD[BV[Double]]): Array[Array[BV[Double]]] = {
private def initKMeansParallel(data: RDD[BreezeVectorWithSquaredNorm])
: Array[Array[BreezeVectorWithSquaredNorm]] = {
// Initialize each run's center to a random point
val seed = new XORShiftRandom().nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDenseVector))
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))

// On each step, sample 2 * k points on average for each run with probability proportional
// to their squared distance from that run's current centers
@@ -251,7 +268,7 @@
} yield (r, p)
}.collect()
for ((r, p) <- chosen) {
centers(r) += p.toDenseVector
centers(r) += p.toDense
}
}
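
Each pass of the loop above keeps a point for run r with probability proportional to its squared distance from that run's current centers, drawing about 2 * k points per run per step as in the k-means|| paper. A local sketch of that selection rule (samplingStep and costs are illustrative names; costs(i) stands for the point cost of point i under the current centers):

import scala.util.Random

// Keep each point independently with probability 2k * cost / sumCosts,
// capped at 1.0, so roughly 2k points are selected per step in expectation.
def samplingStep(costs: Array[Double], k: Int, rand: Random): Seq[Int] = {
  val sumCosts = costs.sum
  costs.indices.filter { i =>
    rand.nextDouble() < math.min(2.0 * k * costs(i) / sumCosts, 1.0)
  }
}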

@@ -262,7 +279,7 @@
for (r <- 0 until runs) yield ((r, KMeans.findClosest(centers(r), p)._1), 1.0)
}.reduceByKey(_ + _).collectAsMap()
val finalCenters = (0 until runs).map { r =>
val myCenters = centers(r).toArray.asInstanceOf[Array[BV[Double]]]
val myCenters = centers(r).toArray
val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
}
@@ -375,9 +392,7 @@ object KMeans {
var bestIndex = 0
var i = 0
centers.foreach { center =>
val distance: Double = MLUtils.fastSquaredDistance(
center.vector, center.squaredNorm, point.vector, point.squaredNorm
)
val distance: Double = fastSquaredDistance(center, point)
if (distance < bestDistance) {
bestDistance = distance
bestIndex = i
@@ -407,6 +422,24 @@
private[mllib] def pointCost(centers: TraversableOnce[BV[Double]], point: BV[Double]): Double =
findClosest(centers, point)._2

/**
* Returns the K-means cost of a given point against the given cluster centers.
*/
private[mllib] def pointCost(
centers: TraversableOnce[BreezeVectorWithSquaredNorm],
point: BreezeVectorWithSquaredNorm): Double =
findClosest(centers, point)._2

/**
* Returns the squared Euclidean distance between two vectors computed by
* [[org.apache.spark.mllib.util.MLUtils.fastSquaredDistance()]].
*/
private[clustering]
def fastSquaredDistance(v1: BreezeVectorWithSquaredNorm, v2: BreezeVectorWithSquaredNorm)
: Double = {
MLUtils.fastSquaredDistance(v1.vector, v1.squaredNorm, v2.vector, v2.squaredNorm)
}
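
Since the helper is private[clustering], it can be exercised directly from code in the same package; a usage sketch with values chosen so the exact answer is easy to verify by hand:

import breeze.linalg.{DenseVector => BDV}

val a = new BreezeVectorWithSquaredNorm(BDV(1.0, 2.0, 6.0))
val b = new BreezeVectorWithSquaredNorm(BDV(1.0, 3.0, 0.0))
// (1-1)^2 + (2-3)^2 + (6-0)^2 = 37, up to floating-point error.
val d2 = KMeans.fastSquaredDistance(a, b)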

def main(args: Array[String]) {
if (args.length < 4) {
println("Usage: KMeans <master> <input_file> <k> <max_iterations> [<runs>]")

mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering

import scala.util.Random

import breeze.linalg.{Vector => BV, DenseVector => BDV}
import breeze.linalg.{Vector => BV, DenseVector => BDV, norm => breezeNorm}

/**
* A utility object to run K-means locally. This is private to the ML package because it's used
@@ -34,9 +34,13 @@ private[mllib] object LocalKMeans {
k: Int,
maxIterations: Int
): Array[Array[Double]] = {
val breezePoints = points.map(v => new BDV[Double](v).asInstanceOf[BV[Double]])
val breezePoints = points.map { v =>
val bv = new BDV[Double](v)
val norm: Double = breezeNorm(bv, 2.0)
new BreezeVectorWithSquaredNorm(bv, norm * norm)
}
val breezeCenters = kMeansPlusPlus(seed, breezePoints, weights, k, maxIterations)
breezeCenters.map(_.toArray)
breezeCenters.map(_.vector.toArray)
}
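
pickWeighted, used to draw the seeds below, is not part of this diff; a typical weighted-sampling implementation (a hypothetical reconstruction, not necessarily what LocalKMeans contains) walks the cumulative weights until a uniform draw is exceeded:

import scala.util.Random

// Draw one element with probability proportional to its weight.
def pickWeightedSketch[T](rand: Random, data: Array[T], weights: Array[Double]): T = {
  val r = rand.nextDouble() * weights.sum
  var i = 0
  var cumulative = 0.0
  while (i < data.length && cumulative < r) {
    cumulative += weights(i)
    i += 1
  }
  data(math.max(i - 1, 0)) // guard against r == 0.0
}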

/**
@@ -45,20 +49,20 @@
*/
def kMeansPlusPlus(
seed: Int,
points: Array[BV[Double]],
points: Array[BreezeVectorWithSquaredNorm],
weights: Array[Double],
k: Int,
maxIterations: Int
)(implicit d: DummyImplicit): Array[BV[Double]] = {
)(implicit d: DummyImplicit): Array[BreezeVectorWithSquaredNorm] = {
val rand = new Random(seed)
val dimensions = points(0).length
val centers = new Array[BV[Double]](k)
val dimensions = points(0).vector.length
val centers = new Array[BreezeVectorWithSquaredNorm](k)

// Initialize centers by sampling using the k-means++ procedure.
centers(0) = (pickWeighted(rand, points, weights)).toDenseVector
centers(0) = pickWeighted(rand, points, weights).toDense
for (i <- 1 until k) {
// Pick the next center with a probability proportional to cost under current centers
val curCenters = centers.slice(0, i)
val curCenters = centers.view.take(i)
val sum = points.zip(weights).map { case (p, w) =>
w * KMeans.pointCost(curCenters, p)
}.sum
@@ -69,7 +73,7 @@
cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j))
j += 1
}
centers(i) = points(j-1).toDenseVector
centers(i) = points(j-1).toDense
}

// Run up to maxIterations iterations of Lloyd's algorithm
@@ -79,12 +83,12 @@
while (moved && iteration < maxIterations) {
moved = false
val sums = Array.fill(k)(
new BDV[Double](new Array[Double](dimensions)).asInstanceOf[BV[Double]]
BDV.zeros[Double](dimensions).asInstanceOf[BV[Double]]
)
val counts = Array.fill(k)(0.0)
for ((p, i) <- points.zipWithIndex) {
val index = KMeans.findClosest(centers, p)._1
breeze.linalg.axpy(weights(i), p, sums(index))
breeze.linalg.axpy(weights(i), p.vector, sums(index))
counts(index) += weights(i)
if (index != oldClosest(i)) {
moved = true
@@ -95,10 +99,10 @@
for (i <- 0 until k) {
if (counts(i) == 0.0) {
// Assign center to a random point
centers(i) = points(rand.nextInt(points.length)).toDenseVector
centers(i) = points(rand.nextInt(points.length)).toDense
} else {
sums(i) /= counts(i)
centers(i) = sums(i)
centers(i) = new BreezeVectorWithSquaredNorm(sums(i))
}
}
iteration += 1

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -131,13 +131,20 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
}

test("single cluster with sparse data") {
val n = 1000
val smallData = Array(
Vectors.sparse(n, Seq((0, 1.0), (1, 2.0), (2, 6.0))),
Vectors.sparse(n, Seq((0, 1.0), (1, 3.0))),
Vectors.sparse(n, Seq((0, 1.0), (1, 4.0), (2, 6.0)))
)
val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4)
val n = 1000000
val data = sc.parallelize((1 to 100).flatMap { i =>
val x = i / 1000.0
Array(
Vectors.sparse(n, Seq((0, 1.0 + x), (1, 2.0), (2, 6.0))),
Vectors.sparse(n, Seq((0, 1.0 - x), (1, 2.0), (2, 6.0))),
Vectors.sparse(n, Seq((0, 1.0), (1, 3.0 + x))),
Vectors.sparse(n, Seq((0, 1.0), (1, 3.0 - x))),
Vectors.sparse(n, Seq((0, 1.0), (1, 4.0), (2, 6.0 + x))),
Vectors.sparse(n, Seq((0, 1.0), (1, 4.0), (2, 6.0 - x)))
)
}, 4)

data.persist()

// No matter how many runs or iterations we use, we should get one cluster,
// centered at the mean of the points
@@ -167,6 +174,8 @@ class KMeansSuite extends FunSuite with LocalSparkContext {

model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
assertSetsEqual(model.clusterCenters, Array(center))

data.unpersist()
}
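
The jump from n = 1000 to n = 1000000 is what gives this test its teeth: with the old breezeSquaredDistance path every distance call materialized a dense length-n difference, so at a million dimensions any regression back to dense arithmetic would show up immediately in time and memory. A quick illustration of the storage gap using only the public Vectors API (the assertion is illustrative, not part of the suite):

import org.apache.spark.mllib.linalg.Vectors

val n = 1000000
// Three stored entries, versus the ~8 MB a dense copy of length n would use.
val v = Vectors.sparse(n, Seq((0, 1.0), (1, 2.0), (2, 6.0)))
assert(v.size == n)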

test("k-means|| initialization") {