diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala index 4889be8e3c7ec..618fe79a7b14a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala @@ -52,10 +52,7 @@ private[mllib] object ModelExportFactory { svm, "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") case logisticRegression: LogisticRegressionModel => - new LogisticRegressionPMMLModelExport( - logisticRegression, - "logistic regression: if predicted value > 0.5, " - + "the outcome is positive, or negative otherwise") + new LogisticRegressionPMMLModelExport(logisticRegression, "logistic regression") case _ => throw new IllegalArgumentException("Export not supported for model: " + model.getClass) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala index f0c6708af58c4..0d65bc9ddc627 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala @@ -58,10 +58,10 @@ private[mllib] class LogisticRegressionPMMLModelExport( val miningSchema = new MiningSchema() val regressionTableYES = new RegressionTable(model.intercept) - .withTargetCategory("YES") + .withTargetCategory("1") val regressionTableNO = new RegressionTable(0.0) - .withTargetCategory("NO") + .withTargetCategory("0") val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.CLASSIFICATION) .withModelName(description) @@ -83,7 +83,7 @@ private[mllib] class LogisticRegressionPMMLModelExport( val targetField = FieldName.create("target"); dataDictionary .withDataFields( - new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE) + new DataField(targetField, OpType.CATEGORICAL, DataType.STRING) ) miningSchema .withMiningFields(new MiningField(targetField) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala index 27093938102ba..0bb6c9a60a485 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala @@ -39,18 +39,18 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite{ //assert that the PMML format is as expected assert(logisticModelExport.isInstanceOf[PMMLModelExport]) var pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml() - assert(pmml.getHeader().getDescription() === "logistic regression: if predicted value > 0.5, the outcome is positive, or negative otherwise") + assert(pmml.getHeader().getDescription() === "logistic regression") //check that the number of fields match the weights size assert(pmml.getDataDictionary().getNumberOfFields() === logisticRegressionModel.weights.size + 1) //this verify that there is a model attached to the pmml object and the model is a regression one - //it also verifies that the pmml model has a regression table (for target category YES) with the same number of predictors of the model weights + //it also verifies that the pmml model has a regression table (for target category 1) with the same number of predictors of the model weights assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(0).getTargetCategory() === "YES") + .getRegressionTables().get(0).getTargetCategory() === "1") assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] .getRegressionTables().get(0).getNumericPredictors().size() === logisticRegressionModel.weights.size) - //verify if there is a second table with target category NO and no predictors + //verify if there is a second table with target category 0 and no predictors assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] - .getRegressionTables().get(1).getTargetCategory() === "NO") + .getRegressionTables().get(1).getTargetCategory() === "0") assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] .getRegressionTables().get(1).getNumericPredictors().size() === 0)