From b2851066c88f5436e18877ee2eb337b522c08512 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Mon, 21 Apr 2014 13:47:09 -0400 Subject: [PATCH] Clean up NNLS test cases. --- .../spark/mllib/optimization/NNLSSuite.scala | 53 ++++++++++++------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index f6f054741721a..7f6e828a10b51 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -20,40 +20,55 @@ package org.apache.spark.mllib.optimization import scala.util.Random import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers -import org.apache.spark.mllib.util.LocalSparkContext +import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas} -import org.jblas.DoubleMatrix -import org.jblas.SimpleBlas - -class NNLSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { +class NNLSSuite extends FunSuite { test("NNLSbyPCG: exact solution case") { - val A = new DoubleMatrix(20, 20) - val b = new DoubleMatrix(20, 1) + val n = 20 + val A = new DoubleMatrix(n, n) + val b = new DoubleMatrix(n, 1) val rand = new Random(12345) - for (i <- 0 until 20; j <- 0 until 20) { + for (i <- 0 until n; j <- 0 until n) { val aij = rand.nextDouble() A.put(i, j, aij) b.put(i, b.get(i, 0) + aij) } - val ata = new DoubleMatrix(20, 20) - val atb = new DoubleMatrix(20, 1) - for (i <- 0 until 20; j <- 0 until 20; k <- 0 until 20) { - ata.put(i, j, ata.get(i, j) + A.get(k, i) * A.get(k, j)) - } - for (i <- 0 until 20; j <- 0 until 20) { - atb.put(i, atb.get(i, 0) + A.get(j, i) * b.get(j)) - } + val ata = new DoubleMatrix(n, n) + val atb = new DoubleMatrix(n, 1) + + NativeBlas.dgemm('T', 'N', n, n, n, 1.0, A.data, 0, n, A.data, 0, n, 0.0, ata.data, 0, n) + NativeBlas.dgemv('T', n, n, 1.0, A.data, 0, n, b.data, 0, 1, 0.0, atb.data, 0, 1) val x = NNLSbyPCG.solve(ata, atb, true) - assert(x.length == 20) + assert(x.length == n) var error = 0.0 - for (i <- 0 until 20) { + for (i <- 0 until n) { error = error + (x(i) - 1) * (x(i) - 1) assert(Math.abs(x(i) - 1) < 1e-3) } assert(error < 1e-2) } + + test("NNLSbyPCG: nonnegativity constraint active") { + val n = 5 + val M = Array( + Array( 4.377, -3.531, -1.306, -0.139, 3.418, -1.632), + Array(-3.531, 4.344, 0.934, 0.305, -2.140, 2.115), + Array(-1.306, 0.934, 2.644, -0.203, -0.170, 1.094), + Array(-0.139, 0.305, -0.203, 5.883, 1.428, -1.025), + Array( 3.418, -2.140, -0.170, 1.428, 4.684, -0.636)) + val ata = new DoubleMatrix(5, 5) + val atb = new DoubleMatrix(5, 1) + for (i <- 0 until 5; j <- 0 until 5) ata.put(i, j, M(i)(j)) + for (i <- 0 until 5) atb.put(i, M(i)(5)) + + val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628) + + val x = NNLSbyPCG.solve(ata, atb, true) + for (i <- 0 until 5) { + assert(Math.abs(x(i) - goodx(i)) < 1e-3) + } + } }