Skip to content

Commit

Permalink
updated based on similar previous PR comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
imatiach-msft committed Dec 11, 2018
1 parent aca6255 commit f708edb
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
def setLabelCol(value: String): this.type = set(labelCol, value)

/** @group setParam */
@Since("2.2.0")
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

setDefault(metricName -> "rmse")
Expand All @@ -88,7 +88,7 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
.rdd
.map { case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight) }
val metrics = new RegressionMetrics(false, predictionAndLabelsWithWeights)
val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
val metric = $(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
case "mse" => metrics.meanSquaredError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,20 @@ import org.apache.spark.sql.DataFrame
*/
@Since("1.2.0")
class RegressionMetrics @Since("2.0.0") (
throughOrigin: Boolean, predAndObsWithOptWeight: RDD[_])
predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean)
extends Logging {

@Since("1.2.0")
def this(predictionAndObservations: RDD[(Double, Double)]) =
this(false, predictionAndObservations)

/**
* Evaluator for regression.
*
* @param predictionAndObservations an RDD of (prediction, observation) pairs
* @param throughOrigin True if the regression is through the origin. For example, in linear
* regression, it will be true without fitting intercept.
*/
@Since("2.0.0")
def this(predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) =
this(throughOrigin, predictionAndObservations)
def this(predictionAndObservations: RDD[_ <: Product]) =
this(predictionAndObservations, false)

/**
* An auxiliary constructor taking a DataFrame.
* @param predictionAndObservations a DataFrame with two double columns:
* prediction and observation
*/
private[mllib] def this(predictionAndObservations: DataFrame) =
this(false, predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1))))

/**
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var totalCnt: Long = 0
private var totalWeightSum: Double = 0.0
private var weightSquareSum: Double = 0.0
private var currWeightSum: Array[Double] = _
private var weightSum: Array[Double] = _
private var nnz: Array[Long] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _
Expand All @@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currM2n = Array.ofDim[Double](n)
currM2 = Array.ofDim[Double](n)
currL1 = Array.ofDim[Double](n)
currWeightSum = Array.ofDim[Double](n)
weightSum = Array.ofDim[Double](n)
nnz = Array.ofDim[Long](n)
currMax = Array.fill[Double](n)(Double.MinValue)
currMin = Array.fill[Double](n)(Double.MaxValue)
Expand All @@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val localCurrM2n = currM2n
val localCurrM2 = currM2
val localCurrL1 = currL1
val localWeightSum = currWeightSum
val localWeightSum = weightSum
val localNumNonzeros = nnz
val localCurrMax = currMax
val localCurrMin = currMin
Expand Down Expand Up @@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
val thisNnz = currWeightSum(i)
val otherNnz = other.currWeightSum(i)
val thisNnz = weightSum(i)
val otherNnz = other.weightSum(i)
val totalNnz = thisNnz + otherNnz
val totalCnnz = nnz(i) + other.nnz(i)
if (totalNnz != 0.0) {
Expand All @@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMax(i) = math.max(currMax(i), other.currMax(i))
currMin(i) = math.min(currMin(i), other.currMin(i))
}
currWeightSum(i) = totalNnz
weightSum(i) = totalNnz
nnz(i) = totalCnnz
i += 1
}
Expand All @@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
this.totalCnt = other.totalCnt
this.totalWeightSum = other.totalWeightSum
this.weightSquareSum = other.weightSquareSum
this.currWeightSum = other.currWeightSum.clone()
this.weightSum = other.weightSum.clone()
this.nnz = other.nnz.clone()
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
Expand All @@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
i += 1
}
Vectors.dense(realMean)
Expand All @@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val len = currM2n.length
while (i < len) {
// Clamp at zero so that numerical error cannot produce a negative variance.
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
(totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ trait MultivariateStatisticalSummary {
/**
* Sum of weights.
*/
@Since("2.2.0")
@Since("3.0.0")
def weightSum: Double

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("regression metrics with same (1.0) weight samples") {
val predictionAndObservationWithWeight = sc.parallelize(
Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2)
val metrics = new RegressionMetrics(false, predictionAndObservationWithWeight)
val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
assert(metrics.explainedVariance ~== 8.79687 absTol eps,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch")
Expand Down Expand Up @@ -174,7 +174,7 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("regression metrics with weighted samples") {
val predictionAndObservationWithWeight = sc.parallelize(
Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2)
val metrics = new RegressionMetrics(false, predictionAndObservationWithWeight)
val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
assert(metrics.explainedVariance ~== 5.2425 absTol eps,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch")
Expand Down

0 comments on commit f708edb

Please sign in to comment.