diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala index 7e2e76f53988c..d7b1efc19dedb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala @@ -25,16 +25,18 @@ object ModelExportFactory { /** * Factory object to help creating the necessary ModelExport implementation - * taking as input the ModelExportType (for example PMML) and the machine learning model (for example KMeansModel). + * taking as input the ModelExportType (for example PMML) + * and the machine learning model (for example KMeansModel). */ def createModelExport(model: Any, exportType: ModelExportType): ModelExport = { - return exportType match{ - case PMML => model match{ - case kmeans: KMeansModel => new KMeansPMMLModelExport(kmeans) - case _ => throw new IllegalArgumentException("Export not supported for model: " + model.getClass) - } - case _ => throw new IllegalArgumentException("Export type not supported:" + exportType) - } + return exportType match{ + case PMML => model match{ + case kmeans: KMeansModel => new KMeansPMMLModelExport(kmeans) + case _ => + throw new IllegalArgumentException("Export not supported for model: " + model.getClass) + } + case _ => throw new IllegalArgumentException("Export type not supported:" + exportType) + } } - + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala index 1e940a6aa5e50..60bc3eabb9144 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala @@ -19,7 +19,8 @@ package org.apache.spark.mllib.export /** * Defines export types. - * - PMML exports the machine learning models in an XML-based file format called Predictive Model Markup Language developed by the Data Mining Group (www.dmg.org). + * - PMML exports the machine learning models in an XML-based file format + * called Predictive Model Markup Language developed by the Data Mining Group (www.dmg.org). */ object ModelExportType extends Enumeration{ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala index 2f0af9a18f470..37d7b6bf71734 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala @@ -62,30 +62,42 @@ class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ var miningSchema = new MiningSchema() for ( i <- 0 to (clusterCenter.size - 1)) { - fields(i) = FieldName.create("field_"+i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) - miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) - } + fields(i) = FieldName.create("field_" + i) + dataDictionary + .withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + } var comparisonMeasure = new ComparisonMeasure() - .withKind(Kind.DISTANCE) - .withMeasure(new SquaredEuclidean() + .withKind(Kind.DISTANCE) + .withMeasure(new SquaredEuclidean() ); dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size()); - + pmml.setDataDictionary(dataDictionary); - var clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure, MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length) - .withModelName("k-means"); + var clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure, + MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length) + .withModelName("k-means"); for ( i <- 0 to (clusterCenter.size - 1)) { - clusteringModel.withClusteringFields(new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) - var cluster = new Cluster().withName("cluster_"+i).withArray(new org.dmg.pmml.Array().withType(Type.REAL).withN(clusterCenter.size).withValue(model.clusterCenters(i).toArray.mkString(" "))) - //cluster.withSize(value) //we don't have the size of the single cluster but only the centroids (withValue) - clusteringModel.withClusters(cluster) + clusteringModel.withClusteringFields( + new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF) + ) + var cluster = new Cluster() + .withName("cluster_" + i) + .withArray(new org.dmg.pmml.Array() + .withType(Type.REAL) + .withN(clusterCenter.size) + .withValue(model.clusterCenters(i).toArray.mkString(" "))) + //we don't have the size of the single cluster but only the centroids (withValue) + //.withSize(value) + clusteringModel.withClusters(cluster) } - + pmml.withModels(clusteringModel); } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala index c1e84f62f9223..f18f9cee8ea05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala @@ -41,13 +41,14 @@ trait PMMLModelExport extends ModelExport{ private def setHeader(pmml : PMML): Unit = { var version = getClass().getPackage().getImplementationVersion() - var app = new Application().withName("Apache Spark MLlib").withVersion(version) - var timestamp = new Timestamp().withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) - var header = new Header() - .withCopyright("www.dmg.org") + var app = new Application().withName("Apache Spark MLlib").withVersion(version) + var timestamp = new Timestamp() + .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + var header = new Header() + .withCopyright("www.dmg.org") .withApplication(app) .withTimestamp(timestamp); - pmml.setHeader(header); + pmml.setHeader(header); } /**