Skip to content

Commit

Permalink
add parallel mean and variance
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent 9af2e95 commit cc65810
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils._
import org.apache.spark.rdd.RDD
import breeze.numerics._

/**
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
Expand Down Expand Up @@ -161,4 +162,24 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
}
}
}

def parallelMeanAndVar(size: Int): (Vector, Vector) = {
val statistics = self.map(_.toBreeze).aggregate((BV.zeros[Double](size), BV.zeros[Double](size), 0.0))(
seqOp = (c, v) => (c, v) match {
case ((prevMean, prevM2n, cnt), currData) =>
val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0)
(currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0)
},
combOp = (lhs, rhs) => (lhs, rhs) match {
case ((lhsMean, lhsM2n, lhsCnt), (rhsMean, rhsM2n, rhsCnt)) =>
val totalCnt = lhsCnt + rhsCnt
val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
val deltaMean = rhsMean - lhsMean
val totalM2n = lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
(totalMean, totalM2n, totalCnt)
}
)

(Vectors.fromBreeze(statistics._1), Vectors.fromBreeze(statistics._2 :/ statistics._3))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
val colMeans = Array(4.0, 5.0, 6.0)
val colNorm2 = Array(math.sqrt(66.0), math.sqrt(93.0), math.sqrt(126.0))
val colSDs = Array(math.sqrt(6.0), math.sqrt(6.0), math.sqrt(6.0))
val colVar = Array(6.0, 6.0, 6.0)

val maxVec = Array(7.0, 8.0, 9.0)
val minVec = Array(1.0, 2.0, 3.0)
Expand Down Expand Up @@ -128,6 +129,13 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
assert(equivVector(lhs, rhs), "Column shrink error.")
}
}

test("meanAndVar") {
val data = sc.parallelize(localData, 2)
val (mean, sd) = data.parallelMeanAndVar(3)
assert(equivVector(mean, Vectors.dense(colMeans)), "Column means do not match.")
assert(equivVector(sd, Vectors.dense(colVar)), "Column SD do not match.")
}
}

object VectorRDDFunctionsSuite {
Expand Down

0 comments on commit cc65810

Please sign in to comment.