Skip to content

Commit

Permalink
[SPARK-24102][ML][MLLIB] ML Evaluators should use weight column - add…
Browse files Browse the repository at this point in the history
…ed weight column for regression evaluator

## What changes were proposed in this pull request?

The evaluators BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator and the corresponding metrics classes BinaryClassificationMetrics, RegressionMetrics and MulticlassMetrics should use sample weight data.

I've closed PR #16557, as recommended, in favor of creating three separate pull requests, one for each of the evaluators (binary/regression/multiclass), to make them easier to review and update.

The updates to the regression metrics are based on the earlier pull request for SPARK-11520 ("RegressionMetrics should support instance weights", https://issues.apache.org/jira/browse/SPARK-11520), with further changes made in response to review comments. That pull request was closed without its changes ever being checked in.

## How was this patch tested?

I added tests to the metrics class.

Closes #17085 from imatiach-msft/ilmat/regression-evaluate.

Authored-by: Ilya Matiach <ilmat@microsoft.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
  • Loading branch information
imatiach-msft authored and srowen committed Dec 12, 2018
1 parent 79e36e2 commit 570b8f3
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{Dataset, Row}
Expand All @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
@Since("1.4.0")
@Experimental
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
extends Evaluator with HasPredictionCol with HasLabelCol
with HasWeightCol with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("regEval"))
Expand Down Expand Up @@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

/**
 * Sets the `weightCol` param: the name of the column whose values are read as
 * per-row weights when the metric is evaluated.
 *
 * @group setParam
 */
@Since("3.0.0")
def setWeightCol(value: String): this.type = {
  set(weightCol, value)
}

setDefault(metricName -> "rmse")

@Since("2.0.0")
Expand All @@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
SchemaUtils.checkNumericType(schema, $(labelCol))

val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
// Select (prediction, label, weight) as Doubles. When no weight column is configured
// (param unset or empty string) every row gets a constant weight of 1.0.
// The weight column is cast to DoubleType just like prediction and label: without the cast,
// an Int/Long/Float weight column does not match the `weight: Double` pattern applied to
// each Row below and the evaluation fails with a scala.MatchError.
val predictionAndLabelsWithWeights = dataset
  .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
    if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
    else col($(weightCol)).cast(DoubleType))
.rdd
.map { case Row(prediction: Double, label: Double) => (prediction, label) }
val metrics = new RegressionMetrics(predictionAndLabels)
.map { case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight) }
val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
val metric = $(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
case "mse" => metrics.meanSquaredError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for regression.
*
* @param predictionAndObservations an RDD of (prediction, observation) pairs
* @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight) triples
* or (prediction, observation) pairs
* @param throughOrigin True if the regression is through the origin. For example, in linear
* regression, it will be true without fitting intercept.
*/
@Since("1.2.0")
class RegressionMetrics @Since("2.0.0") (
predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean)
predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean)
extends Logging {

@Since("1.2.0")
def this(predictionAndObservations: RDD[(Double, Double)]) =
def this(predictionAndObservations: RDD[_ <: Product]) =
this(predictionAndObservations, false)

/**
Expand All @@ -52,22 +53,27 @@ class RegressionMetrics @Since("2.0.0") (
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
*/
private lazy val summary: MultivariateStatisticalSummary = {
val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
case (prediction, observation) => Vectors.dense(observation, observation - prediction)
val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map {
case (prediction: Double, observation: Double, weight: Double) =>
(Vectors.dense(observation, observation - prediction), weight)
case (prediction: Double, observation: Double) =>
(Vectors.dense(observation, observation - prediction), 1.0)
}.treeAggregate(new MultivariateOnlineSummarizer())(
(summary, v) => summary.add(v),
(summary, sample) => summary.add(sample._1, sample._2),
(sum1, sum2) => sum1.merge(sum2)
)
summary
}

private lazy val SSy = math.pow(summary.normL2(0), 2)
private lazy val SSerr = math.pow(summary.normL2(1), 2)
private lazy val SStot = summary.variance(0) * (summary.count - 1)
private lazy val SStot = summary.variance(0) * (summary.weightSum - 1)
private lazy val SSreg = {
val yMean = summary.mean(0)
predictionAndObservations.map {
case (prediction, _) => math.pow(prediction - yMean, 2)
predAndObsWithOptWeight.map {
case (prediction: Double, _: Double, weight: Double) =>
math.pow(prediction - yMean, 2) * weight
case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2)
}.sum()
}

Expand All @@ -79,7 +85,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def explainedVariance: Double = {
SSreg / summary.count
SSreg / summary.weightSum
}

/**
Expand All @@ -88,7 +94,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def meanAbsoluteError: Double = {
summary.normL1(1) / summary.count
summary.normL1(1) / summary.weightSum
}

/**
Expand All @@ -97,7 +103,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def meanSquaredError: Double = {
SSerr / summary.count
SSerr / summary.weightSum
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var totalCnt: Long = 0
private var totalWeightSum: Double = 0.0
private var weightSquareSum: Double = 0.0
private var weightSum: Array[Double] = _
private var currWeightSum: Array[Double] = _
private var nnz: Array[Long] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _
Expand All @@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currM2n = Array.ofDim[Double](n)
currM2 = Array.ofDim[Double](n)
currL1 = Array.ofDim[Double](n)
weightSum = Array.ofDim[Double](n)
currWeightSum = Array.ofDim[Double](n)
nnz = Array.ofDim[Long](n)
currMax = Array.fill[Double](n)(Double.MinValue)
currMin = Array.fill[Double](n)(Double.MaxValue)
Expand All @@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val localCurrM2n = currM2n
val localCurrM2 = currM2
val localCurrL1 = currL1
val localWeightSum = weightSum
val localWeightSum = currWeightSum
val localNumNonzeros = nnz
val localCurrMax = currMax
val localCurrMin = currMin
Expand Down Expand Up @@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
val thisNnz = weightSum(i)
val otherNnz = other.weightSum(i)
val thisNnz = currWeightSum(i)
val otherNnz = other.currWeightSum(i)
val totalNnz = thisNnz + otherNnz
val totalCnnz = nnz(i) + other.nnz(i)
if (totalNnz != 0.0) {
Expand All @@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMax(i) = math.max(currMax(i), other.currMax(i))
currMin(i) = math.min(currMin(i), other.currMin(i))
}
weightSum(i) = totalNnz
currWeightSum(i) = totalNnz
nnz(i) = totalCnnz
i += 1
}
Expand All @@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
this.totalCnt = other.totalCnt
this.totalWeightSum = other.totalWeightSum
this.weightSquareSum = other.weightSquareSum
this.weightSum = other.weightSum.clone()
this.currWeightSum = other.currWeightSum.clone()
this.nnz = other.nnz.clone()
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
Expand All @@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
i += 1
}
Vectors.dense(realMean)
Expand All @@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val len = currM2n.length
while (i < len) {
// We prevent variance from negative value caused by numerical error.
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
(totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
Expand All @@ -229,6 +229,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
@Since("1.1.0")
override def count: Long = totalCnt

/**
 * Sum of weights.
 *
 * Returns the accumulated `totalWeightSum`.
 */
@Since("3.0.0")
override def weightSum: Double = totalWeightSum

/**
* Number of nonzero elements in each dimension.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary {
@Since("1.0.0")
def count: Long

/**
 * Sum of weights.
 *
 * NOTE(review): presumably the total weight of all samples folded into this summary, i.e. the
 * weighted analogue of [[count]] -- confirm against the implementing classes.
 */
@Since("3.0.0")
def weightSum: Double

/**
* Number of nonzero elements (including explicitly presented zero values) in each column.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
"root mean squared error mismatch")
assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch")
}

test("regression metrics with same (1.0) weight samples") {
  // Four (prediction, observation, weight) rows; every weight is 1.0.
  val unitWeightRows = Seq(
    (2.25, 3.0, 1.0),
    (-0.25, -0.5, 1.0),
    (1.75, 2.0, 1.0),
    (7.75, 7.0, 1.0))
  val unitWeightRdd = sc.parallelize(unitWeightRows, 2)
  val regMetrics = new RegressionMetrics(unitWeightRdd, false)

  assert(regMetrics.explainedVariance ~== 8.79687 absTol eps,
    "explained variance regression score mismatch")
  assert(regMetrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch")
  assert(regMetrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error mismatch")
  assert(regMetrics.rootMeanSquaredError ~== 0.55901 absTol eps,
    "root mean squared error mismatch")
  assert(regMetrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch")
}

/**
* The following values are hand calculated using the formula:
* [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
* preds = c(2.25, -0.25, 1.75, 7.75)
* obs = c(3.0, -0.5, 2.0, 7.0)
* weights = c(0.1, 0.2, 0.15, 0.05)
* count = 4
* weightedCount = sum(weights) = 0.5
*
* Weighted metrics can be calculated with MultivariateStatisticalSummary.
* (observations, observations - predictions)
* mean (1.7, 0.05)
* variance (7.3, 0.3)
* numNonZeros (0.5, 0.5)
* max (7.0, 0.75)
* min (-0.5, -0.75)
* normL2 (2.0, 0.32596)
* normL1 (1.05, 0.2)
*
* explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425
* meanAbsoluteError: normL1(1) / weightedCount = 0.4
* meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125
* rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098
* r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910
* (r2 exceeds 1 here because the weights sum to 0.5, which makes (weightedCount - 1) negative)
*/
test("regression metrics with weighted samples") {
  // Four (prediction, observation, weight) rows with non-uniform weights that sum to 0.5.
  val weightedRows = Seq(
    (2.25, 3.0, 0.1),
    (-0.25, -0.5, 0.2),
    (1.75, 2.0, 0.15),
    (7.75, 7.0, 0.05))
  val weightedRdd = sc.parallelize(weightedRows, 2)
  val regMetrics = new RegressionMetrics(weightedRdd, false)

  assert(regMetrics.explainedVariance ~== 5.2425 absTol eps,
    "explained variance regression score mismatch")
  assert(regMetrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch")
  assert(regMetrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error mismatch")
  assert(regMetrics.rootMeanSquaredError ~== 0.46098 absTol eps,
    "root mean squared error mismatch")
  assert(regMetrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch")
}
}
5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,10 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"),

// [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum")
) ++ Seq(
// [SPARK-17019] Expose on-heap and off-heap memory usage in various places
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"),
Expand Down

0 comments on commit 570b8f3

Please sign in to comment.