From e01656978174f8ecbd75ef6a50211234a1babfc6 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Thu, 5 Mar 2015 11:28:05 -0800 Subject: [PATCH] updated test suite with model type fix --- .../mllib/classification/NaiveBayesSuite.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index eceea68a0284b..0874bb0b90ce4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -25,7 +25,6 @@ import scala.util.Random import org.scalatest.FunSuite import org.apache.spark.SparkException -import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -49,7 +48,7 @@ object NaiveBayesSuite { theta: Array[Array[Double]], // CXD nPoints: Int, seed: Int, - dataModel: NaiveBayesModels= NaiveBayesModels.Multinomial, + dataModel: NaiveBayes.ModelType = NaiveBayes.Multinomial, sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) @@ -60,10 +59,10 @@ object NaiveBayesSuite { for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = dataModel match { - case NaiveBayesModels.Bernoulli => Array.tabulate[Double] (D) {j => + case NaiveBayes.Bernoulli => Array.tabulate[Double] (D) {j => if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 } - case NaiveBayesModels.Multinomial => + case NaiveBayes.Multinomial => val mult = Multinomial(BDV(_theta(y))) val emptyMap = (0 until D).map(x => (x, 0.0)).toMap val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { @@ -78,7 +77,7 @@ object NaiveBayesSuite { /** Binary labels, 3 features */ private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), - theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayesModels.Bernoulli) + theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli) } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -121,7 +120,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 42, NaiveBayesModels.Multinomial) + pi, theta, nPoints, 42, NaiveBayes.Multinomial) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -133,7 +132,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { theta, nPoints, 17, - NaiveBayesModels.Multinomial) + NaiveBayes.Multinomial) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -158,7 +157,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { theta, nPoints, 45, - NaiveBayesModels.Bernoulli) + NaiveBayes.Bernoulli) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -170,7 +169,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { theta, nPoints, 20, - NaiveBayesModels.Bernoulli) + NaiveBayes.Bernoulli) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD.