From 561d31d2f13cc7b1112ba9f9aa8f08bcd032aebb Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 26 Nov 2014 08:22:50 -0800 Subject: [PATCH] [SPARK-4614][MLLIB] Slight API changes in Matrix and Matrices Before we have a full picture of the operators we want to add, it might be safer to hide `Matrix.transposeMultiply` in 1.2.0. Another update we want to change is `Matrix.randn` and `Matrix.rand`, both of which should take a `Random` implementation. Otherwise, it is very likely to produce inconsistent RDDs. I also added some unit tests for matrix factory methods. All APIs are new in 1.2, so there is no incompatible changes. brkyvz Author: Xiangrui Meng Closes #3468 from mengxr/SPARK-4614 and squashes the following commits: 3b0e4e2 [Xiangrui Meng] add mima excludes 6bfd8a4 [Xiangrui Meng] hide transposeMultiply; add rng to rand and randn; add unit tests --- .../apache/spark/mllib/linalg/Matrices.scala | 20 ++++---- .../spark/mllib/linalg/MatricesSuite.scala | 50 +++++++++++++++++++ project/MimaExcludes.scala | 6 +++ 3 files changed, 65 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 2cc52e94282ba..327366a1a3a82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -17,12 +17,10 @@ package org.apache.spark.mllib.linalg -import java.util.Arrays +import java.util.{Random, Arrays} import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM} -import org.apache.spark.util.random.XORShiftRandom - /** * Trait for a local matrix. */ @@ -67,14 +65,14 @@ sealed trait Matrix extends Serializable { } /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ - def transposeMultiply(y: DenseMatrix): DenseMatrix = { + private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] BLAS.gemm(true, false, 1.0, this, y, 0.0, C) C } /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ - def transposeMultiply(y: DenseVector): DenseVector = { + private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { val output = new DenseVector(new Array[Double](numCols)) BLAS.gemv(true, 1.0, this, y, 0.0, output) output @@ -291,22 +289,22 @@ object Matrices { * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ - def rand(numRows: Int, numCols: Int): Matrix = { - val rand = new XORShiftRandom - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble())) + def rand(numRows: Int, numCols: Int, rng: Random): Matrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } /** * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ - def randn(numRows: Int, numCols: Int): Matrix = { - val rand = new XORShiftRandom - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian())) + def randn(numRows: Int, numCols: Int, rng: Random): Matrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 5f8b8c4b72697..322a0e9242918 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -17,7 +17,11 @@ package org.apache.spark.mllib.linalg +import java.util.Random + +import org.mockito.Mockito.when import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -112,4 +116,50 @@ class MatricesSuite extends FunSuite { assert(sparseMat(0, 1) === 10.0) assert(sparseMat.values(2) === 10.0) } + + test("zeros") { + val mat = Matrices.zeros(2, 3).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 3) + assert(mat.values.forall(_ == 0.0)) + } + + test("ones") { + val mat = Matrices.ones(2, 3).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 3) + assert(mat.values.forall(_ == 1.0)) + } + + test("eye") { + val mat = Matrices.eye(2).asInstanceOf[DenseMatrix] + assert(mat.numCols === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 1.0)) + } + + test("rand") { + val rng = mock[Random] + when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = Matrices.rand(2, 2, rng).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } + + test("randn") { + val rng = mock[Random] + when(rng.nextGaussian()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = Matrices.randn(2, 2, rng).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } + + test("diag") { + val mat = Matrices.diag(Vectors.dense(1.0, 2.0)).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 2.0)) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 94de14ddbd2bb..230239aa40500 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -47,6 +47,12 @@ object MimaExcludes { "org.apache.spark.SparkStageInfoImpl.this"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkStageInfo.submissionTime") + ) ++ Seq( + // SPARK-4614 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.randn"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.rand") ) case v if v.startsWith("1.2") =>