[SPARK-4708][MLlib] Make k-means run two to three times faster with dense/sparse samples

Note that `breezeSquaredDistance` is on the critical path via
`org.apache.spark.mllib.util.MLUtils.fastSquaredDistance`, and
`breezeSquaredDistance` is slow.
We should replace it with our own implementation.
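
For context, the speedup comes from caching each vector's L2 norm and computing
the squared distance via ||v1 - v2||^2 = ||v1||^2 + ||v2||^2 - 2 * dot(v1, v2),
so the common case costs a single dot product instead of a full element-wise pass.
A minimal sketch of the idea, assuming access to the package-private `BLAS.dot`
used in this patch (the real `MLUtils.fastSquaredDistance` below adds precision
bounds and an exact fallback):

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.BLAS.dot

// Sketch only: squared Euclidean distance from precomputed norms.
// One dot product replaces a full pairwise pass over both vectors.
def sketchFastSquaredDistance(
    v1: Vector, norm1: Double,
    v2: Vector, norm2: Double): Double = {
  val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
  math.max(sumSquaredNorm - 2.0 * dot(v1, v2), 0.0)
}
```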

Here is the benchmark against the mnist8m dataset.

Before:
  DenseVector:  70.04 secs
  SparseVector: 59.05 secs

With this PR:
  DenseVector:  30.58 secs
  SparseVector: 21.14 secs

Author: DB Tsai <dbtsai@alpinenow.com>

Closes apache#3565 from dbtsai/kmean and squashes the following commits:

08bc068 [DB Tsai] restyle
de24662 [DB Tsai] address feedback
b185a77 [DB Tsai] cleanup
4554ddd [DB Tsai] first commit
DB Tsai authored and mengxr committed Dec 3, 2014
1 parent 4ac2151 commit 7fc49ed
Showing 5 changed files with 70 additions and 68 deletions.
67 changes: 33 additions & 34 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -19,12 +19,11 @@ package org.apache.spark.mllib.clustering

 import scala.collection.mutable.ArrayBuffer

-import breeze.linalg.{DenseVector => BDV, Vector => BV}
-
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -127,10 +126,10 @@ class KMeans private (
     // Compute squared norms and cache them.
     val norms = data.map(Vectors.norm(_, 2.0))
     norms.persist()
-    val breezeData = data.map(_.toBreeze).zip(norms).map { case (v, norm) =>
-      new BreezeVectorWithNorm(v, norm)
+    val zippedData = data.zip(norms).map { case (v, norm) =>
+      new VectorWithNorm(v, norm)
     }
-    val model = runBreeze(breezeData)
+    val model = runAlgorithm(zippedData)
     norms.unpersist()

     // Warn at the end of the run as well, for increased visibility.
@@ -142,9 +141,9 @@
   }

   /**
-   * Implementation of K-Means using breeze.
+   * Implementation of K-Means algorithm.
    */
-  private def runBreeze(data: RDD[BreezeVectorWithNorm]): KMeansModel = {
+  private def runAlgorithm(data: RDD[VectorWithNorm]): KMeansModel = {

     val sc = data.sparkContext

@@ -170,9 +169,10 @@

     // Execute iterations of Lloyd's algorithm until all runs have converged
     while (iteration < maxIterations && !activeRuns.isEmpty) {
-      type WeightedPoint = (BV[Double], Long)
-      def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
-        (p1._1 += p2._1, p1._2 + p2._2)
+      type WeightedPoint = (Vector, Long)
+      def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {
+        axpy(1.0, x._1, y._1)
+        (y._1, x._2 + y._2)
       }

       val activeCenters = activeRuns.map(r => centers(r)).toArray
@@ -185,16 +185,17 @@
         val thisActiveCenters = bcActiveCenters.value
         val runs = thisActiveCenters.length
         val k = thisActiveCenters(0).length
-        val dims = thisActiveCenters(0)(0).vector.length
+        val dims = thisActiveCenters(0)(0).vector.size

-        val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
+        val sums = Array.fill(runs, k)(Vectors.zeros(dims))
         val counts = Array.fill(runs, k)(0L)

         points.foreach { point =>
           (0 until runs).foreach { i =>
             val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
             costAccums(i) += cost
-            sums(i)(bestCenter) += point.vector
+            val sum = sums(i)(bestCenter)
+            axpy(1.0, point.vector, sum)
             counts(i)(bestCenter) += 1
           }
         }
@@ -212,8 +213,8 @@
         while (j < k) {
           val (sum, count) = totalContribs((i, j))
           if (count != 0) {
-            sum /= count.toDouble
-            val newCenter = new BreezeVectorWithNorm(sum)
+            scal(1.0 / count, sum)
+            val newCenter = new VectorWithNorm(sum)
             if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
               changed = true
             }
@@ -245,18 +246,18 @@

     logInfo(s"The cost for the best run is $minCost.")

-    new KMeansModel(centers(bestRun).map(c => Vectors.fromBreeze(c.vector)))
+    new KMeansModel(centers(bestRun).map(_.vector))
   }

   /**
    * Initialize `runs` sets of cluster centers at random.
    */
-  private def initRandom(data: RDD[BreezeVectorWithNorm])
-  : Array[Array[BreezeVectorWithNorm]] = {
+  private def initRandom(data: RDD[VectorWithNorm])
+  : Array[Array[VectorWithNorm]] = {
     // Sample all the cluster centers in one pass to avoid repeated scans
     val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
     Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
-      new BreezeVectorWithNorm(v.vector.toDenseVector, v.norm)
+      new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
     }.toArray)
   }

@@ -269,8 +270,8 @@
    *
    * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
    */
-  private def initKMeansParallel(data: RDD[BreezeVectorWithNorm])
-  : Array[Array[BreezeVectorWithNorm]] = {
+  private def initKMeansParallel(data: RDD[VectorWithNorm])
+  : Array[Array[VectorWithNorm]] = {
     // Initialize each run's center to a random point
     val seed = new XORShiftRandom().nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
@@ -376,8 +377,8 @@ object KMeans {
    * Returns the index of the closest center to the given point, as well as the squared distance.
    */
   private[mllib] def findClosest(
-      centers: TraversableOnce[BreezeVectorWithNorm],
-      point: BreezeVectorWithNorm): (Int, Double) = {
+      centers: TraversableOnce[VectorWithNorm],
+      point: VectorWithNorm): (Int, Double) = {
     var bestDistance = Double.PositiveInfinity
     var bestIndex = 0
     var i = 0
@@ -402,35 +403,33 @@
    * Returns the K-means cost of a given point against the given cluster centers.
    */
   private[mllib] def pointCost(
-      centers: TraversableOnce[BreezeVectorWithNorm],
-      point: BreezeVectorWithNorm): Double =
+      centers: TraversableOnce[VectorWithNorm],
+      point: VectorWithNorm): Double =
     findClosest(centers, point)._2

   /**
    * Returns the squared Euclidean distance between two vectors computed by
    * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
    */
   private[clustering] def fastSquaredDistance(
-      v1: BreezeVectorWithNorm,
-      v2: BreezeVectorWithNorm): Double = {
+      v1: VectorWithNorm,
+      v2: VectorWithNorm): Double = {
     MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
   }
 }

 /**
- * A breeze vector with its norm for fast distance computation.
+ * A vector with its norm for fast distance computation.
  *
  * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]]
  */
 private[clustering]
-class BreezeVectorWithNorm(val vector: BV[Double], val norm: Double) extends Serializable {
-
-  def this(vector: BV[Double]) = this(vector, Vectors.norm(Vectors.fromBreeze(vector), 2.0))
+class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable {

-  def this(array: Array[Double]) = this(new BDV[Double](array))
+  def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))

-  def this(v: Vector) = this(v.toBreeze)
+  def this(array: Array[Double]) = this(Vectors.dense(array))

   /** Converts the vector to a dense vector. */
-  def toDense = new BreezeVectorWithNorm(vector.toDenseVector, norm)
+  def toDense = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
 }
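
For reference, a hypothetical usage of the new wrapper from inside the clustering
package (`VectorWithNorm` and `KMeans.fastSquaredDistance` are package-private,
so this is illustration only, not part of the patch):

```scala
import org.apache.spark.mllib.linalg.Vectors

// Precompute each point's L2 norm once, then reuse it for every
// distance evaluation during Lloyd's iterations.
val p = new VectorWithNorm(Vectors.dense(1.0, 2.0, 3.0))
val q = new VectorWithNorm(Vectors.sparse(3, Array(0, 2), Array(1.0, 4.0)))
val d2 = KMeans.fastSquaredDistance(p, q) // squared Euclidean distance
```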
10 changes: 5 additions & 5 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -32,14 +32,14 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {

   /** Returns the cluster index that a given point belongs to. */
   def predict(point: Vector): Int = {
-    KMeans.findClosest(clusterCentersWithNorm, new BreezeVectorWithNorm(point))._1
+    KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1
   }

   /** Maps given points to their cluster indices. */
   def predict(points: RDD[Vector]): RDD[Int] = {
     val centersWithNorm = clusterCentersWithNorm
     val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
-    points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))._1)
+    points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
   }

   /** Maps given points to their cluster indices. */
@@ -53,9 +53,9 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
   def computeCost(data: RDD[Vector]): Double = {
     val centersWithNorm = clusterCentersWithNorm
     val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
-    data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))).sum()
+    data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
   }

-  private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] =
-    clusterCenters.map(new BreezeVectorWithNorm(_))
+  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
+    clusterCenters.map(new VectorWithNorm(_))
 }
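
The public API around these internals is unchanged; a minimal usage sketch,
assuming an existing SparkContext `sc`:

```scala
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Train on a toy RDD and query the model; predict and computeCost
// both go through the Breeze-free findClosest/pointCost paths above.
val data = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))
val model = KMeans.train(data, k = 2, maxIterations = 20)
val cluster = model.predict(Vectors.dense(0.05, 0.05))
val cost = model.computeCost(data)
```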
22 changes: 10 additions & 12 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.clustering

 import scala.util.Random

-import breeze.linalg.{Vector => BV, DenseVector => BDV, norm => breezeNorm}
-
 import org.apache.spark.Logging
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}

 /**
  * An utility object to run K-means locally. This is private to the ML package because it's used
@@ -35,14 +35,14 @@ private[mllib] object LocalKMeans extends Logging {
    */
   def kMeansPlusPlus(
       seed: Int,
-      points: Array[BreezeVectorWithNorm],
+      points: Array[VectorWithNorm],
       weights: Array[Double],
       k: Int,
       maxIterations: Int
-  ): Array[BreezeVectorWithNorm] = {
+  ): Array[VectorWithNorm] = {
     val rand = new Random(seed)
-    val dimensions = points(0).vector.length
-    val centers = new Array[BreezeVectorWithNorm](k)
+    val dimensions = points(0).vector.size
+    val centers = new Array[VectorWithNorm](k)

     // Initialize centers by sampling using the k-means++ procedure.
     centers(0) = pickWeighted(rand, points, weights).toDense
@@ -75,14 +75,12 @@ private[mllib] object LocalKMeans extends Logging {
     while (moved && iteration < maxIterations) {
       moved = false
       val counts = Array.fill(k)(0.0)
-      val sums = Array.fill(k)(
-        BDV.zeros[Double](dimensions).asInstanceOf[BV[Double]]
-      )
+      val sums = Array.fill(k)(Vectors.zeros(dimensions))
       var i = 0
       while (i < points.length) {
         val p = points(i)
         val index = KMeans.findClosest(centers, p)._1
-        breeze.linalg.axpy(weights(i), p.vector, sums(index))
+        axpy(weights(i), p.vector, sums(index))
         counts(index) += weights(i)
         if (index != oldClosest(i)) {
           moved = true
@@ -97,8 +95,8 @@
           // Assign center to a random point
           centers(j) = points(rand.nextInt(points.length)).toDense
         } else {
-          sums(j) /= counts(j)
-          centers(j) = new BreezeVectorWithNorm(sums(j))
+          scal(1.0 / counts(j), sums(j))
+          centers(j) = new VectorWithNorm(sums(j))
         }
         j += 1
       }
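
For intuition, the accumulate-then-scale pattern above (axpy to build the sum in
place, scal to divide by the count) computes each new center as a mean without
allocating temporaries. A self-contained sketch with a hypothetical helper,
assuming the package-private `BLAS` routines:

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}

// Hypothetical helper: the mean of the points assigned to one cluster.
def clusterMean(points: Seq[Vector], dims: Int): Vector = {
  val sum = Vectors.zeros(dims)
  points.foreach(p => axpy(1.0, p, sum)) // sum += p, in place
  scal(1.0 / points.size, sum)           // sum /= n, in place
  sum
}
```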
26 changes: 15 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.util

 import scala.reflect.ClassTag

-import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV,
   squaredDistance => breezeSquaredDistance}

 import org.apache.spark.annotation.Experimental
@@ -28,7 +28,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.PartitionwiseSampledRDD
 import org.apache.spark.util.random.BernoulliCellSampler
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContext
 import org.apache.spark.streaming.dstream.DStream
@@ -281,9 +282,9 @@ object MLUtils {
    * @return squared distance between v1 and v2 within the specified precision
    */
   private[mllib] def fastSquaredDistance(
-      v1: BV[Double],
+      v1: Vector,
       norm1: Double,
-      v2: BV[Double],
+      v2: Vector,
       norm2: Double,
       precision: Double = 1e-6): Double = {
     val n = v1.size
@@ -306,16 +307,19 @@
      */
     val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
     if (precisionBound1 < precision) {
-      sqDist = sumSquaredNorm - 2.0 * v1.dot(v2)
-    } else if (v1.isInstanceOf[BSV[Double]] || v2.isInstanceOf[BSV[Double]]) {
-      val dot = v1.dot(v2)
-      sqDist = math.max(sumSquaredNorm - 2.0 * dot, 0.0)
-      val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dot)) / (sqDist + EPSILON)
+      sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)
+    } else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
+      val dotValue = dot(v1, v2)
+      sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
+      val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
+        (sqDist + EPSILON)
       if (precisionBound2 > precision) {
-        sqDist = breezeSquaredDistance(v1, v2)
+        // TODO: breezeSquaredDistance is slow,
+        // so we should replace it with our own implementation.
+        sqDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze)
       }
     } else {
-      sqDist = breezeSquaredDistance(v1, v2)
+      sqDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze)
     }
     sqDist
   }
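
For intuition about the precision guards above: the identity
||v1 - v2||^2 = ||v1||^2 + ||v2||^2 - 2 * dot(v1, v2) cancels catastrophically
when the two vectors are nearly equal, which is why the code falls back to an
exact pairwise computation once the estimated error bound exceeds `precision`.
A small standalone illustration with hypothetical values:

```scala
// Two nearly identical vectors: the true squared distance is 3.0,
// but the norm-based formula works with intermediates near 6e16,
// where doubles cannot resolve a difference as small as 3.0.
val x = Array.fill(3)(1e8)
val y = x.map(_ + 1.0)
val sumSq = x.map(v => v * v).sum + y.map(v => v * v).sum
val dotXY = x.zip(y).map { case (a, b) => a * b }.sum
val viaNorms = sumSq - 2.0 * dotXY // inaccurate due to cancellation
val exact = x.zip(y).map { case (a, b) => (a - b) * (a - b) }.sum // 3.0
```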
13 changes: 7 additions & 6 deletions mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -44,18 +44,19 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
   test("fast squared distance") {
     val a = (30 to 0 by -1).map(math.pow(2.0, _)).toArray
     val n = a.length
-    val v1 = new BDV[Double](a)
-    val norm1 = breezeNorm(v1, 2.0)
+    val v1 = Vectors.dense(a)
+    val norm1 = Vectors.norm(v1, 2.0)
     val precision = 1e-6
     for (m <- 0 until n) {
       val indices = (0 to m).toArray
       val values = indices.map(i => a(i))
-      val v2 = new BSV[Double](indices, values, n)
-      val norm2 = breezeNorm(v2, 2.0)
-      val squaredDist = breezeSquaredDistance(v1, v2)
+      val v2 = Vectors.sparse(n, indices, values)
+      val norm2 = Vectors.norm(v2, 2.0)
+      val squaredDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze)
       val fastSquaredDist1 = fastSquaredDistance(v1, norm1, v2, norm2, precision)
       assert((fastSquaredDist1 - squaredDist) <= precision * squaredDist, s"failed with m = $m")
-      val fastSquaredDist2 = fastSquaredDistance(v1, norm1, v2.toDenseVector, norm2, precision)
+      val fastSquaredDist2 =
+        fastSquaredDistance(v1, norm1, Vectors.dense(v2.toArray), norm2, precision)
       assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m")
     }
   }
