diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index cacea5b1eced6..6fc88754e4a17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -35,8 +35,6 @@ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} -import NaiveBayes.ModelType.{Bernoulli, Multinomial} - /** * Model for Naive Bayes Classifiers. @@ -45,18 +43,17 @@ import NaiveBayes.ModelType.{Bernoulli, Multinomial} * @param pi log of class priors, whose dimension is C, number of labels * @param theta log of class conditional probabilities, whose dimension is C-by-D, * where D is number of features - * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be - * Multinomial or Bernoulli + * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli" */ class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], val theta: Array[Array[Double]], - val modelType: NaiveBayes.ModelType) + val modelType: String) extends ClassificationModel with Serializable with Saveable { private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = - this(labels, pi, theta, Multinomial) + this(labels, pi, theta, "Multinomial") /** A Java-friendly constructor that takes three Iterable parameters. */ private[mllib] def this( @@ -72,8 +69,8 @@ class NaiveBayesModel private[mllib] ( // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra // application of this condition (in predict function). private val (brzNegTheta, brzNegThetaSum) = modelType match { - case Multinomial => (None, None) - case Bernoulli => + case "Multinomial" => (None, None) + case "Bernoulli" => val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) (Option(negTheta), Option(brzSum(negTheta, Axis._1))) case _ => @@ -91,9 +88,9 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { modelType match { - case Multinomial => + case "Multinomial" => labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) - case Bernoulli => + case "Bernoulli" => labels (brzArgmax (brzPi + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) case _ => @@ -103,7 +100,7 @@ class NaiveBayesModel private[mllib] ( } override def save(sc: SparkContext, path: String): Unit = { - val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType.toString) + val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) } @@ -155,7 +152,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray - val modelType = NaiveBayes.ModelType.fromString(data.getString(3)) + val modelType = data.getString(3) new NaiveBayesModel(labels, pi, theta, modelType) } @@ -248,11 +245,11 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { class NaiveBayes private ( private var lambda: Double, - private var modelType: NaiveBayes.ModelType) extends Serializable with Logging { + private var modelType: String) extends Serializable with Logging { - def this(lambda: Double) = this(lambda, Multinomial) + def this(lambda: Double) = this(lambda, "Multinomial") - def this() = this(1.0, Multinomial) + def this() = this(1.0, "Multinomial") /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -264,26 +261,21 @@ class NaiveBayes private ( def getLambda: Double = lambda /** - * Set the model type using a string (case-insensitive). - * Supported options: "multinomial" and "bernoulli". - * (default: multinomial) - */ - def setModelType(modelType: String): NaiveBayes = { - setModelType(NaiveBayes.ModelType.fromString(modelType)) - } - - /** - * Set the model type. - * Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]] + * Set the model type using a string (case-sensitive). + * Supported options: "Multinomial" and "Bernoulli". * (default: Multinomial) */ - def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = { - this.modelType = modelType - this + def setModelType(modelType:String): NaiveBayes = { + if (NaiveBayes.supportedModelTypes.contains(modelType)) { + this.modelType = modelType + this + } else { + throw new UnknownError(s"NaiveBayesModel does not support ModelType: $modelType") + } } /** Get the model type. */ - def getModelType: NaiveBayes.ModelType = this.modelType + def getModelType: String = this.modelType /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. @@ -336,8 +328,8 @@ class NaiveBayes private ( labels(i) = label pi(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = modelType match { - case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) - case Bernoulli => math.log(n + 2.0 * lambda) + case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case "Bernoulli" => math.log(n + 2.0 * lambda) case _ => // This should never happen. throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") @@ -358,6 +350,10 @@ class NaiveBayes private ( * Top-level methods for calling naive Bayes. */ object NaiveBayes { + + /* Set of modelTypes that NaiveBayes supports */ + private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli") + /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * @@ -386,7 +382,7 @@ object NaiveBayes { * @param lambda The smoothing parameter */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda, NaiveBayes.ModelType.Multinomial).run(input) + new NaiveBayes(lambda, "Multinomial").run(input) } /** @@ -408,42 +404,11 @@ object NaiveBayes { * multinomial or bernoulli */ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { - new NaiveBayes(lambda, ModelType.fromString(modelType)).run(input) - } - - /** Provides static methods for using ModelType. */ - sealed abstract class ModelType extends Serializable - - object ModelType extends Serializable { - - /** - * Get the model type from a string. - * @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive) - */ - def fromString(modelType: String): ModelType = modelType.toLowerCase match { - case "multinomial" => Multinomial - case "bernoulli" => Bernoulli - case _ => - throw new IllegalArgumentException( - s"NaiveBayes.ModelType.fromString did not recognize string: $modelType") - } - - final val Multinomial: ModelType = { - case object Multinomial extends ModelType with Serializable { - override def toString: String = "multinomial" - } - Multinomial - } - - final val Bernoulli: ModelType = { - case object Bernoulli extends ModelType with Serializable { - override def toString: String = "bernoulli" - } - Bernoulli + if (supportedModelTypes.contains(modelType)) { + new NaiveBayes(lambda, modelType).run(input) + } else { + throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") } } - /** Java-friendly accessor for supported ModelType options */ - final val modelTypes = ModelType - } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 4d89c06b88c0e..71fb7f13c39c2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -108,7 +108,7 @@ public Vector call(LabeledPoint v) throws Exception { @Test public void testModelTypeSetters() { NaiveBayes nb = new NaiveBayes() - .setModelType(NaiveBayes.modelTypes().Bernoulli()) - .setModelType(NaiveBayes.modelTypes().Multinomial()); + .setModelType("Bernoulli") + .setModelType("Multinomial"); } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 2d87d6893250b..f9fe3e006ccb8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -25,7 +25,6 @@ import breeze.stats.distributions.{Multinomial => BrzMultinomial} import org.scalatest.FunSuite import org.apache.spark.SparkException -import org.apache.spark.mllib.classification.NaiveBayes.ModelType.{Bernoulli, Multinomial} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -49,7 +48,7 @@ object NaiveBayesSuite { theta: Array[Array[Double]], // CXD nPoints: Int, seed: Int, - modelType: NaiveBayes.ModelType = Multinomial, + modelType: String = "Multinomial", sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) @@ -59,10 +58,10 @@ object NaiveBayesSuite { for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { - case Bernoulli => Array.tabulate[Double] (D) { j => + case "Bernoulli" => Array.tabulate[Double] (D) { j => if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 } - case Multinomial => + case "Multinomial" => val mult = BrzMultinomial(BDV(_theta(y))) val emptyMap = (0 until D).map(x => (x, 0.0)).toMap val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { @@ -81,12 +80,12 @@ object NaiveBayesSuite { /** Bernoulli NaiveBayes with binary labels, 3 features */ private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), - Bernoulli) + "Bernoulli") /** Multinomial NaiveBayes with binary labels, 3 features */ private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), - Multinomial) + "Multinomial") } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -136,15 +135,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 42, Multinomial) + pi, theta, nPoints, 42, "Multinomial") val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "multinomial") + val model = NaiveBayes.train(testRDD, 1.0, "Multinomial") validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 17, Multinomial) + pi, theta, nPoints, 17, "Multinomial") val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -164,15 +163,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 45, Bernoulli) + pi, theta, nPoints, 45, "Bernoulli") val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "bernoulli") + val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 20, Bernoulli) + pi, theta, nPoints, 20, "Bernoulli") val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -243,7 +242,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(model.labels === sameModel.labels) assert(model.pi === sameModel.pi) assert(model.theta === sameModel.theta) - assert(model.modelType === NaiveBayes.ModelType.Multinomial) + assert(model.modelType === "Multinomial") } finally { Utils.deleteRecursively(tempDir) }