From 84f4f0f4c5546e5b3cddb2655ab59508d04a121a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 10 May 2024 23:02:52 +0800 Subject: [PATCH 1/2] Be more lenient on floating point error in AUC. --- src/metric/auc.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 212a3a027d35..75857307ba49 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -311,7 +311,8 @@ class EvalAUC : public MetricNoCache { } auc = collective::GlobalRatio(ctx_, info, auc, fp * tp); if (!std::isnan(auc)) { - CHECK_LE(auc, 1.0); + CHECK_LE(auc, 1.0 + kRtEps); + auc = std::min(auc, 1.0); } } if (std::isnan(auc)) { From 887dd78f22ff99de85d38e64683918e0bd939131 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 11 May 2024 01:24:11 +0800 Subject: [PATCH 2/2] More places. --- src/metric/auc.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 75857307ba49..189c2b8e7269 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -177,7 +177,7 @@ double GroupRankingROC(Context const* ctx, common::Span predts, if (sum_w != 0) { auc /= sum_w; } - CHECK_LE(auc, 1.0f); + CHECK_LE(auc, 1.0 + kRtEps); return auc; } @@ -290,8 +290,8 @@ class EvalAUC : public MetricNoCache { auc = collective::GlobalRatio(ctx_, info, auc, static_cast(valid_groups)); if (!std::isnan(auc)) { - CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups - << ", valid groups: " << valid_groups; + CHECK_LE(auc, 1.0 + kRtEps) << "Total AUC across groups: " << auc * valid_groups + << ", valid groups: " << valid_groups; } } else if (meta[0] != meta[1] && meta[1] % meta[0] == 0) { /**