Skip to content

Commit

Permalink
[SPARK-1406] Added target field to the regression model for completeness
Browse files Browse the repository at this point in the history
Adjusted unit test to deal with this change
  • Loading branch information
selvinsource committed Nov 29, 2014
1 parent 3ae8ae5 commit 1faf985
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

//for completeness add target field
val targetField = FieldName.create("target");
dataDictionary
.withDataFields(
new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)
)
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))

dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())

pmml.setDataDictionary(dataDictionary)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
var pmml = linearModelExport.asInstanceOf[PMMLModelExport].getPmml()
assert(pmml.getHeader().getDescription() === "linear regression")
//check that the number of fields match the weights size
assert(pmml.getDataDictionary().getNumberOfFields() === linearRegressionModel.weights.size)
assert(pmml.getDataDictionary().getNumberOfFields() === linearRegressionModel.weights.size + 1)
//this verify that there is a model attached to the pmml object and the model is a regression one
//it also verifies that the pmml model has a regression table with the same number of predictors of the model weights
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
Expand All @@ -58,7 +58,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
pmml = ridgeModelExport.asInstanceOf[PMMLModelExport].getPmml()
assert(pmml.getHeader().getDescription() === "ridge regression")
//check that the number of fields match the weights size
assert(pmml.getDataDictionary().getNumberOfFields() === ridgeRegressionModel.weights.size)
assert(pmml.getDataDictionary().getNumberOfFields() === ridgeRegressionModel.weights.size + 1)
//this verify that there is a model attached to the pmml object and the model is a regression one
//it also verifies that the pmml model has a regression table with the same number of predictors of the model weights
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
Expand All @@ -71,7 +71,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
pmml = lassoModelExport.asInstanceOf[PMMLModelExport].getPmml()
assert(pmml.getHeader().getDescription() === "lasso regression")
//check that the number of fields match the weights size
assert(pmml.getDataDictionary().getNumberOfFields() === lassoModel.weights.size)
assert(pmml.getDataDictionary().getNumberOfFields() === lassoModel.weights.size + 1)
//this verify that there is a model attached to the pmml object and the model is a regression one
//it also verifies that the pmml model has a regression table with the same number of predictors of the model weights
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
Expand Down

0 comments on commit 1faf985

Please sign in to comment.