Skip to content

Commit

Permalink
Cleanup matrix math in NNLSSuite.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmyklebu committed Apr 25, 2014
1 parent 65ef7f2 commit 7fbabf1
Showing 1 changed file with 17 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,13 @@ import org.scalatest.FunSuite
import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas}

class NNLSSuite extends FunSuite {
/** Generate a NNLS problem whose optimal solution is the all-ones vector. */
/** Generate an NNLS problem whose optimal solution is the all-ones vector. */
def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = {
val A = new DoubleMatrix(n, n)
val b = new DoubleMatrix(n, 1)
for (i <- 0 until n; j <- 0 until n) {
val aij = rand.nextDouble()
A.put(i, j, aij)
b.put(i, b.get(i, 0) + aij)
}

val ata = new DoubleMatrix(n, n)
val atb = new DoubleMatrix(n, 1)
val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*)
val b = A.mmul(DoubleMatrix.ones(n, 1))

NativeBlas.dgemm('T', 'N', n, n, n, 1.0, A.data, 0, n, A.data, 0, n, 0.0, ata.data, 0, n)
NativeBlas.dgemv('T', n, n, 1.0, A.data, 0, n, b.data, 0, 1, 0.0, atb.data, 0, 1)
val ata = A.transpose.mmul(A)
val atb = A.transpose.mmul(b)

(ata, atb)
}
Expand All @@ -53,17 +45,13 @@ class NNLSSuite extends FunSuite {
// can legitimately fail to solve these anywhere close to exactly. So we grab a considerable
// sample of these matrices and make sure that we solved a substantial fraction of them.

for (kase <- 0 until 100) {
for (k <- 0 until 100) {
val (ata, atb) = genOnesData(n, rand)
val x = NNLS.solve(ata, atb, ws)
val x = new DoubleMatrix(NNLS.solve(ata, atb, ws))
assert(x.length === n)
var error = 0.0
var solved = true
for (i <- 0 until n) {
error = error + (x(i) - 1) * (x(i) - 1)
if (Math.abs(x(i) - 1) > 1e-3) solved = false
}
if (error > 1e-2) solved = false
val answer = DoubleMatrix.ones(n, 1)
SimpleBlas.axpy(-1.0, answer, x)
val solved = (x.norm2 < 1e-2) && (x.normmax < 1e-3)
if (solved) numSolved = numSolved + 1
}

Expand All @@ -72,16 +60,13 @@ class NNLSSuite extends FunSuite {

test("NNLS: nonnegativity constraint active") {
val n = 5
val M = Array(
Array( 4.377, -3.531, -1.306, -0.139, 3.418, -1.632),
Array(-3.531, 4.344, 0.934, 0.305, -2.140, 2.115),
Array(-1.306, 0.934, 2.644, -0.203, -0.170, 1.094),
Array(-0.139, 0.305, -0.203, 5.883, 1.428, -1.025),
Array( 3.418, -2.140, -0.170, 1.428, 4.684, -0.636))
val ata = new DoubleMatrix(n, n)
val atb = new DoubleMatrix(n, 1)
for (i <- 0 until n; j <- 0 until n) ata.put(i, j, M(i)(j))
for (i <- 0 until n) atb.put(i, M(i)(n))
val ata = new DoubleMatrix(Array(
Array( 4.377, -3.531, -1.306, -0.139, 3.418),
Array(-3.531, 4.344, 0.934, 0.305, -2.140),
Array(-1.306, 0.934, 2.644, -0.203, -0.170),
Array(-0.139, 0.305, -0.203, 5.883, 1.428),
Array( 3.418, -2.140, -0.170, 1.428, 4.684)))
val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636))

val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628)

Expand Down

0 comments on commit 7fbabf1

Please sign in to comment.