Skip to content

Commit

Permalink
Further performance tuning.
Browse files Browse the repository at this point in the history
  • Loading branch information
DB Tsai committed Nov 21, 2014
1 parent 1907ae1 commit 03dd693
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 143 deletions.
44 changes: 12 additions & 32 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,11 @@ sealed trait Vector extends Serializable {
/**
* Applies a function `f` to all the active elements of a dense or sparse vector.
*
* @param f the function takes (Int, Double) as input where the first element
* in the tuple is the index, and the second element is the corresponding value.
* @param skippingZeros if true, skipping zero elements explicitly. It will be useful when
* iterating through dense vector which has lots of zero elements to be
* skipped. Default is false.
* @param f the function to apply; it takes two parameters, where the first is the index of
* the element in the vector, of type `Int`, and the second is the corresponding value,
* of type `Double`.
*/
private[spark] def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit)
private[spark] def foreachActive(f: (Int, Double) => Unit)
}

/**
Expand Down Expand Up @@ -285,23 +283,14 @@ class DenseVector(val values: Array[Double]) extends Vector {
new DenseVector(values.clone())
}

private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
var i = 0
val localValuesSize = values.size
val localValues = values

if (skippingZeros) {
while (i < localValuesSize) {
if (localValues(i) != 0.0) {
f(i, localValues(i))
}
i += 1
}
} else {
while (i < localValuesSize) {
f(i, localValues(i))
i += 1
}
while (i < localValuesSize) {
f(i, localValues(i))
i += 1
}
}
}
Expand Down Expand Up @@ -341,24 +330,15 @@ class SparseVector(

private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)

private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
var i = 0
val localValuesSize = values.size
val localIndices = indices
val localValues = values

if (skippingZeros) {
while (i < localValuesSize) {
if (localValues(i) != 0.0) {
f(localIndices(i), localValues(i))
}
i += 1
}
} else {
while (i < localValuesSize) {
f(localIndices(i), localValues(i))
i += 1
}
while (i < localValuesSize) {
f(localIndices(i), localValues(i))
i += 1
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@

package org.apache.spark.mllib.stat

import breeze.linalg.{DenseVector => BDV}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
import org.apache.spark.mllib.linalg.{Vectors, Vector}

/**
* :: DeveloperApi ::
Expand All @@ -40,35 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {

private var n = 0
private var currMean: BDV[Double] = _
private var currM2n: BDV[Double] = _
private var currM2: BDV[Double] = _
private var currL1: BDV[Double] = _
private var currMean: Array[Double] = _
private var currM2n: Array[Double] = _
private var currM2: Array[Double] = _
private var currL1: Array[Double] = _
private var totalCnt: Long = 0
private var nnz: BDV[Double] = _
private var currMax: BDV[Double] = _
private var currMin: BDV[Double] = _

/**
* Adds input value to position i.
*/
private[this] def add(i: Int, value: Double) = {
if (currMax(i) < value) {
currMax(i) = value
}
if (currMin(i) > value) {
currMin(i) = value
}

val prevMean = currMean(i)
val diff = value - prevMean
currMean(i) = prevMean + diff / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * diff
currM2(i) += value * value
currL1(i) += math.abs(value)

nnz(i) += 1.0
}
private var nnz: Array[Double] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _

/**
* Add a new sample to this summarizer, and update the statistical summary.
Expand All @@ -81,19 +58,37 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(sample.size > 0, s"Vector should have dimension larger than zero.")
n = sample.size

currMean = BDV.zeros[Double](n)
currM2n = BDV.zeros[Double](n)
currM2 = BDV.zeros[Double](n)
currL1 = BDV.zeros[Double](n)
nnz = BDV.zeros[Double](n)
currMax = BDV.fill(n)(Double.MinValue)
currMin = BDV.fill(n)(Double.MaxValue)
currMean = Array.ofDim[Double](n)
currM2n = Array.ofDim[Double](n)
currM2 = Array.ofDim[Double](n)
currL1 = Array.ofDim[Double](n)
nnz = Array.ofDim[Double](n)
currMax = Array.fill[Double](n)(Double.MinValue)
currMin = Array.fill[Double](n)(Double.MaxValue)
}

require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")

sample.foreach(true)(x => add(x._1, x._2))
sample.foreachActive((index, value) => {
if(value != 0.0){
if (currMax(index) < value) {
currMax(index) = value
}
if (currMin(index) > value) {
currMin(index) = value
}

val prevMean = currMean(index)
val diff = value - prevMean
currMean(index) = prevMean + diff / (nnz(index) + 1.0)
currM2n(index) += (value - currMean(index)) * diff
currM2(index) += value * value
currL1(index) += math.abs(value)

nnz(index) += 1.0
}
})

totalCnt += 1
this
Expand Down Expand Up @@ -135,34 +130,34 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
this.currMean = other.currMean.copy
this.currM2n = other.currM2n.copy
this.currM2 = other.currM2.copy
this.currL1 = other.currL1.copy
this.currMean = other.currMean.clone
this.currM2n = other.currM2n.clone
this.currM2 = other.currM2.clone
this.currL1 = other.currL1.clone
this.totalCnt = other.totalCnt
this.nnz = other.nnz.copy
this.currMax = other.currMax.copy
this.currMin = other.currMin.copy
this.nnz = other.nnz.clone
this.currMax = other.currMax.clone
this.currMin = other.currMin.clone
}
this
}

override def mean: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

val realMean = BDV.zeros[Double](n)
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (nnz(i) / totalCnt)
i += 1
}
Vectors.fromBreeze(realMean)
Vectors.dense(realMean)
}

override def variance: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

val realVariance = BDV.zeros[Double](n)
val realVariance = Array.ofDim[Double](n)

val denominator = totalCnt - 1.0

Expand All @@ -177,16 +172,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
i += 1
}
}

Vectors.fromBreeze(realVariance)
Vectors.dense(realVariance)
}

override def count: Long = totalCnt

override def numNonzeros: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

Vectors.fromBreeze(nnz)
Vectors.dense(nnz)
}

override def max: Vector = {
Expand All @@ -197,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMax)
Vectors.dense(currMax)
}

override def min: Vector = {
Expand All @@ -208,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMin)
Vectors.dense(currMin)
}

override def normL2: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")

val realMagnitude = BDV.zeros[Double](n)
val realMagnitude = Array.ofDim[Double](n)

var i = 0
while (i < currM2.size) {
realMagnitude(i) = math.sqrt(currM2(i))
i += 1
}

Vectors.fromBreeze(realMagnitude)
Vectors.dense(realMagnitude)
}

override def normL1: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
Vectors.fromBreeze(currL1)

Vectors.dense(currL1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,58 +178,19 @@ class VectorsSuite extends FunSuite {
val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))

val dvMap0 = scala.collection.mutable.Map[Int, Double]()
dv.foreach() {
case (index: Int, value: Double) => dvMap0.put(index, value)
}
assert(dvMap0.size === 4)
assert(dvMap0.get(0) === Some(0.0))
assert(dvMap0.get(1) === Some(1.2))
assert(dvMap0.get(2) === Some(3.1))
assert(dvMap0.get(3) === Some(0.0))

val dvMap1 = scala.collection.mutable.Map[Int, Double]()
dv.foreach(false) {
case (index, value) => dvMap1.put(index, value)
}
assert(dvMap1.size === 4)
assert(dvMap1.get(0) === Some(0.0))
assert(dvMap1.get(1) === Some(1.2))
assert(dvMap1.get(2) === Some(3.1))
assert(dvMap1.get(3) === Some(0.0))

val dvMap2 = scala.collection .mutable.Map[Int, Double]()
dv.foreach(true) {
case (index, value) => dvMap2.put(index, value)
}
assert(dvMap2.size === 2)
assert(dvMap2.get(1) === Some(1.2))
assert(dvMap2.get(2) === Some(3.1))

val svMap0 = scala.collection.mutable.Map[Int, Double]()
sv.foreach() {
case (index, value) => svMap0.put(index, value)
}
assert(svMap0.size === 3)
assert(svMap0.get(1) === Some(1.2))
assert(svMap0.get(2) === Some(3.1))
assert(svMap0.get(3) === Some(0.0))

val svMap1 = scala.collection.mutable.Map[Int, Double]()
sv.foreach(false) {
case (index, value) => svMap1.put(index, value)
}
assert(svMap1.size === 3)
assert(svMap1.get(1) === Some(1.2))
assert(svMap1.get(2) === Some(3.1))
assert(svMap1.get(3) === Some(0.0))

val svMap2 = scala.collection.mutable.Map[Int, Double]()
sv.foreach(true) {
case (index, value) => svMap2.put(index, value)
}
assert(svMap2.size === 2)
assert(svMap2.get(1) === Some(1.2))
assert(svMap2.get(2) === Some(3.1))
val dvMap = scala.collection.mutable.Map[Int, Double]()
dv.foreachActive((index, value) => dvMap.put(index, value))
assert(dvMap.size === 4)
assert(dvMap.get(0) === Some(0.0))
assert(dvMap.get(1) === Some(1.2))
assert(dvMap.get(2) === Some(3.1))
assert(dvMap.get(3) === Some(0.0))

val svMap = scala.collection.mutable.Map[Int, Double]()
sv.foreachActive((index, value) => svMap.put(index, value))
assert(svMap.size === 3)
assert(svMap.get(1) === Some(1.2))
assert(svMap.get(2) === Some(3.1))
assert(svMap.get(3) === Some(0.0))
}
}

0 comments on commit 03dd693

Please sign in to comment.