From aca62557fe394d500bd084ad840f9c0ff352cde3 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Mon, 27 Feb 2017 13:20:44 -0500 Subject: [PATCH] Added weight column for regression evaluator --- .../ml/evaluation/RegressionEvaluator.scala | 19 ++++--- .../mllib/evaluation/RegressionMetrics.scala | 43 +++++++++++----- .../stat/MultivariateOnlineSummarizer.scala | 25 ++++++---- .../stat/MultivariateStatisticalSummary.scala | 6 +++ .../evaluation/RegressionMetricsSuite.scala | 50 +++++++++++++++++++ project/MimaExcludes.scala | 5 +- 6 files changed, 118 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 031cd0d635bf4..7589eedb8ce61 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("2.2.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "rmse") @Since("2.0.0") @@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType)) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = dataset - .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) + val predictionAndLabelsWithWeights = dataset + .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) .rdd - .map { case Row(prediction: Double, label: Double) => (prediction, label) } - val metrics = new RegressionMetrics(predictionAndLabels) + .map { case Row(prediction: Double, label: Double, weight: Double) => + (prediction, label, weight) } + val metrics = new RegressionMetrics(false, predictionAndLabelsWithWeights) val metric = $(metricName) match { case "rmse" => metrics.rootMeanSquaredError case "mse" => metrics.meanSquaredError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 020676cac5a64..36dd032c15919 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -27,18 +27,30 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for regression. * - * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight) + * or (prediction, observation) pairs * @param throughOrigin True if the regression is through the origin. For example, in linear * regression, it will be true without fitting intercept. */ @Since("1.2.0") class RegressionMetrics @Since("2.0.0") ( - predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) + throughOrigin: Boolean, predAndObsWithOptWeight: RDD[_]) extends Logging { @Since("1.2.0") def this(predictionAndObservations: RDD[(Double, Double)]) = - this(predictionAndObservations, false) + this(false, predictionAndObservations) + + /** + * Evaluator for regression. + * + * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param throughOrigin True if the regression is through the origin. For example, in linear + * regression, it will be true without fitting intercept. + */ + @Since("2.0.0") + def this(predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) = + this(throughOrigin, predictionAndObservations) /** * An auxiliary constructor taking a DataFrame. @@ -46,16 +58,19 @@ class RegressionMetrics @Since("2.0.0") ( * prediction and observation */ private[mllib] def this(predictionAndObservations: DataFrame) = - this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) + this(false, predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. */ private lazy val summary: MultivariateStatisticalSummary = { - val summary: MultivariateStatisticalSummary = predictionAndObservations.map { - case (prediction, observation) => Vectors.dense(observation, observation - prediction) + val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map { + case (prediction: Double, observation: Double, weight: Double) => + (Vectors.dense(observation, observation - prediction), weight) + case (prediction: Double, observation: Double) => + (Vectors.dense(observation, observation - prediction), 1.0) }.treeAggregate(new MultivariateOnlineSummarizer())( - (summary, v) => summary.add(v), + (summary, sample) => summary.add(sample._1, sample._2), (sum1, sum2) => sum1.merge(sum2) ) summary @@ -63,11 +78,13 @@ class RegressionMetrics @Since("2.0.0") ( private lazy val SSy = math.pow(summary.normL2(0), 2) private lazy val SSerr = math.pow(summary.normL2(1), 2) - private lazy val SStot = summary.variance(0) * (summary.count - 1) + private lazy val SStot = summary.variance(0) * (summary.weightSum - 1) private lazy val SSreg = { val yMean = summary.mean(0) - predictionAndObservations.map { - case (prediction, _) => math.pow(prediction - yMean, 2) + predAndObsWithOptWeight.map { + case (prediction: Double, _: Double, weight: Double) => + math.pow(prediction - yMean, 2) * weight + case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2) }.sum() } @@ -79,7 +96,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def explainedVariance: Double = { - SSreg / summary.count + SSreg / summary.weightSum } /** @@ -88,7 +105,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def meanAbsoluteError: Double = { - summary.normL1(1) / summary.count + summary.normL1(1) / summary.weightSum } /** @@ -97,7 +114,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def meanSquaredError: Double = { - SSerr / summary.count + SSerr / summary.weightSum } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 0554b6d8ff5b5..6d510e1633d67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var totalCnt: Long = 0 private var totalWeightSum: Double = 0.0 private var weightSquareSum: Double = 0.0 - private var weightSum: Array[Double] = _ + private var currWeightSum: Array[Double] = _ private var nnz: Array[Long] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currM2n = Array.ofDim[Double](n) currM2 = Array.ofDim[Double](n) currL1 = Array.ofDim[Double](n) - weightSum = Array.ofDim[Double](n) + currWeightSum = Array.ofDim[Double](n) nnz = Array.ofDim[Long](n) currMax = Array.fill[Double](n)(Double.MinValue) currMin = Array.fill[Double](n)(Double.MaxValue) @@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localCurrM2n = currM2n val localCurrM2 = currM2 val localCurrL1 = currL1 - val localWeightSum = weightSum + val localWeightSum = currWeightSum val localNumNonzeros = nnz val localCurrMax = currMax val localCurrMin = currMin @@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { - val thisNnz = weightSum(i) - val otherNnz = other.weightSum(i) + val thisNnz = currWeightSum(i) + val otherNnz = other.currWeightSum(i) val totalNnz = thisNnz + otherNnz val totalCnnz = nnz(i) + other.nnz(i) if (totalNnz != 0.0) { @@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMax(i) = math.max(currMax(i), other.currMax(i)) currMin(i) = math.min(currMin(i), other.currMin(i)) } - weightSum(i) = totalNnz + currWeightSum(i) = totalNnz nnz(i) = totalCnnz i += 1 } @@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this.totalCnt = other.totalCnt this.totalWeightSum = other.totalWeightSum this.weightSquareSum = other.weightSquareSum - this.weightSum = other.weightSum.clone() + this.currWeightSum = other.currWeightSum.clone() this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum) + realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum) i += 1 } Vectors.dense(realMean) @@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val len = currM2n.length while (i < len) { // We prevent variance from negative value caused by numerical error. - realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * - (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) + realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) * + (totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 } } @@ -229,6 +229,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S @Since("1.1.0") override def count: Long = totalCnt + /** + * Sum of weights. + */ + override def weightSum: Double = totalWeightSum + /** * Number of nonzero elements in each dimension. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 39a16fb743d64..978842fdb27e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary { @Since("1.0.0") def count: Long + /** + * Sum of weights. + */ + @Since("2.2.0") + def weightSum: Double + /** * Number of nonzero elements (including explicitly presented zero values) in each column. */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index f1d517383643d..6de76f333b862 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { "root mean squared error mismatch") assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch") } + + test("regression metrics with same (1.0) weight samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2) + val metrics = new RegressionMetrics(false, predictionAndObservationWithWeight) + assert(metrics.explainedVariance ~== 8.79687 absTol eps, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.55901 absTol eps, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch") + } + + /** + * The following values are hand calculated using the formula: + * [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + * preds = c(2.25, -0.25, 1.75, 7.75) + * obs = c(3.0, -0.5, 2.0, 7.0) + * weights = c(0.1, 0.2, 0.15, 0.05) + * count = 4 + * + * Weighted metrics can be calculated with MultivariateStatisticalSummary. + * (observations, observations - predictions) + * mean (1.7, 0.05) + * variance (7.3, 0.3) + * numNonZeros (0.5, 0.5) + * max (7.0, 0.75) + * min (-0.5, -0.75) + * normL2 (2.0, 0.32596) + * normL1 (1.05, 0.2) + * + * explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425 + * meanAbsoluteError: normL1(1) / weightedCount = 0.4 + * meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125 + * rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098 + * r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910 + */ + test("regression metrics with weighted samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2) + val metrics = new RegressionMetrics(false, predictionAndObservationWithWeight) + assert(metrics.explainedVariance ~== 5.2425 absTol eps, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.46098 absTol eps, + "root mean squared error mismatch") + assert(metrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch") + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b3252d70a80c8..883913332ca1e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -531,7 +531,10 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"), + + // [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum") ) ++ Seq( // [SPARK-17019] Expose on-heap and off-heap memory usage in various places ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"),