diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 685e6e27f0ba9..ca236d0ad5127 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -40,19 +40,32 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
   lazy val precAtK: RDD[Array[Double]] = predictionAndLabels.map {case (pred, lab)=>
     val labSet = lab.toSet
     val n = pred.length
-    val topkPrec = Array.fill[Double](n)(0.0)
+    val topKPrec = Array.fill[Double](n)(0.0)
     var (i, cnt) = (0, 0)
 
     while (i < n) {
       if (labSet.contains(pred(i))) {
         cnt += 1
       }
-      topkPrec(i) = cnt.toDouble / (i + 1)
+      topKPrec(i) = cnt.toDouble / (i + 1)
       i += 1
     }
-    topkPrec
+    topKPrec
   }
 
+  /**
+   * @param k the position to compute the truncated precision
+   * @return the average precision at the first k ranking positions
+   */
+  def precision(k: Int): Double = precAtK.map {topKPrec =>
+    val n = topKPrec.length
+    if (k <= n) {
+      topKPrec(k - 1)
+    } else {
+      topKPrec(n - 1) * n / k
+    }
+  }.mean
+
   /**
    * Returns the average precision for each query
    */
@@ -79,24 +92,34 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
   /**
    * Returns the normalized discounted cumulative gain for each query
    */
-  lazy val ndcg: RDD[Double] = predictionAndLabels.map {case (pred, lab) =>
+  lazy val ndcgAtK: RDD[Array[Double]] = predictionAndLabels.map {case (pred, lab) =>
     val labSet = lab.toSet
-    val n = math.min(pred.length, labSet.size)
+    val labSetSize = labSet.size
+    val n = math.max(pred.length, labSetSize)
+    val topKNdcg = Array.fill[Double](n)(0.0)
     var (maxDcg, dcg, i) = (0.0, 0.0, 0)
     while (i < n) {
       /* Calculate 1/log2(i + 2) */
-      val gain = 1.0 / (math.log(i + 2) / math.log(2))
+      val gain = math.log(2) / math.log(i + 2)
       if (labSet.contains(pred(i))) {
         dcg += gain
       }
-      maxDcg += gain
+      if (i < labSetSize) {
+        maxDcg += gain
+      }
+      topKNdcg(i) = dcg / maxDcg
       i += 1
     }
-    dcg / maxDcg
+    topKNdcg
   }
 
   /**
-   * Returns the mean NDCG of all the queries
+   * @param k the position to compute the truncated ndcg
+   * @return the average ndcg at the first k ranking positions
    */
-  lazy val meanNdcg: Double = ndcg.mean
+  def ndcg(k: Int): Double = ndcgAtK.map {topKNdcg =>
+    val pos = math.min(k, topKNdcg.length) - 1
+    topKNdcg(pos)
+  }.mean
+
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index 1c9c8d7210288..986be6f66cd33 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -34,16 +34,33 @@ class RankingMetricsSuite extends FunSuite with LocalSparkContext {
     val precAtK = metrics.precAtK.collect()
     val avePrec = metrics.avePrec.collect()
     val map = metrics.meanAvePrec
-    val ndcg = metrics.ndcg.collect()
-    val aveNdcg = metrics.meanNdcg
+    val ndcgAtK = metrics.ndcgAtK.collect()
 
     assert(precAtK(0)(4) ~== 0.4 absTol eps)
     assert(precAtK(1)(6) ~== 3.0/7 absTol eps)
+    assert(precAtK(0)(2) ~== 2.0/3 absTol eps)
+    assert(precAtK(1)(2) ~== 1.0/3 absTol eps)
+    assert(precAtK(0)(9) ~== 0.5 absTol eps)
+    assert(precAtK(1)(9) ~== 0.3 absTol eps)
+
+    assert(metrics.precision(1) ~== 0.5 absTol eps)
+    assert(metrics.precision(2) ~== 0.5 absTol eps)
+    assert(metrics.precision(3) ~== 0.5 absTol eps)
+    assert(metrics.precision(4) ~== 0.375 absTol eps)
+    assert(metrics.precision(10) ~== 0.4 absTol eps)
+
     assert(avePrec(0) ~== 0.622222 absTol eps)
     assert(avePrec(1) ~== 0.442857 absTol eps)
+
     assert(map ~== 0.532539 absTol eps)
-    assert(ndcg(0) ~== 0.508740 absTol eps)
-    assert(ndcg(1) ~== 0.296082 absTol eps)
-    assert(aveNdcg ~== 0.402411 absTol eps)
+
+    assert(ndcgAtK(0)(4) ~== 0.508740 absTol eps)
+    assert(ndcgAtK(1)(2) ~== 0.296082 absTol eps)
+
+    assert(metrics.ndcg(3) ~== (ndcgAtK(0)(2) + ndcgAtK(1)(2)) / 2 absTol eps)
+    assert(metrics.ndcg(5) ~== (ndcgAtK(0)(4) + ndcgAtK(1)(4)) / 2 absTol eps)
+    assert(metrics.ndcg(10) ~== (ndcgAtK(0)(9) + ndcgAtK(1)(9)) / 2 absTol eps)
+    assert(metrics.ndcg(15) ~== metrics.ndcg(10) absTol eps)
+
   }
 }