From da2ec11fdf79567a0c6ee175f09b7ef7e6c35398 Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Mon, 8 Dec 2014 22:10:49 +0000 Subject: [PATCH] [SPARK-1406] added linear SVM PMML export --- .../spark/mllib/export/ModelExportFactory.scala | 3 +++ .../mllib/export/ModelExportFactorySuite.scala | 11 +++++++++-- .../GeneralizedLinearPMMLModelExportSuite.scala | 16 ++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala index f52d22093739d..282a32ebc5ced 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.export +import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.export.ModelExportType.ModelExportType import org.apache.spark.mllib.export.ModelExportType.PMML @@ -44,6 +45,8 @@ private[mllib] object ModelExportFactory { new GeneralizedLinearPMMLModelExport(ridgeRegression, "ridge regression") case lassoRegression: LassoModel => new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression") + case svm: SVMModel => + new GeneralizedLinearPMMLModelExport(svm, "linear SVM") case _ => throw new IllegalArgumentException("Export not supported for model: " + model.getClass) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala index b2208dd6d0c7d..6792e2d674bb4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.export import org.scalatest.FunSuite +import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LassoModel @@ -48,8 +49,8 @@ class ModelExportFactorySuite extends FunSuite{ } - test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a" - +"LinearRegressionModel, RidgeRegressionModel or LassoModel") { + test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " + +"LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") { //arrange val linearInput = LinearDataGenerator.generateLinearInput( @@ -57,6 +58,7 @@ class ModelExportFactorySuite extends FunSuite{ val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label); val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label); val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label); + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label); //act val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML) @@ -73,6 +75,11 @@ class ModelExportFactorySuite extends FunSuite{ //assert assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + //act + val svmModelExport = ModelExportFactory.createModelExport(svmModel, ModelExportType.PMML) + //assert + assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + } test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala index 628c74e418a28..402a84c2c8a47 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.export.pmml import org.dmg.pmml.RegressionModel import org.scalatest.FunSuite +import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.export.ModelExportFactory import org.apache.spark.mllib.export.ModelExportType import org.apache.spark.mllib.regression.LassoModel @@ -37,6 +38,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label); val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label); val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label); + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label); //act by exporting the model to the PMML format val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML) @@ -76,11 +78,25 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ //it also verifies that the pmml model has a regression table with the same number of predictors of the model weights assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] .getRegressionTables().get(0).getNumericPredictors().size() === lassoModel.weights.size) + + //act + val svmModelExport = ModelExportFactory.createModelExport(svmModel, ModelExportType.PMML) + //assert that the PMML format is as expected + assert(svmModelExport.isInstanceOf[PMMLModelExport]) + pmml = svmModelExport.asInstanceOf[PMMLModelExport].getPmml() + assert(pmml.getHeader().getDescription() === "linear SVM") + //check that the number of fields match the weights size + assert(pmml.getDataDictionary().getNumberOfFields() === svmModel.weights.size + 1) + //this verify that there is a model attached to the pmml object and the model is a regression one + //it also verifies that the pmml model has a regression table with the same number of predictors of the model weights + assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] + .getRegressionTables().get(0).getNumericPredictors().size() === svmModel.weights.size) //manual checking //ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml") //ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml") //ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml") + //ModelExporter.toPMML(svmModel,"/tmp/svm.xml") }