[SPARK-1406] Mllib pmml model export #3062

Closed
Changes from 46 commits
48 commits
a0e3679
export and pmml export traits
selvinsource Oct 12, 2014
226e184
added javadoc and export model type in case there is a need to support
selvinsource Oct 18, 2014
9bc494f
added scala suite tests
selvinsource Oct 26, 2014
8e71b8d
kmeans pmml export implementation
selvinsource Oct 27, 2014
1433b11
complete suite tests
selvinsource Oct 29, 2014
8841439
adjust scala style in order to compile
selvinsource Oct 29, 2014
07a29bf
Update LICENSE
selvinsource Oct 29, 2014
f75b988
Merge remote-tracking branch 'origin/master' into mllib_pmml_model_ex…
selvinsource Oct 29, 2014
cd6c07c
fixed scala style to run tests
selvinsource Oct 29, 2014
aba5ee1
fixed cluster export
selvinsource Oct 30, 2014
e1eb251
removed serialization part, this will be part of the ModelExporter
selvinsource Nov 5, 2014
6357b98
set it to private
selvinsource Nov 5, 2014
c3ef9b8
set it to private
selvinsource Nov 5, 2014
349a76b
new helper object to serialize the models to pmml format
selvinsource Nov 5, 2014
834ca44
reordered the import accordingly to the guidelines
selvinsource Nov 5, 2014
a1b4dc3
updated imports
selvinsource Nov 5, 2014
df8a89e
added pmml version to pmml model
selvinsource Nov 5, 2014
ae8b993
updated some commented tests to use the new ModelExporter object
selvinsource Nov 5, 2014
e29dfb9
removed version, by default is set to 4.2 (latest from jpmml)
selvinsource Nov 5, 2014
78515ec
[SPARK-1406] added pmml export for LinearRegressionModel,
selvinsource Nov 28, 2014
c67ce81
Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_…
selvinsource Nov 29, 2014
3ae8ae5
[SPARK-1406] Adjusted imported order according to the guidelines
selvinsource Nov 29, 2014
1faf985
[SPARK-1406] Added target field to the regression model for completeness
selvinsource Nov 29, 2014
19adf29
[SPARK-1406] Fixed scala style
selvinsource Nov 29, 2014
82f2131
Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_…
selvinsource Dec 7, 2014
da2ec11
[SPARK-1406] added linear SVM PMML export
selvinsource Dec 8, 2014
03bc3a5
added logistic regression
selvinsource Dec 9, 2014
8fe12bb
[SPARK-1406] Adjusted logistic regression export description and target
selvinsource Dec 13, 2014
d559ec5
Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_…
selvinsource Feb 8, 2015
7b33b4e
[SPARK-1406] Added a PMMLExportable interface
selvinsource Feb 8, 2015
f46c75c
[SPARK-1406] Added PMMLExportable to supported models
selvinsource Feb 8, 2015
7a949d0
[SPARK-1406] Fixed scala style
selvinsource Feb 8, 2015
b25bbf7
[SPARK-1406] Added export of pmml to distributed file system using the
selvinsource Mar 1, 2015
b8823b0
Merge remote-tracking branch 'upstream/master' into
selvinsource Apr 19, 2015
e2ffae8
fixed scala style
selvinsource Apr 19, 2015
1676e15
fixed scala issue
selvinsource Apr 19, 2015
472d757
fix code style
mengxr Apr 20, 2015
e2313df
Merge pull request #1 from mengxr/SPARK-1406
selvinsource Apr 21, 2015
3c22f79
more code style
mengxr Apr 21, 2015
a0a55f7
Merge pull request #2 from mengxr/SPARK-1406
selvinsource Apr 21, 2015
66b7c12
[SPARK-1406] Updated pmml model lib to 1.1.15, latest Java 6 compatible
selvinsource Apr 21, 2015
dea98ca
[SPARK-1406] Exclude transitive dependency for pmml model
selvinsource Apr 21, 2015
25dce33
[SPARK-1406] Update code to latest pmml model
selvinsource Apr 21, 2015
cfcb596
[SPARK-1406] Throw IllegalArgumentException when exporting a multinomial
selvinsource Apr 25, 2015
7a5e0ec
[SPARK-1406] Binary classification for SVM and Logistic Regression
selvinsource Apr 28, 2015
30165c4
[SPARK-1406] Fixed extreme cases for logit
selvinsource Apr 28, 2015
085cf42
[SPARK-1406] Added Double Min and Max
selvinsource Apr 29, 2015
852aac6
[SPARK-1406] Update JPMML version to 1.1.15 in LICENSE file
selvinsource Apr 29, 2015
1 change: 1 addition & 0 deletions LICENSE
@@ -814,6 +814,7 @@ BSD-style licenses
The following components are provided under a BSD-style license. See project link for details.

(BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
(BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.7 - https://github.com/jpmml/jpmml-model)
(BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/)
(BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
(BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)
15 changes: 15 additions & 0 deletions mllib/pom.xml
@@ -109,6 +109,21 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-model</artifactId>
<version>1.1.15</version>
<exclusions>
<exclusion>
<groupId>com.sun.xml.fastinfoset</groupId>
<artifactId>FastInfoset</artifactId>
</exclusion>
<exclusion>
<groupId>com.sun.istack</groupId>
<artifactId>istack-commons-runtime</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
<profiles>
<profile>
@@ -23,6 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
@@ -46,7 +47,7 @@ class LogisticRegressionModel (
val numFeatures: Int,
val numClasses: Int)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
- with Saveable {
+ with Saveable with PMMLExportable {

if (numClasses == 2) {
require(weights.size == numFeatures,
@@ -22,6 +22,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
import org.apache.spark.rdd.RDD
@@ -36,7 +37,7 @@ class SVMModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
- with Saveable {
+ with Saveable with PMMLExportable {

private var threshold: Option[Double] = Some(0.0)

@@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
@@ -34,7 +35,8 @@ import org.apache.spark.sql.Row
/**
* A clustering model for K-means. Each point belongs to the cluster with the closest center.
*/
- class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
+ class KMeansModel (
+ val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable {

/** A Java-friendly constructor that takes an Iterable of Vectors. */
def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.pmml

import java.io.{File, OutputStream, StringWriter}
import javax.xml.transform.stream.StreamResult

import org.jpmml.model.JAXBUtil

import org.apache.spark.SparkContext
import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory

/**
* Export model to the PMML format
* Predictive Model Markup Language (PMML) is an XML-based file format
* developed by the Data Mining Group (www.dmg.org).
*/
trait PMMLExportable {

/**
* Export the model to the stream result in PMML format
*/
private def toPMML(streamResult: StreamResult): Unit = {
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
}

/**
* Export the model to a local file in PMML format
*/
def toPMML(localPath: String): Unit = {
toPMML(new StreamResult(new File(localPath)))
}

/**
* Export the model to a directory on a distributed file system in PMML format
*/
def toPMML(sc: SparkContext, path: String): Unit = {
val pmml = toPMML()
sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
}

/**
* Export the model to the OutputStream in PMML format
*/
def toPMML(outputStream: OutputStream): Unit = {
toPMML(new StreamResult(outputStream))
}

/**
* Export the model to a String in PMML format
*/
def toPMML(): String = {
val writer = new StringWriter
toPMML(new StreamResult(writer))
writer.toString
}

}
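
The trait above routes every public `toPMML` overload through one private method that takes a `StreamResult`. The sketch below illustrates that same funnel pattern using only the JDK, so it runs without Spark or JPMML. The names `XmlExportable`, `toXml`, `xml`, and `SampleDoc` are hypothetical and are not part of this PR; the JDK identity `Transformer` stands in for `JAXBUtil.marshalPMML`.

```scala
import java.io.{ByteArrayOutputStream, OutputStream, StringReader, StringWriter}
import javax.xml.transform.TransformerFactory
import javax.xml.transform.stream.{StreamResult, StreamSource}

// Hypothetical stand-in for PMMLExportable; not part of Spark.
trait XmlExportable {

  /** Hypothetical hook: the XML document this object serializes to. */
  protected def xml: String

  // Single private sink. Every public overload wraps its target in a
  // StreamResult and delegates here, as the trait in the diff does.
  private def toXml(result: StreamResult): Unit = {
    // Identity transform: copies the XML source to the result unchanged.
    val identity = TransformerFactory.newInstance().newTransformer()
    identity.transform(new StreamSource(new StringReader(xml)), result)
  }

  /** Write the document to an OutputStream. */
  def toXml(out: OutputStream): Unit = toXml(new StreamResult(out))

  /** Return the document as a String. */
  def toXml(): String = {
    val writer = new StringWriter
    toXml(new StreamResult(writer))
    writer.toString
  }
}

// Hypothetical implementer used only to exercise the overloads.
object SampleDoc extends XmlExportable {
  protected def xml: String = "<a><b>1</b></a>"
}
```

Calling `SampleDoc.toXml()` returns the document as a string, and `SampleDoc.toXml(new ByteArrayOutputStream)` writes the same bytes to a stream. Keeping the `StreamResult` overload private means callers only see targets the trait chooses to support.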
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.pmml.export

import scala.{Array => SArray}

import org.dmg.pmml._

import org.apache.spark.mllib.regression.GeneralizedLinearModel

/**
* PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
*/
private[mllib] class BinaryClassificationPMMLModelExport(
model : GeneralizedLinearModel,
description : String,
normalizationMethod : RegressionNormalizationMethodType,
threshold: Double)
extends PMMLModelExport {

populateBinaryClassificationPMML()

/**
* Export the input LogisticRegressionModel or SVMModel to PMML format.
*/
private def populateBinaryClassificationPMML(): Unit = {
pmml.getHeader.setDescription(description)

if (model.weights.size > 0) {
val fields = new SArray[FieldName](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1")
var interceptNO = threshold
if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) {
if (threshold <= 0)
interceptNO = -1000
[Review comment, Contributor] Should be Double.MinValue.
else if (threshold >= 1)
interceptNO = 1000
[Review comment, Contributor] Double.MaxValue
else
interceptNO = -math.log(1/threshold -1)
[Review comment, Contributor] space around / and after -: math.log(1 / threshold - 1)
}
val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0")
val regressionModel = new RegressionModel()
.withFunctionName(MiningFunctionType.CLASSIFICATION)
.withMiningSchema(miningSchema)
.withModelName(description)
.withNormalizationMethod(normalizationMethod)
.withRegressionTables(regressionTableYES, regressionTableNO)

for (i <- 0 until model.weights.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

// add target field
val targetField = FieldName.create("target")
dataDictionary
.withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING))
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))

dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)

pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)
}
}
}
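
For the LOGIT case, the expression `-math.log(1/threshold - 1)` in the diff is the logit of the threshold, `log(threshold / (1 - threshold))`. Reading the diff, the NO table carries no predictors, only this intercept, which suggests the intent is to compare the model margin against a constant so that the YES category wins exactly when `sigmoid(margin) > threshold`. The sketch below checks that identity numerically; `LogitThreshold`, `sigmoid`, and `interceptNo` are hypothetical names, and the interpretation of how a PMML consumer scores the two tables is an assumption, not something the PR states.

```scala
object LogitThreshold {

  /** Logistic function applied to a raw margin. */
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  /**
   * Same expression as in the diff, valid only for 0 < threshold < 1.
   * Algebraically equal to log(threshold / (1 - threshold)).
   */
  def interceptNo(threshold: Double): Double =
    -math.log(1.0 / threshold - 1.0)

  def main(args: Array[String]): Unit = {
    for (t <- Seq(0.1, 0.5, 0.9)) {
      val cut = interceptNo(t)
      // sigmoid(cut) recovers the threshold, so margin > cut is the same
      // test as sigmoid(margin) > threshold.
      println(s"threshold=$t cut=$cut sigmoid(cut)=${sigmoid(cut)}")
    }
  }
}
```

The formula diverges as the threshold approaches 0 or 1, which is why the diff special-cases `threshold <= 0` and `threshold >= 1` with fixed constants.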
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.pmml.export

import scala.{Array => SArray}

import org.dmg.pmml._

import org.apache.spark.mllib.regression.GeneralizedLinearModel

/**
* PMML Model Export for GeneralizedLinearModel abstract class
*/
private[mllib] class GeneralizedLinearPMMLModelExport(
model: GeneralizedLinearModel,
description: String)
extends PMMLModelExport {

populateGeneralizedLinearPMML(model)

/**
* Export the input GeneralizedLinearModel model to PMML format.
*/
private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {
pmml.getHeader.setDescription(description)

if (model.weights.size > 0) {
val fields = new SArray[FieldName](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTable = new RegressionTable(model.intercept)
val regressionModel = new RegressionModel()
.withFunctionName(MiningFunctionType.REGRESSION)
.withMiningSchema(miningSchema)
.withModelName(description)
.withRegressionTables(regressionTable)

for (i <- 0 until model.weights.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

// for completeness add target field
val targetField = FieldName.create("target")
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))

dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)

pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)
}
}
}
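
The export above names inputs `field_0`, `field_1`, and so on, and attaches one `NumericPredictor` per weight to a single `RegressionTable`. A PMML regression table with numeric predictors of default exponent evaluates to the intercept plus the sum of coefficient times input. The sketch below mirrors that computation in plain Scala as a way to sanity-check an exported model by hand; `RegressionTableSketch`, `fieldNames`, and `predict` are hypothetical helpers, not part of the PR or of JPMML.

```scala
object RegressionTableSketch {

  /** Hypothetical helper mirroring the "field_" + i naming loop in the diff. */
  def fieldNames(numWeights: Int): Seq[String] =
    (0 until numWeights).map(i => "field_" + i)

  /**
   * Evaluates intercept + sum(weights(i) * inputs("field_i")).
   * Throws NoSuchElementException if an expected field is missing.
   */
  def predict(
      intercept: Double,
      weights: Seq[Double],
      inputs: Map[String, Double]): Double = {
    val names = fieldNames(weights.size)
    intercept + names.zip(weights).map { case (name, w) => w * inputs(name) }.sum
  }
}
```

For example, `RegressionTableSketch.predict(1.0, Seq(2.0, 3.0), Map("field_0" -> 1.0, "field_1" -> 2.0))` evaluates `1.0 + 2.0 * 1.0 + 3.0 * 2.0`.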