From cfcb5960cee46c2a84b69c3f3a766d78e38f6c42 Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Sat, 25 Apr 2015 08:40:51 +0100 Subject: [PATCH] [SPARK-1406] Throw IllegalArgumentException when exporting a multinomial logistic regression --- .../mllib/pmml/export/PMMLModelExportFactory.scala | 6 +++++- .../pmml/export/PMMLModelExportFactorySuite.scala | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index 0c374a46fb562..bd8c8f96a6e55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -44,7 +44,11 @@ private[mllib] object PMMLModelExportFactory { new GeneralizedLinearPMMLModelExport(svm, "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") case logistic: LogisticRegressionModel => - new LogisticRegressionPMMLModelExport(logistic, "logistic regression") + if(logistic.numClasses == 2) + new LogisticRegressionPMMLModelExport(logistic, "logistic regression") + else + throw new IllegalArgumentException( + "PMML Export not supported for Multinomial Logistic Regression") case _ => throw new IllegalArgumentException( "PMML Export not supported for model: " + model.getClass.getName) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index a94854e4c0f20..b87e96e7032f3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -73,6 +73,18 @@ class PMMLModelExportFactorySuite extends FunSuite { assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport]) } + + test("PMMLModelExportFactory throw IllegalArgumentException " + + "when passing a Multinomial Logistic Regression") { + /** 3 classes, 2 features */ + val multiclassLogisticRegressionModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + numFeatures = 2, numClasses = 3) + + intercept[IllegalArgumentException] { + PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) + } + } test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { val invalidModel = new Object