
Commit 87bc755

tuned the KMeans code: changed some for loops to while, use view to avoid copying arrays

added some log messages

mengxr committed Mar 12, 2014
1 parent 0ff8046 commit 87bc755
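
As context for the two idioms the message names, a minimal standalone sketch (hypothetical arrays and numbers, not code from this commit): .view pairs and maps elements lazily instead of materializing intermediate arrays, and an index-driven while loop replaces a for over a Range on hot paths.

    object ViewAndWhileSketch {
      def main(args: Array[String]): Unit = {
        val xs = Array.tabulate(1000)(_.toDouble)

        // Strict: zipWithIndex copies xs into a new Array[(Double, Int)]
        // before map builds yet another array.
        val strictSum = xs.zipWithIndex.map { case (x, i) => x * i }.sum
        // Lazy: the view pairs and maps element by element, with no copies.
        val viewSum = xs.view.zipWithIndex.map { case (x, i) => x * i }.sum

        // An index-driven while loop avoids the Range and the per-element
        // closure that `for (i <- 0 until xs.length)` desugars to.
        var i = 0
        var acc = 0.0
        while (i < xs.length) {
          acc += xs(i) * i
          i += 1
        }

        assert(strictSum == viewSum && viewSum == acc)
      }
    }
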
Showing 3 changed files with 67 additions and 26 deletions.
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala (43 additions, 16 deletions)
@@ -150,18 +150,26 @@ class KMeans private (

val sc = data.sparkContext

+ val initStartTime = System.nanoTime()

val centers = if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}

+ val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
+ logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
+ " seconds.")

val active = Array.fill(runs)(true)
val costs = Array.fill(runs)(0.0)

var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
var iteration = 0

+ val iterationStartTime = System.nanoTime()

// Execute iterations of Lloyd's algorithm until all runs have converged
while (iteration < maxIterations && !activeRuns.isEmpty) {
type WeightedPoint = (BV[Double], Long)
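
The loop above drives Lloyd's algorithm across all active runs at once. For orientation, here is a single-run, single-machine sketch of one assignment-plus-update pass (assumed array-based types, not the MLlib API):

    object LloydSketch {
      type Point = Array[Double]

      def dist2(a: Point, b: Point): Double =
        a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

      // One Lloyd's pass: assign every point to its closest center, then move
      // each center to the mean of its points. Callers repeat this until the
      // centers stop moving or maxIterations is hit, as the driver loop above does.
      def lloydStep(points: Seq[Point], centers: Array[Point]): Array[Point] = {
        val byCenter = points.groupBy { p =>
          centers.indices.minBy(j => dist2(p, centers(j)))
        }
        centers.zipWithIndex.map { case (c, j) =>
          byCenter.get(j) match {
            case Some(ps) => ps.transpose.map(_.sum / ps.length).toArray
            case None => c // an empty cluster keeps its old center
          }
        }
      }
    }
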
@@ -181,11 +189,13 @@ class KMeans private (
val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
val counts = Array.fill(runs, k)(0L)

- for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex) {
- val (bestCenter, cost) = KMeans.findClosest(centers, point)
- costAccums(runIndex) += cost
- sums(runIndex)(bestCenter) += point.vector
- counts(runIndex)(bestCenter) += 1
+ points.foreach { point =>
+ activeRuns.foreach { r =>
+ val (bestCenter, cost) = KMeans.findClosest(centers(r), point)
+ costAccums(r) += cost
+ sums(r)(bestCenter) += point.vector
+ counts(r)(bestCenter) += 1
+ }
}

val contribs = for (i <- 0 until runs; j <- 0 until k) yield {
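
The hunk above swaps a two-generator for-comprehension for nested foreach calls. A toy desugaring (illustrative values only) shows the allocation being avoided:

    object ForeachSketch {
      def main(args: Array[String]): Unit = {
        val activeCenters = Array("run0-centers", "run1-centers")
        val points = Seq("p0", "p1")

        // Old shape: the second generator desugars to
        // activeCenters.zipWithIndex.foreach { case (centers, runIndex) => ... },
        // building a fresh Array[(String, Int)] for every point.
        for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex)
          println(s"$point against $centers (run $runIndex)")

        // New shape: iterate run indices directly; same traversal, no tuple array.
        points.foreach { point =>
          var r = 0
          while (r < activeCenters.length) {
            println(s"$point against ${activeCenters(r)} (run $r)")
            r += 1
          }
        }
      }
    }
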
@@ -195,9 +205,10 @@ class KMeans private (
}.reduceByKey(mergeContribs).collectAsMap()

// Update the cluster centers and costs for each active run
- for ((run, i) <- activeRuns.zipWithIndex) {
+ for ((run, i) <- activeRuns.view.zipWithIndex) {
var changed = false
- for (j <- 0 until k) {
+ var j = 0
+ while (j < k) {
val (sum, count) = totalContribs((i, j))
if (count != 0) {
sum /= count.toDouble
@@ -207,6 +218,7 @@ class KMeans private (
}
centers(run)(j) = newCenter
}
+ j += 1
}
if (!changed) {
active(run) = false
@@ -219,6 +231,15 @@ class KMeans private (
iteration += 1
}

+ val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
+ logInfo("Iterations took " + "%.3f".format(iterationTimeInSeconds) + " seconds.")
+
+ if (iteration == maxIterations) {
+ logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
+ } else {
+ logInfo(s"KMeans converged in $iteration iterations.")
+ }

val bestRun = costs.zipWithIndex.min._2
new KMeansModel(centers(bestRun).map { v =>
v.vector.toArray
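
A subtlety in the best-run selection above: min on (cost, index) pairs compares lexicographically, so ._2 is the index of the cheapest run, with ties resolving to the lower index. A quick check:

    val costs = Array(3.2, 1.7, 2.9)
    val bestRun = costs.zipWithIndex.min._2
    assert(bestRun == 1) // (1.7, 1) is the smallest pair
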
@@ -255,28 +276,34 @@ class KMeans private (

// On each step, sample 2 * k points on average for each run with probability proportional
// to their squared distance from that run's current centers
- for (step <- 0 until initializationSteps) {
+ var step = 0
+ while (step < initializationSteps) {
val sumCosts = data.flatMap { point =>
- for (r <- 0 until runs) yield (r, KMeans.pointCost(centers(r), point))
+ (0 until runs).map { r =>
+ (r, KMeans.pointCost(centers(r), point))
+ }
}.reduceByKey(_ + _).collectAsMap()
val chosen = data.mapPartitionsWithIndex { (index, points) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
- for {
- p <- points
- r <- 0 until runs
- if rand.nextDouble() < KMeans.pointCost(centers(r), p) * 2 * k / sumCosts(r)
- } yield (r, p)
+ points.flatMap { p =>
+ (0 until runs).filter { r =>
+ rand.nextDouble() < 2.0 * KMeans.pointCost(centers(r), p) * k / sumCosts(r)
+ }.map((_, p))
+ }
}.collect()
- for ((r, p) <- chosen) {
+ chosen.foreach { case (r, p) =>
centers(r) += p.toDense
}
+ step += 1
}
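
The sampling loop above is the k-means|| oversampling round the comment describes: each point survives with probability proportional to its squared distance from the current centers, scaled so that roughly 2 * k points are kept per run in expectation. A single-run sketch under those assumptions (hypothetical helper names):

    import scala.util.Random

    object SamplingSketch {
      type Point = Array[Double]

      // A point's "cost" is its squared distance to the nearest current center.
      def pointCost(centers: Seq[Point], p: Point): Double =
        centers.map(c => p.zip(c).map { case (x, y) => (x - y) * (x - y) }.sum).min

      // One oversampling round: the expected number of points kept is
      // the sum over p of 2 * k * cost(p) / sumCost, which is 2 * k.
      def sampleRound(points: Seq[Point], centers: Seq[Point], k: Int,
          rand: Random): Seq[Point] = {
        val sumCost = points.map(pointCost(centers, _)).sum
        points.filter { p =>
          rand.nextDouble() < 2.0 * k * pointCost(centers, p) / sumCost
        }
      }
    }
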

// Finally, we might have a set of more than k candidate centers for each run; weigh each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
val weightMap = data.flatMap { p =>
- for (r <- 0 until runs) yield ((r, KMeans.findClosest(centers(r), p)._1), 1.0)
+ (0 until runs).map { r =>
+ ((r, KMeans.findClosest(centers(r), p)._1), 1.0)
+ }
}.reduceByKey(_ + _).collectAsMap()
val finalCenters = (0 until runs).map { r =>
val myCenters = centers(r).toArray
mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
@@ -21,11 +21,13 @@ import scala.util.Random

import breeze.linalg.{Vector => BV, DenseVector => BDV, norm => breezeNorm}

+ import org.apache.spark.Logging

/**
* A utility object to run K-means locally. This is private to the ML package because it's used
* in the initialization of KMeans but not meant to be publicly exposed.
*/
- private[mllib] object LocalKMeans {
+ private[mllib] object LocalKMeans extends Logging {

def kMeansPlusPlus(
seed: Int,
@@ -63,7 +65,7 @@ private[mllib] object LocalKMeans {
for (i <- 1 until k) {
// Pick the next center with a probability proportional to cost under current centers
val curCenters = centers.view.take(i)
- val sum = points.zip(weights).map { case (p, w) =>
+ val sum = points.view.zip(weights).map { case (p, w) =>
w * KMeans.pointCost(curCenters, p)
}.sum
val r = rand.nextDouble() * sum
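
The lines above draw r uniformly from [0, sum); the center chosen next is the point whose cumulative weighted cost first reaches r (the scan itself falls in the elided lines). A standalone sketch of that inversion-sampling rule (hypothetical signature, not the LocalKMeans internals):

    import scala.util.Random

    object PickNextSketch {
      type Point = Array[Double]

      def cost(curCenters: Seq[Point], p: Point): Double =
        curCenters.map(c => p.zip(c).map { case (x, y) => (x - y) * (x - y) }.sum).min

      // Inversion sampling: P(points(i) is picked) = weights(i) * cost(i) / total.
      def pickNext(points: IndexedSeq[Point], weights: IndexedSeq[Double],
          curCenters: Seq[Point], rand: Random): Point = {
        val scores = points.indices.map(i => weights(i) * cost(curCenters, points(i)))
        val r = rand.nextDouble() * scores.sum
        var cumulative = 0.0
        var i = 0
        while (i < points.length - 1 && cumulative + scores(i) < r) {
          cumulative += scores(i)
          i += 1
        }
        points(i)
      }
    }
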
@@ -82,32 +84,43 @@ private[mllib] object LocalKMeans {
var moved = true
while (moved && iteration < maxIterations) {
moved = false
+ val counts = Array.fill(k)(0.0)
val sums = Array.fill(k)(
BDV.zeros[Double](dimensions).asInstanceOf[BV[Double]]
)
- val counts = Array.fill(k)(0.0)
- for ((p, i) <- points.zipWithIndex) {
+ var i = 0
+ while (i < points.length) {
+ val p = points(i)
val index = KMeans.findClosest(centers, p)._1
breeze.linalg.axpy(weights(i), p.vector, sums(index))
counts(index) += weights(i)
if (index != oldClosest(i)) {
moved = true
oldClosest(i) = index
}
+ i += 1
}
// Update centers
- for (i <- 0 until k) {
- if (counts(i) == 0.0) {
+ var j = 0
+ while (j < k) {
+ if (counts(j) == 0.0) {
// Assign center to a random point
- centers(i) = points(rand.nextInt(points.length)).toDense
+ centers(j) = points(rand.nextInt(points.length)).toDense
} else {
- sums(i) /= counts(i)
- centers(i) = new BreezeVectorWithSquaredNorm(sums(i))
+ sums(j) /= counts(j)
+ centers(j) = new BreezeVectorWithSquaredNorm(sums(j))
}
+ j += 1
}
iteration += 1
}

+ if (iteration == maxIterations) {
+ logInfo(s"Local KMeans++ reached the max number of iterations: $maxIterations.")
+ } else {
+ logInfo(s"Local KMeans++ converged in $iteration iterations.")
+ }

centers
}

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -131,7 +131,8 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
}

test("single cluster with sparse data") {
- val n = 1000000
+
+ val n = 10000
val data = sc.parallelize((1 to 100).flatMap { i =>
val x = i / 1000.0
Array(