Added model save/load version to support NaiveBayes ModelType #2

Merged: 2 commits, Mar 24, 2015 (showing changes from all commits)
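In summary, the patch bumps the NaiveBayesModel persistence format to 2.0 so the chosen model type survives save/load, keeps a reader for the old 1.0 layout, and adds a string overload of `setModelType`. A minimal end-to-end sketch of the resulting API, e.g. from spark-shell (the training data and output path are illustrative placeholders, `sc` is the shell's SparkContext):

```scala
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

// Stand-in training data; real input would be a larger RDD[LabeledPoint].
val training = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))))

// Train a Bernoulli model through the new string-based model type.
val model = NaiveBayes.train(training, lambda = 1.0, modelType = "bernoulli")

// Save in the new 2.0 format (the model type is persisted) and load it back.
model.save(sc, "/tmp/naive-bayes-bernoulli")
val restored = NaiveBayesModel.load(sc, "/tmp/naive-bayes-bernoulli")
```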
mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -35,6 +35,8 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

import NaiveBayes.ModelType.{Bernoulli, Multinomial}


/**
* Model for Naive Bayes Classifiers.
@@ -54,7 +56,7 @@ class NaiveBayesModel private[mllib] (
extends ClassificationModel with Serializable with Saveable {

private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
this(labels, pi, theta, NaiveBayes.Multinomial)
this(labels, pi, theta, Multinomial)

/** A Java-friendly constructor that takes three Iterable parameters. */
private[mllib] def this(
@@ -70,10 +72,13 @@ class NaiveBayesModel private[mllib] (
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
// application of this condition (in predict function).
private val (brzNegTheta, brzNegThetaSum) = modelType match {
case NaiveBayes.Multinomial => (None, None)
case NaiveBayes.Bernoulli =>
case Multinomial => (None, None)
case Bernoulli =>
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
(Option(negTheta), Option(brzSum(negTheta, Axis._1)))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}

override def predict(testData: RDD[Vector]): RDD[Double] = {
@@ -86,29 +91,32 @@ class NaiveBayesModel private[mllib] (

override def predict(testData: Vector): Double = {
modelType match {
case NaiveBayes.Multinomial =>
case Multinomial =>
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
case NaiveBayes.Bernoulli =>
case Bernoulli =>
labels (brzArgmax (brzPi +
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}
}
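For the Bernoulli case the per-class log-likelihood is `sum_j [ x_j * theta_cj + (1 - x_j) * log(1 - exp(theta_cj)) ]`, which the branch above rearranges to `(theta - negTheta) * x + sum(negTheta)` so that `negTheta = log(1 - exp(theta))` and its row sums can be precomputed once. A small self-contained sanity check of that rearrangement with made-up numbers (plain arrays instead of Breeze vectors):

```scala
// Hypothetical per-feature log conditionals for one class: theta(j) = log P(feature j present | class)
val theta    = Array(math.log(0.2), math.log(0.7), math.log(0.4))
val negTheta = theta.map(t => math.log1p(-math.exp(t)))   // log(1 - exp(theta_j))
val x        = Array(1.0, 0.0, 1.0)                       // binary (Bernoulli) feature vector

// Direct Bernoulli log-likelihood: sum_j [ x_j * theta_j + (1 - x_j) * log(1 - exp(theta_j)) ]
val direct     = (theta, negTheta, x).zipped.map((t, nt, xj) => xj * t + (1 - xj) * nt).sum
// Rearranged form used in predict(): (theta - negTheta) . x + sum(negTheta)
val rearranged = (theta, negTheta, x).zipped.map((t, nt, xj) => (t - nt) * xj).sum + negTheta.sum

assert(math.abs(direct - rearranged) < 1e-12)
```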

override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType.toString)
NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
}

override protected def formatVersion: String = "1.0"
override protected def formatVersion: String = "2.0"
}

object NaiveBayesModel extends Loader[NaiveBayesModel] {

import org.apache.spark.mllib.util.Loader._

private object SaveLoadV1_0 {
private[mllib] object SaveLoadV2_0 {

def thisFormatVersion: String = "1.0"
def thisFormatVersion: String = "2.0"

/** Hard-code class name string in case it changes in the future */
def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
@@ -127,8 +135,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
// Create JSON metadata.
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length) ~
("modelType" -> data.modelType)))
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

// Create Parquet data.
@@ -151,36 +158,82 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
new NaiveBayesModel(labels, pi, theta, modelType)
}

}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
implicit val formats = DefaultFormats
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
private[mllib] object SaveLoadV1_0 {

def thisFormatVersion: String = "1.0"

/** Hard-code class name string in case it changes in the future */
def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"

/** Model data for model import/export */
case class Data(
labels: Array[Double],
pi: Array[Double],
theta: Array[Array[Double]])

def save(sc: SparkContext, path: String, data: Data): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Create JSON metadata.
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

// Create Parquet data.
val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
dataRDD.saveAsParquetFile(dataPath(path))
}

def load(sc: SparkContext, path: String): NaiveBayesModel = {
val sqlContext = new SQLContext(sc)
// Load Parquet data.
val dataRDD = sqlContext.parquetFile(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
new NaiveBayesModel(labels, pi, theta)
}
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
val classNameV2_0 = SaveLoadV2_0.thisClassName
val (model, numFeatures, numClasses) = (loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV1_0.load(sc, path)
assert(model.pi.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class priors vector pi had ${model.pi.size} elements")
assert(model.theta.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class conditionals array theta had ${model.theta.size} elements")
assert(model.theta.forall(_.size == numFeatures),
s"NaiveBayesModel.load expected $numFeatures features," +
s" but class conditionals array theta had elements of size:" +
s" ${model.theta.map(_.size).mkString(",")}")
assert(model.modelType == getModelType(metadata))
model
(model, numFeatures, numClasses)
case (className, "2.0") if className == classNameV2_0 =>
val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV2_0.load(sc, path)
(model, numFeatures, numClasses)
case _ => throw new Exception(
s"NaiveBayesModel.load did not recognize model with (className, format version):" +
s" ($loadedClassName, $version).  Supported:\n" +
s"  ($classNameV1_0, 1.0)\n" +
s"  ($classNameV2_0, 2.0)")
}
assert(model.pi.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class priors vector pi had ${model.pi.size} elements")
assert(model.theta.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class conditionals array theta had ${model.theta.size} elements")
assert(model.theta.forall(_.size == numFeatures),
s"NaiveBayesModel.load expected $numFeatures features," +
s" but class conditionals array theta had elements of size:" +
s" ${model.theta.map(_.size).mkString(",")}")
model
}
}
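In the 2.0 layout the model type travels as a fourth string column of the Parquet data (read back through `ModelType.fromString(data.getString(3))`), while the JSON metadata keeps only class, version, numFeatures, and numClasses. `load` then dispatches on the (className, version) pair: 1.0 data, which never stored a model type, goes through `SaveLoadV1_0.load` and the three-argument constructor, so legacy models come back as Multinomial. A short sketch of that compatibility expectation (the paths and `sc` are hypothetical):

```scala
// A model written by an older release in the 1.0 format (no modelType column)
// still loads; it is rebuilt through the three-argument constructor and
// therefore behaves as a Multinomial model.
val legacy = NaiveBayesModel.load(sc, "/models/nb-saved-in-1.0-format")

// A model written with this patch round-trips its modelType through the 2.0 format.
val current = NaiveBayesModel.load(sc, "/models/nb-saved-in-2.0-format")
```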

@@ -197,9 +250,9 @@ class NaiveBayes private (
private var lambda: Double,
private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {

def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
def this(lambda: Double) = this(lambda, Multinomial)

def this() = this(1.0, NaiveBayes.Multinomial)
def this() = this(1.0, Multinomial)

/** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
@@ -210,9 +263,22 @@ class NaiveBayes private (
/** Get the smoothing parameter. */
def getLambda: Double = lambda

/** Set the model type. Default: Multinomial. */
def setModelType(model: NaiveBayes.ModelType): NaiveBayes = {
this.modelType = model
/**
* Set the model type using a string (case-insensitive).
* Supported options: "multinomial" and "bernoulli".
* (default: multinomial)
*/
def setModelType(modelType: String): NaiveBayes = {
setModelType(NaiveBayes.ModelType.fromString(modelType))
}

/**
* Set the model type.
* Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]]
* (default: Multinomial)
*/
def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = {
this.modelType = modelType
this
}
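The string overload simply normalizes through `ModelType.fromString`, so the two setters are interchangeable; a brief sketch (the `NaiveBayes` instances here are illustrative and never run):

```scala
val nb1 = new NaiveBayes().setLambda(1.0).setModelType("bernoulli")
val nb2 = new NaiveBayes().setLambda(1.0).setModelType(NaiveBayes.ModelType.Bernoulli)
// Both configure the same Bernoulli variant; "Bernoulli" or "BERNOULLI" would also
// be accepted, since fromString lower-cases its argument before matching.
```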

@@ -270,8 +336,11 @@ class NaiveBayes private (
labels(i) = label
pi(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = modelType match {
case NaiveBayes.Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case NaiveBayes.Bernoulli => math.log(n + 2.0 * lambda)
case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case Bernoulli => math.log(n + 2.0 * lambda)
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
}
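The two denominators implement additive (Laplace) smoothing with the default `lambda = 1.0`: the multinomial variant normalizes by the total smoothed term count over all features, the Bernoulli variant by the smoothed document count of the class (assuming 0/1 features, so the per-feature sums count documents containing each feature). A tiny worked example with hypothetical counts (3 features, a class with n = 4 documents and term-frequency sums [5, 0, 3]):

```scala
val lambda       = 1.0
val numFeatures  = 3
val n            = 4.0                      // documents observed for this class
val sumTermFreqs = Array(5.0, 0.0, 3.0)     // per-feature term-frequency sums for this class

// Multinomial: theta_j = log(count_j + lambda) - log(sum(counts) + numFeatures * lambda)
val multinomialDenom = math.log(sumTermFreqs.sum + numFeatures * lambda)   // log(8 + 3) = log(11)
val thetaMultinomial = sumTermFreqs.map(c => math.log(c + lambda) - multinomialDenom)

// Bernoulli: theta_j = log(docsWithFeature_j + lambda) - log(n + 2 * lambda)
val bernoulliDenom = math.log(n + 2.0 * lambda)                            // log(4 + 2) = log(6)
```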
var j = 0
while (j < numFeatures) {
@@ -317,7 +386,7 @@ object NaiveBayes {
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda, NaiveBayes.Multinomial).run(input)
new NaiveBayes(lambda, NaiveBayes.ModelType.Multinomial).run(input)
}

/**
@@ -339,35 +408,42 @@ object NaiveBayes {
* multinomial or bernoulli
*/
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
new NaiveBayes(lambda, MODELTYPE.fromString(modelType)).run(input)
new NaiveBayes(lambda, ModelType.fromString(modelType)).run(input)
}

/** Provides static methods for using ModelType. */
sealed abstract class ModelType extends Serializable

object MODELTYPE extends Serializable{
final val MULTINOMIAL_STRING = "multinomial"
final val BERNOULLI_STRING = "bernoulli"
object ModelType extends Serializable {

def fromString(modelType: String): ModelType = modelType match {
case MULTINOMIAL_STRING => Multinomial
case BERNOULLI_STRING => Bernoulli
/**
* Get the model type from a string.
* @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
*/
def fromString(modelType: String): ModelType = modelType.toLowerCase match {
case "multinomial" => Multinomial
case "bernoulli" => Bernoulli
case _ =>
throw new IllegalArgumentException(s"Cannot recognize NaiveBayes ModelType: $modelType")
throw new IllegalArgumentException(
s"NaiveBayes.ModelType.fromString did not recognize string: $modelType")
}
}
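Because `fromString` lower-cases its argument before matching, any capitalization of the two supported names is accepted and everything else fails fast; a short sketch:

```scala
import org.apache.spark.mllib.classification.NaiveBayes

assert(NaiveBayes.ModelType.fromString("multinomial") == NaiveBayes.ModelType.Multinomial)
assert(NaiveBayes.ModelType.fromString("Bernoulli")   == NaiveBayes.ModelType.Bernoulli)

// Any other string throws IllegalArgumentException:
// NaiveBayes.ModelType.fromString("gaussian")  // would throw
```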

final val ModelType = MODELTYPE
final val Multinomial: ModelType = {
case object Multinomial extends ModelType with Serializable {
override def toString: String = "multinomial"
}
Multinomial
}

/** Constant for specifying ModelType parameter: multinomial model */
final val Multinomial: ModelType = new ModelType {
override def toString: String = ModelType.MULTINOMIAL_STRING
final val Bernoulli: ModelType = {
case object Bernoulli extends ModelType with Serializable {
override def toString: String = "bernoulli"
}
Bernoulli
}
}

/** Constant for specifying ModelType parameter: bernoulli model */
final val Bernoulli: ModelType = new ModelType {
override def toString: String = ModelType.BERNOULLI_STRING
}
/** Java-friendly accessor for supported ModelType options */
final val modelTypes = ModelType

}

mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -17,20 +17,22 @@

package org.apache.spark.mllib.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

public class JavaNaiveBayesSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception {
// Should be able to get the first prediction.
predictions.first();
}

@Test
public void testModelTypeSetters() {
NaiveBayes nb = new NaiveBayes()
.setModelType(NaiveBayes.modelTypes().Bernoulli())
.setModelType(NaiveBayes.modelTypes().Multinomial());
}
}