Skip to content

Commit

Permalink
[SPARK-1406] added linear SVM PMML export
Browse files Browse the repository at this point in the history
  • Loading branch information
selvinsource committed Dec 8, 2014
1 parent 82f2131 commit da2ec11
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.mllib.export

import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.export.ModelExportType.ModelExportType
import org.apache.spark.mllib.export.ModelExportType.PMML
Expand Down Expand Up @@ -44,6 +45,8 @@ private[mllib] object ModelExportFactory {
new GeneralizedLinearPMMLModelExport(ridgeRegression, "ridge regression")
case lassoRegression: LassoModel =>
new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression")
case svm: SVMModel =>
new GeneralizedLinearPMMLModelExport(svm, "linear SVM")
case _ =>
throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.mllib.export

import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LassoModel
Expand Down Expand Up @@ -48,15 +49,16 @@ class ModelExportFactorySuite extends FunSuite{

}

test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a"
+"LinearRegressionModel, RidgeRegressionModel or LassoModel") {
test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
+"LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") {

//arrange
val linearInput = LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 1, 17)
val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label);
val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label);
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label);
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label);

//act
val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML)
Expand All @@ -73,6 +75,11 @@ class ModelExportFactorySuite extends FunSuite{
//assert
assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])

//act
val svmModelExport = ModelExportFactory.createModelExport(svmModel, ModelExportType.PMML)
//assert
assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])

}

test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.export.pmml
import org.dmg.pmml.RegressionModel
import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.export.ModelExportFactory
import org.apache.spark.mllib.export.ModelExportType
import org.apache.spark.mllib.regression.LassoModel
Expand All @@ -37,6 +38,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label);
val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label);
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label);
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label);

//act by exporting the model to the PMML format
val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML)
Expand Down Expand Up @@ -76,11 +78,25 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
//it also verifies that the pmml model has a regression table with as many numeric predictors as the model has weights
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
.getRegressionTables().get(0).getNumericPredictors().size() === lassoModel.weights.size)

//act
val svmModelExport = ModelExportFactory.createModelExport(svmModel, ModelExportType.PMML)
//assert that the PMML format is as expected
assert(svmModelExport.isInstanceOf[PMMLModelExport])
pmml = svmModelExport.asInstanceOf[PMMLModelExport].getPmml()
assert(pmml.getHeader().getDescription() === "linear SVM")
//check that the number of fields equals the number of weights plus one
assert(pmml.getDataDictionary().getNumberOfFields() === svmModel.weights.size + 1)
//this verifies that there is a model attached to the pmml object and that the model is a regression one
//it also verifies that the pmml model has a regression table with as many numeric predictors as the model has weights
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
.getRegressionTables().get(0).getNumericPredictors().size() === svmModel.weights.size)

//manual checking
//ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml")
//ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml")
//ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml")
//ModelExporter.toPMML(svmModel,"/tmp/svm.xml")

}

Expand Down

0 comments on commit da2ec11

Please sign in to comment.