diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index fbdf53ec28963..8f9418deb045a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -62,6 +62,9 @@ class NaiveBayesModel private[mllib] ( val theta: Array[Array[Double]], val modelType: NaiveBayesModels) extends ClassificationModel with Serializable with Saveable { + def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = + this(labels, pi, theta, NaiveBayesModels.Multinomial) + private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t