Skip to content

Commit

Permalink
Rework the NDCG objective. (#9015)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Apr 18, 2023
1 parent ba9d24f commit ef13dd3
Show file tree
Hide file tree
Showing 15 changed files with 1,082 additions and 351 deletions.
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \
Expand Down
31 changes: 16 additions & 15 deletions src/common/math.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
/*!
* Copyright 2015 by Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file math.h
* \brief additional math utils
* \author Tianqi Chen
*/
#ifndef XGBOOST_COMMON_MATH_H_
#define XGBOOST_COMMON_MATH_H_

#include <xgboost/base.h>
#include <xgboost/base.h> // for XGBOOST_DEVICE

#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
#include <vector>
#include <algorithm> // for max
#include <cmath> // for exp, abs, log, lgamma
#include <limits> // for numeric_limits
#include <type_traits> // for is_floating_point, conditional, is_signed, is_same, declval, enable_if
#include <utility> // for pair

namespace xgboost {
namespace common {

template <typename T> XGBOOST_DEVICE T Sqr(T const &w) { return w * w; }

/*!
* \brief calculate the sigmoid of the input.
* \param x input parameter
Expand All @@ -30,9 +33,11 @@ XGBOOST_DEVICE inline float Sigmoid(float x) {
return y;
}

template <typename T>
XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }

XGBOOST_DEVICE inline double Sigmoid(double x) {
auto denom = std::exp(-x) + 1.0;
auto y = 1.0 / denom;
return y;
}
/*!
* \brief Equality test for both integer and floating point.
*/
Expand Down Expand Up @@ -134,10 +139,6 @@ inline static bool CmpFirst(const std::pair<float, unsigned> &a,
const std::pair<float, unsigned> &b) {
return a.first > b.first;
}
inline static bool CmpSecond(const std::pair<float, unsigned> &a,
const std::pair<float, unsigned> &b) {
return a.second > b.second;
}

// Redefined here to workaround a VC bug that doesn't support overloading for integer
// types.
Expand Down
6 changes: 3 additions & 3 deletions src/common/ranking_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
// pairs
// should be accessed by getter for auto configuration.
// nolint so that we can keep the string name.
PairMethod lambdarank_pair_method{PairMethod::kMean}; // NOLINT
PairMethod lambdarank_pair_method{PairMethod::kTopK}; // NOLINT
std::size_t lambdarank_num_pair_per_sample{NotSet()}; // NOLINT

public:
static constexpr position_t NotSet() { return std::numeric_limits<position_t>::max(); }

// unbiased
bool lambdarank_unbiased{false};
double lambdarank_bias_norm{2.0};
double lambdarank_bias_norm{1.0};
// ndcg
bool ndcg_exp_gain{true};

Expand Down Expand Up @@ -135,7 +135,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
.set_default(false)
.describe("Unbiased lambda mart. Use extended IPW to debias click position");
DMLC_DECLARE_FIELD(lambdarank_bias_norm)
.set_default(2.0)
.set_default(1.0)
.set_lower_bound(0.0)
.describe("Lp regularization for unbiased lambdarank.");
DMLC_DECLARE_FIELD(ndcg_exp_gain)
Expand Down
Loading

0 comments on commit ef13dd3

Please sign in to comment.