Skip to content

Commit

Permalink
use axpy and in-place if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2014
1 parent 9a75ebd commit 62a2c3e
Showing 1 changed file with 23 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import breeze.linalg.{Vector => BV}

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import breeze.linalg.axpy

case class VectorRDDStatisticalSummary(
mean: Vector,
Expand Down Expand Up @@ -58,17 +59,22 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
BV.fill(size){Double.MaxValue}))(
seqOp = (c, v) => (c, v) match {
case ((prevMean, prevM2n, cnt, nnzVec, maxVec, minVec), currData) =>
val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0)
val nonZeroCnt = Vectors
.sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze
val currMean = prevMean :* (cnt / (cnt + 1.0))
axpy(1.0/(cnt+1.0), currData, currMean)
axpy(-1.0, currData, prevMean)
prevMean :*= (currMean - currData)
axpy(1.0, prevMean, prevM2n)
axpy(1.0,
Vectors.sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze,
nnzVec)
currData.activeIterator.foreach { case (id, value) =>
if (maxVec(id) < value) maxVec(id) = value
if (minVec(id) > value) minVec(id) = value
}
(currMean,
prevM2n + ((currData - prevMean) :* (currData - currMean)),
prevM2n,
cnt + 1.0,
nnzVec + nonZeroCnt,
nnzVec,
maxVec,
minVec)
},
Expand All @@ -77,23 +83,30 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
(lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin),
(rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
val totalCnt = lhsCnt + rhsCnt
val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
val deltaMean = rhsMean - lhsMean
val totalM2n =
lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
lhsMean :*= (lhsCnt / totalCnt)
axpy(rhsCnt/totalCnt, rhsMean, lhsMean)
val totalMean = lhsMean
deltaMean :*= deltaMean
axpy(lhsCnt*rhsCnt/totalCnt, deltaMean, lhsM2n)
axpy(1.0, rhsM2n, lhsM2n)
val totalM2n = lhsM2n
rhsMax.activeIterator.foreach { case (id, value) =>
if (lhsMax(id) < value) lhsMax(id) = value
}
rhsMin.activeIterator.foreach { case (id, value) =>
if (lhsMin(id) > value) lhsMin(id) = value
}
(totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
axpy(1.0, rhsNNZ, lhsNNZ)
(totalMean, totalM2n, totalCnt, lhsNNZ, lhsMax, lhsMin)
}
)

results._2 :/= results._3

VectorRDDStatisticalSummary(
Vectors.fromBreeze(results._1),
Vectors.fromBreeze(results._2 :/ results._3),
Vectors.fromBreeze(results._2),
results._3.toLong,
Vectors.fromBreeze(results._4),
Vectors.fromBreeze(results._5),
Expand Down

0 comments on commit 62a2c3e

Please sign in to comment.