From 1907ae122ac0f385e5c408b827bd438e209cd71e Mon Sep 17 00:00:00 2001
From: DB Tsai
Date: Thu, 20 Nov 2014 14:49:27 -0800
Subject: [PATCH] address feedback

---
 .../apache/spark/mllib/linalg/Vectors.scala   | 119 +++++-------------
 .../stat/MultivariateOnlineSummarizer.scala   |   4 +-
 .../spark/mllib/linalg/VectorsSuite.scala     |  67 ++++------
 3 files changed, 60 insertions(+), 130 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index a7d2624a817f2..f340f6feae30e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -78,20 +78,15 @@ sealed trait Vector extends Serializable {
   }
 
   /**
-   * It will return the iterator for the active elements of dense and sparse vector as
-   * (index, value) pair. Note that foreach method can be overridden for better performance
-   * in different vector implementation.
+   * Applies a function `f` to all the active elements of dense and sparse vectors.
    *
-   * @param skippingZeros Skipping zero elements explicitly if true. It will be useful when we
-   *                      iterator through dense vector having lots of zero elements which
-   *                      we want to skip. Default is false.
-   * @return Iterator[(Int, Double)] where the first element in the tuple is the index,
-   *         and the second element is the corresponding value.
+   * @param f the function takes a tuple (Int, Double) as input, where the first element
+   *          is the index and the second element is the corresponding value.
+   * @param skippingZeros if true, zero elements are skipped explicitly. This is useful when
+   *                      iterating through a dense vector that contains many zero elements
+   *                      which should be skipped. Default is false.
    */
-  private[spark] def activeIterator(skippingZeros: Boolean): Iterator[(Int, Double)]
-
-  private[spark] def activeIterator: Iterator[(Int, Double)] = activeIterator(false)
-
+  private[spark] def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit)
 }
 
 /**
@@ -290,48 +285,25 @@ class DenseVector(val values: Array[Double]) extends Vector {
     new DenseVector(values.clone())
   }
 
-  private[spark] override def activeIterator(skippingZeros: Boolean) = new Iterator[(Int, Double)] {
-    private var i = 0
-
-    // If zeros are asked to be explicitly skipped, the parent `size` method is called to count
-    // the number of nonzero elements using `hasNext` and `next` methods.
-    final override lazy val size: Int = if (skippingZeros) super.size else values.size
-
-    final override def hasNext = {
-      if (skippingZeros) {
-        var found = false
-        while (!found && i < values.size) if (values(i) != 0.0) found = true else i += 1
-      }
-      i < values.size
-    }
-
-    final override def next = {
-      val result = (i, values(i))
-      i += 1
-      result
-    }
-
-    final override def foreach[@specialized(Unit) U](f: ((Int, Double)) => U) {
-      var i = 0
-      val localValuesSize = values.size
-      val localValues = values
-
-      if (skippingZeros) {
-        while (i < localValuesSize) {
-          if (localValues(i) != 0.0) {
-            f(i, localValues(i))
-          }
-          i += 1
-        }
-      } else {
-        while (i < localValuesSize) {
-          f(i, localValues(i))
-          i += 1
-        }
-      }
-    }
-  }
+  private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
+    var i = 0
+    val localValuesSize = values.size
+    val localValues = values
+
+    if (skippingZeros) {
+      while (i < localValuesSize) {
+        if (localValues(i) != 0.0) {
+          f(i, localValues(i))
+        }
+        i += 1
+      }
+    } else {
+      while (i < localValuesSize) {
+        f(i, localValues(i))
+        i += 1
+      }
+    }
+  }
 }
 
 /**
@@ -369,47 +341,24 @@ class SparseVector(
   private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
 
-  private[spark] override def activeIterator(skippingZeros: Boolean) = new Iterator[(Int, Double)] {
-    private var i = 0
-
-    // If zeros are asked to be explicitly skipped, the parent `size` method is called to count
-    // the number of nonzero elements using `hasNext` and `next` methods.
-    final override lazy val size: Int = if (skippingZeros) super.size else values.size
-
-    final override def hasNext = {
-      if (skippingZeros) {
-        var found = false
-        while (!found && i < values.size) if (values(i) != 0.0) found = true else i += 1
-      }
-      i < values.size
-    }
-
-    final override def next = {
-      val result = (indices(i), values(i))
-      i += 1
-      result
-    }
-
-    final override def foreach[@specialized(Unit) U](f: ((Int, Double)) => U) {
-      var i = 0
-      val localValuesSize = values.size
-      val localIndices = indices
-      val localValues = values
-
-      if (skippingZeros) {
-        while (i < localValuesSize) {
-          if (localValues(i) != 0.0) {
-            f(localIndices(i), localValues(i))
-          }
-          i += 1
-        }
-      } else {
-        while (i < localValuesSize) {
-          f(localIndices(i), localValues(i))
-          i += 1
-        }
-      }
-    }
-  }
+  private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
+    var i = 0
+    val localValuesSize = values.size
+    val localIndices = indices
+    val localValues = values
+
+    if (skippingZeros) {
+      while (i < localValuesSize) {
+        if (localValues(i) != 0.0) {
+          f(localIndices(i), localValues(i))
+        }
+        i += 1
+      }
+    } else {
+      while (i < localValuesSize) {
+        f(localIndices(i), localValues(i))
+        i += 1
+      }
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 55f93bc1b52f4..417564b597914 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -93,9 +93,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
     require(n == sample.size, s"Dimensions mismatch when adding new sample."
+ s" Expecting $n but got ${sample.size}.") - sample.activeIterator(true).foreach { - case (index, value) => add(index, value) - } + sample.foreach(true)(x => add(x._1, x._2)) totalCnt += 1 this diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 2d7c13b7acbb4..81e886b4bd4ee 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -174,47 +174,22 @@ class VectorsSuite extends FunSuite { assert(v.size === x.rows) } - test("activeIterator") { + test("foreach") { val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0) val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0))) - // Testing if the size of iterator is correct when the zeros are explicitly skipped. - // The default setting will not skip any zero explicitly. - assert(dv.activeIterator.size === 4) - assert(dv.activeIterator(false).size === 4) - assert(dv.activeIterator(true).size === 2) - - assert(sv.activeIterator.size === 3) - assert(sv.activeIterator(false).size === 3) - assert(sv.activeIterator(true).size === 2) - - // Testing `hasNext` and `next` - val dvIter1 = dv.activeIterator(false) - assert(dvIter1.hasNext === true && dvIter1.next === (0, 0.0)) - assert(dvIter1.hasNext === true && dvIter1.next === (1, 1.2)) - assert(dvIter1.hasNext === true && dvIter1.next === (2, 3.1)) - assert(dvIter1.hasNext === true && dvIter1.next === (3, 0.0)) - assert(dvIter1.hasNext === false) - - val dvIter2 = dv.activeIterator(true) - assert(dvIter2.hasNext === true && dvIter2.next === (1, 1.2)) - assert(dvIter2.hasNext === true && dvIter2.next === (2, 3.1)) - assert(dvIter2.hasNext === false) - - val svIter1 = sv.activeIterator(false) - assert(svIter1.hasNext === true && svIter1.next === (1, 1.2)) - assert(svIter1.hasNext === true && svIter1.next === (2, 3.1)) - assert(svIter1.hasNext === true && svIter1.next === (3, 0.0)) - assert(svIter1.hasNext === false) - - val svIter2 = sv.activeIterator(true) - assert(svIter2.hasNext === true && svIter2.next === (1, 1.2)) - assert(svIter2.hasNext === true && svIter2.next === (2, 3.1)) - assert(svIter2.hasNext === false) - - // Testing `foreach` + val dvMap0 = scala.collection.mutable.Map[Int, Double]() + dv.foreach() { + case (index: Int, value: Double) => dvMap0.put(index, value) + } + assert(dvMap0.size === 4) + assert(dvMap0.get(0) === Some(0.0)) + assert(dvMap0.get(1) === Some(1.2)) + assert(dvMap0.get(2) === Some(3.1)) + assert(dvMap0.get(3) === Some(0.0)) + val dvMap1 = scala.collection.mutable.Map[Int, Double]() - dvIter1.foreach{ + dv.foreach(false) { case (index, value) => dvMap1.put(index, value) } assert(dvMap1.size === 4) @@ -223,16 +198,25 @@ class VectorsSuite extends FunSuite { assert(dvMap1.get(2) === Some(3.1)) assert(dvMap1.get(3) === Some(0.0)) - val dvMap2 = scala.collection.mutable.Map[Int, Double]() - dvIter2.foreach{ + val dvMap2 = scala.collection .mutable.Map[Int, Double]() + dv.foreach(true) { case (index, value) => dvMap2.put(index, value) } assert(dvMap2.size === 2) assert(dvMap2.get(1) === Some(1.2)) assert(dvMap2.get(2) === Some(3.1)) + val svMap0 = scala.collection.mutable.Map[Int, Double]() + sv.foreach() { + case (index, value) => svMap0.put(index, value) + } + assert(svMap0.size === 3) + assert(svMap0.get(1) === Some(1.2)) + assert(svMap0.get(2) === Some(3.1)) + assert(svMap0.get(3) === Some(0.0)) + val svMap1 = scala.collection.mutable.Map[Int, Double]() - 
-    svIter1.foreach{
+    sv.foreach(false) {
       case (index, value) => svMap1.put(index, value)
     }
     assert(svMap1.size === 3)
@@ -241,12 +225,11 @@ class VectorsSuite extends FunSuite {
     assert(svMap1.get(3) === Some(0.0))
 
     val svMap2 = scala.collection.mutable.Map[Int, Double]()
-    svIter2.foreach{
+    sv.foreach(true) {
       case (index, value) => svMap2.put(index, value)
     }
     assert(svMap2.size === 2)
     assert(svMap2.get(1) === Some(1.2))
     assert(svMap2.get(2) === Some(3.1))
-
   }
 }
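
Usage note (not part of the patch): the sketch below shows how the `foreach(skippingZeros)`
API introduced above is intended to be called, mirroring the MultivariateOnlineSummarizer
change and the updated tests. The method is `private[spark]`, so a real caller has to live
under an `org.apache.spark` package; the helper `sumActive` and its names are illustrative
only and do not come from this patch.

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    // Sums the values of the active elements of a vector. The single (Int, Double)
    // tuple argument passed to `f` carries (index, value), exactly as documented above.
    def sumActive(v: Vector, skippingZeros: Boolean): Double = {
      var sum = 0.0
      v.foreach(skippingZeros) { case (index, value) => sum += value }
      sum
    }

    val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
    val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1)))
    sumActive(dv, skippingZeros = true)   // 4.3; explicit zeros in the dense array are skipped
    sumActive(sv, skippingZeros = false)  // 4.3; only the stored (active) entries are visited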