From 8e71b8d8cc3922e834fe0f21e4f05e773241dad7 Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Mon, 27 Oct 2014 21:58:54 +0000 Subject: [PATCH] kmeans pmml export implementation --- .../export/pmml/KMeansPMMLModelExport.scala | 63 ++++++++++++++++++- .../mllib/export/pmml/PMMLModelExport.scala | 19 +++++- .../export/ModelExportFactorySuite.scala | 2 +- 3 files changed, 79 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala index 99ab256adfd0b..2f0af9a18f470 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala @@ -18,6 +18,24 @@ package org.apache.spark.mllib.export.pmml import org.apache.spark.mllib.clustering.KMeansModel +import org.dmg.pmml.DataDictionary +import org.dmg.pmml.FieldName +import org.dmg.pmml.DataField +import org.dmg.pmml.OpType +import org.dmg.pmml.DataType +import org.dmg.pmml.MiningSchema +import org.dmg.pmml.MiningField +import org.dmg.pmml.FieldUsageType +import org.dmg.pmml.ComparisonMeasure +import org.dmg.pmml.ComparisonMeasure.Kind +import org.dmg.pmml.SquaredEuclidean +import org.dmg.pmml.ClusteringModel +import org.dmg.pmml.MiningFunctionType +import org.dmg.pmml.ClusteringModel.ModelClass +import org.dmg.pmml.ClusteringField +import org.dmg.pmml.CompareFunctionType +import org.dmg.pmml.Cluster +import org.dmg.pmml.Array.Type /** * PMML Model Export for KMeansModel class @@ -30,9 +48,48 @@ class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ populateKMeansPMML(model); private def populateKMeansPMML(model : KMeansModel): Unit = { - //TODO: set here header description - pmml.setVersion("testing... kmeans..."); - //TODO: generate the model... + + pmml.getHeader().setDescription("k-means clustering"); + + if(model.clusterCenters.length > 0){ + + val clusterCenter = model.clusterCenters(0) + + var fields = new Array[FieldName](clusterCenter.size) + + var dataDictionary = new DataDictionary() + + var miningSchema = new MiningSchema() + + for ( i <- 0 to (clusterCenter.size - 1)) { + fields(i) = FieldName.create("field_"+i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) + } + + var comparisonMeasure = new ComparisonMeasure() + .withKind(Kind.DISTANCE) + .withMeasure(new SquaredEuclidean() + ); + + dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size()); + + pmml.setDataDictionary(dataDictionary); + + var clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure, MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length) + .withModelName("k-means"); + + for ( i <- 0 to (clusterCenter.size - 1)) { + clusteringModel.withClusteringFields(new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + var cluster = new Cluster().withName("cluster_"+i).withArray(new org.dmg.pmml.Array().withType(Type.REAL).withN(clusterCenter.size).withValue(model.clusterCenters(i).toArray.mkString(" "))) + //cluster.withSize(value) //we don't have the size of the single cluster but only the centroids (withValue) + clusteringModel.withClusters(cluster) + } + + pmml.withModels(clusteringModel); + + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala index 6d8e8ff0797f6..c1e84f62f9223 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala @@ -23,6 +23,11 @@ import org.jpmml.model.JAXBUtil import org.dmg.pmml.PMML import javax.xml.transform.stream.StreamResult import scala.beans.BeanProperty +import org.dmg.pmml.Application +import org.dmg.pmml.Timestamp +import org.dmg.pmml.Header +import java.text.SimpleDateFormat +import java.util.Date trait PMMLModelExport extends ModelExport{ @@ -31,7 +36,19 @@ trait PMMLModelExport extends ModelExport{ */ @BeanProperty var pmml: PMML = new PMML(); - //TODO: set here header app copyright and timestamp + + setHeader(pmml); + + private def setHeader(pmml : PMML): Unit = { + var version = getClass().getPackage().getImplementationVersion() + var app = new Application().withName("Apache Spark MLlib").withVersion(version) + var timestamp = new Timestamp().withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + var header = new Header() + .withCopyright("www.dmg.org") + .withApplication(app) + .withTimestamp(timestamp); + pmml.setHeader(header); + } /** * Write the exported model (in PMML XML) to the output stream specified diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala index fc627fcb75584..9b6b4160d6120 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala @@ -40,7 +40,7 @@ class ModelExportFactorySuite extends FunSuite{ } - test("ModelExportFactory generate IllegalArgumentException when passing an unsupported model") { + test("ModelExportFactory throws IllegalArgumentException when passing an unsupported model") { val invalidModel = new Object;