diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index a2c8fb3bf2b3c..b8e869b6cc30e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -136,21 +136,19 @@ private class ColumnStatisticsAggregator(private val n: Int) var i = 0 while (i < n) { - // merge mean together - if (other.currMean(i) != 0.0) { + if (nnz(i) + other.nnz(i) != 0.0) { + // merge mean together currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / (nnz(i) + other.nnz(i)) - } - // merge m2n together - if (nnz(i) + other.nnz(i) != 0.0) { + // merge m2n together currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / (nnz(i) + other.nnz(i)) - } - if (currMax(i) < other.currMax(i)) { - currMax(i) = other.currMax(i) - } - if (currMin(i) > other.currMin(i)) { - currMin(i) = other.currMin(i) + if (currMax(i) < other.currMax(i)) { + currMax(i) = other.currMax(i) + } + if (currMin(i) > other.currMin(i)) { + currMin(i) = other.currMin(i) + } } i += 1 }