diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index 07a01b61b5b59..bbf385229081a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -24,21 +24,13 @@ import org.scalatest.FunSuite import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas} class NNLSSuite extends FunSuite { - /** Generate a NNLS problem whose optimal solution is the all-ones vector. */ + /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = { - val A = new DoubleMatrix(n, n) - val b = new DoubleMatrix(n, 1) - for (i <- 0 until n; j <- 0 until n) { - val aij = rand.nextDouble() - A.put(i, j, aij) - b.put(i, b.get(i, 0) + aij) - } - - val ata = new DoubleMatrix(n, n) - val atb = new DoubleMatrix(n, 1) + val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*) + val b = A.mmul(DoubleMatrix.ones(n, 1)) - NativeBlas.dgemm('T', 'N', n, n, n, 1.0, A.data, 0, n, A.data, 0, n, 0.0, ata.data, 0, n) - NativeBlas.dgemv('T', n, n, 1.0, A.data, 0, n, b.data, 0, 1, 0.0, atb.data, 0, 1) + val ata = A.transpose.mmul(A) + val atb = A.transpose.mmul(b) (ata, atb) } @@ -53,17 +45,13 @@ class NNLSSuite extends FunSuite { // can legitimately fail to solve these anywhere close to exactly. So we grab a considerable // sample of these matrices and make sure that we solved a substantial fraction of them. - for (kase <- 0 until 100) { + for (k <- 0 until 100) { val (ata, atb) = genOnesData(n, rand) - val x = NNLS.solve(ata, atb, ws) + val x = new DoubleMatrix(NNLS.solve(ata, atb, ws)) assert(x.length === n) - var error = 0.0 - var solved = true - for (i <- 0 until n) { - error = error + (x(i) - 1) * (x(i) - 1) - if (Math.abs(x(i) - 1) > 1e-3) solved = false - } - if (error > 1e-2) solved = false + val answer = DoubleMatrix.ones(n, 1) + SimpleBlas.axpy(-1.0, answer, x) + val solved = (x.norm2 < 1e-2) && (x.normmax < 1e-3) if (solved) numSolved = numSolved + 1 } @@ -72,16 +60,13 @@ class NNLSSuite extends FunSuite { test("NNLS: nonnegativity constraint active") { val n = 5 - val M = Array( - Array( 4.377, -3.531, -1.306, -0.139, 3.418, -1.632), - Array(-3.531, 4.344, 0.934, 0.305, -2.140, 2.115), - Array(-1.306, 0.934, 2.644, -0.203, -0.170, 1.094), - Array(-0.139, 0.305, -0.203, 5.883, 1.428, -1.025), - Array( 3.418, -2.140, -0.170, 1.428, 4.684, -0.636)) - val ata = new DoubleMatrix(n, n) - val atb = new DoubleMatrix(n, 1) - for (i <- 0 until n; j <- 0 until n) ata.put(i, j, M(i)(j)) - for (i <- 0 until n) atb.put(i, M(i)(n)) + val ata = new DoubleMatrix(Array( + Array( 4.377, -3.531, -1.306, -0.139, 3.418), + Array(-3.531, 4.344, 0.934, 0.305, -2.140), + Array(-1.306, 0.934, 2.644, -0.203, -0.170), + Array(-0.139, 0.305, -0.203, 5.883, 1.428), + Array( 3.418, -2.140, -0.170, 1.428, 4.684))) + val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636)) val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628)