clean up code for distance computation
mengxr committed Mar 20, 2014
1 parent 712cb88 commit 72bde33
Showing 6 changed files with 80 additions and 152 deletions.
126 changes: 34 additions & 92 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -19,28 +19,30 @@ package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable.ArrayBuffer
 
-import breeze.linalg.{DenseVector => BDV, Vector => BV, squaredDistance => breezeSquaredDistance,
-  norm => breezeNorm}
+import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm}
 
 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
 
 /**
- * A breeze vector with its squared norm for fast distance computation.
- * See [[org.apache.spark.mllib.clustering.KMeans.fastSquaredDistance()]].
+ * A breeze vector with its norm for fast distance computation.
+ *
+ * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]]
  */
 private[clustering]
-class BreezeVectorWithSquaredNorm(val vector: BV[Double], val squaredNorm: Double)
-  extends Serializable {
-  def this(vector: BV[Double]) = {
-    this(vector, {val nrm = breezeNorm(vector, 2.0); nrm * nrm})
-  }
+class BreezeVectorWithNorm(val vector: BV[Double], val norm: Double) extends Serializable {
+
+  def this(vector: BV[Double]) = this(vector, breezeNorm(vector, 2.0))
+
+  def this(array: Array[Double]) = this(new BDV[Double](array))
+
+  def this(v: Vector) = this(v.toBreeze)
 
   /** Converts the vector to a dense vector. */
-  def toDense = new BreezeVectorWithSquaredNorm(vector.toDenseVector, squaredNorm)
+  def toDense = new BreezeVectorWithNorm(vector.toDenseVector, norm)
 }
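
An aside on why caching just the norm suffices: the squared distance decomposes as \|a - b\|^2 = \|a\|^2 + \|b\|^2 - 2 a \cdot b, so with norms precomputed, only one dot product is paid per pair. A standalone sketch of that identity (plain breeze, not part of this patch):

import breeze.linalg.{norm => breezeNorm, DenseVector => BDV}

object NormCacheDemo extends App {
  val a = new BDV[Double](Array(1.0, 2.0, 3.0))
  val b = new BDV[Double](Array(4.0, 6.0, 8.0))
  val (na, nb) = (breezeNorm(a, 2.0), breezeNorm(b, 2.0))
  // ||a - b||^2 via cached norms: the dot product is the only per-pair work.
  val viaNorms = na * na + nb * nb - 2.0 * (a dot b)
  val diff = a - b
  println(viaNorms)      // 50.0
  println(diff dot diff) // 50.0, the direct computation
}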

@@ -135,23 +137,20 @@ class KMeans private (
    */
   def run(data: RDD[Vector])(implicit d: DummyImplicit): KMeansModel = {
     // Compute squared norms and cache them.
-    val squaredNorms = data.map { v =>
-      val nrm = breezeNorm(v.toBreeze, 2.0)
-      nrm * nrm
-    }
-    squaredNorms.persist()
-    val breezeData = data.map(_.toBreeze).zip(squaredNorms).map { case (v, squaredNorm) =>
-      new BreezeVectorWithSquaredNorm(v, squaredNorm)
+    val norms = data.map(v => breezeNorm(v.toBreeze, 2.0))
+    norms.persist()
+    val breezeData = data.map(_.toBreeze).zip(norms).map { case (v, norm) =>
+      new BreezeVectorWithNorm(v, norm)
     }
     val model = runBreeze(breezeData)
-    squaredNorms.unpersist()
+    norms.unpersist()
     model
   }
 
   /**
    * Implementation of K-Means using breeze.
    */
-  private def runBreeze(data: RDD[BreezeVectorWithSquaredNorm]): KMeansModel = {
+  private def runBreeze(data: RDD[BreezeVectorWithNorm]): KMeansModel = {
 
     val sc = data.sparkContext
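
Two details in run above are worth spelling out: persisting norms means the one-pass norm computation is reused by every Lloyd iteration instead of being recomputed through the RDD lineage, and zip is safe here because norms is a deterministic map of data, so both RDDs share partitioning and element order. A hedged sketch of the pattern (assumes an existing SparkContext sc; toBreeze is package-private, so this compiles only inside org.apache.spark.mllib):

import breeze.linalg.{norm => breezeNorm}
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)))
val norms = data.map(v => breezeNorm(v.toBreeze, 2.0))
norms.persist() // computed once, served from cache on every later pass
val withNorms = data.map(_.toBreeze).zip(norms).map { case (v, n) =>
  new BreezeVectorWithNorm(v, n)
}
withNorms.count() // any action materializes the zipped RDD
norms.unpersist() // after the last pass, as run does after runBreeze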

@@ -217,7 +216,7 @@ class KMeans private (
         val (sum, count) = totalContribs((i, j))
         if (count != 0) {
           sum /= count.toDouble
-          val newCenter = new BreezeVectorWithSquaredNorm(sum)
+          val newCenter = new BreezeVectorWithNorm(sum)
           if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
             changed = true
           }
@@ -257,12 +256,12 @@
   /**
    * Initialize `runs` sets of cluster centers at random.
    */
-  private def initRandom(data: RDD[BreezeVectorWithSquaredNorm])
-  : Array[Array[BreezeVectorWithSquaredNorm]] = {
+  private def initRandom(data: RDD[BreezeVectorWithNorm])
+  : Array[Array[BreezeVectorWithNorm]] = {
     // Sample all the cluster centers in one pass to avoid repeated scans
     val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
     Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
-      new BreezeVectorWithSquaredNorm(v.vector.toDenseVector, v.squaredNorm)
+      new BreezeVectorWithNorm(v.vector.toDenseVector, v.norm)
     }.toArray)
   }
@@ -275,8 +274,8 @@ class KMeans private (
    *
    * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
    */
-  private def initKMeansParallel(data: RDD[BreezeVectorWithSquaredNorm])
-  : Array[Array[BreezeVectorWithSquaredNorm]] = {
+  private def initKMeansParallel(data: RDD[BreezeVectorWithNorm])
+  : Array[Array[BreezeVectorWithNorm]] = {
     // Initialize each run's center to a random point
     val seed = new XORShiftRandom().nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
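
For orientation on the method above: k-means|| replaces k-means++'s k strictly sequential sampling passes with a few rounds in which every point is promoted to a candidate center with probability proportional to its current cost, after which the small candidate set is reduced to k centers locally. An idea-level sketch under those assumptions (it uses KMeans.pointCost from this patch, so it would live inside the mllib package); this is not the distributed implementation, whose body is truncated above:

import scala.util.Random

// Idea-level, local (non-RDD) sketch of k-means|| oversampling.
def kMeansParallelSketch(
    points: IndexedSeq[BreezeVectorWithNorm],
    k: Int,
    rounds: Int,          // O(log n) rounds suffice per the paper
    oversampling: Double, // "l" in the paper, e.g. 2.0 * k
    rand: Random): Seq[BreezeVectorWithNorm] = {
  var candidates = List(points(rand.nextInt(points.length)))
  for (_ <- 0 until rounds) {
    val costs = points.map(p => KMeans.pointCost(candidates, p))
    val total = costs.sum
    candidates ++= points.zip(costs).collect {
      case (p, cost) if rand.nextDouble() < oversampling * cost / total => p
    }
  }
  candidates // small candidate set; recluster locally (e.g. k-means++) down to k
}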
@@ -380,56 +379,19 @@ object KMeans {
     train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
   }
 
-  /**
-   * Return the index of the closest point in `centers` to `point`, as well as its distance.
-   */
-  private[mllib] def findClosest(centers: Array[Array[Double]], point: Array[Double])
-    : (Int, Double) =
-  {
-    var bestDistance = Double.PositiveInfinity
-    var bestIndex = 0
-    for (i <- 0 until centers.length) {
-      val distance = MLUtils.squaredDistance(point, centers(i))
-      if (distance < bestDistance) {
-        bestDistance = distance
-        bestIndex = i
-      }
-    }
-    (bestIndex, bestDistance)
-  }
-
-  /**
-   * Returns the index of the closest center to the given point, as well as the squared distance.
-   */
-  private[mllib] def findClosest(centers: TraversableOnce[BV[Double]], point: BV[Double])
-    : (Int, Double) = {
-    var bestDistance = Double.PositiveInfinity
-    var bestIndex = 0
-    var i = 0
-    centers.foreach { v =>
-      val distance: Double = breezeSquaredDistance(v, point)
-      if (distance < bestDistance) {
-        bestDistance = distance
-        bestIndex = i
-      }
-      i += 1
-    }
-    (bestIndex, bestDistance)
-  }
-
   /**
    * Returns the index of the closest center to the given point, as well as the squared distance.
    */
   private[mllib] def findClosest(
-      centers: TraversableOnce[BreezeVectorWithSquaredNorm],
-      point: BreezeVectorWithSquaredNorm): (Int, Double) = {
+      centers: TraversableOnce[BreezeVectorWithNorm],
+      point: BreezeVectorWithNorm): (Int, Double) = {
     var bestDistance = Double.PositiveInfinity
     var bestIndex = 0
     var i = 0
     centers.foreach { center =>
       // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
       // distance computation.
-      var lowerBoundOfSqDist = math.sqrt(center.squaredNorm) - math.sqrt(point.squaredNorm)
+      var lowerBoundOfSqDist = center.norm - point.norm
       lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
       if (lowerBoundOfSqDist < bestDistance) {
         val distance: Double = fastSquaredDistance(center, point)
@@ -443,42 +405,22 @@
     (bestIndex, bestDistance)
   }
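
For reference, the pruning bound in the comment above is Cauchy–Schwarz in disguise:

  \|a - b\|^2 = \|a\|^2 - 2\,a \cdot b + \|b\|^2
              \ge \|a\|^2 - 2\,\|a\|\,\|b\| + \|b\|^2 \qquad (a \cdot b \le \|a\|\,\|b\|)
              = (\|a\| - \|b\|)^2

so (center.norm - point.norm)^2 can only underestimate the true squared distance, and any center whose lower bound is already at least bestDistance is skipped without touching its coordinates.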

-  /**
-   * Return the K-means cost of a given point against the given cluster centers.
-   */
-  private[mllib] def pointCost(centers: Array[Array[Double]], point: Array[Double]): Double = {
-    var bestDistance = Double.PositiveInfinity
-    for (i <- 0 until centers.length) {
-      val distance = MLUtils.squaredDistance(point, centers(i))
-      if (distance < bestDistance) {
-        bestDistance = distance
-      }
-    }
-    bestDistance
-  }
-
-  /**
-   * Returns the K-means cost of a given point against the given cluster centers.
-   */
-  private[mllib] def pointCost(centers: TraversableOnce[BV[Double]], point: BV[Double]): Double =
-    findClosest(centers, point)._2
-
   /**
    * Returns the K-means cost of a given point against the given cluster centers.
    */
   private[mllib] def pointCost(
-      centers: TraversableOnce[BreezeVectorWithSquaredNorm],
-      point: BreezeVectorWithSquaredNorm): Double =
+      centers: TraversableOnce[BreezeVectorWithNorm],
+      point: BreezeVectorWithNorm): Double =
     findClosest(centers, point)._2
 
   /**
    * Returns the squared Euclidean distance between two vectors computed by
-   * [[org.apache.spark.mllib.util.MLUtils.fastSquaredDistance()]].
+   * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
    */
   private[clustering]
-  def fastSquaredDistance(v1: BreezeVectorWithSquaredNorm, v2: BreezeVectorWithSquaredNorm)
+  def fastSquaredDistance(v1: BreezeVectorWithNorm, v2: BreezeVectorWithNorm)
     : Double = {
-    MLUtils.fastSquaredDistance(v1.vector, v1.squaredNorm, v2.vector, v2.squaredNorm)
+    MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
   }
 
   def main(args: Array[String]) {
@@ -17,8 +17,6 @@
 
 package org.apache.spark.mllib.clustering
 
-import breeze.linalg.{DenseVector => BreezeDenseVector}
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
@@ -33,35 +31,38 @@ class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable
 
   /** Return the cluster index that a given point belongs to. */
   def predict(point: Array[Double]): Int = {
-    KMeans.findClosest(clusterCenters, point)._1
+    KMeans.findClosest(clusterCentersWithNorm, new BreezeVectorWithNorm(point))._1
   }
 
   /** Returns the cluster index that a given point belongs to. */
   def predict(point: Vector): Int = {
-    val breezeClusterCenters = clusterCenters.view.map(new BreezeDenseVector[Double](_))
-    KMeans.findClosest(breezeClusterCenters, point.toBreeze)._1
+    KMeans.findClosest(clusterCentersWithNorm, new BreezeVectorWithNorm(point))._1
   }
 
   /** Maps given points to their cluster indices. */
   def predict(points: RDD[Vector]): RDD[Int] = {
-    val breezeClusterCenters = clusterCenters.map(new BreezeDenseVector[Double](_))
-    points.map(p => KMeans.findClosest(breezeClusterCenters, p.toBreeze)._1)
+    val centersWithNorm = clusterCentersWithNorm
+    points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
   }
 
   /**
    * Return the K-means cost (sum of squared distances of points to their nearest center) for this
    * model on the given data.
    */
   def computeCost(data: RDD[Array[Double]]): Double = {
-    data.map(p => KMeans.pointCost(clusterCenters, p)).sum()
+    val centersWithNorm = clusterCentersWithNorm
+    data.map(p => KMeans.pointCost(centersWithNorm, new BreezeVectorWithNorm(p))).sum()
   }
 
   /**
   * Return the K-means cost (sum of squared distances of points to their nearest center) for this
    * model on the given data.
    */
   def computeCost(data: RDD[Vector])(implicit d: DummyImplicit): Double = {
-    val breezeClusterCenters = clusterCenters.map(new BreezeDenseVector[Double](_))
-    data.map(p => KMeans.pointCost(breezeClusterCenters, p.toBreeze)).sum()
+    val centersWithNorm = clusterCentersWithNorm
+    data.map(p => KMeans.pointCost(centersWithNorm, new BreezeVectorWithNorm(p))).sum()
   }
 
+  private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] =
+    clusterCenters.map(new BreezeVectorWithNorm(_))
 }
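
A hedged usage sketch of the refactored model (toy centers; the predicted indices follow from the geometry):

import org.apache.spark.mllib.linalg.Vectors

val model = new KMeansModel(Array(Array(0.0, 0.0), Array(10.0, 10.0)))
model.predict(Array(1.0, 1.0))         // 0: nearest to (0, 0)
model.predict(Vectors.dense(9.0, 9.0)) // 1: nearest to (10, 10)
// Note the design choice above: the RDD variants hoist clusterCentersWithNorm
// into a local val so the closure captures just the centers, not the model.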
@@ -29,36 +29,20 @@ import org.apache.spark.Logging
  */
 private[mllib] object LocalKMeans extends Logging {
 
-  def kMeansPlusPlus(
-      seed: Int,
-      points: Array[Array[Double]],
-      weights: Array[Double],
-      k: Int,
-      maxIterations: Int
-  ): Array[Array[Double]] = {
-    val breezePoints = points.map { v =>
-      val bv = new BDV[Double](v)
-      val norm: Double = breezeNorm(bv, 2.0)
-      new BreezeVectorWithSquaredNorm(bv, norm * norm)
-    }
-    val breezeCenters = kMeansPlusPlus(seed, breezePoints, weights, k, maxIterations)
-    breezeCenters.map(_.vector.toArray)
-  }
-
   /**
    * Run K-means++ on the weighted point set `points`. This first does the K-means++
    * initialization procedure and then rounds of Lloyd's algorithm.
    */
   def kMeansPlusPlus(
       seed: Int,
-      points: Array[BreezeVectorWithSquaredNorm],
+      points: Array[BreezeVectorWithNorm],
       weights: Array[Double],
       k: Int,
       maxIterations: Int
-  )(implicit d: DummyImplicit): Array[BreezeVectorWithSquaredNorm] = {
+  ): Array[BreezeVectorWithNorm] = {
     val rand = new Random(seed)
     val dimensions = points(0).vector.length
-    val centers = new Array[BreezeVectorWithSquaredNorm](k)
+    val centers = new Array[BreezeVectorWithNorm](k)
 
     // Initialize centers by sampling using the k-means++ procedure.
     centers(0) = pickWeighted(rand, points, weights).toDense
@@ -108,7 +92,7 @@ private[mllib] object LocalKMeans extends Logging {
           centers(j) = points(rand.nextInt(points.length)).toDense
         } else {
           sums(j) /= counts(j)
-          centers(j) = new BreezeVectorWithSquaredNorm(sums(j))
+          centers(j) = new BreezeVectorWithNorm(sums(j))
         }
         j += 1
       }
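
pickWeighted is referenced above but falls outside this diff; its job is to sample one point with probability proportional to its weight. A minimal sketch of one standard way to do that (a hypothetical stand-in, not necessarily LocalKMeans's version):

import scala.util.Random

// Sample data(i) with probability weights(i) / weights.sum (CDF inversion).
def pickWeightedSketch[T](rand: Random, data: Array[T], weights: Array[Double]): T = {
  val r = rand.nextDouble() * weights.sum
  var i = 0
  var cumulative = weights(0)
  while (cumulative < r && i < data.length - 1) {
    i += 1
    cumulative += weights(i)
  }
  data(i)
}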
30 changes: 7 additions & 23 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -110,22 +110,6 @@ object MLUtils {
     (yMean, xColMean, xColSd)
   }
 
-  /**
-   * Return the squared Euclidean distance between two vectors.
-   */
-  def squaredDistance(v1: Array[Double], v2: Array[Double]): Double = {
-    if (v1.length != v2.length) {
-      throw new IllegalArgumentException("Vector sizes don't match")
-    }
-    var i = 0
-    var sum = 0.0
-    while (i < v1.length) {
-      sum += (v1(i) - v2(i)) * (v1(i) - v2(i))
-      i += 1
-    }
-    sum
-  }
-
   /**
    * Returns the squared Euclidean distance between two vectors. The following formula will be used
    * if it does not introduce too much numerical error:
@@ -136,23 +120,23 @@
    * especially when one of the vectors is a sparse vector.
    *
    * @param v1 the first vector
-   * @param squaredNorm1 the squared norm of the first vector, non-negative
+   * @param norm1 the norm of the first vector, non-negative
    * @param v2 the second vector
-   * @param squaredNorm2 the squared norm of the second vector, non-negative
+   * @param norm2 the norm of the second vector, non-negative
    * @param precision desired relative precision for the squared distance
    * @return squared distance between v1 and v2 within the specified precision
    */
   private[mllib] def fastSquaredDistance(
       v1: BV[Double],
-      squaredNorm1: Double,
+      norm1: Double,
       v2: BV[Double],
-      squaredNorm2: Double,
+      norm2: Double,
       precision: Double = 1e-6): Double = {
     val n = v1.size
     require(v2.size == n)
-    require(squaredNorm1 >= 0.0 && squaredNorm2 >= 0.0)
-    val sumSquaredNorm = squaredNorm1 + squaredNorm2
-    val normDiff = math.sqrt(squaredNorm1) - math.sqrt(squaredNorm2)
+    require(norm1 >= 0.0 && norm2 >= 0.0)
+    val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
+    val normDiff = norm1 - norm2
     var sqDist = 0.0
     val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
     if (precisionBound1 < precision) {
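
The tail of fastSquaredDistance is truncated above; the following is a simplified standalone sketch of the decision it makes (the real method also applies a second, tighter bound before giving up on the fast path; EPSILON is stood in by a literal here):

import breeze.linalg.{squaredDistance => breezeSquaredDistance, Vector => BV}

// Use ||a||^2 + ||b||^2 - 2 a.b when its relative-error bound is small
// enough; otherwise fall back to an exact pass over the elements.
def fastSquaredDistanceSketch(
    v1: BV[Double], norm1: Double,
    v2: BV[Double], norm2: Double,
    precision: Double = 1e-6): Double = {
  val eps = 2.22e-16 // stand-in for MLUtils.EPSILON (machine epsilon)
  val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
  val normDiff = norm1 - norm2
  val bound = 2.0 * eps * sumSquaredNorm / (normDiff * normDiff + eps)
  if (bound < precision) {
    sumSquaredNorm - 2.0 * (v1 dot v2) // one dot product; cheap for sparse vectors
  } else {
    breezeSquaredDistance(v1, v2) // exact computation
  }
}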
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.spark.mllib.util
 
 import org.scalatest.Suite