-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support to optimize for NDCG at a given truncation level #3425
Changes from 2 commits
16955e0
d938551
b2736e6
9e3cc2b
c7890dc
7f36817
4d7183b
c15423f
8e98bd4
a06485a
c51fb3a
7bfc652
af2fe37
b6bd92f
e73d3db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -163,35 +163,30 @@ class LambdarankNDCG : public RankingObjective { | |
} | ||
const double worst_score = score[sorted_idx[worst_idx]]; | ||
double sum_lambdas = 0.0; | ||
// start accmulate lambdas by pairs | ||
for (data_size_t i = 0; i < cnt; ++i) { | ||
const data_size_t high = sorted_idx[i]; | ||
const int high_label = static_cast<int>(label[high]); | ||
const double high_score = score[high]; | ||
if (high_score == kMinScore) { | ||
continue; | ||
} | ||
const double high_label_gain = label_gain_[high_label]; | ||
const double high_discount = DCGCalculator::GetDiscount(i); | ||
double high_sum_lambda = 0.0; | ||
double high_sum_hessian = 0.0; | ||
for (data_size_t j = 0; j < cnt; ++j) { | ||
// skip same data | ||
if (i == j) { | ||
continue; | ||
} | ||
const data_size_t low = sorted_idx[j]; | ||
// start accmulate lambdas by pairs that contain at least one document above truncation level | ||
for (data_size_t i = 0; i < cnt - 1 && i < truncation_level_; ++i) { | ||
for (data_size_t j = i + 1; j < cnt; ++j) { | ||
metpavel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// skip pairs with the same labels | ||
if (label[sorted_idx[i]] == label[sorted_idx[j]]) { continue; } | ||
|
||
const data_size_t high_rank = label[sorted_idx[i]] > label[sorted_idx[j]] ? i : j; | ||
const data_size_t high = sorted_idx[high_rank]; | ||
const int high_label = static_cast<int>(label[high]); | ||
const double high_score = score[high]; | ||
const double high_label_gain = label_gain_[high_label]; | ||
const double high_discount = DCGCalculator::GetDiscount(high_rank); | ||
|
||
const data_size_t low_rank = label[sorted_idx[i]] > label[sorted_idx[j]] ? j : i; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove the additional branching? get high_rank and low_rank by one if. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
const data_size_t low = sorted_idx[low_rank]; | ||
const int low_label = static_cast<int>(label[low]); | ||
const double low_score = score[low]; | ||
// only consider pair with different label | ||
if (high_label <= low_label || low_score == kMinScore) { | ||
continue; | ||
} | ||
const double low_label_gain = label_gain_[low_label]; | ||
const double low_discount = DCGCalculator::GetDiscount(low_rank); | ||
|
||
if (high_score == kMinScore || low_score == kMinScore) { continue; } | ||
|
||
metpavel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const double delta_score = high_score - low_score; | ||
|
||
const double low_label_gain = label_gain_[low_label]; | ||
const double low_discount = DCGCalculator::GetDiscount(j); | ||
// get dcg gap | ||
const double dcg_gap = high_label_gain - low_label_gain; | ||
// get discount of this pair | ||
|
@@ -208,16 +203,13 @@ class LambdarankNDCG : public RankingObjective { | |
// update | ||
p_lambda *= -sigmoid_ * delta_pair_NDCG; | ||
p_hessian *= sigmoid_ * sigmoid_ * delta_pair_NDCG; | ||
high_sum_lambda += p_lambda; | ||
high_sum_hessian += p_hessian; | ||
lambdas[low] -= static_cast<score_t>(p_lambda); | ||
hessians[low] += static_cast<score_t>(p_hessian); | ||
lambdas[high] += static_cast<score_t>(p_lambda); | ||
hessians[high] += static_cast<score_t>(p_hessian); | ||
// lambda is negative, so use minus to accumulate | ||
sum_lambdas -= 2 * p_lambda; | ||
} | ||
// update | ||
lambdas[high] += static_cast<score_t>(high_sum_lambda); | ||
hessians[high] += static_cast<score_t>(high_sum_hessian); | ||
} | ||
if (norm_ && sum_lambdas > 0) { | ||
double norm_factor = std::log2(1 + sum_lambdas) / sum_lambdas; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. /gha run-valgrind There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. /gha run r-valgrind There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ^ this comment started this run checking that these changes pass our valgrind tests: https://github.com/microsoft/LightGBM/runs/1296338580?check_suite_focus=true There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this didn't show any new issues...but I'll run it again once my fixes are merged into this PR: #3425 (comment) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. /gha run-valgrind There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. /gha run r-valgrind There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if this passes, I'll come back and approve: https://github.com/microsoft/LightGBM/runs/1310275649?check_suite_focus=true |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this start from '0' ? if not, why?
updated:
I see. never mind.