Skip to content

Commit

Permalink
[SPARK-4409] Modified genRandMatrix
Browse files (browse the repository at this point in the history)
  • Loading branch information
brkyvz committed Dec 19, 2014
1 parent 3971c93 commit f62d6c7
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg

import java.util.{Arrays, Random}

import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, Map}
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, Map => MutableMap}

import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}

Expand Down Expand Up @@ -408,8 +408,15 @@ object SparseMatrix {
require(density >= 0.0 && density <= 1.0, "density must be a double in the range " +
s"0.0 <= d <= 1.0. Currently, density: $density")
val length = math.ceil(numRows * numCols * density).toInt
val entries = Map[(Int, Int), Double]()
val entries = MutableMap[(Int, Int), Double]()
var i = 0
if (density == 0.0) {
return new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1),
Array[Int](), Array[Double]())
} else if (density == 1.0) {
return new SparseMatrix(numRows, numCols, (0 to numRows * numCols by numRows).toArray,
(0 until numRows * numCols).toArray, Array.fill(numRows * numCols)(method(rng)))
}
// Expected number of iterations is less than 1.5 * length
if (density < 0.34) {
while (i < length) {
Expand All @@ -424,23 +431,18 @@ object SparseMatrix {
}
} else { // selection - rejection method
var j = 0
val triesPerCol = math.ceil(length * 1.0 / numCols).toInt
val pool = numRows * numCols
// loop over columns so that the sort in fromCOO requires less sorting
while (i < length && j < numCols) {
var k = 0
val leftFromPool = (numCols - j) * numRows
while (k < triesPerCol) {
if (rng.nextDouble() < 1.0 * (length - i) / (pool - leftFromPool)) {
var rowIndex = rng.nextInt(numRows)
val colIndex = j
while (entries.contains((rowIndex, colIndex))) {
rowIndex = rng.nextInt(numRows)
}
entries += (rowIndex, colIndex) -> method(rng)
var passedInPool = j * numRows
var r = 0
while (i < length && r < numRows) {
if (rng.nextDouble() < 1.0 * (length - i) / (pool - passedInPool)) {
entries += (r, j) -> method(rng)
i += 1
}
k += 1
r += 1
passedInPool += 1
}
j += 1
}
Expand Down

0 comments on commit f62d6c7

Please sign in to comment.