From ef13dd31b1dc073f4732544739db5a252ae5b6e5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 18 Apr 2023 21:16:06 +0800 Subject: [PATCH] Rework the NDCG objective. (#9015) --- R-package/src/Makevars.in | 1 + R-package/src/Makevars.win | 1 + src/common/math.h | 31 +- src/common/ranking_utils.h | 6 +- src/objective/lambdarank_obj.cc | 440 ++++++++++++++++++++ src/objective/lambdarank_obj.cu | 417 +++++++++++++++++++ src/objective/lambdarank_obj.h | 84 ++-- src/objective/objective.cc | 3 + src/objective/rank_obj.cu | 172 -------- tests/cpp/objective/test_lambdarank_obj.cc | 122 +++++- tests/cpp/objective/test_lambdarank_obj.cu | 18 + tests/cpp/objective/test_lambdarank_obj.h | 21 +- tests/cpp/objective/test_ranking_obj.cc | 45 -- tests/cpp/objective/test_ranking_obj_gpu.cu | 56 --- tests/python-gpu/test_gpu_eval_metrics.py | 16 +- 15 files changed, 1082 insertions(+), 351 deletions(-) create mode 100644 src/objective/lambdarank_obj.cc diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 743bf0a66ce0..04f0a74a5308 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -33,6 +33,7 @@ OBJECTS= \ $(PKGROOT)/src/objective/regression_obj.o \ $(PKGROOT)/src/objective/multiclass_obj.o \ $(PKGROOT)/src/objective/rank_obj.o \ + $(PKGROOT)/src/objective/lambdarank_obj.o \ $(PKGROOT)/src/objective/hinge.o \ $(PKGROOT)/src/objective/aft_obj.o \ $(PKGROOT)/src/objective/adaptive.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index a32d2fd2e45d..969cb7ff42b1 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -33,6 +33,7 @@ OBJECTS= \ $(PKGROOT)/src/objective/regression_obj.o \ $(PKGROOT)/src/objective/multiclass_obj.o \ $(PKGROOT)/src/objective/rank_obj.o \ + $(PKGROOT)/src/objective/lambdarank_obj.o \ $(PKGROOT)/src/objective/hinge.o \ $(PKGROOT)/src/objective/aft_obj.o \ $(PKGROOT)/src/objective/adaptive.o \ diff --git a/src/common/math.h b/src/common/math.h index 71a494544be1..c4d794b5dadf 100644 --- a/src/common/math.h +++ b/src/common/math.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2015 by Contributors +/** + * Copyright 2015-2023 by XGBoost Contributors * \file math.h * \brief additional math utils * \author Tianqi Chen @@ -7,16 +7,19 @@ #ifndef XGBOOST_COMMON_MATH_H_ #define XGBOOST_COMMON_MATH_H_ -#include +#include // for XGBOOST_DEVICE -#include -#include -#include -#include -#include +#include // for max +#include // for exp, abs, log, lgamma +#include // for numeric_limits +#include // for is_floating_point, conditional, is_signed, is_same, declval, enable_if +#include // for pair namespace xgboost { namespace common { + +template XGBOOST_DEVICE T Sqr(T const &w) { return w * w; } + /*! * \brief calculate the sigmoid of the input. * \param x input parameter @@ -30,9 +33,11 @@ XGBOOST_DEVICE inline float Sigmoid(float x) { return y; } -template -XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; } - +XGBOOST_DEVICE inline double Sigmoid(double x) { + auto denom = std::exp(-x) + 1.0; + auto y = 1.0 / denom; + return y; +} /*! * \brief Equality test for both integer and floating point. */ @@ -134,10 +139,6 @@ inline static bool CmpFirst(const std::pair &a, const std::pair &b) { return a.first > b.first; } -inline static bool CmpSecond(const std::pair &a, - const std::pair &b) { - return a.second > b.second; -} // Redefined here to workaround a VC bug that doesn't support overloading for integer // types. 
diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index bc071c2d6965..dd823a0d6e50 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -70,7 +70,7 @@ struct LambdaRankParam : public XGBoostParameter { // pairs // should be accessed by getter for auto configuration. // nolint so that we can keep the string name. - PairMethod lambdarank_pair_method{PairMethod::kMean}; // NOLINT + PairMethod lambdarank_pair_method{PairMethod::kTopK}; // NOLINT std::size_t lambdarank_num_pair_per_sample{NotSet()}; // NOLINT public: @@ -78,7 +78,7 @@ struct LambdaRankParam : public XGBoostParameter { // unbiased bool lambdarank_unbiased{false}; - double lambdarank_bias_norm{2.0}; + double lambdarank_bias_norm{1.0}; // ndcg bool ndcg_exp_gain{true}; @@ -135,7 +135,7 @@ struct LambdaRankParam : public XGBoostParameter { .set_default(false) .describe("Unbiased lambda mart. Use extended IPW to debias click position"); DMLC_DECLARE_FIELD(lambdarank_bias_norm) - .set_default(2.0) + .set_default(1.0) .set_lower_bound(0.0) .describe("Lp regularization for unbiased lambdarank."); DMLC_DECLARE_FIELD(ndcg_exp_gain) diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc new file mode 100644 index 000000000000..30957f81a4f7 --- /dev/null +++ b/src/objective/lambdarank_obj.cc @@ -0,0 +1,440 @@ +/** + * Copyright (c) 2023, XGBoost contributors + */ +#include "lambdarank_obj.h" + +#include // for DMLC_REGISTRY_FILE_TAG + +#include // for transform, copy, fill_n, min, max +#include // for pow, log2 +#include // for size_t +#include // for int32_t +#include // for operator!= +#include // for shared_ptr, __shared_ptr_access, allocator +#include // for operator<<, basic_ostream +#include // for char_traits, operator<, basic_string, string +#include // for apply, make_tuple +#include // for is_floating_point +#include // for pair, swap +#include // for vector + +#include "../common/error_msg.h" // for GroupWeight, LabelScoreSize +#include "../common/linalg_op.h" // for begin, cbegin, cend +#include "../common/optional_weight.h" // for MakeOptionalWeights, OptionalWeights +#include "../common/ranking_utils.h" // for RankingCache, LambdaRankParam, MAPCache, NDCGC... +#include "../common/threading_utils.h" // for ParallelFor, Sched +#include "../common/transform_iterator.h" // for IndexTransformIter +#include "init_estimation.h" // for FitIntercept +#include "xgboost/base.h" // for bst_group_t, GradientPair, kRtEps, GradientPai... +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/json.h" // for Json, get, Value, ToJson, F32Array, FromJson, IsA +#include "xgboost/linalg.h" // for Vector, Range, TensorView, VectorView, All +#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_LE, CHE... 
+#include "xgboost/objective.h" // for ObjFunctionReg, XGBOOST_REGISTER_OBJECTIVE +#include "xgboost/span.h" // for Span, operator!= +#include "xgboost/string_view.h" // for operator<<, StringView +#include "xgboost/task.h" // for ObjInfo + +namespace xgboost::obj { +namespace cpu_impl { +void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView li_full, + linalg::VectorView lj_full, + linalg::Vector* p_ti_plus, + linalg::Vector* p_tj_minus, linalg::Vector* p_li, + linalg::Vector* p_lj, + std::shared_ptr p_cache) { + auto ti_plus = p_ti_plus->HostView(); + auto tj_minus = p_tj_minus->HostView(); + auto li = p_li->HostView(); + auto lj = p_lj->HostView(); + + auto gptr = p_cache->DataGroupPtr(ctx); + auto n_groups = p_cache->Groups(); + auto regularizer = p_cache->Param().Regularizer(); + + // Aggregate over query groups + for (bst_group_t g{0}; g < n_groups; ++g) { + auto begin = gptr[g]; + auto end = gptr[g + 1]; + std::size_t group_size = end - begin; + auto n = std::min(group_size, p_cache->MaxPositionSize()); + + auto g_li = li_full.Slice(linalg::Range(begin, end)); + auto g_lj = lj_full.Slice(linalg::Range(begin, end)); + + for (std::size_t i{0}; i < n; ++i) { + li(i) += g_li(i); + lj(i) += g_lj(i); + } + } + // The ti+ is not guaranteed to decrease since it depends on the |\delta Z| + // + // The update normalizes the ti+ to make ti+(0) equal to 1, which breaks the probability + // meaning. The reasoning behind the normalization is not clear, here we are just + // following the authors. + for (std::size_t i = 0; i < ti_plus.Size(); ++i) { + if (li(0) >= Eps64()) { + ti_plus(i) = std::pow(li(i) / li(0), regularizer); // eq.30 + } + if (lj(0) >= Eps64()) { + tj_minus(i) = std::pow(lj(i) / lj(0), regularizer); // eq.31 + } + assert(!std::isinf(ti_plus(i))); + assert(!std::isinf(tj_minus(i))); + } +} +} // namespace cpu_impl + +/** + * \brief Base class for pair-wise learning to rank. + * + * See `From RankNet to LambdaRank to LambdaMART: An Overview` for a description of the + * algorithm. + * + * In addition to ranking, this also implements `Unbiased LambdaMART: An Unbiased + * Pairwise Learning-to-Rank Algorithm`. + */ +template +class LambdaRankObj : public FitIntercept { + MetaInfo const* p_info_{nullptr}; + + // Update position biased for unbiased click data + void UpdatePositionBias() { + li_full_.SetDevice(ctx_->gpu_id); + lj_full_.SetDevice(ctx_->gpu_id); + li_.SetDevice(ctx_->gpu_id); + lj_.SetDevice(ctx_->gpu_id); + + if (ctx_->IsCPU()) { + cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id), + lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_, + &li_, &lj_, p_cache_); + } else { + cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id), + lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_, + &li_, &lj_, p_cache_); + } + + li_full_.Data()->Fill(0.0); + lj_full_.Data()->Fill(0.0); + + li_.Data()->Fill(0.0); + lj_.Data()->Fill(0.0); + } + + protected: + // L / tj-* (eq. 30) + linalg::Vector li_; + // L / ti+* (eq. 31) + linalg::Vector lj_; + // position bias ratio for relevant doc, ti+ (eq. 30) + linalg::Vector ti_plus_; + // position bias ratio for irrelevant doc, tj- (eq. 
31) + linalg::Vector tj_minus_; + // li buffer for all samples + linalg::Vector li_full_; + // lj buffer for all samples + linalg::Vector lj_full_; + + ltr::LambdaRankParam param_; + // cache + std::shared_ptr p_cache_; + + [[nodiscard]] std::shared_ptr GetCache() const { + auto ptr = std::static_pointer_cast(p_cache_); + CHECK(ptr); + return ptr; + } + + // get group view for li/lj + linalg::VectorView GroupLoss(bst_group_t g, linalg::Vector* v) const { + auto gptr = p_cache_->DataGroupPtr(ctx_); + auto begin = gptr[g]; + auto end = gptr[g + 1]; + if (param_.lambdarank_unbiased) { + return v->HostView().Slice(linalg::Range(begin, end)); + } + return v->HostView(); + } + + // Calculate lambda gradient for each group on CPU. + template + void CalcLambdaForGroup(std::int32_t iter, common::Span g_predt, + linalg::VectorView g_label, float w, + common::Span g_rank, bst_group_t g, Delta delta, + common::Span g_gpair) { + std::fill_n(g_gpair.data(), g_gpair.size(), GradientPair{}); + auto p_gpair = g_gpair.data(); + + auto ti_plus = ti_plus_.HostView(); + auto tj_minus = tj_minus_.HostView(); + + auto li = GroupLoss(g, &li_full_); + auto lj = GroupLoss(g, &lj_full_); + + // Normalization, first used by LightGBM. + // https://github.com/microsoft/LightGBM/pull/2331#issuecomment-523259298 + double sum_lambda{0.0}; + + auto delta_op = [&](auto const&... args) { return delta(args..., g); }; + + auto loop = [&](std::size_t i, std::size_t j) { + // higher/lower on the target ranked list + std::size_t rank_high = i, rank_low = j; + if (g_label(g_rank[rank_high]) == g_label(g_rank[rank_low])) { + return; + } + if (g_label(g_rank[rank_high]) < g_label(g_rank[rank_low])) { + std::swap(rank_high, rank_low); + } + + double cost; + auto pg = LambdaGrad(g_label, g_predt, g_rank, rank_high, rank_low, delta_op, + ti_plus, tj_minus, &cost); + auto ng = Repulse(pg); + + std::size_t idx_high = g_rank[rank_high]; + std::size_t idx_low = g_rank[rank_low]; + p_gpair[idx_high] += pg; + p_gpair[idx_low] += ng; + + if (unbiased) { + auto k = ti_plus.Size(); + // We can probably use all the positions. If we skip the update due to having + // high/low > k, we might be losing too many pairs. On the other hand, if we + // cap the position, then we might be accumulating too much tail bias into the + // last tracked position. + // We use `idx_high` since it represents the original position from the label + // list, and the label list is assumed to be sorted.
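+        // For example, with k = 10 tracked positions, a pair with idx_high = 3 and
+        // idx_low = 7 adds cost / tj_minus(7) to li(3) and cost / ti_plus(3) to lj(7)
+        // below, while a pair involving position 12 is skipped entirely.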
+ if (idx_high < k && idx_low < k) { + if (tj_minus(idx_low) >= Eps64()) { + li(idx_high) += cost / tj_minus(idx_low); // eq.30 + } + if (ti_plus(idx_high) >= Eps64()) { + lj(idx_low) += cost / ti_plus(idx_high); // eq.31 + } + } + } + + sum_lambda += -2.0 * static_cast(pg.GetGrad()); + }; + + MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop); + if (sum_lambda > 0.0) { + double norm = std::log2(1.0 + sum_lambda) / sum_lambda; + std::transform(g_gpair.data(), g_gpair.data() + g_gpair.size(), g_gpair.data(), + [norm](GradientPair const& g) { return g * norm; }); + } + + auto w_norm = p_cache_->WeightNorm(); + std::transform(g_gpair.begin(), g_gpair.end(), g_gpair.begin(), + [&](GradientPair const& gpair) { return gpair * w * w_norm; }); + } + + public: + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String(Loss::Name()); + out["lambdarank_param"] = ToJson(param_); + + auto save_bias = [](linalg::Vector const& in, Json out) { + auto& out_array = get(out); + out_array.resize(in.Size()); + auto h_in = in.HostView(); + std::copy(linalg::cbegin(h_in), linalg::cend(h_in), out_array.begin()); + }; + + if (param_.lambdarank_unbiased) { + out["ti+"] = F32Array(); + save_bias(ti_plus_, out["ti+"]); + out["tj-"] = F32Array(); + save_bias(tj_minus_, out["tj-"]); + } + } + void LoadConfig(Json const& in) override { + auto const& obj = get(in); + if (obj.find("lambdarank_param") != obj.cend()) { + FromJson(in["lambdarank_param"], ¶m_); + } + + if (param_.lambdarank_unbiased) { + auto load_bias = [](Json in, linalg::Vector* out) { + if (IsA(in)) { + // JSON + auto const& array = get(in); + out->Reshape(array.size()); + auto h_out = out->HostView(); + std::copy(array.cbegin(), array.cend(), linalg::begin(h_out)); + } else { + // UBJSON + auto const& array = get(in); + out->Reshape(array.size()); + auto h_out = out->HostView(); + std::transform(array.cbegin(), array.cend(), linalg::begin(h_out), + [](Json const& v) { return get(v); }); + } + }; + load_bias(in["ti+"], &ti_plus_); + load_bias(in["tj-"], &tj_minus_); + } + } + + [[nodiscard]] ObjInfo Task() const override { return ObjInfo{ObjInfo::kRanking}; } + + [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { + CHECK_LE(info.labels.Shape(1), 1) << "multi-output for LTR is not yet supported."; + return 1; + } + + [[nodiscard]] const char* RankEvalMetric(StringView metric) const { + static thread_local std::string name; + if (param_.HasTruncation()) { + name = ltr::MakeMetricName(metric, param_.NumPair(), false); + } else { + name = ltr::MakeMetricName(metric, param_.NotSet(), false); + } + return name.c_str(); + } + + void GetGradient(HostDeviceVector const& predt, MetaInfo const& info, std::int32_t iter, + HostDeviceVector* out_gpair) override { + CHECK_EQ(info.labels.Size(), predt.Size()) << error::LabelScoreSize(); + + // init/renew cache + if (!p_cache_ || p_info_ != &info || p_cache_->Param() != param_) { + p_cache_ = std::make_shared(ctx_, info, param_); + p_info_ = &info; + } + auto n_groups = p_cache_->Groups(); + if (!info.weights_.Empty()) { + CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight(); + } + + if (ti_plus_.Size() == 0 && param_.lambdarank_unbiased) { + CHECK_EQ(iter, 0); + ti_plus_ = linalg::Constant(ctx_, 1.0, p_cache_->MaxPositionSize()); + tj_minus_ = linalg::Constant(ctx_, 1.0, p_cache_->MaxPositionSize()); + + li_ = linalg::Zeros(ctx_, p_cache_->MaxPositionSize()); + lj_ = 
linalg::Zeros(ctx_, p_cache_->MaxPositionSize()); + + li_full_ = linalg::Zeros(ctx_, info.num_row_); + lj_full_ = linalg::Zeros(ctx_, info.num_row_); + } + static_cast(this)->GetGradientImpl(iter, predt, info, out_gpair); + + if (param_.lambdarank_unbiased) { + this->UpdatePositionBias(); + } + } +}; + +class LambdaRankNDCG : public LambdaRankObj { + public: + template + void CalcLambdaForGroupNDCG(std::int32_t iter, common::Span g_predt, + linalg::VectorView g_label, float w, + common::Span g_rank, + common::Span g_gpair, + linalg::VectorView inv_IDCG, + common::Span discount, bst_group_t g) { + auto delta = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low, + bst_group_t g) { + static_assert(std::is_floating_point::value); + return DeltaNDCG(y_high, y_low, rank_high, rank_low, inv_IDCG(g), discount); + }; + this->CalcLambdaForGroup(iter, g_predt, g_label, w, g_rank, g, delta, g_gpair); + } + + void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, + const MetaInfo& info, HostDeviceVector* out_gpair) { + if (ctx_->IsCUDA()) { + cuda_impl::LambdaRankGetGradientNDCG( + ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), + tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), + out_gpair); + return; + } + + bst_group_t n_groups = p_cache_->Groups(); + auto gptr = p_cache_->DataGroupPtr(ctx_); + + out_gpair->Resize(info.num_row_); + auto h_gpair = out_gpair->HostSpan(); + auto h_predt = predt.ConstHostSpan(); + auto h_label = info.labels.HostView(); + auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); + auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); }; + + auto dct = GetCache()->Discount(ctx_); + auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt); + auto inv_IDCG = GetCache()->InvIDCG(ctx_); + + common::ParallelFor(n_groups, ctx_->Threads(), common::Sched::Guided(), [&](auto g) { + std::size_t cnt = gptr[g + 1] - gptr[g]; + auto w = h_weight[g]; + auto g_predt = h_predt.subspan(gptr[g], cnt); + auto g_gpair = h_gpair.subspan(gptr[g], cnt); + auto g_label = h_label.Slice(make_range(g), 0); + auto g_rank = rank_idx.subspan(gptr[g], cnt); + + auto args = + std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g_gpair, inv_IDCG, dct, g); + + if (param_.lambdarank_unbiased) { + if (param_.ndcg_exp_gain) { + std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG, args); + } else { + std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG, args); + } + } else { + if (param_.ndcg_exp_gain) { + std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG, args); + } else { + std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG, args); + } + } + }); + } + + static char const* Name() { return "rank:ndcg"; } + [[nodiscard]] const char* DefaultEvalMetric() const override { + return this->RankEvalMetric("ndcg"); + } + [[nodiscard]] Json DefaultMetricConfig() const override { + Json config{Object{}}; + config["name"] = String{DefaultEvalMetric()}; + config["lambdarank_param"] = ToJson(param_); + return config; + } +}; + +namespace cuda_impl { +#if !defined(XGBOOST_USE_CUDA) +void LambdaRankGetGradientNDCG(Context const*, std::int32_t, HostDeviceVector const&, + const MetaInfo&, std::shared_ptr, + linalg::VectorView, // input bias ratio + linalg::VectorView, // input bias ratio + linalg::VectorView, linalg::VectorView, + HostDeviceVector*) { + common::AssertGPUSupport(); +} + +void LambdaRankUpdatePositionBias(Context const*, linalg::VectorView, + linalg::VectorView, linalg::Vector*, + 
linalg::Vector*, linalg::Vector*, + linalg::Vector*, std::shared_ptr) { + common::AssertGPUSupport(); +} +#endif // !defined(XGBOOST_USE_CUDA) +} // namespace cuda_impl + +XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name()) + .describe("LambdaRank with NDCG loss as objective") + .set_body([]() { return new LambdaRankNDCG{}; }); + +DMLC_REGISTRY_FILE_TAG(lambdarank_obj); +} // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index eb82b17b4008..27b5872a8a49 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -37,6 +37,312 @@ namespace xgboost::obj { DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu); namespace cuda_impl { +namespace { +/** + * \brief Calculate minimum value of bias for floating point truncation. + */ +void MinBias(Context const* ctx, std::shared_ptr p_cache, + linalg::VectorView t_plus, linalg::VectorView tj_minus, + common::Span d_min) { + CHECK_EQ(d_min.size(), 2); + auto cuctx = ctx->CUDACtx(); + + auto k = t_plus.Size(); + auto const& p = p_cache->Param(); + CHECK_GT(k, 0); + CHECK_EQ(k, p_cache->MaxPositionSize()); + + auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return i * k; }); + auto val_it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(std::size_t i) { + if (i >= k) { + return std::abs(tj_minus(i - k)); + } + return std::abs(t_plus(i)); + }); + std::size_t bytes; + cub::DeviceSegmentedReduce::Min(nullptr, bytes, val_it, d_min.data(), 2, key_it, key_it + 1, + cuctx->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Min(temp.data().get(), bytes, val_it, d_min.data(), 2, key_it, + key_it + 1, cuctx->Stream()); +} + +/** + * \brief Type for gradient statistic. (Gradient, cost for unbiased LTR, normalization factor) + */ +using GradCostNorm = thrust::tuple; + +/** + * \brief Obtain and update the gradient for one pair. + */ +template +struct GetGradOp { + MakePairsOp make_pair; + Delta delta; + + bool need_update; + + auto __device__ operator()(std::size_t idx) -> GradCostNorm { + auto const& args = make_pair.args; + auto g = dh::SegmentId(args.d_threads_group_ptr, idx); + + auto data_group_begin = static_cast(args.d_group_ptr[g]); + std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin; + // obtain group segment data. + auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0); + auto g_predt = args.predts.subspan(data_group_begin, n_data); + auto g_gpair = args.gpairs.subspan(data_group_begin, n_data).data(); + auto g_rank = args.d_sorted_idx.subspan(data_group_begin, n_data); + + auto [i, j] = make_pair(idx, g); + + std::size_t rank_high = i, rank_low = j; + if (g_label(g_rank[i]) == g_label(g_rank[j])) { + return thrust::make_tuple(GradientPair{}, 0.0, 0.0); + } + if (g_label(g_rank[i]) < g_label(g_rank[j])) { + thrust::swap(rank_high, rank_low); + } + + double cost{0}; + + auto delta_op = [&](auto const&... 
args) { return delta(args..., g); }; + GradientPair pg = LambdaGrad(g_label, g_predt, g_rank, rank_high, rank_low, delta_op, + args.ti_plus, args.tj_minus, &cost); + + std::size_t idx_high = g_rank[rank_high]; + std::size_t idx_low = g_rank[rank_low]; + + if (need_update) { + // second run, update the gradient + + auto ng = Repulse(pg); + + auto gr = args.d_roundings(g); + // positive gradient truncated + auto pgt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), pg.GetGrad()), + common::TruncateWithRounding(gr.GetHess(), pg.GetHess())}; + // negative gradient truncated + auto ngt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), ng.GetGrad()), + common::TruncateWithRounding(gr.GetHess(), ng.GetHess())}; + + dh::AtomicAddGpair(g_gpair + idx_high, pgt); + dh::AtomicAddGpair(g_gpair + idx_low, ngt); + } + + if (unbiased && need_update) { + // second run, update the cost + assert(args.tj_minus.Size() == args.ti_plus.Size() && "Invalid size of position bias"); + + auto g_li = args.li.Slice(linalg::Range(data_group_begin, data_group_begin + n_data)); + auto g_lj = args.lj.Slice(linalg::Range(data_group_begin, data_group_begin + n_data)); + + if (idx_high < args.ti_plus.Size() && idx_low < args.ti_plus.Size()) { + if (args.tj_minus(idx_low) >= Eps64()) { + // eq.30 + atomicAdd(&g_li(idx_high), common::TruncateWithRounding(args.d_cost_rounding[0], + cost / args.tj_minus(idx_low))); + } + if (args.ti_plus(idx_high) >= Eps64()) { + // eq.31 + atomicAdd(&g_lj(idx_low), common::TruncateWithRounding(args.d_cost_rounding[0], + cost / args.ti_plus(idx_high))); + } + } + } + return thrust::make_tuple(GradientPair{std::abs(pg.GetGrad()), std::abs(pg.GetHess())}, + std::abs(cost), -2.0 * static_cast(pg.GetGrad())); + } +}; + +template +struct MakeGetGrad { + MakePairsOp make_pair; + Delta delta; + + [[nodiscard]] KernelInputs const& Args() const { return make_pair.args; } + + MakeGetGrad(KernelInputs args, Delta d) : make_pair{args}, delta{std::move(d)} {} + + GetGradOp operator()(bool need_update) { + return GetGradOp{make_pair, delta, need_update}; + } +}; + +/** + * \brief Calculate gradient for all pairs using update op created by make_get_grad. + * + * We need to run the gradient calculation twice: the first pass gathers information + * like the maximum gradient, maximum cost, and the normalization term using a + * reduction. The second pass performs the actual update. + * + * Without normalization, we would only need to run it once, since we can bound the + * gradient manually (NDCG \in [0, 1], delta_NDCG \in [0, 1], and ti+/tj- come from the + * previous iteration, so the bound can be calculated for the current iteration). + * However, when normalization is used the delta score is unbounded and we need to + * obtain the gradient sum. As a tradeoff, we simply run the kernel twice, first as a + * reduction, then as a for_each. + * + * Alternatively, we could bound the delta score by limiting the output of the model, + * using a sigmoid for binary output and some normalization for multi-level. But the + * effect on accuracy is not known yet, and it would only be used by the GPU. + * + * For performance, the segmented sort of the scores is the bottleneck, taking up + * about half of the time, while the reduction and for_each take up the other half.
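+ *
+ * Concretely, the first pass reduces each group to (max |gradient|, max |hessian|,
+ * max cost, sum of lambdas): the maxima seed the fixed-point rounding factors used by
+ * TruncateWithRounding for deterministic atomic updates in the second pass, and the
+ * lambda sum drives the final normalization.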
+ */ +template +void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr p_cache, + MakeGetGrad make_get_grad) { + auto n_groups = p_cache->Groups(); + auto d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr(); + auto d_gptr = p_cache->DataGroupPtr(ctx); + auto d_gpair = make_get_grad.Args().gpairs; + + /** + * First pass, gather info for normalization and rounding factor. + */ + auto val_it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + make_get_grad(false)); + auto reduction_op = [] XGBOOST_DEVICE(GradCostNorm const& l, + GradCostNorm const& r) -> GradCostNorm { + // get maximum gradient for each group, along with cost and the normalization term + auto const& lg = thrust::get<0>(l); + auto const& rg = thrust::get<0>(r); + auto grad = std::max(lg.GetGrad(), rg.GetGrad()); + auto hess = std::max(lg.GetHess(), rg.GetHess()); + auto cost = std::max(thrust::get<1>(l), thrust::get<1>(r)); + double sum_lambda = thrust::get<2>(l) + thrust::get<2>(r); + return thrust::make_tuple(GradientPair{std::abs(grad), std::abs(hess)}, cost, sum_lambda); + }; + auto init = thrust::make_tuple(GradientPair{0.0f, 0.0f}, 0.0, 0.0); + common::Span d_max_lambdas = p_cache->MaxLambdas(ctx, n_groups); + CHECK_EQ(n_groups * sizeof(GradCostNorm), d_max_lambdas.size_bytes()); + + std::size_t bytes; + cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, d_max_lambdas.data(), n_groups, + d_threads_group_ptr.data(), d_threads_group_ptr.data() + 1, + reduction_op, init, ctx->CUDACtx()->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Reduce( + temp.data().get(), bytes, val_it, d_max_lambdas.data(), n_groups, d_threads_group_ptr.data(), + d_threads_group_ptr.data() + 1, reduction_op, init, ctx->CUDACtx()->Stream()); + + dh::TemporaryArray min_bias(2); + auto d_min_bias = dh::ToSpan(min_bias); + if (unbiased) { + MinBias(ctx, p_cache, make_get_grad.Args().ti_plus, make_get_grad.Args().tj_minus, d_min_bias); + } + /** + * Create rounding factors + */ + auto d_cost_rounding = p_cache->CUDACostRounding(ctx); + auto d_rounding = p_cache->CUDARounding(ctx); + dh::LaunchN(n_groups, ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t g) mutable { + auto group_size = d_gptr[g + 1] - d_gptr[g]; + auto const& max_grad = thrust::get<0>(d_max_lambdas[g]); + // float group size + auto fgs = static_cast(group_size); + auto grad = common::CreateRoundingFactor(fgs * max_grad.GetGrad(), group_size); + auto hess = common::CreateRoundingFactor(fgs * max_grad.GetHess(), group_size); + d_rounding(g) = GradientPair{grad, hess}; + + auto cost = thrust::get<1>(d_max_lambdas[g]); + if (unbiased) { + cost /= std::min(d_min_bias[0], d_min_bias[1]); + d_cost_rounding[0] = common::CreateRoundingFactor(fgs * cost, group_size); + } + }); + + /** + * Second pass, actual update to gradient and bias. + */ + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), + p_cache->CUDAThreads(), make_get_grad(true)); + + /** + * Lastly, normalization and weight. 
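+   * Each gradient pair in a group is scaled by log2(1 + sum_lambda) / sum_lambda (the
+   * normalization first used by LightGBM), then by the group weight times the weight
+   * normalization term.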
+ */ + auto d_weights = common::MakeOptionalWeights(ctx, info.weights_); + auto w_norm = p_cache->WeightNorm(); + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.size(), + [=] XGBOOST_DEVICE(std::size_t i) { + auto g = dh::SegmentId(d_gptr, i); + auto sum_lambda = thrust::get<2>(d_max_lambdas[g]); + // Normalization + if (sum_lambda > 0.0) { + double norm = std::log2(1.0 + sum_lambda) / sum_lambda; + d_gpair[i] *= norm; + } + d_gpair[i] *= (d_weights[g] * w_norm); + }); +} + +/** + * \brief Handles boilerplate code like getting device span. + */ +template +void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector const& preds, + const MetaInfo& info, std::shared_ptr p_cache, Delta delta, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair) { + // boilerplate + std::int32_t device_id = ctx->gpu_id; + dh::safe_cuda(cudaSetDevice(device_id)); + auto n_groups = p_cache->Groups(); + + info.labels.SetDevice(device_id); + preds.SetDevice(device_id); + out_gpair->SetDevice(device_id); + out_gpair->Resize(preds.Size()); + + CHECK(p_cache); + + auto d_rounding = p_cache->CUDARounding(ctx); + auto d_cost_rounding = p_cache->CUDACostRounding(ctx); + + CHECK_NE(d_rounding.Size(), 0); + + auto label = info.labels.View(ctx->gpu_id); + auto predts = preds.ConstDeviceSpan(); + auto gpairs = out_gpair->DeviceSpan(); + thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.data(), gpairs.size(), GradientPair{0.0f, 0.0f}); + + auto const d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr(); + auto const d_gptr = p_cache->DataGroupPtr(ctx); + auto const rank_idx = p_cache->SortedIdx(ctx, predts); + + auto const unbiased = p_cache->Param().lambdarank_unbiased; + + common::Span d_y_sorted_idx; + if (!p_cache->Param().HasTruncation()) { + d_y_sorted_idx = SortY(ctx, info, rank_idx, p_cache); + } + + KernelInputs args{ti_plus, tj_minus, li, lj, d_gptr, d_threads_group_ptr, + rank_idx, label, predts, gpairs, d_rounding, d_cost_rounding.data(), + d_y_sorted_idx, iter}; + + // dispatch based on unbiased and truncation + if (p_cache->Param().HasTruncation()) { + if (unbiased) { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } else { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } + } else { + if (unbiased) { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } else { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } + } +} +} // anonymous namespace + common::Span SortY(Context const* ctx, MetaInfo const& info, common::Span d_rank, std::shared_ptr p_cache) { @@ -58,5 +364,116 @@ common::Span SortY(Context const* ctx, MetaInfo const& info, common::SegmentedArgSort(ctx, d_y_ranked, d_group_ptr, d_y_sorted_idx); return d_y_sorted_idx; } + +void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, + const HostDeviceVector& preds, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair) { + // boilerplate + std::int32_t device_id = ctx->gpu_id; + dh::safe_cuda(cudaSetDevice(device_id)); + auto const d_inv_IDCG = p_cache->InvIDCG(ctx); + auto const discount = p_cache->Discount(ctx); + + info.labels.SetDevice(device_id); + preds.SetDevice(device_id); + + auto const exp_gain = p_cache->Param().ndcg_exp_gain; + auto delta_ndcg = [=] 
XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high, + std::size_t rank_low, bst_group_t g) { + return exp_gain ? DeltaNDCG(y_high, y_low, rank_high, rank_low, d_inv_IDCG(g), discount) + : DeltaNDCG(y_high, y_low, rank_high, rank_low, d_inv_IDCG(g), discount); + }; + Launch(ctx, iter, preds, info, p_cache, delta_ndcg, ti_plus, tj_minus, li, lj, out_gpair); +} + +namespace { +struct ReduceOp { + template + Tup XGBOOST_DEVICE operator()(Tup const& l, Tup const& r) { + return thrust::make_tuple(thrust::get<0>(l) + thrust::get<0>(r), + thrust::get<1>(l) + thrust::get<1>(r)); + } +}; +} // namespace + +void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView li_full, + linalg::VectorView lj_full, + linalg::Vector* p_ti_plus, + linalg::Vector* p_tj_minus, + linalg::Vector* p_li, // loss + linalg::Vector* p_lj, + std::shared_ptr p_cache) { + auto const d_group_ptr = p_cache->DataGroupPtr(ctx); + auto n_groups = d_group_ptr.size() - 1; + + auto ti_plus = p_ti_plus->View(ctx->gpu_id); + auto tj_minus = p_tj_minus->View(ctx->gpu_id); + + auto li = p_li->View(ctx->gpu_id); + auto lj = p_lj->View(ctx->gpu_id); + CHECK_EQ(li.Size(), ti_plus.Size()); + + auto const& param = p_cache->Param(); + auto regularizer = param.Regularizer(); + std::size_t k = p_cache->MaxPositionSize(); + + CHECK_EQ(li.Size(), k); + CHECK_EQ(lj.Size(), k); + // reduce li_full to li for each group. + auto make_iter = [&](linalg::VectorView l_full) { + auto l_it = [=] XGBOOST_DEVICE(std::size_t i) { + // group index + auto g = i % n_groups; + // rank is the position within a group, also the segment index + auto r = i / n_groups; + + auto begin = d_group_ptr[g]; + std::size_t group_size = d_group_ptr[g + 1] - begin; + auto n = std::min(group_size, k); + // r can be greater than n since we allocate threads based on truncation level + // instead of actual group size. + if (r >= n) { + return 0.0; + } + return l_full(r + begin); + }; + return l_it; + }; + auto li_it = + dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), make_iter(li_full)); + auto lj_it = + dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), make_iter(lj_full)); + // k segments, each segment has size n_groups. 
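+  // For example, with n_groups = 3, flat index i = 7 maps to group g = 7 % 3 = 1 and
+  // position r = 7 / 3 = 2; segment r covers the flat range [r * n_groups, (r + 1) * n_groups).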
+ auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(std::size_t i) { return i * n_groups; }); + auto val_it = thrust::make_zip_iterator(thrust::make_tuple(li_it, lj_it)); + auto out_it = + thrust::make_zip_iterator(thrust::make_tuple(li.Values().data(), lj.Values().data())); + + auto init = thrust::make_tuple(0.0, 0.0); + std::size_t bytes; + cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, out_it, k, key_it, key_it + 1, + ReduceOp{}, init, ctx->CUDACtx()->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Reduce(temp.data().get(), bytes, val_it, out_it, k, key_it, + key_it + 1, ReduceOp{}, init, ctx->CUDACtx()->Stream()); + + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), li.Size(), + [=] XGBOOST_DEVICE(std::size_t i) mutable { + if (li(0) >= Eps64()) { + ti_plus(i) = std::pow(li(i) / li(0), regularizer); + } + if (lj(0) >= Eps64()) { + tj_minus(i) = std::pow(lj(i) / lj(0), regularizer); + } + assert(!std::isinf(ti_plus(i))); + assert(!std::isinf(tj_minus(i))); + }); +} } // namespace cuda_impl } // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.h b/src/objective/lambdarank_obj.h index 3adb27a2e533..0eb06e27cdc4 100644 --- a/src/objective/lambdarank_obj.h +++ b/src/objective/lambdarank_obj.h @@ -1,5 +1,15 @@ /** - * Copyright 2023 XGBoost contributors + * Copyright 2023, XGBoost contributors + * + * Vocabulary explanation: + * + * There are two different lists we need to handle in the objective. The first is the + * list of labels (relevance degrees) provided by the user; its order has no particular + * meaning when bias estimation is NOT used. The other is generated by our model: an + * index sorted by prediction scores. `rank_high` refers to a position in the model's + * rank list that is higher than `rank_low`, while `idx_high` refers to where the + * `rank_high` sample comes from. Simply put, `rank_high` indexes into the rank list + * obtained from the model, while `idx_high` indexes into the user-provided sample list. */ #ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ #define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ @@ -25,14 +35,19 @@ #include "xgboost/span.h" // for Span namespace xgboost::obj { +double constexpr Eps64() { return 1e-16; } + template -XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t r_high, std::size_t r_low, - double inv_IDCG, common::Span discount) { +XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t rank_high, + std::size_t rank_low, double inv_IDCG, + common::Span discount) { + // Use rank_high instead of idx_high as we are calculating discount based on ranks + // provided by the model. double gain_high = exp ? ltr::CalcDCGGain(y_high) : y_high; - double discount_high = discount[r_high]; + double discount_high = discount[rank_high]; double gain_low = exp ?
ltr::CalcDCGGain(y_low) : y_low; - double discount_low = discount[r_low]; + double discount_low = discount[rank_low]; double original = gain_high * discount_high + gain_low * discount_low; double changed = gain_low * discount_high + gain_high * discount_low; @@ -70,9 +85,9 @@ template XGBOOST_DEVICE GradientPair LambdaGrad(linalg::VectorView labels, common::Span predts, common::Span sorted_idx, - std::size_t rank_high, // cordiniate - std::size_t rank_low, // cordiniate - Delta delta, // delta score + std::size_t rank_high, // higher index on the model rank list + std::size_t rank_low, // lower index on the model rank list + Delta delta, // function to calculate delta score linalg::VectorView t_plus, // input bias ratio linalg::VectorView t_minus, // input bias ratio double* p_cost) { @@ -95,30 +110,34 @@ LambdaGrad(linalg::VectorView labels, common::Span pre // Use double whenever possible as we are working on the exp space. double delta_score = std::abs(s_high - s_low); - double sigmoid = common::Sigmoid(s_high - s_low); + double const sigmoid = common::Sigmoid(s_high - s_low); // Change in metric score like \delta NDCG or \delta MAP double delta_metric = std::abs(delta(y_high, y_low, rank_high, rank_low)); if (best_score != worst_score) { - delta_metric /= (delta_score + kRtEps); + delta_metric /= (delta_score + 0.01); } if (unbiased) { *p_cost = std::log(1.0 / (1.0 - sigmoid)) * delta_metric; } - constexpr double kEps = 1e-16; auto lambda_ij = (sigmoid - 1.0) * delta_metric; - auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), kEps) * delta_metric * 2.0; + auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), Eps64()) * delta_metric * 2.0; auto k = t_plus.Size(); assert(t_minus.Size() == k && "Invalid size of position bias"); - if (unbiased && idx_high < k && idx_low < k) { - lambda_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps); - hessian_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps); + // We need to skip samples that exceed the maximum number of tracked positions, and + // samples that have low probability and might bring us floating point issues. + if (unbiased && idx_high < k && idx_low < k && t_minus(idx_low) >= Eps64() && + t_plus(idx_high) >= Eps64()) { + // The index should be `ranks[idx_low]`; since we assume the label list is sorted, + // this reduces to `idx_low`, which represents the position in the input list, as + // explained in the file header. + lambda_ij /= (t_plus(idx_high) * t_minus(idx_low)); + hessian_ij /= (t_plus(idx_high) * t_minus(idx_low)); } - auto pg = GradientPair{static_cast(lambda_ij), static_cast(hessian_ij)}; return pg; } @@ -137,27 +156,6 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, linalg::VectorView li, linalg::VectorView lj, HostDeviceVector* out_gpair); -/** - * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
- */ -void MAPStat(Context const* ctx, MetaInfo const& info, common::Span d_rank_idx, - std::shared_ptr p_cache); - -void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, - HostDeviceVector const& predt, MetaInfo const& info, - std::shared_ptr p_cache, - linalg::VectorView t_plus, // input bias ratio - linalg::VectorView t_minus, // input bias ratio - linalg::VectorView li, linalg::VectorView lj, - HostDeviceVector* out_gpair); - -void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, - HostDeviceVector const& predt, const MetaInfo& info, - std::shared_ptr p_cache, - linalg::VectorView ti_plus, // input bias ratio - linalg::VectorView tj_minus, // input bias ratio - linalg::VectorView li, linalg::VectorView lj, - HostDeviceVector* out_gpair); void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView li_full, linalg::VectorView lj_full, @@ -167,18 +165,6 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView p_cache); } // namespace cuda_impl -namespace cpu_impl { -/** - * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. - * - * \param label Ground truth relevance label. - * \param rank_idx Sorted index of prediction. - * \param p_cache An initialized MAPCache. - */ -void MAPStat(Context const* ctx, linalg::VectorView label, - common::Span rank_idx, std::shared_ptr p_cache); -} // namespace cpu_impl - /** * \param Construct pairs on CPU * diff --git a/src/objective/objective.cc b/src/objective/objective.cc index d3b01d80bf27..7d2c37811d1a 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -48,12 +48,15 @@ DMLC_REGISTRY_LINK_TAG(quantile_obj_gpu); DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu); DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu); DMLC_REGISTRY_LINK_TAG(rank_obj_gpu); +DMLC_REGISTRY_LINK_TAG(lambdarank_obj); +DMLC_REGISTRY_LINK_TAG(lambdarank_obj_cu); #else DMLC_REGISTRY_LINK_TAG(regression_obj); DMLC_REGISTRY_LINK_TAG(quantile_obj); DMLC_REGISTRY_LINK_TAG(hinge_obj); DMLC_REGISTRY_LINK_TAG(multiclass_obj); DMLC_REGISTRY_LINK_TAG(rank_obj); +DMLC_REGISTRY_LINK_TAG(lambdarank_obj); #endif // XGBOOST_USE_CUDA } // namespace obj } // namespace xgboost diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index f1c8702102df..23613d93d9d3 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -207,174 +207,6 @@ class IndexablePredictionSorter { }; #endif -// beta version: NDCG lambda rank -class NDCGLambdaWeightComputer -#if defined(__CUDACC__) - : public IndexablePredictionSorter -#endif -{ - public: -#if defined(__CUDACC__) - // This function object computes the item's DCG value - class ComputeItemDCG : public thrust::unary_function { - public: - XGBOOST_DEVICE ComputeItemDCG(const common::Span &dsorted_labels, - const common::Span &dgroups, - const common::Span &gidxs) - : dsorted_labels_(dsorted_labels), - dgroups_(dgroups), - dgidxs_(gidxs) {} - - // Compute DCG for the item at 'idx' - __device__ __forceinline__ float operator()(uint32_t idx) const { - return ComputeItemDCGWeight(dsorted_labels_[idx], idx - dgroups_[dgidxs_[idx]]); - } - - private: - const common::Span dsorted_labels_; // Labels sorted within a group - const common::Span dgroups_; // The group indices - where each group - // begins and ends - const common::Span dgidxs_; // The group each items belongs to - }; - - // Type containing device pointers that can be cheaply copied on the kernel - class NDCGLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { - public: - 
NDCGLambdaWeightMultiplier(const dh::SegmentSorter &segment_label_sorter, - const NDCGLambdaWeightComputer &lwc) - : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), - dgroup_dcgs_(lwc.GetGroupDcgsSpan()) {} - - // Adjust the items weight by this value - __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { - if (dgroup_dcgs_[gidx] == 0.0) return 0.0f; - - uint32_t group_begin = dgroups_[gidx]; - - auto pos_lab_orig_posn = dorig_pos_[pidx]; - auto neg_lab_orig_posn = dorig_pos_[nidx]; - KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn); - - // Note: the label positive and negative indices are relative to the entire dataset. - // Hence, scale them back to an index within the group - auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin; - auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin; - return NDCGLambdaWeightComputer::ComputeDeltaWeight( - pos_pred_pos, neg_pred_pos, - static_cast(dsorted_labels_[pidx]), static_cast(dsorted_labels_[nidx]), - dgroup_dcgs_[gidx]); - } - - private: - const common::Span dgroup_dcgs_; // Group DCG values - }; - - NDCGLambdaWeightComputer(const bst_float *dpreds, - const bst_float*, - const dh::SegmentSorter &segment_label_sorter) - : IndexablePredictionSorter(dpreds, segment_label_sorter), - dgroup_dcg_(segment_label_sorter.GetNumGroups(), 0.0f), - weight_multiplier_(segment_label_sorter, *this) { - const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan(); - - // Allocator to be used for managing space overhead while performing transformed reductions - dh::XGBCachingDeviceAllocator alloc; - - // Compute each elements DCG values and reduce them across groups concurrently. - auto end_range = - thrust::reduce_by_key(thrust::cuda::par(alloc), - dh::tcbegin(group_segments), dh::tcend(group_segments), - thrust::make_transform_iterator( - // The indices need not be sequential within a group, as we care only - // about the sum of items DCG values within a group - dh::tcbegin(segment_label_sorter.GetOriginalPositionsSpan()), - ComputeItemDCG(segment_label_sorter.GetItemsSpan(), - segment_label_sorter.GetGroupsSpan(), - group_segments)), - thrust::make_discard_iterator(), // We don't care for the group indices - dgroup_dcg_.begin()); // Sum of the item's DCG values in the group - CHECK_EQ(static_cast(end_range.second - dgroup_dcg_.begin()), dgroup_dcg_.size()); - } - - inline const common::Span GetGroupDcgsSpan() const { - return { dgroup_dcg_.data().get(), dgroup_dcg_.size() }; - } - - inline const NDCGLambdaWeightMultiplier GetWeightMultiplier() const { - return weight_multiplier_; - } -#endif - - static void GetLambdaWeight(const std::vector &sorted_list, - std::vector *io_pairs) { - std::vector &pairs = *io_pairs; - float IDCG; // NOLINT - { - std::vector labels(sorted_list.size()); - for (size_t i = 0; i < sorted_list.size(); ++i) { - labels[i] = sorted_list[i].label; - } - std::stable_sort(labels.begin(), labels.end(), std::greater<>()); - IDCG = ComputeGroupDCGWeight(&labels[0], labels.size()); - } - if (IDCG == 0.0) { - for (auto & pair : pairs) { - pair.weight = 0.0f; - } - } else { - for (auto & pair : pairs) { - unsigned pos_idx = pair.pos_index; - unsigned neg_idx = pair.neg_index; - pair.weight *= ComputeDeltaWeight(pos_idx, neg_idx, - sorted_list[pos_idx].label, sorted_list[neg_idx].label, - IDCG); - } - } - } - - static char const* Name() { - return "rank:ndcg"; - } - - inline static bst_float 
ComputeGroupDCGWeight(const float *sorted_labels, uint32_t size) { - double sumdcg = 0.0; - for (uint32_t i = 0; i < size; ++i) { - sumdcg += ComputeItemDCGWeight(sorted_labels[i], i); - } - - return static_cast(sumdcg); - } - - private: - XGBOOST_DEVICE inline static bst_float ComputeItemDCGWeight(unsigned label, uint32_t idx) { - return (label != 0) ? (((1 << label) - 1) / std::log2(static_cast(idx + 2))) : 0; - } - - // Compute the weight adjustment for an item within a group: - // pos_pred_pos => Where does the positive label live, had the list been sorted by prediction - // neg_pred_pos => Where does the negative label live, had the list been sorted by prediction - // pos_label => positive label value from sorted label list - // neg_label => negative label value from sorted label list - XGBOOST_DEVICE inline static bst_float ComputeDeltaWeight(uint32_t pos_pred_pos, - uint32_t neg_pred_pos, - int pos_label, int neg_label, - float idcg) { - float pos_loginv = 1.0f / std::log2(pos_pred_pos + 2.0f); - float neg_loginv = 1.0f / std::log2(neg_pred_pos + 2.0f); - bst_float original = ((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv; - float changed = ((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv; - bst_float delta = (original - changed) * (1.0f / idcg); - if (delta < 0.0f) delta = - delta; - return delta; - } - -#if defined(__CUDACC__) - dh::caching_device_vector dgroup_dcg_; - // This computes the adjustment to the weight - const NDCGLambdaWeightMultiplier weight_multiplier_; -#endif -}; - class MAPLambdaWeightComputer #if defined(__CUDACC__) : public IndexablePredictionSorter @@ -948,10 +780,6 @@ XGBOOST_REGISTER_OBJECTIVE(PairwiseRankObj, PairwiseLambdaWeightComputer::Name() .describe("Pairwise rank objective.") .set_body([]() { return new LambdaRankObj(); }); -XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, NDCGLambdaWeightComputer::Name()) -.describe("LambdaRank with NDCG as objective.") -.set_body([]() { return new LambdaRankObj(); }); - XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, MAPLambdaWeightComputer::Name()) .describe("LambdaRank with MAP as objective.") .set_body([]() { return new LambdaRankObj(); }); diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc index 11cbf6bec3f6..d02a55c1b7b8 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cc +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -5,6 +5,7 @@ #include // for Test, Message, TestPartResult, CmpHel... +#include // for sort #include // for size_t #include // for initializer_list #include // for map @@ -13,7 +14,6 @@ #include // for char_traits, basic_string, string #include // for vector -#include "../../../src/common/ranking_utils.h" // for LambdaRankParam #include "../../../src/common/ranking_utils.h" // for NDCGCache, LambdaRankParam #include "../helpers.h" // for CheckRankingObjFunction, CheckConfigReload #include "xgboost/base.h" // for GradientPair, bst_group_t, Args @@ -25,6 +25,126 @@ #include "xgboost/span.h" // for Span namespace xgboost::obj { +TEST(LambdaRank, NDCGJsonIO) { + Context ctx; + TestNDCGJsonIO(&ctx); +} + +void TestNDCGGPair(Context const* ctx) { + { + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; + obj->Configure(Args{{"lambdarank_pair_method", "topk"}}); + CheckConfigReload(obj, "rank:ndcg"); + + // No gain in swapping 2 documents. 
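+    // All labels (and predictions) are identical, so every pair has |delta NDCG| = 0
+    // and the expected gradients and hessians below are all zero.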
+ CheckRankingObjFunction(obj, + {1, 1, 1, 1}, + {1, 1, 1, 1}, + {1.0f, 1.0f}, + {0, 2, 4}, + {0.0f, -0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 0.0f}); + } + { + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; + obj->Configure(Args{{"lambdarank_pair_method", "topk"}}); + // Test with setting sample weight to second query group + CheckRankingObjFunction(obj, + {0, 0.1f, 0, 0.1f}, + {0, 1, 0, 1}, + {2.0f, 0.0f}, + {0, 2, 4}, + {2.06611f, -2.06611f, 0.0f, 0.0f}, + {2.169331f, 2.169331f, 0.0f, 0.0f}); + + CheckRankingObjFunction(obj, + {0, 0.1f, 0, 0.1f}, + {0, 1, 0, 1}, + {2.0f, 2.0f}, + {0, 2, 4}, + {2.06611f, -2.06611f, 2.06611f, -2.06611f}, + {2.169331f, 2.169331f, 2.169331f, 2.169331f}); + } + + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; + obj->Configure(Args{{"lambdarank_pair_method", "topk"}}); + + HostDeviceVector predts{0, 1, 0, 1}; + MetaInfo info; + info.labels = linalg::Tensor{{0, 1, 0, 1}, {4, 1}, GPUIDX}; + info.group_ptr_ = {0, 2, 4}; + info.num_row_ = 4; + HostDeviceVector gpairs; + obj->GetGradient(predts, info, 0, &gpairs); + ASSERT_EQ(gpairs.Size(), predts.Size()); + + { + predts = {1, 0, 1, 0}; + HostDeviceVector gpairs; + obj->GetGradient(predts, info, 0, &gpairs); + for (size_t i = 0; i < gpairs.Size(); ++i) { + ASSERT_GT(gpairs.HostSpan()[i].GetHess(), 0); + } + ASSERT_LT(gpairs.HostSpan()[1].GetGrad(), 0); + ASSERT_LT(gpairs.HostSpan()[3].GetGrad(), 0); + + ASSERT_GT(gpairs.HostSpan()[0].GetGrad(), 0); + ASSERT_GT(gpairs.HostSpan()[2].GetGrad(), 0); + + info.weights_ = {2, 3}; + HostDeviceVector weighted_gpairs; + obj->GetGradient(predts, info, 0, &weighted_gpairs); + auto const& h_gpairs = gpairs.ConstHostSpan(); + auto const& h_weighted_gpairs = weighted_gpairs.ConstHostSpan(); + for (size_t i : {0ul, 1ul}) { + ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 2.0f); + ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 2.0f); + } + for (size_t i : {2ul, 3ul}) { + ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 3.0f); + ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 3.0f); + } + } + + ASSERT_NO_THROW(obj->DefaultEvalMetric()); +} + +TEST(LambdaRank, NDCGGPair) { + Context ctx; + TestNDCGGPair(&ctx); +} + +void TestUnbiasedNDCG(Context const* ctx) { + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; + obj->Configure(Args{{"lambdarank_pair_method", "topk"}, + {"lambdarank_unbiased", "true"}, + {"lambdarank_bias_norm", "0"}}); + std::shared_ptr p_fmat{RandomDataGenerator{10, 1, 0.0f}.GenerateDMatrix(true, false, 2)}; + auto h_label = p_fmat->Info().labels.HostView().Values(); + // Move clicked samples to the beginning. + std::sort(h_label.begin(), h_label.end(), std::greater<>{}); + HostDeviceVector predt(p_fmat->Info().num_row_, 1.0f); + + HostDeviceVector out_gpair; + obj->GetGradient(predt, p_fmat->Info(), 0, &out_gpair); + + Json config{Object{}}; + obj->SaveConfig(&config); + auto ti_plus = get(config["ti+"]); + ASSERT_FLOAT_EQ(ti_plus[0], 1.0); + // bias is non-increasing when prediction is constant. 
(constant cost on swapping documents) + for (std::size_t i = 1; i < ti_plus.size(); ++i) { + ASSERT_LE(ti_plus[i], ti_plus[i - 1]); + } + auto tj_minus = get(config["tj-"]); + ASSERT_FLOAT_EQ(tj_minus[0], 1.0); +} + +TEST(LambdaRank, UnbiasedNDCG) { + Context ctx; + TestUnbiasedNDCG(&ctx); +} + void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector* out_predt) { out_predt->SetDevice(ctx->gpu_id); MetaInfo& info = *out_info; diff --git a/tests/cpp/objective/test_lambdarank_obj.cu b/tests/cpp/objective/test_lambdarank_obj.cu index 03ccdef8b15c..01d020dda1cd 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cu +++ b/tests/cpp/objective/test_lambdarank_obj.cu @@ -12,6 +12,18 @@ #include "test_lambdarank_obj.h" namespace xgboost::obj { +TEST(LambdaRank, GPUNDCGJsonIO) { + Context ctx; + ctx.gpu_id = 0; + TestNDCGJsonIO(&ctx); +} + +TEST(LambdaRank, GPUNDCGGPair) { + Context ctx; + ctx.gpu_id = 0; + TestNDCGGPair(&ctx); +} + void TestGPUMakePair() { Context ctx; ctx.gpu_id = 0; @@ -107,6 +119,12 @@ void TestGPUMakePair() { TEST(LambdaRank, GPUMakePair) { TestGPUMakePair(); } +TEST(LambdaRank, GPUUnbiasedNDCG) { + Context ctx; + ctx.gpu_id = 0; + TestUnbiasedNDCG(&ctx); +} + template void RankItemCountImpl(std::vector const &sorted_items, CountFunctor f, std::uint32_t find_val, std::uint32_t exp_val) { diff --git a/tests/cpp/objective/test_lambdarank_obj.h b/tests/cpp/objective/test_lambdarank_obj.h index 8dd238d2bad5..aebe3ad54f3e 100644 --- a/tests/cpp/objective/test_lambdarank_obj.h +++ b/tests/cpp/objective/test_lambdarank_obj.h @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright (c) 2023, XGBoost Contributors */ #ifndef XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ #define XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ @@ -18,6 +18,25 @@ #include "../helpers.h" // for EmptyDMatrix namespace xgboost::obj { +inline void TestNDCGJsonIO(Context const* ctx) { + std::unique_ptr obj{ObjFunction::Create("rank:ndcg", ctx)}; + + obj->Configure(Args{}); + Json j_obj{Object()}; + obj->SaveConfig(&j_obj); + + ASSERT_EQ(get(j_obj["name"]), "rank:ndcg"); + auto const& j_param = j_obj["lambdarank_param"]; + + ASSERT_EQ(get(j_param["ndcg_exp_gain"]), "1"); + ASSERT_EQ(get(j_param["lambdarank_num_pair_per_sample"]), + std::to_string(ltr::LambdaRankParam::NotSet())); +} + +void TestNDCGGPair(Context const* ctx); + +void TestUnbiasedNDCG(Context const* ctx); + /** * \brief Initialize test data for make pair tests. 
*/ diff --git a/tests/cpp/objective/test_ranking_obj.cc b/tests/cpp/objective/test_ranking_obj.cc index a007750e3d81..2072f530e8da 100644 --- a/tests/cpp/objective/test_ranking_obj.cc +++ b/tests/cpp/objective/test_ranking_obj.cc @@ -35,24 +35,6 @@ TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPair)) { ASSERT_NO_THROW(obj->DefaultEvalMetric()); } -TEST(Objective, DeclareUnifiedTest(NDCG_JsonIO)) { - xgboost::Context ctx; - ctx.UpdateAllowUnknown(Args{}); - - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", &ctx)}; - - obj->Configure(Args{}); - Json j_obj {Object()}; - obj->SaveConfig(&j_obj); - - ASSERT_EQ(get(j_obj["name"]), "rank:ndcg");; - - auto const& j_param = j_obj["lambda_rank_param"]; - - ASSERT_EQ(get(j_param["num_pairsample"]), "1"); - ASSERT_EQ(get(j_param["fix_list_weight"]), "0"); -} - TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) { std::vector> args; xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); @@ -71,33 +53,6 @@ TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) { ASSERT_NO_THROW(obj->DefaultEvalMetric()); } -TEST(Objective, DeclareUnifiedTest(NDCGRankingGPair)) { - std::vector> args; - xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", &ctx)}; - obj->Configure(args); - CheckConfigReload(obj, "rank:ndcg"); - - // Test with setting sample weight to second query group - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {2.0f, 0.0f}, - {0, 2, 4}, - {0.7f, -0.7f, 0.0f, 0.0f}, - {0.74f, 0.74f, 0.0f, 0.0f}); - - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {1.0f, 1.0f}, - {0, 2, 4}, - {0.35f, -0.35f, 0.35f, -0.35f}, - {0.368f, 0.368f, 0.368f, 0.368f}); - ASSERT_NO_THROW(obj->DefaultEvalMetric()); -} - TEST(Objective, DeclareUnifiedTest(MAPRankingGPair)) { std::vector> args; xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); diff --git a/tests/cpp/objective/test_ranking_obj_gpu.cu b/tests/cpp/objective/test_ranking_obj_gpu.cu index 540560c1f64b..cd40b49284f6 100644 --- a/tests/cpp/objective/test_ranking_obj_gpu.cu +++ b/tests/cpp/objective/test_ranking_obj_gpu.cu @@ -89,62 +89,6 @@ TEST(Objective, RankSegmentSorterAscendingTest) { 5, 4, 6}); } -TEST(Objective, NDCGLambdaWeightComputerTest) { - std::vector hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels - 7.8f, 5.01f, 6.96f, - 10.3f, 8.7f, 11.4f, 9.45f, 11.4f}; - dh::device_vector dlabels(hlabels); - - auto segment_label_sorter = RankSegmentSorterTestImpl( - {0, 4, 7, 12}, // Groups - hlabels, - {4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels - 7.8f, 6.96f, 5.01f, - 11.4f, 11.4f, 10.3f, 9.45f, 8.7f}, - {3, 0, 2, 1, // Expected original positions - 4, 6, 5, - 9, 11, 7, 10, 8}); - - // Created segmented predictions for the labels from above - std::vector hpreds{-9.78f, 24.367f, 0.908f, -11.47f, - -1.03f, -2.79f, -3.1f, - 104.22f, 103.1f, -101.7f, 100.5f, 45.1f}; - dh::device_vector dpreds(hpreds); - - xgboost::obj::NDCGLambdaWeightComputer ndcg_lw_computer(dpreds.data().get(), - dlabels.data().get(), - *segment_label_sorter); - - // Where will the predictions move from its current position, if they were sorted - // descendingly? 
- auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositionsSpan(); - std::vector hsorted_pred_pos(segment_label_sorter->GetNumItems()); - dh::CopyDeviceSpanToVector(&hsorted_pred_pos, dsorted_pred_pos); - std::vector expected_sorted_pred_pos{2, 0, 1, 3, - 4, 5, 6, - 7, 8, 11, 9, 10}; - EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos); - - // Check group DCG values - std::vector hgroup_dcgs(segment_label_sorter->GetNumGroups()); - dh::CopyDeviceSpanToVector(&hgroup_dcgs, ndcg_lw_computer.GetGroupDcgsSpan()); - std::vector hgroups(segment_label_sorter->GetNumGroups() + 1); - dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan()); - EXPECT_EQ(hgroup_dcgs.size(), segment_label_sorter->GetNumGroups()); - std::vector hsorted_labels(segment_label_sorter->GetNumItems()); - dh::CopyDeviceSpanToVector(&hsorted_labels, segment_label_sorter->GetItemsSpan()); - for (size_t i = 0; i < hgroup_dcgs.size(); ++i) { - // Compute group DCG value on CPU and compare - auto gbegin = hgroups[i]; - auto gend = hgroups[i + 1]; - EXPECT_NEAR( - hgroup_dcgs[i], - xgboost::obj::NDCGLambdaWeightComputer::ComputeGroupDCGWeight(&hsorted_labels[gbegin], - gend - gbegin), - 0.01f); - } -} - TEST(Objective, IndexableSortedItemsTest) { std::vector hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels 7.8f, 5.01f, 6.96f, diff --git a/tests/python-gpu/test_gpu_eval_metrics.py b/tests/python-gpu/test_gpu_eval_metrics.py index 6d16aa44e1d7..1e9d1a282bcb 100644 --- a/tests/python-gpu/test_gpu_eval_metrics.py +++ b/tests/python-gpu/test_gpu_eval_metrics.py @@ -1,3 +1,4 @@ +import json import sys import pytest @@ -36,19 +37,16 @@ def test_roc_auc_ltr(self, n_samples): Xy = xgboost.DMatrix(X, y, group=group) - cpu = xgboost.train( + booster = xgboost.train( {"tree_method": "hist", "eval_metric": "auc", "objective": "rank:ndcg"}, Xy, num_boost_round=10, ) - cpu_auc = float(cpu.eval(Xy).split(":")[1]) - - gpu = xgboost.train( - {"tree_method": "gpu_hist", "eval_metric": "auc", "objective": "rank:ndcg"}, - Xy, - num_boost_round=10, - ) - gpu_auc = float(gpu.eval(Xy).split(":")[1]) + cpu_auc = float(booster.eval(Xy).split(":")[1]) + booster.set_param({"gpu_id": "0"}) + assert json.loads(booster.save_config())["learner"]["generic_param"]["gpu_id"] == "0" + gpu_auc = float(booster.eval(Xy).split(":")[1]) + assert json.loads(booster.save_config())["learner"]["generic_param"]["gpu_id"] == "0" np.testing.assert_allclose(cpu_auc, gpu_auc)
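--
A minimal usage sketch of the reworked objective from Python (illustrative only, not part of the diff; it assumes the new LambdaRankParam fields above are exposed as training parameters under the same names):

import numpy as np
import xgboost

# Two query groups of 50 rows each, with graded relevance labels in [0, 3].
rng = np.random.default_rng(2023)
X = rng.normal(size=(100, 10))
y = rng.integers(0, 4, size=100)
Xy = xgboost.DMatrix(X, y, group=[50, 50])

booster = xgboost.train(
    {
        "objective": "rank:ndcg",
        "lambdarank_pair_method": "topk",     # the default set by this patch
        "lambdarank_num_pair_per_sample": 8,  # truncation level for pair construction
        "lambdarank_unbiased": False,         # set to True for position debiasing
    },
    Xy,
    num_boost_round=10,
)
# With truncation enabled, RankEvalMetric makes the default metric ndcg@8.
print(booster.eval(Xy))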