diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 093aa391dfaa8..1b23576d59bd4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -144,7 +144,7 @@ class LogisticRegressionModel ( // Create JSON metadata. val metadata = LogisticRegressionModel.Metadata( - clazz = this.getClass.getName, version = Exportable.latestVersion) + clazz = this.getClass.getName, version = latestVersion) val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) metadataRDD.toJSON.saveAsTextFile(path + "/metadata") @@ -153,6 +153,9 @@ class LogisticRegressionModel ( val dataRDD: DataFrame = sc.parallelize(Seq(data)) dataRDD.saveAsParquetFile(path + "/data") } + + override protected def latestVersion: String = LogisticRegressionModel.latestVersion + } object LogisticRegressionModel extends Importable[LogisticRegressionModel] { @@ -176,7 +179,7 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { case Row(clazz: String, version: String) => assert(clazz == classOf[LogisticRegressionModel].getName, s"LogisticRegressionModel.load" + s" was given model file with metadata specifying a different model class: $clazz") - assert(version == Exportable.latestVersion, // only 1 version exists currently + assert(version == latestVersion, // only 1 version exists currently s"LogisticRegressionModel.load did not recognize model format version: $version") } @@ -199,6 +202,8 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { lr } + override protected def latestVersion: String = "1.0" + } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index fa46b64e80fbf..704bda6b55784 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -78,7 +78,7 @@ class NaiveBayesModel private[mllib] ( // Create JSON metadata. val metadata = NaiveBayesModel.Metadata( - clazz = this.getClass.getName, version = Exportable.latestVersion) + clazz = this.getClass.getName, version = latestVersion) val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) metadataRDD.toJSON.saveAsTextFile(path + "/metadata") @@ -87,6 +87,8 @@ class NaiveBayesModel private[mllib] ( val dataRDD: DataFrame = sc.parallelize(Seq(data)) dataRDD.saveAsParquetFile(path + "/data") } + + override protected def latestVersion: String = NaiveBayesModel.latestVersion } object NaiveBayesModel extends Importable[NaiveBayesModel] { @@ -110,7 +112,7 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] { case Row(clazz: String, version: String) => assert(clazz == classOf[NaiveBayesModel].getName, s"NaiveBayesModel.load" + s" was given model file with metadata specifying a different model class: $clazz") - assert(version == Exportable.latestVersion, // only 1 version exists currently + assert(version == latestVersion, // only 1 version exists currently s"NaiveBayesModel.load did not recognize model format version: $version") } @@ -127,6 +129,8 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] { val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray new NaiveBayesModel(labels, pi, theta) } + + override protected def latestVersion: String = "1.0" } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala index 06cd822afff5f..66490ca5ad5c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -49,12 +49,8 @@ trait Exportable { */ def save(sc: SparkContext, path: String): Unit -} - -private[mllib] object Exportable { - /** Current version of model import/export format. */ - val latestVersion: String = "1.0" + protected def latestVersion: String } @@ -78,6 +74,9 @@ trait Importable[Model <: Exportable] { */ def load(sc: SparkContext, path: String): Model + /** Current version of model import/export format. */ + protected def latestVersion: String + } private[mllib] object Importable {