Reorganized save/load for regression and classification. Renamed concepts to Saveable, Loader
jkbradley committed Feb 4, 2015
1 parent a34aef5 commit b4ee064
Showing 15 changed files with 224 additions and 227 deletions.
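
For orientation, a minimal usage sketch of the renamed API this commit introduces (not part of the diff). It assumes an existing SparkContext sc, an already trained LogisticRegressionModel, and a purely illustrative output path; the save/load signatures follow the changed code below, where models mix in Saveable and their companion objects extend Loader.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.LogisticRegressionModel

// Hypothetical sketch: round-trip a trained model through the renamed API.
// `sc` and `model` are assumed to exist; the path is illustrative only.
def saveAndReload(sc: SparkContext, model: LogisticRegressionModel): LogisticRegressionModel = {
  val path = "/tmp/lr-model"              // hypothetical location
  model.save(sc, path)                    // Saveable: writes JSON metadata + Parquet data
  LogisticRegressionModel.load(sc, path)  // Loader: dispatches on (className, format version)
}
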
mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable}
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD


@@ -46,7 +46,7 @@ class LogisticRegressionModel (
val numFeatures: Int,
val numClasses: Int)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
with Exportable {
with Saveable {

def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)

@@ -139,27 +139,33 @@ class LogisticRegressionModel (
}

override def save(sc: SparkContext, path: String): Unit = {
GLMClassificationModel.save(sc, path, this.getClass.getName, weights, intercept, threshold)
GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
weights, intercept, threshold)
}

override protected def formatVersion: String = LogisticRegressionModel.formatVersion

override protected def formatVersion: String = "1.0"
}

object LogisticRegressionModel extends Importable[LogisticRegressionModel] {
object LogisticRegressionModel extends Loader[LogisticRegressionModel] {

override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
val data = GLMClassificationModel.loadData(sc, path, classOf[LogisticRegressionModel].getName)
val lr = new LogisticRegressionModel(data.weights, data.intercept)
data.threshold match {
case Some(t) => lr.setThreshold(t)
case None => lr.clearThreshold()
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val model = new LogisticRegressionModel(data.weights, data.intercept)
data.threshold match {
case Some(t) => model.setThreshold(t)
case None => model.clearThreshold()
}
model
case _ => throw new Exception(
s"LogisticRegressionModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
lr
}

override protected def formatVersion: String = GLMClassificationModel.formatVersion

}

/**
mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -22,7 +22,7 @@ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgma
import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Importable, Exportable}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

@@ -38,7 +38,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Exportable {
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {

private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -69,37 +69,44 @@ class NaiveBayesModel private[mllib] (
}

override def save(sc: SparkContext, path: String): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext._

// Create JSON metadata.
val metadataRDD =
sc.parallelize(Seq((this.getClass.getName, formatVersion))).toDataFrame("class", "version")
metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata")

// Create Parquet data.
val data = NaiveBayesModel.Data(labels, pi, theta)
val dataRDD: DataFrame = sc.parallelize(Seq(data))
dataRDD.repartition(1).saveAsParquetFile(path + "/data")
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta)
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
}

override protected def formatVersion: String = NaiveBayesModel.formatVersion

override protected def formatVersion: String = "1.0"
}

object NaiveBayesModel extends Importable[NaiveBayesModel] {
object NaiveBayesModel extends Loader[NaiveBayesModel] {

private object SaveLoadV1_0 {

def thisFormatVersion = "1.0"

def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"

/** Model data for model import/export */
private case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])
/** Model data for model import/export */
case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])

private object ImporterV1 extends Importer {
def save(sc: SparkContext, path: String, data: Data): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext._

// Create JSON metadata.
val metadataRDD =
sc.parallelize(Seq((thisClassName, thisFormatVersion))).toDataFrame("class", "version")
metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata")

// Create Parquet data.
val dataRDD: DataFrame = sc.parallelize(Seq(data))
dataRDD.repartition(1).saveAsParquetFile(path + "/data")
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
def load(sc: SparkContext, path: String): NaiveBayesModel = {
val sqlContext = new SQLContext(sc)
// Load Parquet data.
val dataRDD = sqlContext.parquetFile(path + "/data")
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Importable.checkSchema[Data](dataRDD.schema)
Loader.checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
val data = dataArray(0)
@@ -110,27 +117,18 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] {
}
}

protected object Importer {

def get(clazz: String, version: String): Importer = {
assert(clazz == classOf[NaiveBayesModel].getName, s"NaiveBayesModel.load" +
s" was given model file with metadata specifying a different model class: $clazz")
version match {
case "1.0" => ImporterV1
case _ => throw new Exception(
s"NaiveBayesModel.load did not recognize model format version: $version." +
s" Supported versions: 1.0.")
}
}
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
val (clazz, version, metadata) = Importable.loadMetadata(sc, path)
val importer = Importer.get(clazz, version)
importer.load(sc, path)
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
SaveLoadV1_0.load(sc, path)
case _ => throw new Exception(
s"NaiveBayesModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
}

override protected def formatVersion: String = "1.0"
}

/**
mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -23,7 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable}
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD


@@ -37,7 +37,7 @@ class SVMModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
with Exportable {
with Saveable {

private var threshold: Option[Double] = Some(0.0)

@@ -82,26 +82,33 @@ class SVMModel (
}

override def save(sc: SparkContext, path: String): Unit = {
GLMClassificationModel.save(sc, path, this.getClass.getName, weights, intercept, threshold)
GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
weights, intercept, threshold)
}

override protected def formatVersion: String = SVMModel.formatVersion
override protected def formatVersion: String = "1.0"
}

object SVMModel extends Importable[SVMModel] {
object SVMModel extends Loader[SVMModel] {

override def load(sc: SparkContext, path: String): SVMModel = {
val data = GLMClassificationModel.loadData(sc, path, classOf[SVMModel].getName)
val lr = new SVMModel(data.weights, data.intercept)
data.threshold match {
case Some(t) => lr.setThreshold(t)
case None => lr.clearThreshold()
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val model = new SVMModel(data.weights, data.intercept)
data.threshold match {
case Some(t) => model.setThreshold(t)
case None => model.clearThreshold()
}
model
case _ => throw new Exception(
s"SVMModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
lr
}

override protected def formatVersion: String = GLMClassificationModel.formatVersion

}

/**
mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -19,42 +19,43 @@ package org.apache.spark.mllib.classification.impl

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Importable
import org.apache.spark.sql.{Row, DataFrame, SQLContext}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

/**
* Helper methods for import/export of GLM classification models.
* Helper class for import/export of GLM classification models.
*/
private[classification] object GLMClassificationModel {

/** Model data for model import/export */
case class Data(weights: Vector, intercept: Double, threshold: Option[Double])
object SaveLoadV1_0 {

def save(
sc: SparkContext,
path: String,
modelClass: String,
weights: Vector,
intercept: Double,
threshold: Option[Double]): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext._
def thisFormatVersion = "1.0"

// Create JSON metadata.
val metadataRDD =
sc.parallelize(Seq((modelClass, formatVersion))).toDataFrame("class", "version")
metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata")
/** Model data for model import/export */
case class Data(weights: Vector, intercept: Double, threshold: Option[Double])

// Create Parquet data.
val data = Data(weights, intercept, threshold)
val dataRDD: DataFrame = sc.parallelize(Seq(data))
// TODO: repartition with 1 partition after SPARK-5532 gets fixed
dataRDD.saveAsParquetFile(path + "/data")
}
def save(
sc: SparkContext,
path: String,
modelClass: String,
weights: Vector,
intercept: Double,
threshold: Option[Double]): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext._

private object ImporterV1 {
// Create JSON metadata.
val metadataRDD =
sc.parallelize(Seq((modelClass, thisFormatVersion))).toDataFrame("class", "version")
metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata")

// Create Parquet data.
val data = Data(weights, intercept, threshold)
val dataRDD: DataFrame = sc.parallelize(Seq(data))
// TODO: repartition with 1 partition after SPARK-5532 gets fixed
dataRDD.saveAsParquetFile(path + "/data")
}

def load(sc: SparkContext, path: String, modelClass: String): Data = {
def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
val sqlContext = new SQLContext(sc)
val dataRDD = sqlContext.parquetFile(path + "/data")
val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
@@ -74,21 +75,4 @@ private[classification] object GLMClassificationModel {
}
}

def formatVersion: String = "1.0"

def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
val (clazz, version, metadata) = Importable.loadMetadata(sc, path)
// Note: This check of the class name should happen here since we may eventually want to load
// other classes (such as deprecated versions).
assert(clazz == modelClass, s"$modelClass.load" +
s" was given model file with metadata specifying a different model class: $clazz")
version match {
case "1.0" =>
ImporterV1.load(sc, path, modelClass)
case _ => throw new Exception(
s"$modelClass.load did not recognize model format version: $version." +
s" Supported versions: 1.0.")
}
}

}
25 changes: 16 additions & 9 deletions mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Exportable, Importable}
import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD

/**
@@ -34,7 +34,7 @@ class LassoModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable with Exportable {
with RegressionModel with Serializable with Saveable {

override protected def predictPoint(
dataMatrix: Vector,
@@ -44,20 +44,27 @@ class LassoModel (
}

override def save(sc: SparkContext, path: String): Unit = {
GLMRegressionModel.save(sc, path, this.getClass.getName, weights, intercept)
GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept)
}

override protected def formatVersion: String = LassoModel.formatVersion
override protected def formatVersion: String = "1.0"
}

object LassoModel extends Importable[LassoModel] {
object LassoModel extends Loader[LassoModel] {

override def load(sc: SparkContext, path: String): LassoModel = {
val data = GLMRegressionModel.loadData(sc, path, classOf[LassoModel].getName)
new LassoModel(data.weights, data.intercept)
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
new LassoModel(data.weights, data.intercept)
case _ => throw new Exception(
s"LassoModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
}

override protected def formatVersion: String = LassoModel.formatVersion
}

/**