From 7d1877a1101ef3201d61c826ffd275063eb1b6f7 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 5 May 2015 00:56:41 -0700 Subject: [PATCH 1/4] Made spark.ml prediction abstractions public. Organized their prediction methods for efficient computation of multiple output columns. --- .../ml/{impl/estimator => }/Predictor.scala | 50 ++------ .../spark/ml/classification/Classifier.scala | 110 ++++++------------ .../DecisionTreeClassifier.scala | 5 +- .../ml/classification/GBTClassifier.scala | 5 +- .../classification/LogisticRegression.scala | 100 ++++++---------- .../ProbabilisticClassifier.scala | 102 +++++++++++----- .../RandomForestClassifier.scala | 5 +- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../ml/regression/DecisionTreeRegressor.scala | 5 +- .../spark/ml/regression/GBTRegressor.scala | 5 +- .../ml/regression/LinearRegression.scala | 5 +- .../ml/regression/RandomForestRegressor.scala | 5 +- .../spark/ml/regression/Regressor.scala | 42 ++----- .../spark/ml/{impl => }/tree/treeParams.scala | 4 +- .../apache/spark/mllib/linalg/Vectors.scala | 20 ++++ 15 files changed, 204 insertions(+), 265 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/{impl/estimator => }/Predictor.scala (86%) rename mllib/src/main/scala/org/apache/spark/ml/{impl => }/tree/treeParams.scala (99%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala similarity index 86% rename from mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala rename to mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e8b3628140e99..0e53877de92db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -15,29 +15,23 @@ * limitations under the License. */ -package org.apache.spark.ml.impl.estimator +package org.apache.spark.ml -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} /** - * :: DeveloperApi :: - * - * Trait for parameters for prediction (regression and classification). - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + * (private[ml]) Trait for parameters for prediction (regression and classification). */ -@DeveloperApi -private[spark] trait PredictorParams extends Params +private[ml] trait PredictorParams extends Params with HasLabelCol with HasFeaturesCol with HasPredictionCol { /** @@ -63,7 +57,7 @@ private[spark] trait PredictorParams extends Params } /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Abstraction for prediction problems (regression and classification). * @@ -73,11 +67,9 @@ private[spark] trait PredictorParams extends Params * parameter to specify the concrete type. * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type * parameter to specify the concrete type for the corresponding model. 
- * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent -private[spark] abstract class Predictor[ +@DeveloperApi +abstract class Predictor[ FeaturesType, Learner <: Predictor[FeaturesType, Learner, M], M <: PredictionModel[FeaturesType, M]] @@ -104,8 +96,6 @@ private[spark] abstract class Predictor[ } /** - * :: DeveloperApi :: - * * Train a model using the given dataset and parameters. * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation * and copying parameters into the model. @@ -113,12 +103,9 @@ private[spark] abstract class Predictor[ * @param dataset Training dataset * @return Fitted model */ - @DeveloperApi protected def train(dataset: DataFrame): M /** - * :: DeveloperApi :: - * * Returns the SQL DataType corresponding to the FeaturesType type parameter. * * This is used by [[validateAndTransformSchema()]]. @@ -126,7 +113,6 @@ private[spark] abstract class Predictor[ * * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. */ - @DeveloperApi protected def featuresDataType: DataType = new VectorUDT override def transformSchema(schema: StructType): StructType = { @@ -146,7 +132,7 @@ private[spark] abstract class Predictor[ } /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Abstraction for a model for prediction tasks (regression and classification). * @@ -154,11 +140,9 @@ private[spark] abstract class Predictor[ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type * parameter to specify the concrete type for the corresponding model. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent -private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] +@DeveloperApi +abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] extends Model[M] with PredictorParams { /** @group setParam */ @@ -168,8 +152,6 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] /** - * :: DeveloperApi :: - * * Returns the SQL DataType corresponding to the FeaturesType type parameter. * * This is used by [[validateAndTransformSchema()]]. @@ -177,7 +159,6 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel * * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. */ - @DeveloperApi protected def featuresDataType: DataType = new VectorUDT override def transformSchema(schema: StructType): StructType = { @@ -192,12 +173,8 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel * @return transformed dataset with [[predictionCol]] of type [[Double]] */ override def transform(dataset: DataFrame): DataFrame = { - // This default implementation should be overridden as needed. 
- - // Check schema transformSchema(dataset.schema, logging = true) - - if ($(predictionCol) != "") { + if ($(predictionCol).nonEmpty) { dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + @@ -207,11 +184,8 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel } /** - * :: DeveloperApi :: - * * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. */ - @DeveloperApi protected def predict(features: FeaturesType): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d3361e24705c8..db8232da67562 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -26,15 +26,12 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + /** - * :: DeveloperApi :: - * Params for classification. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + * (private[spark]) Params for classification. */ -@DeveloperApi -private[spark] trait ClassifierParams extends PredictorParams - with HasRawPredictionCol { +private[spark] trait ClassifierParams + extends PredictorParams with HasRawPredictionCol { override protected def validateAndTransformSchema( schema: StructType, @@ -46,23 +43,21 @@ private[spark] trait ClassifierParams extends PredictorParams } /** - * :: AlphaComponent :: + * :: DeveloperApi :: + * * Single-label binary or multiclass classification. * Classes are indexed {0, 1, ..., numClasses - 1}. * * @tparam FeaturesType Type of input features. E.g., [[Vector]] * @tparam E Concrete Estimator type * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent -private[spark] abstract class Classifier[ +@DeveloperApi +abstract class Classifier[ FeaturesType, E <: Classifier[FeaturesType, E, M], M <: ClassificationModel[FeaturesType, M]] - extends Predictor[FeaturesType, E, M] - with ClassifierParams { + extends Predictor[FeaturesType, E, M] with ClassifierParams { /** @group setParam */ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] @@ -71,17 +66,15 @@ private[spark] abstract class Classifier[ } /** - * :: AlphaComponent :: + * :: DeveloperApi :: + * * Model produced by a [[Classifier]]. * Classes are indexed {0, 1, ..., numClasses - 1}. * * @tparam FeaturesType Type of input features. E.g., [[Vector]] * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. 
*/ -@AlphaComponent -private[spark] +@DeveloperApi abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]] extends PredictionModel[FeaturesType, M] with ClassifierParams { @@ -101,13 +94,27 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * @return transformed dataset */ override def transform(dataset: DataFrame): DataFrame = { - // This default implementation should be overridden as needed. - - // Check schema transformSchema(dataset.schema, logging = true) - val (numColsOutput, outputData) = - ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this) + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. + var outputData = dataset + var numColsOutput = 0 + if (getRawPredictionCol != "") { + outputData = outputData.withColumn(getRawPredictionCol, + callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + numColsOutput += 1 + } + if (getPredictionCol != "") { + val predUDF = if (getRawPredictionCol != "") { + callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol)) + } else { + callUDF(predict _, DoubleType, col(getFeaturesCol)) + } + outputData = outputData.withColumn(getPredictionCol, predUDF) + numColsOutput += 1 + } + if (numColsOutput == 0) { logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + " since no output columns were set.") @@ -116,22 +123,17 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur } /** - * :: DeveloperApi :: - * * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. * * This default implementation for classification predicts the index of the maximum value * from [[predictRaw()]]. */ - @DeveloperApi override protected def predict(features: FeaturesType): Double = { - predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2 + raw2prediction(predictRaw(features)) } /** - * :: DeveloperApi :: - * * Raw prediction for each possible label. * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives * a measure of confidence in each possible label (where larger = more confident). @@ -141,48 +143,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * This raw prediction may be any real number, where a larger value indicates greater * confidence for that label. */ - @DeveloperApi protected def predictRaw(features: FeaturesType): Vector -} - -private[ml] object ClassificationModel { /** - * Added prediction column(s). This is separated from [[ClassificationModel.transform()]] - * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]]. - * @param dataset Input dataset - * @return (number of columns added, transformed dataset) + * Given a vector of raw predictions, select the predicted label. + * This may be overridden to support thresholds which favor particular labels. + * @return predicted label */ - def transformColumnsImpl[FeaturesType]( - dataset: DataFrame, - model: ClassificationModel[FeaturesType, _]): (Int, DataFrame) = { - - // Output selected columns only. - // This is a bit complicated since it tries to avoid repeated computation. 
- var tmpData = dataset - var numColsOutput = 0 - if (model.getRawPredictionCol != "") { - // output raw prediction - val features2raw: FeaturesType => Vector = model.predictRaw - tmpData = tmpData.withColumn(model.getRawPredictionCol, - callUDF(features2raw, new VectorUDT, col(model.getFeaturesCol))) - numColsOutput += 1 - if (model.getPredictionCol != "") { - val raw2pred: Vector => Double = (rawPred) => { - rawPred.toArray.zipWithIndex.maxBy(_._1)._2 - } - tmpData = tmpData.withColumn(model.getPredictionCol, - callUDF(raw2pred, DoubleType, col(model.getRawPredictionCol))) - numColsOutput += 1 - } - } else if (model.getPredictionCol != "") { - // output prediction - val features2pred: FeaturesType => Double = model.predict - tmpData = tmpData.withColumn(model.getPredictionCol, - callUDF(features2pred, DoubleType, col(model.getFeaturesCol))) - numColsOutput += 1 - } - (numColsOutput, tmpData) - } - + protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.findMax } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 419e5ba05d38a..dcebea1d4b015 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,10 +18,9 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 534ea95b1c538..ae51b05a0c42d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -21,11 +21,10 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index b73be035e29b5..e8c2e8212f2ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -21,9 +21,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel /** @@ -99,76 +98,17 @@ class LogisticRegressionModel private[ml] ( /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) + /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { BLAS.dot(features, weights) + intercept } + /** Score (probability) for class label 1. For binary classification only. */ private val score: Vector => Double = (features) => { val m = margin(features) 1.0 / (1.0 + math.exp(-m)) } - override def transform(dataset: DataFrame): DataFrame = { - // This is overridden (a) to be more efficient (avoiding re-computing values when creating - // multiple output columns) and (b) to handle threshold, which the abstractions do not use. - // TODO: We should abstract away the steps defined by UDFs below so that the abstractions - // can call whichever UDFs are needed to create the output columns. - - // Check schema - transformSchema(dataset.schema, logging = true) - - // Output selected columns only. - // This is a bit complicated since it tries to avoid repeated computation. - // rawPrediction (-margin, margin) - // probability (1.0-score, score) - // prediction (max margin) - var tmpData = dataset - var numColsOutput = 0 - if ($(rawPredictionCol) != "") { - val features2raw: Vector => Vector = (features) => predictRaw(features) - tmpData = tmpData.withColumn($(rawPredictionCol), - callUDF(features2raw, new VectorUDT, col($(featuresCol)))) - numColsOutput += 1 - } - if ($(probabilityCol) != "") { - if ($(rawPredictionCol) != "") { - val raw2prob = udf { (rawPreds: Vector) => - val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) - Vectors.dense(1.0 - prob1, prob1): Vector - } - tmpData = tmpData.withColumn($(probabilityCol), raw2prob(col($(rawPredictionCol)))) - } else { - val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector } - tmpData = tmpData.withColumn($(probabilityCol), features2prob(col($(featuresCol)))) - } - numColsOutput += 1 - } - if ($(predictionCol) != "") { - val t = $(threshold) - if ($(probabilityCol) != "") { - val predict = udf { probs: Vector => - if (probs(1) > t) 1.0 else 0.0 - } - tmpData = tmpData.withColumn($(predictionCol), predict(col($(probabilityCol)))) - } else if ($(rawPredictionCol) != "") { - val predict = udf { rawPreds: Vector => - val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) - if (prob1 > t) 1.0 else 0.0 - } - tmpData = tmpData.withColumn($(predictionCol), predict(col($(rawPredictionCol)))) - } else { - val predict = udf { features: Vector => this.predict(features) } - tmpData = tmpData.withColumn($(predictionCol), predict(col($(featuresCol)))) - } - numColsOutput += 1 - } - if (numColsOutput == 0) { - this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" + - " since no output columns were set.") - } - tmpData - } - override val numClasses: Int = 2 /** @@ -179,17 +119,43 @@ class LogisticRegressionModel private[ml] ( if (score(features) > 
getThreshold) 1 else 0 } - override protected def predictProbabilities(features: Vector): Vector = { - val s = score(features) - Vectors.dense(1.0 - s, s) + override protected def raw2probabilitiesInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + while (i < dv.size) { + dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i))) + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in LogisticRegressionModel:" + + " raw2probabilitiesInPlace encountered SparseVector") + } } override protected def predictRaw(features: Vector): Vector = { val m = margin(features) - Vectors.dense(0.0, m) + Vectors.dense(-m, m) } override def copy(extra: ParamMap): LogisticRegressionModel = { copyValues(new LogisticRegressionModel(parent, weights, intercept), extra) } + + override protected def raw2prediction(rawPrediction: Vector): Double = { + val t = getThreshold + val rawThreshold = if (t == 0.0) { + Double.NegativeInfinity + } else if (t == 1.0) { + Double.PositiveInfinity + } else { + Math.log(t / (1.0 - t)) + } + if (rawPrediction(1) > rawThreshold) 1 else 0 + } + + override protected def probability2prediction(probability: Vector): Double = { + if (probability(1) > getThreshold) 1 else 0 + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 8519841c5c26c..3184795efa67f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DoubleType, DataType, StructType} /** - * Params for probabilistic classification. + * (private[classification]) Params for probabilistic classification. */ private[classification] trait ProbabilisticClassifierParams extends ClassifierParams with HasProbabilityCol { @@ -42,17 +42,15 @@ private[classification] trait ProbabilisticClassifierParams /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Single-label binary or multiclass classifier which can output class conditional probabilities. * * @tparam FeaturesType Type of input features. E.g., [[Vector]] * @tparam E Concrete Estimator type * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent +@DeveloperApi private[spark] abstract class ProbabilisticClassifier[ FeaturesType, E <: ProbabilisticClassifier[FeaturesType, E, M], @@ -65,17 +63,15 @@ private[spark] abstract class ProbabilisticClassifier[ /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Model produced by a [[ProbabilisticClassifier]]. * Classes are indexed {0, 1, ..., numClasses - 1}. * * @tparam FeaturesType Type of input features. E.g., [[Vector]] * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. 
*/ -@AlphaComponent +@DeveloperApi private[spark] abstract class ProbabilisticClassificationModel[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]] @@ -95,39 +91,81 @@ private[spark] abstract class ProbabilisticClassificationModel[ * @return transformed dataset */ override def transform(dataset: DataFrame): DataFrame = { - // This default implementation should be overridden as needed. - - // Check schema transformSchema(dataset.schema, logging = true) - val (numColsOutput, outputData) = - ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this) - // Output selected columns only. - if ($(probabilityCol) != "") { - // output probabilities - outputData.withColumn($(probabilityCol), - callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol)))) - } else { - if (numColsOutput == 0) { - this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + - " since no output columns were set.") + // This is a bit complicated since it tries to avoid repeated computation. + var outputData = dataset + var numColsOutput = 0 + if ($(rawPredictionCol).nonEmpty) { + outputData = outputData.withColumn(getRawPredictionCol, + callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + numColsOutput += 1 + } + if ($(probabilityCol).nonEmpty) { + val probUDF = if ($(rawPredictionCol).nonEmpty) { + // rawPrediction -> probability + callUDF(raw2probabilities _, new VectorUDT, col($(rawPredictionCol))) + } else { + // features -> probability + callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol))) + } + outputData = outputData.withColumn($(probabilityCol), probUDF) + numColsOutput += 1 + } + if ($(predictionCol).nonEmpty) { + val predUDF = if ($(rawPredictionCol).nonEmpty) { + callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol))) + } else if ($(probabilityCol).nonEmpty) { + callUDF(probability2prediction _, DoubleType, col($(probabilityCol))) + } else { + callUDF(predict _, DoubleType, col($(featuresCol))) } - outputData + outputData = outputData.withColumn($(predictionCol), predUDF) + numColsOutput += 1 + } + + if (numColsOutput == 0) { + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") } + outputData } /** - * :: DeveloperApi :: + * Estimate the probability of each class given the raw prediction, + * doing the computation in-place. + * These predictions are also called class conditional probabilities. + * + * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. * + * @return Estimated class conditional probabilities (modified input vector) + */ + protected def raw2probabilitiesInPlace(rawPrediction: Vector): Vector + + /** Non-in-place version of raw2probabilitiesInPlace */ + protected def raw2probabilities(rawPrediction: Vector): Vector = { + val probs = rawPrediction.copy + raw2probabilitiesInPlace(probs) + } + + /** * Predict the probability of each class given the features. * These predictions are also called class conditional probabilities. * - * WARNING: Not all models output well-calibrated probability estimates! These probabilities - * should be treated as confidences, not precise probabilities. - * * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. 
+ * + * @return Estimated class conditional probabilities + */ + protected def predictProbabilities(features: FeaturesType): Vector = { + val rawPreds = predictRaw(features) + raw2probabilitiesInPlace(rawPreds) + } + + /** + * Given a vector of class conditional probabilities, select the predicted label. + * This may be overridden to support thresholds which favor particular labels. + * @return predicted label */ - @DeveloperApi - protected def predictProbabilities(features: FeaturesType): Vector + protected def probability2prediction(probability: Vector): Double = probability.findMax } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 17f59bb42e129..9954893f14359 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -20,10 +20,9 @@ package org.apache.spark.ml.classification import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index d379172e0bf53..0e1ff97a8bf60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -40,8 +40,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", Some("\"rawPrediction\"")), - ParamDesc[String]("probabilityCol", - "column name for predicted class conditional probabilities", Some("\"probability\"")), + ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" + + " probabilities. Note: Not all models output well-calibrated probability estimates!" 
+ + " These probabilities should be treated as confidences, not precise probabilities.", + Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", isValid = "ParamValidators.inRange(0, 1)"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index b07c26fe79b36..f8f0b161a4812 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -18,10 +18,9 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index bc796958e4545..461905c12701a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -21,10 +21,9 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 66c475f2d9840..e63c9a3eead52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -25,6 +25,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -39,7 +40,7 @@ import org.apache.spark.util.StatCounter /** * Params for linear regression. 
*/ -private[regression] trait LinearRegressionParams extends RegressorParams +private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** @@ -240,7 +241,7 @@ class LinearRegressionModel private[ml] ( * + \bar{y} / \hat{y}||^2 * = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 * }}} - * where w_i^\prime is the effective weights defined by w_i/\hat{x_i}, offset is + * where w_i^\prime^ is the effective weights defined by w_i/\hat{x_i}, offset is * {{{ * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. * }}}, and diff is diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 0468a1be1ba74..dbc628927433d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -18,10 +18,9 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams} +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala index c6b3327db6ad3..c72ef29680329 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -17,62 +17,40 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} -/** - * :: DeveloperApi :: - * Params for regression. - * Currently empty, but may add functionality later. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. - */ -@DeveloperApi -private[spark] trait RegressorParams extends PredictorParams /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Single-label regression * * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] * @tparam Learner Concrete Estimator type * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent +@DeveloperApi private[spark] abstract class Regressor[ FeaturesType, Learner <: Regressor[FeaturesType, Learner, M], M <: RegressionModel[FeaturesType, M]] - extends Predictor[FeaturesType, Learner, M] - with RegressorParams { + extends Predictor[FeaturesType, Learner, M] with PredictorParams { // TODO: defaultEvaluator (follow-up PR) } /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Model produced by a [[Regressor]]. * * @tparam FeaturesType Type of input features. 
E.g., [[org.apache.spark.mllib.linalg.Vector]] * @tparam M Concrete Model type. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent -private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]] - extends PredictionModel[FeaturesType, M] with RegressorParams { - - /** - * :: DeveloperApi :: - * - * Predict real-valued label for the given features. - * This internal method is used to implement [[transform()]] and output [[predictionCol]]. - */ - @DeveloperApi - protected def predict(features: FeaturesType): Double +@DeveloperApi +abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with PredictorParams { + // TODO: defaultEvaluator (follow-up PR) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala rename to mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 0e225627d4ee3..816fcedf2efb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -15,10 +15,10 @@ * limitations under the License. */ -package org.apache.spark.ml.impl.tree +package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.impl.estimator.PredictorParams +import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 188d1e542b5b5..e20fe1c562316 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -150,6 +150,26 @@ sealed trait Vector extends Serializable { toDense } } + + /** + * Find the index of a maximal element. Returns the first maximal element in case of a tie. + * Returns -1 if vector has length 0. + */ + private[spark] def findMax: Int = { + if (size == 0) { + 0 + } else { + var maxIdx = 0 + var maxValue = apply(0) + foreachActive { (idx, value) => + if (value > maxValue) { + maxIdx = idx + maxValue = value + } + } + maxIdx + } + } } /** From 5cda84dece1fbfd6ba770123f8f2a693e87ef557 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 5 May 2015 10:25:59 -0700 Subject: [PATCH 2/4] regenerated sharedParams --- .../spark/ml/classification/ProbabilisticClassifier.scala | 2 -- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 3184795efa67f..34fe441b5cd00 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -104,10 +104,8 @@ private[spark] abstract class ProbabilisticClassificationModel[ } if ($(probabilityCol).nonEmpty) { val probUDF = if ($(rawPredictionCol).nonEmpty) { - // rawPrediction -> probability callUDF(raw2probabilities _, new VectorUDT, col($(rawPredictionCol))) } else { - // features -> probability callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol))) } outputData = outputData.withColumn($(probabilityCol), probUDF) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index fb1874ccfc8dc..87f86807c3c91 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -128,10 +128,10 @@ private[ml] trait HasRawPredictionCol extends Params { private[ml] trait HasProbabilityCol extends Params { /** - * Param for column name for predicted class conditional probabilities. + * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. * @group param */ - final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") setDefault(probabilityCol, "probability") From 15b995777a544df087fc9eeed9f0d6962a0afc97 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 5 May 2015 10:30:06 -0700 Subject: [PATCH 3/4] renamed probabilities to probability in method names --- .../ml/classification/LogisticRegression.scala | 2 +- .../classification/ProbabilisticClassifier.scala | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index e8c2e8212f2ac..550369d18cfec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -119,7 +119,7 @@ class LogisticRegressionModel private[ml] ( if (score(features) > getThreshold) 1 else 0 } - override protected def raw2probabilitiesInPlace(rawPrediction: Vector): Vector = { + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { rawPrediction match { case dv: DenseVector => var i = 0 diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 34fe441b5cd00..23e0efb1e7473 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -104,9 +104,9 @@ private[spark] abstract class ProbabilisticClassificationModel[ } if ($(probabilityCol).nonEmpty) { val probUDF = if ($(rawPredictionCol).nonEmpty) { - callUDF(raw2probabilities _, new VectorUDT, col($(rawPredictionCol))) + callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol))) } else { - callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol))) + callUDF(predictProbability _, new VectorUDT, col($(featuresCol))) } outputData = outputData.withColumn($(probabilityCol), probUDF) numColsOutput += 1 @@ -139,12 +139,12 @@ private[spark] abstract class ProbabilisticClassificationModel[ * * @return Estimated class conditional probabilities (modified input vector) */ - protected def raw2probabilitiesInPlace(rawPrediction: Vector): Vector + protected def raw2probabilityInPlace(rawPrediction: Vector): Vector - /** Non-in-place version of raw2probabilitiesInPlace */ - protected def raw2probabilities(rawPrediction: Vector): Vector = { + /** Non-in-place version of [[raw2probabilityInPlace()]] */ + protected def raw2probability(rawPrediction: Vector): Vector = { val probs = rawPrediction.copy - raw2probabilitiesInPlace(probs) + raw2probabilityInPlace(probs) } /** @@ -155,9 +155,9 @@ private[spark] abstract class ProbabilisticClassificationModel[ * * @return Estimated class conditional probabilities */ - protected def predictProbabilities(features: FeaturesType): Vector = { + protected def predictProbability(features: FeaturesType): Vector = { val rawPreds = predictRaw(features) - raw2probabilitiesInPlace(rawPreds) + raw2probabilityInPlace(rawPreds) } /** From e9aa0eafc04e1f0a80e088dbe2b897433b65827d Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 5 May 2015 20:31:52 -0700 Subject: [PATCH 4/4] moved findMax to DenseVector and renamed to argmax. 
fixed bug for vector of length 0 --- .../spark/ml/classification/Classifier.scala | 2 +- .../ProbabilisticClassifier.scala | 2 +- .../apache/spark/mllib/linalg/Vectors.scala | 42 ++++++++++--------- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index db8232da67562..263d580fe2dd3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -150,5 +150,5 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * This may be overridden to support thresholds which favor particular labels. * @return predicted label */ - protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.findMax + protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 23e0efb1e7473..330ae2938f4e0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -165,5 +165,5 @@ private[spark] abstract class ProbabilisticClassificationModel[ * This may be overridden to support thresholds which favor particular labels. * @return predicted label */ - protected def probability2prediction(probability: Vector): Double = probability.findMax + protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index e20fe1c562316..f6bcdf83cd337 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -150,26 +150,6 @@ sealed trait Vector extends Serializable { toDense } } - - /** - * Find the index of a maximal element. Returns the first maximal element in case of a tie. - * Returns -1 if vector has length 0. - */ - private[spark] def findMax: Int = { - if (size == 0) { - 0 - } else { - var maxIdx = 0 - var maxValue = apply(0) - foreachActive { (idx, value) => - if (value > maxValue) { - maxIdx = idx - maxValue = value - } - } - maxIdx - } - } } /** @@ -607,6 +587,28 @@ class DenseVector(val values: Array[Double]) extends Vector { } new SparseVector(size, ii, vv) } + + /** + * Find the index of a maximal element. Returns the first maximal element in case of a tie. + * Returns -1 if vector has length 0. + */ + private[spark] def argmax: Int = { + if (size == 0) { + -1 + } else { + var maxIdx = 0 + var maxValue = values(0) + var i = 1 + while (i < size) { + if (values(i) > maxValue) { + maxIdx = i + maxValue = values(i) + } + i += 1 + } + maxIdx + } + } } object DenseVector {
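
The restructured transform() implementations above derive each requested
output column from the nearest upstream value that was already computed
(features -> rawPrediction -> probability -> prediction), so predictRaw is
evaluated at most once per row no matter how many columns are requested. A
plain-Scala schematic of that cascade; the doubling and softmax are
placeholder math standing in for a real model, and nothing here touches
DataFrames or the patch's actual classes:

object CascadeSketch {
  // Placeholder "raw confidence" per class; a real model would score features.
  def predictRaw(features: Array[Double]): Array[Double] =
    features.map(_ * 2.0)

  // Placeholder normalization standing in for raw2probabilityInPlace.
  def raw2probability(raw: Array[Double]): Array[Double] = {
    val expd = raw.map(math.exp)
    val total = expd.sum
    expd.map(_ / total)
  }

  def probability2prediction(prob: Array[Double]): Double =
    prob.indexWhere(_ == prob.max).toDouble

  def main(args: Array[String]): Unit = {
    val features = Array(0.2, 0.5, 0.3)
    val raw = predictRaw(features)          // computed once
    val prob = raw2probability(raw)         // derived from raw, not from features
    val pred = probability2prediction(prob) // derived from prob
    println(s"raw=${raw.toSeq} prob=${prob.toSeq} prediction=$pred")
  }
}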
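
LogisticRegressionModel.raw2prediction compares the class-1 margin against
log(t / (1 - t)) instead of converting the margin to a probability first.
Since the sigmoid is strictly increasing, thresholding in margin space and in
probability space select the same label (with t = 0 and t = 1 mapped to the
infinities, as in the patch). A standalone sanity check of the equivalence in
plain Scala, independent of Spark:

object ThresholdEquivalence {
  // Same sigmoid applied elementwise by raw2probabilityInPlace.
  def sigmoid(margin: Double): Double = 1.0 / (1.0 + math.exp(-margin))

  def main(args: Array[String]): Unit = {
    val t = 0.7
    val rawThreshold = math.log(t / (1.0 - t)) // ~0.847 for t = 0.7
    for (margin <- Seq(-3.0, 0.0, 0.8, 0.9, 3.0)) {
      val viaProbability = if (sigmoid(margin) > t) 1.0 else 0.0
      val viaMargin = if (margin > rawThreshold) 1.0 else 0.0
      assert(viaProbability == viaMargin, s"mismatch at margin $margin")
    }
    println("margin-space and probability-space thresholds agree")
  }
}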
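
DenseVector.argmax is private[spark], so it cannot be called from user code;
the loop below is a standalone copy of it (as a free function over an array)
to illustrate the documented behavior: the first maximal index wins on ties,
and a length-0 input yields -1, the bug this patch fixes relative to findMax
from patch 1:

object ArgmaxDemo {
  def argmax(values: Array[Double]): Int = {
    if (values.length == 0) {
      -1 // findMax returned 0 here, contradicting its own doc comment
    } else {
      var maxIdx = 0
      var maxValue = values(0)
      var i = 1
      while (i < values.length) {
        if (values(i) > maxValue) { // strict '>' keeps the first maximal index on ties
          maxIdx = i
          maxValue = values(i)
        }
        i += 1
      }
      maxIdx
    }
  }

  def main(args: Array[String]): Unit = {
    assert(argmax(Array(0.1, 0.7, 0.2)) == 1) // prediction = index of the max raw score
    assert(argmax(Array(0.5, 0.5)) == 0)      // tie: first maximal element
    assert(argmax(Array.empty[Double]) == -1) // length-0 case after the fix
    println("argmax behaves as documented")
  }
}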