Skip to content

Commit

Permalink
kmeans pmml export implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
selvinsource committed Oct 27, 2014
1 parent 9bc494f commit 8e71b8d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@
package org.apache.spark.mllib.export.pmml

import org.apache.spark.mllib.clustering.KMeansModel
import org.dmg.pmml.DataDictionary
import org.dmg.pmml.FieldName
import org.dmg.pmml.DataField
import org.dmg.pmml.OpType
import org.dmg.pmml.DataType
import org.dmg.pmml.MiningSchema
import org.dmg.pmml.MiningField
import org.dmg.pmml.FieldUsageType
import org.dmg.pmml.ComparisonMeasure
import org.dmg.pmml.ComparisonMeasure.Kind
import org.dmg.pmml.SquaredEuclidean
import org.dmg.pmml.ClusteringModel
import org.dmg.pmml.MiningFunctionType
import org.dmg.pmml.ClusteringModel.ModelClass
import org.dmg.pmml.ClusteringField
import org.dmg.pmml.CompareFunctionType
import org.dmg.pmml.Cluster
import org.dmg.pmml.Array.Type

/**
* PMML Model Export for KMeansModel class
Expand All @@ -30,9 +48,48 @@ class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
// Fill the PMML document from the supplied KMeansModel; this statement is not inside
// any method here, so it presumably runs as part of construction (the enclosing class
// header is only partially visible in this view -- confirm).
populateKMeansPMML(model);

/**
 * Populates the PMML document with a center-based clustering model built from
 * the centroids of the given k-means model.
 *
 * The header description is always set. The data dictionary, mining schema and
 * clustering model are only added when the model has at least one cluster
 * center; the dimensionality of the first center determines the number of
 * input fields (named "field_0", "field_1", ...).
 *
 * @param model the k-means model whose cluster centers are exported
 */
private def populateKMeansPMML(model : KMeansModel): Unit = {

  pmml.getHeader().setDescription("k-means clustering")

  if (model.clusterCenters.length > 0) {

    // All centers are assumed to share the dimensionality of the first one.
    val clusterCenter = model.clusterCenters(0)
    val dimension = clusterCenter.size

    val fields = new Array[FieldName](dimension)
    val dataDictionary = new DataDictionary()
    val miningSchema = new MiningSchema()

    // One continuous, active input field per vector dimension.
    for (i <- 0 until dimension) {
      fields(i) = FieldName.create("field_" + i)
      dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
      miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE))
    }

    val comparisonMeasure = new ComparisonMeasure()
      .withKind(Kind.DISTANCE)
      .withMeasure(new SquaredEuclidean())

    dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())

    pmml.setDataDictionary(dataDictionary)

    val clusteringModel = new ClusteringModel(
        miningSchema,
        comparisonMeasure,
        MiningFunctionType.CLUSTERING,
        ModelClass.CENTER_BASED,
        model.clusterCenters.length)
      .withModelName("k-means")

    // Clustering fields are per dimension ...
    for (i <- 0 until dimension) {
      clusteringModel.withClusteringFields(
        new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF))
    }

    // ... whereas clusters are per centroid. Iterating the centers separately
    // ensures every cluster is exported even when the number of clusters
    // differs from the number of dimensions.
    for (i <- model.clusterCenters.indices) {
      val centroid = new org.dmg.pmml.Array()
        .withType(Type.REAL)
        .withN(dimension)
        .withValue(model.clusterCenters(i).toArray.mkString(" "))
      // The size of each cluster is not available from the model, only its centroid.
      val cluster = new Cluster().withName("cluster_" + i).withArray(centroid)
      clusteringModel.withClusters(cluster)
    }

    pmml.withModels(clusteringModel)

  }

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ import org.jpmml.model.JAXBUtil
import org.dmg.pmml.PMML
import javax.xml.transform.stream.StreamResult
import scala.beans.BeanProperty
import org.dmg.pmml.Application
import org.dmg.pmml.Timestamp
import org.dmg.pmml.Header
import java.text.SimpleDateFormat
import java.util.Date

trait PMMLModelExport extends ModelExport{

Expand All @@ -31,7 +36,19 @@ trait PMMLModelExport extends ModelExport{
*/
@BeanProperty
var pmml: PMML = new PMML();
// The header (application, copyright, timestamp) is filled in right after the document is created.

setHeader(pmml);

/**
 * Builds the PMML header (application name and version, copyright and the
 * current timestamp) and attaches it to the given PMML document.
 *
 * @param pmml the PMML document whose header is set
 */
private def setHeader(pmml : PMML): Unit = {
  // getPackage can return null, and the implementation version is null when the
  // information is not available; in both cases the version is left unset
  // instead of failing with a NullPointerException.
  val version = Option(getClass().getPackage()).map(_.getImplementationVersion()).orNull
  val app = new Application().withName("Apache Spark MLlib").withVersion(version)
  // Timestamp is formatted in the JVM default time zone, without zone designator.
  val timestamp = new Timestamp()
    .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date()))
  val header = new Header()
    .withCopyright("www.dmg.org")
    .withApplication(app)
    .withTimestamp(timestamp)
  pmml.setHeader(header)
}

/**
* Write the exported model (in PMML XML) to the output stream specified
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ModelExportFactorySuite extends FunSuite{

}

test("ModelExportFactory generate IllegalArgumentException when passing an unsupported model") {
test("ModelExportFactory throws IllegalArgumentException when passing an unsupported model") {

val invalidModel = new Object;

Expand Down

0 comments on commit 8e71b8d

Please sign in to comment.