diff --git a/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png b/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png
new file mode 100644
index 0000000000000..ed9adad11d03a
Binary files /dev/null and b/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png differ
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index c696ae9c8e8c8..413b824e369da 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -34,6 +34,26 @@ a given dataset, the algorithm returns the best clustering result).
 * *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
 * *epsilon* determines the distance threshold within which we consider k-means to have converged.
 
+### Power Iteration Clustering
+
+Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering points given pairwise similarities between them. Internally, the algorithm:
+
+* accepts a [Graph](https://spark.apache.org/docs/0.9.2/api/graphx/index.html#org.apache.spark.graphx.Graph) that represents the normalized pairwise affinities between all input points,
+* calculates the principal eigenvalue and eigenvector of the affinity matrix,
+* clusters each input point according to its principal eigenvector component value.
+
+Details of this algorithm are found in [Power Iteration Clustering, Lin and Cohen](http://www.icml2010.org/papers/387.pdf).
+
+Example output of our implementation for a dataset inspired by the paper, but with five clusters instead of three, is shown below:
+
+<p style="text-align: center;">

+  <img src="img/PIClusteringFiveCirclesInputsAndOutputs.png"
+       title="Power Iteration Clustering: five circles, inputs and outputs"
+       alt="Power Iteration Clustering: five circles, inputs and outputs"
+       width="50%" />
+

+</p>
+
+### Examples
+
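+A minimal sketch of invoking the new API follows; the similarity triples and parameter
+values here are illustrative placeholders, not a recommended configuration:
+
+{% highlight scala %}
+import org.apache.spark.mllib.clustering.PowerIterationClustering
+
+// (i, j, s_ij) tuples encoding a nonnegative, symmetric affinity matrix;
+// each pair needs to appear only once, in either orientation.
+val similarities = sc.parallelize(Seq(
+  (0L, 1L, 1.0), (1L, 2L, 1.0), (0L, 2L, 1.0),
+  (3L, 4L, 1.0), (4L, 5L, 1.0), (3L, 5L, 1.0)))
+
+val pic = new PowerIterationClustering()
+  .setK(2)               // number of clusters
+  .setMaxIterations(20)  // cap on the power iteration loop
+val model = pic.run(similarities)
+
+model.assignments.collect().foreach { case (vertexId, clusterId) =>
+  println(s"$vertexId -> $clusterId")
+}
+{% endhighlight %}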
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fc2b2cc09c717..a8cee3d51a780 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -50,6 +50,11 @@
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-graphx_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
     <dependency>
       <groupId>org.jblas</groupId>
       <artifactId>jblas</artifactId>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
new file mode 100644
index 0000000000000..fcb9a3643cc48
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.GraphImpl
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Model produced by [[PowerIterationClustering]].
+ *
+ * @param k number of clusters
+ * @param assignments an RDD of (vertexID, clusterID) pairs
+ */
+class PowerIterationClusteringModel(
+    val k: Int,
+    val assignments: RDD[(Long, Int)]) extends Serializable
+
+/**
+ * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and
+ * Cohen (see http://www.icml2010.org/papers/387.pdf). From the abstract: PIC finds a very
+ * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise
+ * similarity matrix of the data.
+ *
+ * @param k Number of clusters.
+ * @param maxIterations Maximum number of iterations of the PIC algorithm.
+ */
+class PowerIterationClustering private[clustering] (
+    private var k: Int,
+    private var maxIterations: Int) extends Serializable {
+
+  import org.apache.spark.mllib.clustering.PowerIterationClustering._
+
+  /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100}. */
+  def this() = this(k = 2, maxIterations = 100)
+
+  /**
+   * Set the number of clusters.
+   */
+  def setK(k: Int): this.type = {
+    this.k = k
+    this
+  }
+
+  /**
+   * Set the maximum number of iterations of the power iteration loop.
+   */
+  def setMaxIterations(maxIterations: Int): this.type = {
+    this.maxIterations = maxIterations
+    this
+  }
+
+  /**
+   * Run the PIC algorithm.
+   *
+   * @param similarities an RDD of (i, j, s_ij_) tuples representing the affinity matrix, which is
+   *                     the matrix A in the PIC paper. The similarity s_ij_ must be nonnegative.
+   *                     This is a symmetric matrix and hence s_ij_ = s_ji_. For any (i, j) with
+   *                     nonzero similarity, there should be either (i, j, s_ij_) or (j, i, s_ji_)
+   *                     in the input. Tuples with i = j are ignored, because we assume
+   *                     s_ij_ = 0.0.
+   *
+   * @return a [[PowerIterationClusteringModel]] that contains the clustering result
+   */
+  def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = {
+    val w = normalize(similarities)
+    val w0 = randomInit(w)
+    pic(w0)
+  }
+
+  /**
+   * Runs the PIC algorithm.
+   *
+   * @param w The normalized affinity matrix, which is the matrix W in the PIC paper with
+   *          w_ij_ = a_ij_ / d_ii_ as its edge properties and the initial vector of the power
+   *          iteration as its vertex properties.
+   */
+  private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = {
+    val v = powerIter(w, maxIterations)
+    val assignments = kMeans(v, k)
+    new PowerIterationClusteringModel(k, assignments)
+  }
+}
+
+private[clustering] object PowerIterationClustering extends Logging {
+  /**
+   * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
+   */
+  def normalize(similarities: RDD[(Long, Long, Double)]): Graph[Double, Double] = {
+    val edges = similarities.flatMap { case (i, j, s) =>
+      if (s < 0.0) {
+        throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.")
+      }
+      if (i != j) {
+        Seq(Edge(i, j, s), Edge(j, i, s))
+      } else {
+        None
+      }
+    }
+    val gA = Graph.fromEdges(edges, 0.0)
+    val vD = gA.aggregateMessages[Double](
+      sendMsg = ctx => {
+        ctx.sendToSrc(ctx.attr)
+      },
+      mergeMsg = _ + _,
+      TripletFields.EdgeOnly)
+    GraphImpl.fromExistingRDDs(vD, gA.edges)
+      .mapTriplets(
+        e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON),
+        TripletFields.Src)
+  }
+
+  /**
+   * Generates random vertex properties (v0) to start power iteration.
+   *
+   * @param g a graph representing the normalized affinity matrix (W)
+   * @return a graph with edges representing W and vertices representing a random vector
+   *         with unit 1-norm
+   */
+  def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = {
+    val r = g.vertices.mapPartitionsWithIndex(
+      (part, iter) => {
+        val random = new XORShiftRandom(part)
+        iter.map { case (id, _) =>
+          (id, random.nextGaussian())
+        }
+      }, preservesPartitioning = true).cache()
+    val sum = r.values.map(math.abs).sum()
+    val v0 = r.mapValues(x => x / sum)
+    GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges)
+  }
+
+  /**
+   * Runs power iteration.
+   * @param g input graph with edges representing the normalized affinity matrix (W) and vertices
+   *          representing the initial vector of the power iterations.
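+   *          The vertex values are assumed to form a vector with unit 1-norm, as produced by
+   *          [[randomInit]]; each iterate is re-normalized to unit 1-norm inside the loop.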
+ * @param maxIterations maximum number of iterations + * @return a [[VertexRDD]] representing the pseudo-eigenvector + */ + def powerIter( + g: Graph[Double, Double], + maxIterations: Int): VertexRDD[Double] = { + // the default tolerance used in the PIC paper, with a lower bound 1e-8 + val tol = math.max(1e-5 / g.vertices.count(), 1e-8) + var prevDelta = Double.MaxValue + var diffDelta = Double.MaxValue + var curG = g + for (iter <- 0 until maxIterations if math.abs(diffDelta) > tol) { + val msgPrefix = s"Iteration $iter" + // multiply W by vt + val v = curG.aggregateMessages[Double]( + sendMsg = ctx => ctx.sendToSrc(ctx.attr * ctx.dstAttr), + mergeMsg = _ + _, + TripletFields.Dst).cache() + // normalize v + val norm = v.values.map(math.abs).sum() + logInfo(s"$msgPrefix: norm(v) = $norm.") + val v1 = v.mapValues(x => x / norm) + // compare difference + val delta = curG.joinVertices(v1) { case (_, x, y) => + math.abs(x - y) + }.vertices.values.sum() + logInfo(s"$msgPrefix: delta = $delta.") + diffDelta = math.abs(delta - prevDelta) + logInfo(s"$msgPrefix: diff(delta) = $diffDelta.") + // update v + curG = GraphImpl.fromExistingRDDs(VertexRDD(v1), g.edges) + prevDelta = delta + } + curG.vertices + } + + /** + * Runs k-means clustering. + * @param v a [[VertexRDD]] representing the pseudo-eigenvector + * @param k number of clusters + * @return a [[VertexRDD]] representing the clustering assignments + */ + def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = { + val points = v.mapValues(x => Vectors.dense(x)).cache() + val model = new KMeans() + .setK(k) + .setRuns(5) + .setSeed(0L) + .run(points.values) + points.mapValues(p => model.predict(p)).cache() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 693419f827379..a6405975ebe2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -21,8 +21,8 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrix} +import org.apache.spark.{SparkException, Logging, Partitioner} +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -158,11 +158,13 @@ class BlockMatrix( private[mllib] var partitioner: GridPartitioner = GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size) + private lazy val blockInfo = blocks.mapValues(block => (block.numRows, block.numCols)).cache() + /** Estimates the dimensions of the matrix. 
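    * The estimate scans only the cached per-block dimensions in `blockInfo`, never the
    * block contents themselves.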
*/ private def estimateDim(): Unit = { - val (rows, cols) = blocks.map { case ((blockRowIndex, blockColIndex), mat) => - (blockRowIndex.toLong * rowsPerBlock + mat.numRows, - blockColIndex.toLong * colsPerBlock + mat.numCols) + val (rows, cols) = blockInfo.map { case ((blockRowIndex, blockColIndex), (m, n)) => + (blockRowIndex.toLong * rowsPerBlock + m, + blockColIndex.toLong * colsPerBlock + n) }.reduce { (x0, x1) => (math.max(x0._1, x1._1), math.max(x0._2, x1._2)) } @@ -172,6 +174,41 @@ class BlockMatrix( assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.") } + def validate(): Unit = { + logDebug("Validating BlockMatrix...") + // check if the matrix is larger than the claimed dimensions + estimateDim() + logDebug("BlockMatrix dimensions are okay...") + + // Check if there are multiple MatrixBlocks with the same index. + blockInfo.countByKey().foreach { case (key, cnt) => + if (cnt > 1) { + throw new SparkException(s"Found multiple MatrixBlocks with the indices $key. Please " + + "remove blocks with duplicate indices.") + } + } + logDebug("MatrixBlock indices are okay...") + // Check if each MatrixBlock (except edges) has the dimensions rowsPerBlock x colsPerBlock + // The first tuple is the index and the second tuple is the dimensions of the MatrixBlock + val dimensionMsg = s"dimensions different than rowsPerBlock: $rowsPerBlock, and " + + s"colsPerBlock: $colsPerBlock. Blocks on the right and bottom edges can have smaller " + + s"dimensions. You may use the repartition method to fix this issue." + blockInfo.foreach { case ((blockRowIndex, blockColIndex), (m, n)) => + if ((blockRowIndex < numRowBlocks - 1 && m != rowsPerBlock) || + (blockRowIndex == numRowBlocks - 1 && (m <= 0 || m > rowsPerBlock))) { + throw new SparkException(s"The MatrixBlock at ($blockRowIndex, $blockColIndex) has " + + dimensionMsg) + } + if ((blockColIndex < numColBlocks - 1 && n != colsPerBlock) || + (blockColIndex == numColBlocks - 1 && (n <= 0 || n > colsPerBlock))) { + throw new SparkException(s"The MatrixBlock at ($blockRowIndex, $blockColIndex) has " + + dimensionMsg) + } + } + logDebug("MatrixBlock dimensions are okay...") + logDebug("BlockMatrix is valid!") + } + /** Caches the underlying RDD. */ def cache(): this.type = { blocks.cache() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala new file mode 100644 index 0000000000000..2bae465d392aa --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { + + import org.apache.spark.mllib.clustering.PowerIterationClustering._ + + test("power iteration clustering") { + /* + We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for + edge (3, 4). + + 15-14 -13 -12 + | | + 4 . 3 - 2 11 + | | x | | + 5 0 - 1 10 + | | + 6 - 7 - 8 - 9 + */ + + val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), + (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge + (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), + (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) + val model = new PowerIterationClustering() + .setK(2) + .run(sc.parallelize(similarities, 2)) + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + model.assignments.collect().foreach { case (i, c) => + predictions(c) += i + } + assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + } + + test("normalize and powerIter") { + /* + Test normalize() with the following graph: + + 0 - 3 + | \ | + 1 - 2 + + The affinity matrix (A) is + + 0 1 1 1 + 1 0 1 0 + 1 1 0 1 + 1 0 1 0 + + D is diag(3, 2, 3, 2) and hence W is + + 0 1/3 1/3 1/3 + 1/2 0 1/2 0 + 1/3 1/3 0 1/3 + 1/2 0 1/2 0 + */ + val similarities = Seq[(Long, Long, Double)]( + (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0)) + val expected = Array( + Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0), + Array(1.0/2.0, 0.0, 1.0/2.0, 0.0), + Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0), + Array(1.0/2.0, 0.0, 1.0/2.0, 0.0)) + val w = normalize(sc.parallelize(similarities, 2)) + w.edges.collect().foreach { case Edge(i, j, x) => + assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14) + } + val v0 = sc.parallelize(Seq[(Long, Double)]((0, 0.1), (1, 0.2), (2, 0.3), (3, 0.4)), 2) + val w0 = Graph(v0, w.edges) + val v1 = powerIter(w0, maxIterations = 1).collect() + val u = Array(0.3, 0.2, 0.7/3.0, 0.2) + val norm = u.sum + val u1 = u.map(x => x / norm) + v1.foreach { case (i, x) => + assert(x ~== u1(i.toInt) absTol 1e-14) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index 03f34308dd09b..461f1f92df1d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -22,6 +22,7 @@ import scala.util.Random import breeze.linalg.{DenseMatrix => BDM} import org.scalatest.FunSuite +import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -147,6 +148,47 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(gridBasedMat.toBreeze() === expected) } + test("validate") { + // No error + gridBasedMat.validate() + // Wrong MatrixBlock dimensions + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 
2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val wrongRowPerParts = new BlockMatrix(rdd, rowPerPart + 1, colPerPart) + val wrongColPerParts = new BlockMatrix(rdd, rowPerPart, colPerPart + 1) + intercept[SparkException] { + wrongRowPerParts.validate() + } + intercept[SparkException] { + wrongColPerParts.validate() + } + // Wrong BlockMatrix dimensions + val wrongRowSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 4, 4) + intercept[AssertionError] { + wrongRowSize.validate() + } + val wrongColSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 5, 2) + intercept[AssertionError] { + wrongColSize.validate() + } + // Duplicate indices + val duplicateBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 0), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 1), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val dupMatrix = new BlockMatrix(sc.parallelize(duplicateBlocks, numPartitions), 2, 2) + intercept[SparkException] { + dupMatrix.validate() + } + } + test("transpose") { val expected = BDM( (1.0, 0.0, 3.0, 0.0, 0.0),