Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-24102][ML][MLLIB] ML Evaluators should use weight column - added weight column for regression evaluator #17085

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{Dataset, Row}
Expand All @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
@Since("1.4.0")
@Experimental
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
extends Evaluator with HasPredictionCol with HasLabelCol
with HasWeightCol with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("regEval"))
Expand Down Expand Up @@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

/** @group setParam */
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

setDefault(metricName -> "rmse")

@Since("2.0.0")
Expand All @@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
SchemaUtils.checkNumericType(schema, $(labelCol))

val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
val predictionAndLabelsWithWeights = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
.rdd
.map { case Row(prediction: Double, label: Double) => (prediction, label) }
val metrics = new RegressionMetrics(predictionAndLabels)
.map { case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight) }
val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
val metric = $(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
case "mse" => metrics.meanSquaredError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for regression.
*
* @param predictionAndObservations an RDD of (prediction, observation) pairs
* @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight)
* or (prediction, observation) pairs
* @param throughOrigin True if the regression is through the origin. For example, in linear
* regression, it will be true without fitting intercept.
*/
@Since("1.2.0")
class RegressionMetrics @Since("2.0.0") (
predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean)
predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean)
extends Logging {

@Since("1.2.0")
def this(predictionAndObservations: RDD[(Double, Double)]) =
def this(predictionAndObservations: RDD[_ <: Product]) =
this(predictionAndObservations, false)

/**
Expand All @@ -52,22 +53,27 @@ class RegressionMetrics @Since("2.0.0") (
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
*/
private lazy val summary: MultivariateStatisticalSummary = {
val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
case (prediction, observation) => Vectors.dense(observation, observation - prediction)
val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map {
case (prediction: Double, observation: Double, weight: Double) =>
(Vectors.dense(observation, observation - prediction), weight)
case (prediction: Double, observation: Double) =>
(Vectors.dense(observation, observation - prediction), 1.0)
}.treeAggregate(new MultivariateOnlineSummarizer())(
(summary, v) => summary.add(v),
(summary, sample) => summary.add(sample._1, sample._2),
(sum1, sum2) => sum1.merge(sum2)
)
summary
}

private lazy val SSy = math.pow(summary.normL2(0), 2)
private lazy val SSerr = math.pow(summary.normL2(1), 2)
private lazy val SStot = summary.variance(0) * (summary.count - 1)
private lazy val SStot = summary.variance(0) * (summary.weightSum - 1)
private lazy val SSreg = {
val yMean = summary.mean(0)
predictionAndObservations.map {
case (prediction, _) => math.pow(prediction - yMean, 2)
predAndObsWithOptWeight.map {
case (prediction: Double, _: Double, weight: Double) =>
math.pow(prediction - yMean, 2) * weight
case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2)
}.sum()
}

Expand All @@ -79,7 +85,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def explainedVariance: Double = {
SSreg / summary.count
SSreg / summary.weightSum
}

/**
Expand All @@ -88,7 +94,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def meanAbsoluteError: Double = {
summary.normL1(1) / summary.count
summary.normL1(1) / summary.weightSum
}

/**
Expand All @@ -97,7 +103,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def meanSquaredError: Double = {
SSerr / summary.count
SSerr / summary.weightSum
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var totalCnt: Long = 0
private var totalWeightSum: Double = 0.0
private var weightSquareSum: Double = 0.0
private var weightSum: Array[Double] = _
private var currWeightSum: Array[Double] = _
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I don't think the rename was necessary, but it is OK

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

Copy link
Contributor Author

@imatiach-msft imatiach-msft Dec 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind, it looks like the build failed because the private variable conflicts with the public variable that was defined:

/**

  • Sum of weights.
    */
    override def weightSum: Double = totalWeightSum

I think this may be the best name for the public variable so I would prefer to keep it. The private variable now follows the naming convention of the other private array variables so I think this makes sense.

private var nnz: Array[Long] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _
Expand All @@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currM2n = Array.ofDim[Double](n)
currM2 = Array.ofDim[Double](n)
currL1 = Array.ofDim[Double](n)
weightSum = Array.ofDim[Double](n)
currWeightSum = Array.ofDim[Double](n)
nnz = Array.ofDim[Long](n)
currMax = Array.fill[Double](n)(Double.MinValue)
currMin = Array.fill[Double](n)(Double.MaxValue)
Expand All @@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val localCurrM2n = currM2n
val localCurrM2 = currM2
val localCurrL1 = currL1
val localWeightSum = weightSum
val localWeightSum = currWeightSum
val localNumNonzeros = nnz
val localCurrMax = currMax
val localCurrMin = currMin
Expand Down Expand Up @@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
val thisNnz = weightSum(i)
val otherNnz = other.weightSum(i)
val thisNnz = currWeightSum(i)
val otherNnz = other.currWeightSum(i)
val totalNnz = thisNnz + otherNnz
val totalCnnz = nnz(i) + other.nnz(i)
if (totalNnz != 0.0) {
Expand All @@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMax(i) = math.max(currMax(i), other.currMax(i))
currMin(i) = math.min(currMin(i), other.currMin(i))
}
weightSum(i) = totalNnz
currWeightSum(i) = totalNnz
nnz(i) = totalCnnz
i += 1
}
Expand All @@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
this.totalCnt = other.totalCnt
this.totalWeightSum = other.totalWeightSum
this.weightSquareSum = other.weightSquareSum
this.weightSum = other.weightSum.clone()
this.currWeightSum = other.currWeightSum.clone()
this.nnz = other.nnz.clone()
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
Expand All @@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
i += 1
}
Vectors.dense(realMean)
Expand All @@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val len = currM2n.length
while (i < len) {
// We prevent variance from negative value caused by numerical error.
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
(totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
Expand All @@ -229,6 +229,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
@Since("1.1.0")
override def count: Long = totalCnt

/**
* Sum of weights.
*/
override def weightSum: Double = totalWeightSum

/**
* Number of nonzero elements in each dimension.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary {
@Since("1.0.0")
def count: Long

/**
* Sum of weights.
*/
@Since("3.0.0")
def weightSum: Double

/**
* Number of nonzero elements (including explicitly presented zero values) in each column.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
"root mean squared error mismatch")
assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch")
}

test("regression metrics with same (1.0) weight samples") {
val predictionAndObservationWithWeight = sc.parallelize(
Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
assert(metrics.explainedVariance ~== 8.79687 absTol eps,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error mismatch")
assert(metrics.rootMeanSquaredError ~== 0.55901 absTol eps,
"root mean squared error mismatch")
assert(metrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch")
}

/**
* The following values are hand calculated using the formula:
* [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
* preds = c(2.25, -0.25, 1.75, 7.75)
* obs = c(3.0, -0.5, 2.0, 7.0)
* weights = c(0.1, 0.2, 0.15, 0.05)
* count = 4
*
* Weighted metrics can be calculated with MultivariateStatisticalSummary.
* (observations, observations - predictions)
* mean (1.7, 0.05)
* variance (7.3, 0.3)
* numNonZeros (0.5, 0.5)
* max (7.0, 0.75)
* min (-0.5, -0.75)
* normL2 (2.0, 0.32596)
* normL1 (1.05, 0.2)
*
* explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425
* meanAbsoluteError: normL1(1) / weightedCount = 0.4
* meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125
* rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098
* r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910
*/
test("regression metrics with weighted samples") {
val predictionAndObservationWithWeight = sc.parallelize(
Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2)
val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
assert(metrics.explainedVariance ~== 5.2425 absTol eps,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error mismatch")
assert(metrics.rootMeanSquaredError ~== 0.46098 absTol eps,
"root mean squared error mismatch")
assert(metrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch")
}
}
5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,10 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"),

// [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum")
) ++ Seq(
// [SPARK-17019] Expose on-heap and off-heap memory usage in various places
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"),
Expand Down