From b98bb18b69f30537fd235ef73d525c0f59f27293 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 12 Aug 2014 23:19:26 -0700
Subject: [PATCH] add comments

---
 .../correlation/SpearmanCorrelation.scala | 36 ++++++++++++++++++------------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala
index 8b6cd07a42c01..c7efad6b0401f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.mllib.stat.correlation
 
-import org.apache.spark.storage.StorageLevel
-
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{Logging, HashPartitioner}
+import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.{Vectors, DenseVector, Matrix, Vector}
-import org.apache.spark.rdd.{CoGroupedRDD, RDD}
+import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 
 /**
  * Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix
@@ -45,18 +44,18 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging {
   /**
    * Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the
    * correlation between column i and j.
-   *
-   * Input RDD[Vector] should be cached or checkpointed if possible since it would be split into
-   * numCol RDD[Double]s, each of which sorted, and the joined back into a single RDD[Vector].
    */
   override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
-    val transposed = X.zipWithUniqueId().flatMap { case (vec, uid) =>
+    // ((columnIndex, value), rowId)
+    val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
       vec.toArray.view.zipWithIndex.map { case (v, j) =>
         ((j, v), uid)
       }
-    }.persist(StorageLevel.MEMORY_AND_DISK)
-    val sorted = transposed.sortByKey().persist(StorageLevel.MEMORY_AND_DISK)
-    val ranked = sorted.zipWithIndex().mapPartitions { iter =>
+    }.persist(StorageLevel.MEMORY_AND_DISK) // used by sortByKey
+    // global sort by (columnIndex, value)
+    val sorted = colBased.sortByKey().persist(StorageLevel.MEMORY_AND_DISK) // used by zipWithIndex
+    // Assign global ranks (using average ranks for tied values)
+    val globalRanks = sorted.zipWithIndex().mapPartitions { iter =>
       var preCol = -1
       var preVal = Double.NaN
       var startRank = -1.0
@@ -85,14 +84,15 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging {
         flush()
       }
     }
-    val ranks = tied.groupByKey().map { case (uid, iter) =>
-      val values = iter.toSeq.sortBy(_._1).map(_._2).toArray
-      println(values.toSeq)
-      Vectors.dense(values)
+    // Replace values in the input matrix with their ranks relative to values in the same column.
+    // Note that shifting all ranks in a column by a constant value doesn't affect the result.
+    val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) =>
+      // sort by column index and then convert values to a vector
+      Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray)
     }
-    val corrMatrix = PearsonCorrelation.computeCorrelationMatrix(ranks)
+    val corrMatrix = PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
 
-    transposed.unpersist(blocking = false)
+    colBased.unpersist(blocking = false)
     sorted.unpersist(blocking = false)
     corrMatrix
   }
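
Reviewer note (not part of the patch): the subtlest step above is the average-rank assignment for ties inside the mapPartitions block. The same rule can be sketched locally for a single column; RankSketch and rankWithTies below are illustrative names only, not code from this PR:

object RankSketch {
  // Map each value in one column to its 0-based rank, averaging the ranks
  // occupied by tied values -- the same rule the patched mapPartitions block
  // applies to each run of equal (columnIndex, value) keys after the sort.
  def rankWithTies(column: Array[Double]): Map[Double, Double] = {
    column.sorted.zipWithIndex            // sort, as sortByKey does for (j, v)
      .groupBy(_._1)                      // group runs of tied values
      .map { case (v, tied) =>
        v -> tied.map(_._2).sum.toDouble / tied.size // average rank of the run
      }
  }

  def main(args: Array[String]): Unit = {
    val column = Array(3.0, 1.0, 4.0, 1.0, 5.0)
    // 1.0 occupies ranks 0 and 1, so it receives the average rank 0.5.
    println(rankWithTies(column).toSeq.sortBy(_._1))
    // -> (1.0,0.5), (3.0,2.0), (4.0,3.0), (5.0,4.0)
  }
}

This also shows why the added comment about shifting holds: Pearson correlation is invariant to adding a constant to every rank in a column, so the per-column rank offsets left over from the single global sort never need to be re-based to zero.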
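
Usage note: applications don't call SpearmanCorrelation directly; it is reached through the public MLlib entry point Statistics.corr, which dispatches on the method name. A minimal, self-contained run (the local master and the tiny made-up matrix are just for illustration):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

object SpearmanDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("SpearmanDemo").setMaster("local"))
    // Three rows, three columns; every column is strictly increasing.
    val X = sc.parallelize(Seq(
      Vectors.dense(1.0, 10.0, 100.0),
      Vectors.dense(2.0, 20.0, 200.0),
      Vectors.dense(5.0, 33.0, 366.0)))
    // "spearman" routes to SpearmanCorrelation.computeCorrelationMatrix.
    val corr = Statistics.corr(X, "spearman")
    // Monotone columns have identical rank vectors, so every entry is 1.0.
    println(corr)
    sc.stop()
  }
}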