Skip to content

Commit

Permalink
[WIP] [SPARK-1328] Add vector statistics
Browse files Browse the repository at this point in the history
As with the new vector system in MLlib, we find that it is good to add some new APIs to precess the `RDD[Vector]`. Beside, the former implementation of `computeStat` is not stable which could loss precision, and has the possibility to cause `Nan` in scientific computing, just as said in the [SPARK-1328](https://spark-project.atlassian.net/browse/SPARK-1328).

APIs contain:

* rowMeans(): RDD[Double]
* rowNorm2(): RDD[Double]
* rowSDs(): RDD[Double]
* colMeans(): Vector
* colMeans(size: Int): Vector
* colNorm2(): Vector
* colNorm2(size: Int): Vector
* colSDs(): Vector
* colSDs(size: Int): Vector
* maxOption((Vector, Vector) => Boolean): Option[Vector]
* minOption((Vector, Vector) => Boolean): Option[Vector]
* rowShrink(): RDD[Vector]
* colShrink(): RDD[Vector]

This is working in process now, and some more APIs will add to `LabeledPoint`. Moreover, the implicit declaration will move from `MLUtils` to `MLContext` later.

Author: Xusen Yin <yinxusen@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #268 from yinxusen/vector-statistics and squashes the following commits:

d61363f [Xusen Yin] rebase to latest master
16ae684 [Xusen Yin] fix minor error and remove useless method
10cf5d3 [Xusen Yin] refine some return type
b064714 [Xusen Yin] remove computeStat in MLUtils
cbbefdb [Xiangrui Meng] update multivariate statistical summary interface and clean tests
4eaf28a [Xusen Yin] merge VectorRDDStatistics into RowMatrix
48ee053 [Xusen Yin] fix minor error
e624f93 [Xusen Yin] fix scala style error
1fba230 [Xusen Yin] merge while loop together
69e1f37 [Xusen Yin] remove lazy eval, and minor memory footprint
548e9de [Xusen Yin] minor revision
86522c4 [Xusen Yin] add comments on functions
dc77e38 [Xusen Yin] test sparse vector RDD
18cf072 [Xusen Yin] change def to lazy val to make sure that the computations in function be evaluated only once
f7a3ca2 [Xusen Yin] fix the corner case of maxmin
967d041 [Xusen Yin] full revision with Aggregator class
138300c [Xusen Yin] add new Aggregator class
1376ff4 [Xusen Yin] rename variables and adjust code
4a5c38d [Xusen Yin] add scala doc, refine code and comments
036b7a5 [Xusen Yin] fix the bug of Nan occur
f6e8e9a [Xusen Yin] add sparse vectors test
4cfbadf [Xusen Yin] fix bug of min max
4e4fbd1 [Xusen Yin] separate seqop and combop out as independent functions
a6d5a2e [Xusen Yin] rewrite for only computing non-zero elements
3980287 [Xusen Yin] rename variables
62a2c3e [Xusen Yin] use axpy and in-place if possible
9a75ebd [Xusen Yin] add case class to wrap return values
d816ac7 [Xusen Yin] remove useless APIs
c4651bb [Xusen Yin] remove row-wise APIs and refine code
1338ea1 [Xusen Yin] all-in-one version test passed
cc65810 [Xusen Yin] add parallel mean and variance
9af2e95 [Xusen Yin] refine the code style
ad6c82d [Xusen Yin] add shrink test
e09d5d2 [Xusen Yin] add scala docs and refine shrink method
8ef3377 [Xusen Yin] pass all tests
28cf060 [Xusen Yin] fix error of column means
54b19ab [Xusen Yin] add new API to shrink RDD[Vector]
8c6c0e1 [Xusen Yin] add basic statistics
  • Loading branch information
yinxusen authored and pwendell committed Apr 12, 2014
1 parent 9afaeed commit ce0ce3d
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,146 @@ package org.apache.spark.mllib.linalg.distributed

import java.util

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
import breeze.numerics.{sqrt => brzSqrt}
import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary

/**
* Column statistics aggregator implementing
* [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
* together with add() and merge() function.
* A numerically stable algorithm is implemented to compute sample mean and variance:
*[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
* Zero elements (including explicit zero values) are skipped when calling add() and merge(),
* to have time complexity O(nnz) instead of O(n) for each column.
*/
private class ColumnStatisticsAggregator(private val n: Int)
extends MultivariateStatisticalSummary with Serializable {

private val currMean: BDV[Double] = BDV.zeros[Double](n)
private val currM2n: BDV[Double] = BDV.zeros[Double](n)
private var totalCnt = 0.0
private val nnz: BDV[Double] = BDV.zeros[Double](n)
private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)

override def mean: Vector = {
val realMean = BDV.zeros[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * nnz(i) / totalCnt
i += 1
}
Vectors.fromBreeze(realMean)
}

override def variance: Vector = {
val realVariance = BDV.zeros[Double](n)

val denominator = totalCnt - 1.0

// Sample variance is computed, if the denominator is less than 0, the variance is just 0.
if (denominator > 0.0) {
val deltaMean = currMean
var i = 0
while (i < currM2n.size) {
realVariance(i) =
currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
realVariance(i) /= denominator
i += 1
}
}

Vectors.fromBreeze(realVariance)
}

override def count: Long = totalCnt.toLong

override def numNonzeros: Vector = Vectors.fromBreeze(nnz)

override def max: Vector = {
var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMax)
}

override def min: Vector = {
var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMin)
}

/**
* Aggregates a row.
*/
def add(currData: BV[Double]): this.type = {
currData.activeIterator.foreach {
case (_, 0.0) => // Skip explicit zero elements.
case (i, value) =>
if (currMax(i) < value) {
currMax(i) = value
}
if (currMin(i) > value) {
currMin(i) = value
}

val tmpPrevMean = currMean(i)
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)

nnz(i) += 1.0
}

totalCnt += 1.0
this
}

/**
* Merges another aggregator.
*/
def merge(other: ColumnStatisticsAggregator): this.type = {
require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")

totalCnt += other.totalCnt
val deltaMean = currMean - other.currMean

var i = 0
while (i < n) {
// merge mean together
if (other.currMean(i) != 0.0) {
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
(nnz(i) + other.nnz(i))
}
// merge m2n together
if (nnz(i) + other.nnz(i) != 0.0) {
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
(nnz(i) + other.nnz(i))
}
if (currMax(i) < other.currMax(i)) {
currMax(i) = other.currMax(i)
}
if (currMin(i) > other.currMin(i)) {
currMin(i) = other.currMin(i)
}
i += 1
}

nnz += other.nnz
this
}
}

/**
* :: Experimental ::
Expand Down Expand Up @@ -182,13 +314,7 @@ class RowMatrix(
combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
)

// Update _m if it is not set, or verify its value.
if (nRows <= 0L) {
nRows = m
} else {
require(nRows == m,
s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
}
updateNumRows(m)

mean :/= m.toDouble

Expand Down Expand Up @@ -240,6 +366,19 @@ class RowMatrix(
}
}

/**
* Computes column-wise summary statistics.
*/
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
)
updateNumRows(summary.count)
summary
}

/**
* Multiply this matrix by a local matrix on the right.
*
Expand Down Expand Up @@ -276,6 +415,16 @@ class RowMatrix(
}
mat
}

/** Updates or verfires the number of rows. */
private def updateNumRows(m: Long) {
if (nRows <= 0) {
nRows == m
} else {
require(nRows == m,
s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
}
}
}

object RowMatrix {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.stat

import org.apache.spark.mllib.linalg.Vector

/**
* Trait for multivariate statistical summary of a data matrix.
*/
trait MultivariateStatisticalSummary {

/**
* Sample mean vector.
*/
def mean: Vector

/**
* Sample variance vector. Should return a zero vector if the sample size is 1.
*/
def variance: Vector

/**
* Sample size.
*/
def count: Long

/**
* Number of nonzero elements (including explicitly presented zero values) in each column.
*/
def numNonzeros: Vector

/**
* Maximum value of each column.
*/
def max: Vector

/**
* Minimum value of each column.
*/
def min: Vector
}
57 changes: 2 additions & 55 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@

package org.apache.spark.mllib.util

import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
squaredDistance => breezeSquaredDistance}
import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}

import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.Vectors

/**
* Helper methods to load, save and pre-process data used in ML Lib.
Expand Down Expand Up @@ -158,58 +157,6 @@ object MLUtils {
dataStr.saveAsTextFile(dir)
}

/**
* Utility function to compute mean and standard deviation on a given dataset.
*
* @param data - input data set whose statistics are computed
* @param numFeatures - number of features
* @param numExamples - number of examples in input dataset
*
* @return (yMean, xColMean, xColSd) - Tuple consisting of
* yMean - mean of the labels
* xColMean - Row vector with mean for every column (or feature) of the input data
* xColSd - Row vector standard deviation for every column (or feature) of the input data.
*/
private[mllib] def computeStats(
data: RDD[LabeledPoint],
numFeatures: Int,
numExamples: Long): (Double, Vector, Vector) = {
val brzData = data.map { case LabeledPoint(label, features) =>
(label, features.toBreeze)
}
val aggStats = brzData.aggregate(
(0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures))
)(
seqOp = (c, v) => (c, v) match {
case ((n, sumLabel, sum, sumSq), (label, features)) =>
features.activeIterator.foreach { case (i, x) =>
sumSq(i) += x * x
}
(n + 1L, sumLabel + label, sum += features, sumSq)
},
combOp = (c1, c2) => (c1, c2) match {
case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) =>
(n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2)
}
)
val (nl, sumLabel, sum, sumSq) = aggStats

require(nl > 0, "Input data is empty.")
require(nl == numExamples)

val n = nl.toDouble
val yMean = sumLabel / n
val mean = sum / n
val std = new Array[Double](sum.length)
var i = 0
while (i < numFeatures) {
std(i) = sumSq(i) / n - mean(i) * mean(i)
i += 1
}

(yMean, Vectors.fromBreeze(mean), Vectors.dense(std))
}

/**
* Returns the squared Euclidean distance between two vectors. The following formula will be used
* if it does not introduce too much numerical error:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,19 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
))
}
}

test("compute column summary statistics") {
for (mat <- Seq(denseMat, sparseMat)) {
val summary = mat.computeColumnSummaryStatistics()
// Run twice to make sure no internal states are changed.
for (k <- 0 to 1) {
assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch")
assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch")
assert(summary.count === m, "count mismatch.")
assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch")
assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch")
assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._

class MLUtilsSuite extends FunSuite with LocalSparkContext {
Expand Down Expand Up @@ -56,18 +55,6 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
}
}

test("compute stats") {
val data = Seq.fill(3)(Seq(
LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)),
LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0))
)).flatten
val rdd = sc.parallelize(data, 2)
val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6)
assert(meanLabel === 0.5)
assert(mean === Vectors.dense(2.0, 3.0, 4.0))
assert(std === Vectors.dense(1.0, 1.0, 1.0))
}

test("loadLibSVMData") {
val lines =
"""
Expand Down

0 comments on commit ce0ce3d

Please sign in to comment.