Skip to content

Commit

Permalink
made version for model import/export local to each model
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Feb 3, 2015
1 parent 1496852 commit c495dba
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class LogisticRegressionModel (

// Create JSON metadata.
val metadata = LogisticRegressionModel.Metadata(
clazz = this.getClass.getName, version = Exportable.latestVersion)
clazz = this.getClass.getName, version = latestVersion)
val metadataRDD: DataFrame = sc.parallelize(Seq(metadata))
metadataRDD.toJSON.saveAsTextFile(path + "/metadata")

Expand All @@ -153,6 +153,9 @@ class LogisticRegressionModel (
val dataRDD: DataFrame = sc.parallelize(Seq(data))
dataRDD.saveAsParquetFile(path + "/data")
}

override protected def latestVersion: String = LogisticRegressionModel.latestVersion

}

object LogisticRegressionModel extends Importable[LogisticRegressionModel] {
Expand All @@ -176,7 +179,7 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] {
case Row(clazz: String, version: String) =>
assert(clazz == classOf[LogisticRegressionModel].getName, s"LogisticRegressionModel.load" +
s" was given model file with metadata specifying a different model class: $clazz")
assert(version == Exportable.latestVersion, // only 1 version exists currently
assert(version == latestVersion, // only 1 version exists currently
s"LogisticRegressionModel.load did not recognize model format version: $version")
}

Expand All @@ -199,6 +202,8 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] {
lr
}

override protected def latestVersion: String = "1.0"

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class NaiveBayesModel private[mllib] (

// Create JSON metadata.
val metadata = NaiveBayesModel.Metadata(
clazz = this.getClass.getName, version = Exportable.latestVersion)
clazz = this.getClass.getName, version = latestVersion)
val metadataRDD: DataFrame = sc.parallelize(Seq(metadata))
metadataRDD.toJSON.saveAsTextFile(path + "/metadata")

Expand All @@ -87,6 +87,8 @@ class NaiveBayesModel private[mllib] (
val dataRDD: DataFrame = sc.parallelize(Seq(data))
dataRDD.saveAsParquetFile(path + "/data")
}

override protected def latestVersion: String = NaiveBayesModel.latestVersion
}

object NaiveBayesModel extends Importable[NaiveBayesModel] {
Expand All @@ -110,7 +112,7 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] {
case Row(clazz: String, version: String) =>
assert(clazz == classOf[NaiveBayesModel].getName, s"NaiveBayesModel.load" +
s" was given model file with metadata specifying a different model class: $clazz")
assert(version == Exportable.latestVersion, // only 1 version exists currently
assert(version == latestVersion, // only 1 version exists currently
s"NaiveBayesModel.load did not recognize model format version: $version")
}

Expand All @@ -127,6 +129,8 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] {
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
new NaiveBayesModel(labels, pi, theta)
}

override protected def latestVersion: String = "1.0"
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,8 @@ trait Exportable {
*/
def save(sc: SparkContext, path: String): Unit

}

private[mllib] object Exportable {

/** Current version of model import/export format. */
val latestVersion: String = "1.0"
protected def latestVersion: String

}

Expand All @@ -78,6 +74,9 @@ trait Importable[Model <: Exportable] {
*/
def load(sc: SparkContext, path: String): Model

/** Current version of model import/export format. */
protected def latestVersion: String

}

private[mllib] object Importable {
Expand Down

0 comments on commit c495dba

Please sign in to comment.