Skip to content

Commit

Permalink
Added weight column for regression evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Dec 11, 2018
1 parent 0a37da6 commit aca6255
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{Dataset, Row}
Expand All @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
@Since("1.4.0")
@Experimental
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
extends Evaluator with HasPredictionCol with HasLabelCol
with HasWeightCol with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("regEval"))
Expand Down Expand Up @@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

/**
* Sets the name of the column holding per-row weights and returns this evaluator.
* When the param is not defined or is the empty string, evaluation substitutes a
* constant weight of 1.0 for every row.
*
* @group setParam
*/
@Since("2.2.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

setDefault(metricName -> "rmse")

@Since("2.0.0")
Expand All @@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
SchemaUtils.checkNumericType(schema, $(labelCol))

val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
val predictionAndLabelsWithWeights = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
.rdd
.map { case Row(prediction: Double, label: Double) => (prediction, label) }
val metrics = new RegressionMetrics(predictionAndLabels)
.map { case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight) }
val metrics = new RegressionMetrics(false, predictionAndLabelsWithWeights)
val metric = $(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
case "mse" => metrics.meanSquaredError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,64 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for regression.
*
* @param predictionAndObservations an RDD of (prediction, observation) pairs
* @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight) triples
* or (prediction, observation) pairs
* @param throughOrigin True if the regression is through the origin. For example, in linear
* regression, it will be true without fitting intercept.
*/
@Since("1.2.0")
class RegressionMetrics @Since("2.0.0") (
predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean)
throughOrigin: Boolean, predAndObsWithOptWeight: RDD[_])
extends Logging {

@Since("1.2.0")
def this(predictionAndObservations: RDD[(Double, Double)]) =
this(predictionAndObservations, false)
this(false, predictionAndObservations)

/**
* Auxiliary constructor for unweighted data that accepts the arguments in
* (RDD, Boolean) order. It forwards to the primary constructor, which takes
* `throughOrigin` first.
*
* @param predictionAndObservations an RDD of (prediction, observation) pairs
* @param throughOrigin True if the regression is through the origin. For example, in linear
* regression, it will be true without fitting intercept.
*/
@Since("2.0.0")
def this(predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) =
this(throughOrigin, predictionAndObservations)

/**
* An auxiliary constructor taking a DataFrame.
* @param predictionAndObservations a DataFrame with two double columns:
* prediction and observation
*/
private[mllib] def this(predictionAndObservations: DataFrame) =
this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
this(false, predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1))))

/**
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
*/
private lazy val summary: MultivariateStatisticalSummary = {
val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
case (prediction, observation) => Vectors.dense(observation, observation - prediction)
val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map {
case (prediction: Double, observation: Double, weight: Double) =>
(Vectors.dense(observation, observation - prediction), weight)
case (prediction: Double, observation: Double) =>
(Vectors.dense(observation, observation - prediction), 1.0)
}.treeAggregate(new MultivariateOnlineSummarizer())(
(summary, v) => summary.add(v),
(summary, sample) => summary.add(sample._1, sample._2),
(sum1, sum2) => sum1.merge(sum2)
)
summary
}

private lazy val SSy = math.pow(summary.normL2(0), 2)
private lazy val SSerr = math.pow(summary.normL2(1), 2)
private lazy val SStot = summary.variance(0) * (summary.count - 1)
private lazy val SStot = summary.variance(0) * (summary.weightSum - 1)
private lazy val SSreg = {
val yMean = summary.mean(0)
predictionAndObservations.map {
case (prediction, _) => math.pow(prediction - yMean, 2)
predAndObsWithOptWeight.map {
case (prediction: Double, _: Double, weight: Double) =>
math.pow(prediction - yMean, 2) * weight
case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2)
}.sum()
}

Expand All @@ -79,7 +96,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def explainedVariance: Double = {
SSreg / summary.count
SSreg / summary.weightSum
}

/**
Expand All @@ -88,7 +105,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def meanAbsoluteError: Double = {
summary.normL1(1) / summary.count
summary.normL1(1) / summary.weightSum
}

/**
Expand All @@ -97,7 +114,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def meanSquaredError: Double = {
SSerr / summary.count
SSerr / summary.weightSum
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var totalCnt: Long = 0
private var totalWeightSum: Double = 0.0
private var weightSquareSum: Double = 0.0
private var weightSum: Array[Double] = _
private var currWeightSum: Array[Double] = _
private var nnz: Array[Long] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _
Expand All @@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currM2n = Array.ofDim[Double](n)
currM2 = Array.ofDim[Double](n)
currL1 = Array.ofDim[Double](n)
weightSum = Array.ofDim[Double](n)
currWeightSum = Array.ofDim[Double](n)
nnz = Array.ofDim[Long](n)
currMax = Array.fill[Double](n)(Double.MinValue)
currMin = Array.fill[Double](n)(Double.MaxValue)
Expand All @@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val localCurrM2n = currM2n
val localCurrM2 = currM2
val localCurrL1 = currL1
val localWeightSum = weightSum
val localWeightSum = currWeightSum
val localNumNonzeros = nnz
val localCurrMax = currMax
val localCurrMin = currMin
Expand Down Expand Up @@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
val thisNnz = weightSum(i)
val otherNnz = other.weightSum(i)
val thisNnz = currWeightSum(i)
val otherNnz = other.currWeightSum(i)
val totalNnz = thisNnz + otherNnz
val totalCnnz = nnz(i) + other.nnz(i)
if (totalNnz != 0.0) {
Expand All @@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMax(i) = math.max(currMax(i), other.currMax(i))
currMin(i) = math.min(currMin(i), other.currMin(i))
}
weightSum(i) = totalNnz
currWeightSum(i) = totalNnz
nnz(i) = totalCnnz
i += 1
}
Expand All @@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
this.totalCnt = other.totalCnt
this.totalWeightSum = other.totalWeightSum
this.weightSquareSum = other.weightSquareSum
this.weightSum = other.weightSum.clone()
this.currWeightSum = other.currWeightSum.clone()
this.nnz = other.nnz.clone()
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
Expand All @@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
i += 1
}
Vectors.dense(realMean)
Expand All @@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val len = currM2n.length
while (i < len) {
// We prevent variance from negative value caused by numerical error.
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
(totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
Expand All @@ -229,6 +229,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
@Since("1.1.0")
override def count: Long = totalCnt

/**
* Sum of weights.
*
* Returns the summarizer's `totalWeightSum` as a `Double`; the sample count
* is reported separately by `count` as a `Long`.
*/
override def weightSum: Double = totalWeightSum

/**
* Number of nonzero elements in each dimension.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary {
@Since("1.0.0")
def count: Long

/**
* Sum of weights.
*
* Reported as a `Double`, separately from the `Long` sample `count`.
* NOTE(review): presumably equals `count` when every sample has weight 1.0 --
* confirm against each implementation.
*/
@Since("2.2.0")
def weightSum: Double

/**
* Number of nonzero elements (including explicitly presented zero values) in each column.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
"root mean squared error mismatch")
assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch")
}

test("regression metrics with same (1.0) weight samples") {
  // (prediction, observation, weight) triples; every sample carries the same weight of 1.0.
  val uniformlyWeighted = Seq(
    (2.25, 3.0, 1.0),
    (-0.25, -0.5, 1.0),
    (1.75, 2.0, 1.0),
    (7.75, 7.0, 1.0))
  val uniformRdd = sc.parallelize(uniformlyWeighted, 2)
  val metrics = new RegressionMetrics(false, uniformRdd)

  assert(
    metrics.explainedVariance ~== 8.79687 absTol eps,
    "explained variance regression score mismatch")
  assert(
    metrics.meanAbsoluteError ~== 0.5 absTol eps,
    "mean absolute error mismatch")
  assert(
    metrics.meanSquaredError ~== 0.3125 absTol eps,
    "mean squared error mismatch")
  assert(
    metrics.rootMeanSquaredError ~== 0.55901 absTol eps,
    "root mean squared error mismatch")
  assert(
    metrics.r2 ~== 0.95717 absTol eps,
    "r2 score mismatch")
}

/**
* The following values are hand calculated using the formula:
* [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
* preds = c(2.25, -0.25, 1.75, 7.75)
* obs = c(3.0, -0.5, 2.0, 7.0)
* weights = c(0.1, 0.2, 0.15, 0.05)
* count = 4
* weightedCount = sum(weights) = 0.5
*
* Weighted metrics can be calculated with MultivariateStatisticalSummary.
* (observations, observations - predictions)
* mean (1.7, 0.05)
* variance (7.3, 0.3)
* numNonZeros (0.5, 0.5)
* max (7.0, 0.75)
* min (-0.5, -0.75)
* normL2 (2.0, 0.32596)
* normL1 (1.05, 0.2)
*
* explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425
* meanAbsoluteError: normL1(1) / weightedCount = 0.4
* meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125
* rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098
* r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910
* (r2 exceeds 1 here because the weights sum to 0.5, which makes weightedCount - 1 negative)
*/
test("regression metrics with weighted samples") {
  // (prediction, observation, weight) triples with non-uniform weights that sum to 0.5.
  val weightedSamples = Seq(
    (2.25, 3.0, 0.1),
    (-0.25, -0.5, 0.2),
    (1.75, 2.0, 0.15),
    (7.75, 7.0, 0.05))
  val weightedRdd = sc.parallelize(weightedSamples, 2)
  val metrics = new RegressionMetrics(false, weightedRdd)

  assert(
    metrics.explainedVariance ~== 5.2425 absTol eps,
    "explained variance regression score mismatch")
  assert(
    metrics.meanAbsoluteError ~== 0.4 absTol eps,
    "mean absolute error mismatch")
  assert(
    metrics.meanSquaredError ~== 0.2125 absTol eps,
    "mean squared error mismatch")
  assert(
    metrics.rootMeanSquaredError ~== 0.46098 absTol eps,
    "root mean squared error mismatch")
  assert(
    metrics.r2 ~== 1.02910 absTol eps,
    "r2 score mismatch")
}
}
5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,10 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"),

// [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum")
) ++ Seq(
// [SPARK-17019] Expose on-heap and off-heap memory usage in various places
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"),
Expand Down

0 comments on commit aca6255

Please sign in to comment.