Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
DB Tsai committed Nov 20, 2014
1 parent 98448bb commit 1907ae1
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 130 deletions.
119 changes: 34 additions & 85 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,15 @@ sealed trait Vector extends Serializable {
}

/**
* It will return the iterator for the active elements of dense and sparse vector as
* (index, value) pair. Note that foreach method can be overridden for better performance
* in different vector implementation.
* Applies a function `f` to all the active elements of dense and sparse vector.
*
* @param skippingZeros Skipping zero elements explicitly if true. It will be useful when we
* iterator through dense vector having lots of zero elements which
* we want to skip. Default is false.
* @return Iterator[(Int, Double)] where the first element in the tuple is the index,
* and the second element is the corresponding value.
* @param f the function takes (Int, Double) as input where the first element
* in the tuple is the index, and the second element is the corresponding value.
* @param skippingZeros if true, skipping zero elements explicitly. It will be useful when
* iterating through dense vector which has lots of zero elements to be
* skipped. Default is false.
*/
private[spark] def activeIterator(skippingZeros: Boolean): Iterator[(Int, Double)]

private[spark] def activeIterator: Iterator[(Int, Double)] = activeIterator(false)

private[spark] def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit)
}

/**
Expand Down Expand Up @@ -290,48 +285,25 @@ class DenseVector(val values: Array[Double]) extends Vector {
new DenseVector(values.clone())
}

private[spark] override def activeIterator(skippingZeros: Boolean) = new Iterator[(Int, Double)] {
private var i = 0

// If zeros are asked to be explicitly skipped, the parent `size` method is called to count
// the number of nonzero elements using `hasNext` and `next` methods.
final override lazy val size: Int = if (skippingZeros) super.size else values.size

final override def hasNext = {
if (skippingZeros) {
var found = false
while (!found && i < values.size) if (values(i) != 0.0) found = true else i += 1
}
i < values.size
}

final override def next = {
val result = (i, values(i))
i += 1
result
}
private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
var i = 0
val localValuesSize = values.size
val localValues = values

final override def foreach[@specialized(Unit) U](f: ((Int, Double)) => U) {
var i = 0
val localValuesSize = values.size
val localValues = values

if (skippingZeros) {
while (i < localValuesSize) {
if (localValues(i) != 0.0) {
f(i, localValues(i))
}
i += 1
}
} else {
while (i < localValuesSize) {
if (skippingZeros) {
while (i < localValuesSize) {
if (localValues(i) != 0.0) {
f(i, localValues(i))
i += 1
}
i += 1
}
} else {
while (i < localValuesSize) {
f(i, localValues(i))
i += 1
}
}
}

}

/**
Expand Down Expand Up @@ -369,47 +341,24 @@ class SparseVector(

private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)

private[spark] override def activeIterator(skippingZeros: Boolean) = new Iterator[(Int, Double)] {
private var i = 0

// If zeros are asked to be explicitly skipped, the parent `size` method is called to count
// the number of nonzero elements using `hasNext` and `next` methods.
final override lazy val size: Int = if (skippingZeros) super.size else values.size

final override def hasNext = {
if (skippingZeros) {
var found = false
while (!found && i < values.size) if (values(i) != 0.0) found = true else i += 1
}
i < values.size
}

final override def next = {
val result = (indices(i), values(i))
i += 1
result
}
private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
var i = 0
val localValuesSize = values.size
val localIndices = indices
val localValues = values

final override def foreach[@specialized(Unit) U](f: ((Int, Double)) => U) {
var i = 0
val localValuesSize = values.size
val localIndices = indices
val localValues = values

if (skippingZeros) {
while (i < localValuesSize) {
if (localValues(i) != 0.0) {
f(localIndices(i), localValues(i))
}
i += 1
}
} else {
while (i < localValuesSize) {
if (skippingZeros) {
while (i < localValuesSize) {
if (localValues(i) != 0.0) {
f(localIndices(i), localValues(i))
i += 1
}
i += 1
}
} else {
while (i < localValuesSize) {
f(localIndices(i), localValues(i))
i += 1
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")

sample.activeIterator(true).foreach {
case (index, value) => add(index, value)
}
sample.foreach(true)(x => add(x._1, x._2))

totalCnt += 1
this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,47 +174,22 @@ class VectorsSuite extends FunSuite {
assert(v.size === x.rows)
}

test("activeIterator") {
test("foreach") {
val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))

// Testing if the size of iterator is correct when the zeros are explicitly skipped.
// The default setting will not skip any zero explicitly.
assert(dv.activeIterator.size === 4)
assert(dv.activeIterator(false).size === 4)
assert(dv.activeIterator(true).size === 2)

assert(sv.activeIterator.size === 3)
assert(sv.activeIterator(false).size === 3)
assert(sv.activeIterator(true).size === 2)

// Testing `hasNext` and `next`
val dvIter1 = dv.activeIterator(false)
assert(dvIter1.hasNext === true && dvIter1.next === (0, 0.0))
assert(dvIter1.hasNext === true && dvIter1.next === (1, 1.2))
assert(dvIter1.hasNext === true && dvIter1.next === (2, 3.1))
assert(dvIter1.hasNext === true && dvIter1.next === (3, 0.0))
assert(dvIter1.hasNext === false)

val dvIter2 = dv.activeIterator(true)
assert(dvIter2.hasNext === true && dvIter2.next === (1, 1.2))
assert(dvIter2.hasNext === true && dvIter2.next === (2, 3.1))
assert(dvIter2.hasNext === false)

val svIter1 = sv.activeIterator(false)
assert(svIter1.hasNext === true && svIter1.next === (1, 1.2))
assert(svIter1.hasNext === true && svIter1.next === (2, 3.1))
assert(svIter1.hasNext === true && svIter1.next === (3, 0.0))
assert(svIter1.hasNext === false)

val svIter2 = sv.activeIterator(true)
assert(svIter2.hasNext === true && svIter2.next === (1, 1.2))
assert(svIter2.hasNext === true && svIter2.next === (2, 3.1))
assert(svIter2.hasNext === false)

// Testing `foreach`
val dvMap0 = scala.collection.mutable.Map[Int, Double]()
dv.foreach() {
case (index: Int, value: Double) => dvMap0.put(index, value)
}
assert(dvMap0.size === 4)
assert(dvMap0.get(0) === Some(0.0))
assert(dvMap0.get(1) === Some(1.2))
assert(dvMap0.get(2) === Some(3.1))
assert(dvMap0.get(3) === Some(0.0))

val dvMap1 = scala.collection.mutable.Map[Int, Double]()
dvIter1.foreach{
dv.foreach(false) {
case (index, value) => dvMap1.put(index, value)
}
assert(dvMap1.size === 4)
Expand All @@ -223,16 +198,25 @@ class VectorsSuite extends FunSuite {
assert(dvMap1.get(2) === Some(3.1))
assert(dvMap1.get(3) === Some(0.0))

val dvMap2 = scala.collection.mutable.Map[Int, Double]()
dvIter2.foreach{
val dvMap2 = scala.collection .mutable.Map[Int, Double]()
dv.foreach(true) {
case (index, value) => dvMap2.put(index, value)
}
assert(dvMap2.size === 2)
assert(dvMap2.get(1) === Some(1.2))
assert(dvMap2.get(2) === Some(3.1))

val svMap0 = scala.collection.mutable.Map[Int, Double]()
sv.foreach() {
case (index, value) => svMap0.put(index, value)
}
assert(svMap0.size === 3)
assert(svMap0.get(1) === Some(1.2))
assert(svMap0.get(2) === Some(3.1))
assert(svMap0.get(3) === Some(0.0))

val svMap1 = scala.collection.mutable.Map[Int, Double]()
svIter1.foreach{
sv.foreach(false) {
case (index, value) => svMap1.put(index, value)
}
assert(svMap1.size === 3)
Expand All @@ -241,12 +225,11 @@ class VectorsSuite extends FunSuite {
assert(svMap1.get(3) === Some(0.0))

val svMap2 = scala.collection.mutable.Map[Int, Double]()
svIter2.foreach{
sv.foreach(true) {
case (index, value) => svMap2.put(index, value)
}
assert(svMap2.size === 2)
assert(svMap2.get(1) === Some(1.2))
assert(svMap2.get(2) === Some(3.1))

}
}

0 comments on commit 1907ae1

Please sign in to comment.