Many cleanups after code review. Major changes: Storing numFeatures, numClasses in model metadata. Improvements to unit tests
jkbradley committed Feb 4, 2015
1 parent b4ee064 commit 12d9059
Showing 17 changed files with 299 additions and 144 deletions.
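
In outline, the commit makes every GLM classification save() record numFeatures and numClasses in the model's JSON metadata, and makes load() validate the restored model against them. A minimal round-trip sketch under those semantics — the path and model values are made up, and public constructor visibility is assumed for this era of the API:

// Sketch of the save/load round trip this commit hardens (values illustrative).
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.linalg.Vectors

def roundTrip(sc: SparkContext): Unit = {
  val model = new LogisticRegressionModel(Vectors.dense(0.1, -0.2, 0.3), 0.5)
  model.save(sc, "/tmp/lrModel")   // writes /tmp/lrModel/metadata (JSON, now including
                                   // numFeatures and numClasses) and /tmp/lrModel/data (Parquet)
  val loaded = LogisticRegressionModel.load(sc, "/tmp/lrModel")
  assert(loaded.weights == model.weights && loaded.intercept == model.intercept)
}
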
mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -20,7 +20,9 @@ package org.apache.spark.mllib.classification
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}

/**
* :: Experimental ::
@@ -53,3 +55,21 @@ trait ClassificationModel extends Serializable {
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}

private[mllib] object ClassificationModel {

/**
* Helper method for loading GLM classification model metadata.
*
* @param modelClass String name for model class (used for error messages)
* @return (numFeatures, numClasses)
*/
def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
metadata.select("numFeatures", "numClasses").take(1)(0) match {
case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
case _ => throw new Exception(s"$modelClass unable to load" +
s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
}
}

}
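
For context, a hypothetical call site for the new helper, mirroring the loaders later in this diff; it would need to live inside the org.apache.spark.mllib package, since both the helper and Loader are package-private:

// Hypothetical call site for getNumFeaturesClasses (pattern taken from the
// loaders below; assumes mllib-internal visibility).
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.Loader

def loadCounts(sc: SparkContext, path: String): (Int, Int) = {
  val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
  // metadata is a DataFrame with columns "class", "version", "numFeatures", "numClasses"
  ClassificationModel.getNumFeaturesClasses(metadata, loadedClassName, path)
}
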
mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -48,6 +48,20 @@ class LogisticRegressionModel (
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
with Saveable {

if (numClasses == 2) {
require(weights.size == numFeatures,
s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" +
s" numFeatures = $numFeatures, but weights.size = ${weights.size}")
} else {
val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures
val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1)
require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept,
s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" +
s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" +
s" or $weightsSizeWithIntercept (with intercept)," +
s" but was given weights of length ${weights.size}")
}

def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)

private var threshold: Option[Double] = Some(0.5)
@@ -81,7 +95,9 @@ class LogisticRegressionModel (
this
}

override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
override protected def predictPoint(
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double) = {
require(dataMatrix.size == numFeatures)

@@ -140,7 +156,7 @@ class LogisticRegressionModel (

override def save(sc: SparkContext, path: String): Unit = {
GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
weights, intercept, threshold)
numFeatures, numClasses, weights, intercept, threshold)
}

override protected def formatVersion: String = "1.0"
@@ -150,11 +166,16 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] {

override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
// Hard-code class name string in case it changes in the future
val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val (numFeatures, numClasses) =
ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val model = new LogisticRegressionModel(data.weights, data.intercept)
// numFeatures, numClasses, weights are checked in model initialization
val model =
new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses)
data.threshold match {
case Some(t) => model.setThreshold(t)
case None => model.clearThreshold()
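
To make the new multinomial weights-size check concrete, a worked example with illustrative dimensions:

// Worked example of the require(...) added to LogisticRegressionModel above
// (dimensions are illustrative).
val numClasses = 3
val numFeatures = 4
val withoutIntercepts = (numClasses - 1) * numFeatures    // 2 * 4 = 8
val withIntercepts = (numClasses - 1) * (numFeatures + 1) // 2 * 5 = 10
// A weights vector of length 8 or 10 is accepted; any other length now fails fast
// at construction rather than surfacing later as bad predictions.
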
mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.classification

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import org.apache.spark.mllib.classification.impl.GLMClassificationModel

import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
@@ -78,10 +79,13 @@ class NaiveBayesModel private[mllib] (

object NaiveBayesModel extends Loader[NaiveBayesModel] {

import Loader._

private object SaveLoadV1_0 {

def thisFormatVersion = "1.0"

/** Hard-code class name string in case it changes in the future */
def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"

/** Model data for model import/export */
@@ -93,22 +97,23 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {

// Create JSON metadata.
val metadataRDD =
sc.parallelize(Seq((thisClassName, thisFormatVersion))).toDataFrame("class", "version")
metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata")
sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
.toDataFrame("class", "version", "numFeatures", "numClasses")
metadataRDD.toJSON.saveAsTextFile(metadataPath(path))

// Create Parquet data.
val dataRDD: DataFrame = sc.parallelize(Seq(data))
dataRDD.repartition(1).saveAsParquetFile(path + "/data")
val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
dataRDD.saveAsParquetFile(dataPath(path))
}

def load(sc: SparkContext, path: String): NaiveBayesModel = {
val sqlContext = new SQLContext(sc)
// Load Parquet data.
val dataRDD = sqlContext.parquetFile(path + "/data")
val dataRDD = sqlContext.parquetFile(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataRDD.schema)
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
@@ -118,11 +123,24 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
SaveLoadV1_0.load(sc, path)
val (numFeatures, numClasses) =
ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
val model = SaveLoadV1_0.load(sc, path)
assert(model.pi.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class priors vector pi had ${model.pi.size} elements")
assert(model.theta.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
s" but class conditionals array theta had ${model.theta.size} elements")
assert(model.theta.forall(_.size == numFeatures),
s"NaiveBayesModel.load expected $numFeatures features," +
s" but class conditionals array theta had elements of size:" +
s" ${model.theta.map(_.size).mkString(",")}")
model
case _ => throw new Exception(
s"NaiveBayesModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
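
The three assertions added to NaiveBayesModel.load encode the expected shapes of the model arrays. A small sketch with made-up log-probabilities, for numClasses = 2 and numFeatures = 3:

// Shapes validated by NaiveBayesModel.load (probability values made up).
val pi = Array(math.log(0.4), math.log(0.6))            // class log-priors: length numClasses
val theta = Array(                                      // one row per class,
  Array(math.log(0.2), math.log(0.3), math.log(0.5)),   // each row of length numFeatures
  Array(math.log(0.1), math.log(0.6), math.log(0.3)))
assert(pi.length == 2)
assert(theta.length == 2 && theta.forall(_.length == 3))
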
mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -83,7 +83,7 @@ class SVMModel (

override def save(sc: SparkContext, path: String): Unit = {
GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
weights, intercept, threshold)
numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
}

override protected def formatVersion: String = "1.0"
@@ -93,11 +93,18 @@ object SVMModel extends Loader[SVMModel] {

override def load(sc: SparkContext, path: String): SVMModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
// Hard-code class name string in case it changes in the future
val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val (numFeatures, numClasses) =
ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val model = new SVMModel(data.weights, data.intercept)
assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +
s" was given non-matching weights vector of size ${model.weights.size}")
assert(numClasses == 2,
s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes")
data.threshold match {
case Some(t) => model.setThreshold(t)
case None => model.clearThreshold()
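
SVMModel is binary-only, so its save path pins numClasses = 2 and derives numFeatures from the weights vector; load then asserts both. A hypothetical save sketch (values made up; public constructor assumed):

// Hypothetical SVM save (values made up).
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.linalg.Vectors

def saveSvm(sc: SparkContext): Unit = {
  val svm = new SVMModel(Vectors.dense(0.5, -1.0, 2.0), 0.0)
  svm.save(sc, "/tmp/svmModel") // metadata records numFeatures = 3, numClasses = 2
}
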
mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.classification.impl

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

/**
@@ -30,13 +31,20 @@ private[classification] object GLMClassificationModel {

def thisFormatVersion = "1.0"

/** Model data for model import/export */
/** Model data for import/export */
case class Data(weights: Vector, intercept: Double, threshold: Option[Double])

/**
* Helper method for saving GLM classification model metadata and data.
* @param modelClass String name for model class, to be saved with metadata
* @param numClasses Number of classes label can take, to be saved with metadata
*/
def save(
sc: SparkContext,
path: String,
modelClass: String,
numFeatures: Int,
numClasses: Int,
weights: Vector,
intercept: Double,
threshold: Option[Double]): Unit = {
@@ -45,23 +53,32 @@

// Create JSON metadata.
val metadataRDD =
sc.parallelize(Seq((modelClass, thisFormatVersion))).toDataFrame("class", "version")
metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata")
sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1)
.toDataFrame("class", "version", "numFeatures", "numClasses")
metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))

// Create Parquet data.
val data = Data(weights, intercept, threshold)
val dataRDD: DataFrame = sc.parallelize(Seq(data))
val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
// TODO: repartition with 1 partition after SPARK-5532 gets fixed
dataRDD.saveAsParquetFile(path + "/data")
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}

/**
* Helper method for loading GLM classification model data.
*
* NOTE: Callers of this method should check numClasses, numFeatures on their own.
*
* @param modelClass String name for model class (used for error messages)
*/
def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
val datapath = Loader.dataPath(path)
val sqlContext = new SQLContext(sc)
val dataRDD = sqlContext.parquetFile(path + "/data")
val dataRDD = sqlContext.parquetFile(datapath)
val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
assert(dataArray.size == 1, s"Unable to load $modelClass data from: ${path + "/data"}")
assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
val data = dataArray(0)
assert(data.size == 3, s"Unable to load $modelClass data from: ${path + "/data"}")
assert(data.size == 3, s"Unable to load $modelClass data from: $datapath")
val (weights, intercept) = data match {
case Row(weights: Vector, intercept: Double, _) =>
(weights, intercept)
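
Taken together, save writes one JSON metadata record and one Parquet data row, and loadData reads the Parquet side back while callers check the metadata. A sketch of driving the helper directly; it would have to sit inside the classification package given the private[classification] scope, and the path and values are illustrative:

// Hypothetical direct use of SaveLoadV1_0 (normally reached via model.save/load).
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors

def glmRoundTrip(sc: SparkContext): Unit = {
  val className = "org.apache.spark.mllib.classification.LogisticRegressionModel"
  GLMClassificationModel.SaveLoadV1_0.save(sc, "/tmp/glm", className,
    numFeatures = 3, numClasses = 2,
    weights = Vectors.dense(0.1, -0.2, 0.3), intercept = 0.0, threshold = Some(0.5))
  // /tmp/glm/metadata: one JSON record with class, version, numFeatures, numClasses
  // /tmp/glm/data:     one Parquet row of (weights, intercept, threshold)
  val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, "/tmp/glm", className)
  assert(data.weights.size == 3) // the numFeatures/numClasses checks are left to the caller
}
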
mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -54,10 +54,12 @@ object LassoModel extends Loader[LassoModel] {

override def load(sc: SparkContext, path: String): LassoModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
// Hard-code class name string in case it changes in the future
val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
new LassoModel(data.weights, data.intercept)
case _ => throw new Exception(
s"LassoModel.load did not recognize model with (className, format version):" +
mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -54,10 +54,12 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] {

override def load(sc: SparkContext, path: String): LinearRegressionModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
// Hard-code class name string in case it changes in the future
val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
new LinearRegressionModel(data.weights, data.intercept)
case _ => throw new Exception(
s"LinearRegressionModel.load did not recognize model with (className, format version):" +
mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -19,8 +19,10 @@ package org.apache.spark.mllib.regression

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}

@Experimental
trait RegressionModel extends Serializable {
@@ -48,3 +50,21 @@
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}

private[mllib] object RegressionModel {

/**
* Helper method for loading GLM regression model metadata.
*
* @param modelClass String name for model class (used for error messages)
* @return numFeatures
*/
def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = {
metadata.select("numFeatures").take(1)(0) match {
case Row(nFeatures: Int) => nFeatures
case _ => throw new Exception(s"$modelClass unable to load" +
s" numFeatures from metadata: ${Loader.metadataPath(path)}")
}
}

}
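
The Lasso, LinearRegression, and RidgeRegression loaders in this commit all thread the result of this helper into GLMRegressionModel.SaveLoadV1_0.loadData; a condensed sketch of that shared pattern, again assuming mllib-internal visibility and an illustrative path:

// Shared load pattern of the GLM regression loaders (mllib-internal sketch).
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.Loader

def glmRegressionLoadPattern(sc: SparkContext, path: String): Int = {
  val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
  val numFeatures = RegressionModel.getNumFeatures(metadata, loadedClassName, path)
  // loadData(sc, path, loadedClassName, numFeatures) can then check the size of the
  // loaded weights vector against numFeatures while reading the Parquet data.
  numFeatures
}
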
mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -55,10 +55,12 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] {

override def load(sc: SparkContext, path: String): RidgeRegressionModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
// Hard-code class name string in case it changes in the future
val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
new RidgeRegressionModel(data.weights, data.intercept)
case _ => throw new Exception(
s"RidgeRegressionModel.load did not recognize model with (className, format version):" +
[Diffs for the remaining 8 of the 17 changed files are not shown in this capture.]
