[SPARK-5995] [ML] Make Prediction dev API public
Changes:
* Update protected prediction methods, following design doc. **<--most interesting change**
* Changed abstract classes for Estimator and Model to be public.  Added DeveloperApi tag.  (I kept the traits for Estimator/Model Params private.)
* Changed ProbabilisticClassificationModel method names to use probability instead of probabilities.

CC: mengxr shivaram etrain

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5913 from jkbradley/public-dev-api and squashes the following commits:

e9aa0ea [Joseph K. Bradley] moved findMax to DenseVector and renamed to argmax. fixed bug for vector of length 0
15b9957 [Joseph K. Bradley] renamed probabilities to probability in method names
5cda84d [Joseph K. Bradley] regenerated sharedParams
7d1877a [Joseph K. Bradley] Made spark.ml prediction abstractions public.  Organized their prediction methods for efficient computation of multiple output columns.

(cherry picked from commit 1ad04da)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
jkbradley authored and mengxr committed May 6, 2015
1 parent 14bcb84 commit b681b93
Showing 16 changed files with 206 additions and 267 deletions.
@@ -15,29 +15,23 @@
* limitations under the License.
*/

package org.apache.spark.ml.impl.estimator
package org.apache.spark.ml

import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.sql.{DataFrame, Row}

/**
* :: DeveloperApi ::
*
* Trait for parameters for prediction (regression and classification).
*
* NOTE: This is currently private[spark] but will be made public later once it is stabilized.
* (private[ml]) Trait for parameters for prediction (regression and classification).
*/
@DeveloperApi
private[spark] trait PredictorParams extends Params
private[ml] trait PredictorParams extends Params
with HasLabelCol with HasFeaturesCol with HasPredictionCol {

/**
@@ -63,7 +57,7 @@ private[spark] trait PredictorParams extends Params
}

/**
* :: AlphaComponent ::
* :: DeveloperApi ::
*
* Abstraction for prediction problems (regression and classification).
*
@@ -73,11 +67,9 @@ private[spark] trait PredictorParams extends Params
* parameter to specify the concrete type.
* @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
* parameter to specify the concrete type for the corresponding model.
*
* NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
private[spark] abstract class Predictor[
@DeveloperApi
abstract class Predictor[
FeaturesType,
Learner <: Predictor[FeaturesType, Learner, M],
M <: PredictionModel[FeaturesType, M]]
@@ -104,29 +96,23 @@ private[spark] abstract class Predictor[
}

/**
* :: DeveloperApi ::
*
* Train a model using the given dataset and parameters.
* Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
* and copying parameters into the model.
*
* @param dataset Training dataset
* @return Fitted model
*/
@DeveloperApi
protected def train(dataset: DataFrame): M

/**
* :: DeveloperApi ::
*
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
*
* This is used by [[validateAndTransformSchema()]].
* This workaround is needed since SQL has different APIs for Scala and Java.
*
* The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
*/
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT

override def transformSchema(schema: StructType): StructType = {
@@ -146,19 +132,17 @@ private[spark] abstract class Predictor[
}

/**
* :: AlphaComponent ::
* :: DeveloperApi ::
*
* Abstraction for a model for prediction tasks (regression and classification).
*
* @tparam FeaturesType Type of features.
* E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
* @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
* parameter to specify the concrete type for the corresponding model.
*
* NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
@DeveloperApi
abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
extends Model[M] with PredictorParams {

/** @group setParam */
@@ -168,16 +152,13 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]

/**
* :: DeveloperApi ::
*
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
*
* This is used by [[validateAndTransformSchema()]].
* This workaround is needed since SQL has different APIs for Scala and Java.
*
* The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
*/
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT

override def transformSchema(schema: StructType): StructType = {
@@ -192,12 +173,8 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
* @return transformed dataset with [[predictionCol]] of type [[Double]]
*/
override def transform(dataset: DataFrame): DataFrame = {
// This default implementation should be overridden as needed.

// Check schema
transformSchema(dataset.schema, logging = true)

if ($(predictionCol) != "") {
if ($(predictionCol).nonEmpty) {
dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
@@ -207,11 +184,8 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
}

/**
* :: DeveloperApi ::
*
* Predict label for the given features.
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
*/
@DeveloperApi
protected def predict(features: FeaturesType): Double
}
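
To give a feel for the newly public developer API above, here is a minimal sketch of a third-party algorithm written against Predictor/PredictionModel as they appear in this diff. MyMeanRegressor and MyMeanModel are hypothetical names, and the sketch implements only the abstract members visible here (train and predict); concrete classes in this snapshot also provide members such as the model's parent reference, which are omitted.

import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.DataFrame

// Hypothetical estimator: always predicts the mean of the training labels.
class MyMeanRegressor extends Predictor[Vector, MyMeanRegressor, MyMeanModel] {
  // fit() has already validated the schema and copies params into the model,
  // so train() only has to produce the fitted model from the data.
  override protected def train(dataset: DataFrame): MyMeanModel = {
    val mean = dataset.select($(labelCol)).rdd.map(_.getDouble(0)).mean()
    new MyMeanModel(mean)
  }
}

// Hypothetical model: ignores the features and returns the stored mean.
// PredictionModel.transform() takes care of producing predictionCol from predict().
class MyMeanModel(val mean: Double) extends PredictionModel[Vector, MyMeanModel] {
  override protected def predict(features: Vector): Double = mean
}

Calling fit() on such a class goes through Predictor.fit(), which validates the schema and delegates to train(); the resulting model's transform() appends the prediction column.
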
@@ -17,24 +17,21 @@

package org.apache.spark.ml.classification

import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}


/**
* :: DeveloperApi ::
* Params for classification.
*
* NOTE: This is currently private[spark] but will be made public later once it is stabilized.
* (private[spark]) Params for classification.
*/
@DeveloperApi
private[spark] trait ClassifierParams extends PredictorParams
with HasRawPredictionCol {
private[spark] trait ClassifierParams
extends PredictorParams with HasRawPredictionCol {

override protected def validateAndTransformSchema(
schema: StructType,
@@ -46,23 +43,21 @@ private[spark] trait ClassifierParams extends PredictorParams
}

/**
* :: AlphaComponent ::
* :: DeveloperApi ::
*
* Single-label binary or multiclass classification.
* Classes are indexed {0, 1, ..., numClasses - 1}.
*
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam E Concrete Estimator type
* @tparam M Concrete Model type
*
* NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
private[spark] abstract class Classifier[
@DeveloperApi
abstract class Classifier[
FeaturesType,
E <: Classifier[FeaturesType, E, M],
M <: ClassificationModel[FeaturesType, M]]
extends Predictor[FeaturesType, E, M]
with ClassifierParams {
extends Predictor[FeaturesType, E, M] with ClassifierParams {

/** @group setParam */
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
@@ -71,17 +66,15 @@ private[spark] abstract class Classifier[
}

/**
* :: AlphaComponent ::
* :: DeveloperApi ::
*
* Model produced by a [[Classifier]].
* Classes are indexed {0, 1, ..., numClasses - 1}.
*
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam M Concrete Model type
*
* NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
@AlphaComponent
private[spark]
@DeveloperApi
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
extends PredictionModel[FeaturesType, M] with ClassifierParams {

@@ -101,13 +94,27 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* @return transformed dataset
*/
override def transform(dataset: DataFrame): DataFrame = {
// This default implementation should be overridden as needed.

// Check schema
transformSchema(dataset.schema, logging = true)

val (numColsOutput, outputData) =
ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this)
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = dataset
var numColsOutput = 0
if (getRawPredictionCol != "") {
outputData = outputData.withColumn(getRawPredictionCol,
callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
numColsOutput += 1
}
if (getPredictionCol != "") {
val predUDF = if (getRawPredictionCol != "") {
callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol))
} else {
callUDF(predict _, DoubleType, col(getFeaturesCol))
}
outputData = outputData.withColumn(getPredictionCol, predUDF)
numColsOutput += 1
}

if (numColsOutput == 0) {
logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
@@ -116,22 +123,17 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
}

/**
* :: DeveloperApi ::
*
* Predict label for the given features.
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
*
* This default implementation for classification predicts the index of the maximum value
* from [[predictRaw()]].
*/
@DeveloperApi
override protected def predict(features: FeaturesType): Double = {
predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2
raw2prediction(predictRaw(features))
}

/**
* :: DeveloperApi ::
*
* Raw prediction for each possible label.
* The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
* a measure of confidence in each possible label (where larger = more confident).
@@ -141,48 +143,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* This raw prediction may be any real number, where a larger value indicates greater
* confidence for that label.
*/
@DeveloperApi
protected def predictRaw(features: FeaturesType): Vector
}

private[ml] object ClassificationModel {

/**
* Added prediction column(s). This is separated from [[ClassificationModel.transform()]]
* since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]].
* @param dataset Input dataset
* @return (number of columns added, transformed dataset)
* Given a vector of raw predictions, select the predicted label.
* This may be overridden to support thresholds which favor particular labels.
* @return predicted label
*/
def transformColumnsImpl[FeaturesType](
dataset: DataFrame,
model: ClassificationModel[FeaturesType, _]): (Int, DataFrame) = {

// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var tmpData = dataset
var numColsOutput = 0
if (model.getRawPredictionCol != "") {
// output raw prediction
val features2raw: FeaturesType => Vector = model.predictRaw
tmpData = tmpData.withColumn(model.getRawPredictionCol,
callUDF(features2raw, new VectorUDT, col(model.getFeaturesCol)))
numColsOutput += 1
if (model.getPredictionCol != "") {
val raw2pred: Vector => Double = (rawPred) => {
rawPred.toArray.zipWithIndex.maxBy(_._1)._2
}
tmpData = tmpData.withColumn(model.getPredictionCol,
callUDF(raw2pred, DoubleType, col(model.getRawPredictionCol)))
numColsOutput += 1
}
} else if (model.getPredictionCol != "") {
// output prediction
val features2pred: FeaturesType => Double = model.predict
tmpData = tmpData.withColumn(model.getPredictionCol,
callUDF(features2pred, DoubleType, col(model.getFeaturesCol)))
numColsOutput += 1
}
(numColsOutput, tmpData)
}

protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax
}
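
For reference, the new default raw2prediction above simply picks the index of the largest raw score, using the argmax that this PR moves onto DenseVector (commit e9aa0ea). The scores below are made-up values:

import org.apache.spark.mllib.linalg.Vectors

// Raw scores for a hypothetical 3-class problem; class 1 has the largest score.
val rawPrediction = Vectors.dense(0.1, 2.5, -0.3)
val predictedLabel = rawPrediction.toDense.argmax.toDouble  // 1.0

Per the commit message above, argmax also handles a zero-length vector as a special case (the earlier findMax helper had a bug for empty vectors).
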
@@ -18,10 +18,9 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -21,11 +21,10 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
(Diff for the remaining changed files not shown.)
