Skip to content

Commit

Permalink
[SPARK-1406] Added PMMLExportable to supported models
Browse files Browse the repository at this point in the history
  • Loading branch information
selvinsource committed Feb 8, 2015
1 parent 7b33b4e commit f46c75c
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
Expand All @@ -46,7 +47,7 @@ class LogisticRegressionModel (
val numFeatures: Int,
val numClasses: Int)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
with Saveable {
with Saveable with PMMLExportable {

if (numClasses == 2) {
require(weights.size == numFeatures,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
Expand All @@ -37,7 +38,7 @@ class SVMModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
with Saveable {
with Saveable with PMMLExportable {

private var threshold: Option[Double] = Some(0.0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.pmml.PMMLExportable

/**
* A clustering model for K-means. Each point belongs to the cluster with the closest center.
*/
class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable with PMMLExportable {

/** Total number of clusters. */
def k: Int = clusterCenters.length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
Expand All @@ -34,7 +35,7 @@ class LassoModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable with Saveable {
with RegressionModel with Serializable with Saveable with PMMLExportable {

override protected def predictPoint(
dataMatrix: Vector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
Expand All @@ -34,7 +35,7 @@ class LinearRegressionModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable
with Saveable {
with Saveable with PMMLExportable {

override protected def predictPoint(
dataMatrix: Vector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
Expand All @@ -35,7 +36,7 @@ class RidgeRegressionModel (
override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable with Saveable {
with RegressionModel with Serializable with Saveable with PMMLExportable {

override protected def predictPoint(
dataMatrix: Vector,
Expand Down

0 comments on commit f46c75c

Please sign in to comment.