Skip to content

Commit

Permalink
some refine
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Feb 23, 2020
1 parent 06b84d4 commit 0c9069e
Showing 1 changed file with 6 additions and 27 deletions.
33 changes: 6 additions & 27 deletions src/objective/rank_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ class RankingObjective : public ObjectiveFunction {
explicit RankingObjective(const Config& config)
: seed_(config.objective_seed) {}

explicit RankingObjective(const std::vector<std::string>&)
: seed_(0) {}
explicit RankingObjective(const std::vector<std::string>&) : seed_(0) {}

~RankingObjective() {}

Expand Down Expand Up @@ -295,28 +294,8 @@ class RankXENDCG : public RankingObjective {

void Init(const Metadata& metadata, data_size_t num_data) override {
RankingObjective::Init(metadata, num_data);
for (data_size_t i = 0;
i < (num_queries_ + query_block_size_ - 1) / query_block_size_; ++i) {
rands_.emplace_back(seed_ + num_queries_ + i);
}
}

void GetGradients(const double* score, score_t* gradients,
score_t* hessians) const override {
#pragma omp parallel for schedule(static, query_block_size_)
for (data_size_t i = 0; i < num_queries_; ++i) {
const data_size_t start = query_boundaries_[i];
const data_size_t cnt = query_boundaries_[i + 1] - query_boundaries_[i];
GetGradientsForOneQuery(i, cnt, label_ + start, score + start,
gradients + start, hessians + start);
if (weights_ != nullptr) {
for (data_size_t j = 0; j < cnt; ++j) {
gradients[start + j] =
static_cast<score_t>(gradients[start + j] * weights_[start + j]);
hessians[start + j] =
static_cast<score_t>(hessians[start + j] * weights_[start + j]);
}
}
rands_.emplace_back(seed_ + i);
}
}

Expand All @@ -332,11 +311,11 @@ class RankXENDCG : public RankingObjective {
std::vector<double> l1s(cnt);
double sum_labels = 0;
for (data_size_t i = 0; i < cnt; ++i) {
l1s[i] = Phi(label[i], rands_[query_id / query_block_size_].NextFloat());
l1s[i] = Phi(label[i], rands_[query_id].NextFloat());
sum_labels += l1s[i];
}
sum_labels = Common::Sign(sum_labels) *
std::max<double>(kEpsilon, std::fabs(sum_labels));
// sum_labels will always be positive number
sum_labels = std::max<double>(kEpsilon, sum_labels);
// Approximate gradients and inverse Hessian.
// First order terms.
double sum_l1 = 0.0f;
Expand All @@ -345,6 +324,7 @@ class RankXENDCG : public RankingObjective {
sum_l1 += l1s[i];
}
if (cnt <= 1) {
// when cnt <= 1, the l2 and l3 are zeros
for (data_size_t i = 0; i < cnt; ++i) {
lambdas[i] = static_cast<score_t>(l1s[i]);
hessians[i] = static_cast<score_t>(rho[i] * (1.0 - rho[i]));
Expand Down Expand Up @@ -373,7 +353,6 @@ class RankXENDCG : public RankingObjective {
const char* GetName() const override { return "rank_xendcg"; }

private:
const int query_block_size_ = 128;
mutable std::vector<Random> rands_;
};

Expand Down

0 comments on commit 0c9069e

Please sign in to comment.