Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Feb 23, 2020
1 parent df70292 commit 8a9dd46
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 150 deletions.
24 changes: 13 additions & 11 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ Core Parameters

- ``lambdarank``, `lambdarank <https://papers.nips.cc/paper/2971-learning-to-rank-with-nonsmooth-cost-functions.pdf>`__ objective. `label_gain <#label_gain>`__ can be used to set the gain (weight) of ``int`` label and all values in ``label`` must be smaller than number of elements in ``label_gain``

- ``rank_xendcg``, `XE_NDCG_MART <https://arxiv.org/abs/1911.09798>`__ ranking objective function. To obtain reproducible results, you should disable parallelism by setting ``num_threads`` to 1, aliases: ``xendcg``, ``xe_ndcg``, ``xe_ndcg_mart``, ``xendcg_mart``
- ``rank_xendcg``, `XE_NDCG_MART <https://arxiv.org/abs/1911.09798>`__ ranking objective function. aliases: ``xendcg``, ``xe_ndcg``, ``xe_ndcg_mart``, ``xendcg_mart``.

- ``rank_xendcg`` is faster than ``lambdarank`` and achieves the similar performance as ``lambdarank``

- label should be ``int`` type, and larger number represents the higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect)

Expand Down Expand Up @@ -790,6 +792,12 @@ IO Parameters
Objective Parameters
--------------------

- ``objective_seed`` :raw-html:`<a id="objective_seed" title="Permalink to this parameter" href="#objective_seed">&#x1F517;&#xFE0E;</a>`, default = ``5``, type = int

- random seed for objectives, if random process is needed

- used in ``rank_xendcg``

- ``num_class`` :raw-html:`<a id="num_class" title="Permalink to this parameter" href="#num_class">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, aliases: ``num_classes``, constraints: ``num_class > 0``

- used only in ``multi-class`` classification application
Expand Down Expand Up @@ -862,19 +870,19 @@ Objective Parameters

- set this closer to ``1`` to shift towards a **Poisson** distribution

- ``max_position`` :raw-html:`<a id="max_position" title="Permalink to this parameter" href="#max_position">&#x1F517;&#xFE0E;</a>`, default = ``20``, type = int, constraints: ``max_position > 0``
- ``lambdarank_truncation_level`` :raw-html:`<a id="lambdarank_truncation_level" title="Permalink to this parameter" href="#lambdarank_truncation_level">&#x1F517;&#xFE0E;</a>`, default = ``20``, type = int, constraints: ``lambdarank_truncation_level > 0``

- used only in ``lambdarank`` application

- optimizes `NDCG <https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG>`__ at this position
- used for truncating the max_ndcg, refer to "truncation level" in the Sec.3 of `LambdaMART paper <https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf>`__ .

- ``lambdamart_norm`` :raw-html:`<a id="lambdamart_norm" title="Permalink to this parameter" href="#lambdamart_norm">&#x1F517;&#xFE0E;</a>`, default = ``true``, type = bool
- ``lambdarank_norm`` :raw-html:`<a id="lambdarank_norm" title="Permalink to this parameter" href="#lambdarank_norm">&#x1F517;&#xFE0E;</a>`, default = ``true``, type = bool

- used only in ``lambdarank`` application

- set this to ``true`` to normalize the lambdas for different queries, and improve the performance for unbalanced data

- set this to ``false`` to enforce the original lambdamart algorithm
- set this to ``false`` to enforce the original lambdarank algorithm

- ``label_gain`` :raw-html:`<a id="label_gain" title="Permalink to this parameter" href="#label_gain">&#x1F517;&#xFE0E;</a>`, default = ``0,1,3,7,15,31,63,...,2^30-1``, type = multi-double

Expand All @@ -884,12 +892,6 @@ Objective Parameters

- separate by ``,``

- ``objective_seed`` :raw-html:`<a id="objective_seed" title="Permalink to this parameter" href="#objective_seed">&#x1F517;&#xFE0E;</a>`, default = ``5``, type = int

- used only in the ``rank_xendcg`` objective

- random seed for objectives

Metric Parameters
-----------------

Expand Down
19 changes: 10 additions & 9 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ struct Config {
// descl2 = label is anything in interval [0, 1]
// desc = ranking application
// descl2 = ``lambdarank``, `lambdarank <https://papers.nips.cc/paper/2971-learning-to-rank-with-nonsmooth-cost-functions.pdf>`__ objective. `label_gain <#label_gain>`__ can be used to set the gain (weight) of ``int`` label and all values in ``label`` must be smaller than number of elements in ``label_gain``
// descl2 = ``rank_xendcg``, `XE_NDCG_MART <https://arxiv.org/abs/1911.09798>`__ ranking objective function. To obtain reproducible results, you should disable parallelism by setting ``num_threads`` to 1, aliases: ``xendcg``, ``xe_ndcg``, ``xe_ndcg_mart``, ``xendcg_mart``
// descl2 = ``rank_xendcg``, `XE_NDCG_MART <https://arxiv.org/abs/1911.09798>`__ ranking objective function. aliases: ``xendcg``, ``xe_ndcg``, ``xe_ndcg_mart``, ``xendcg_mart``.
// descl2 = ``rank_xendcg`` is faster than ``lambdarank`` and achieves the similar performance as ``lambdarank``
// descl2 = label should be ``int`` type, and larger number represents the higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect)
std::string objective = "regression";

Expand Down Expand Up @@ -692,6 +693,10 @@ struct Config {

#pragma region Objective Parameters

// desc = random seed for objectives, if random process is needed
// desc = used in ``rank_xendcg``
int objective_seed = 5;

// check = >0
// alias = num_classes
// desc = used only in ``multi-class`` classification application
Expand Down Expand Up @@ -750,13 +755,13 @@ struct Config {

// check = >0
// desc = used only in ``lambdarank`` application
// desc = optimizes `NDCG <https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG>`__ at this position
int max_position = 20;
// desc = used for truncating the max_ndcg, refer to "truncation level" in the Sec.3 of `LambdaMART paper <https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf>`__ .
int lambdarank_truncation_level = 20;

// desc = used only in ``lambdarank`` application
// desc = set this to ``true`` to normalize the lambdas for different queries, and improve the performance for unbalanced data
// desc = set this to ``false`` to enforce the original lambdamart algorithm
bool lambdamart_norm = true;
// desc = set this to ``false`` to enforce the original lambdarank algorithm
bool lambdarank_norm = true;

// type = multi-double
// default = 0,1,3,7,15,31,63,...,2^30-1
Expand All @@ -765,10 +770,6 @@ struct Config {
// desc = separate by ``,``
std::vector<double> label_gain;

// desc = used only in the ``rank_xendcg`` objective
// desc = random seed for objectives
int objective_seed = 5;

#pragma endregion

#pragma region Metric Parameters
Expand Down
22 changes: 11 additions & 11 deletions src/io/config_auto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"predict_disable_shape_check",
"convert_model_language",
"convert_model",
"objective_seed",
"num_class",
"is_unbalance",
"scale_pos_weight",
Expand All @@ -267,10 +268,9 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"fair_c",
"poisson_max_delta_step",
"tweedie_variance_power",
"max_position",
"lambdamart_norm",
"lambdarank_truncation_level",
"lambdarank_norm",
"label_gain",
"objective_seed",
"metric",
"metric_freq",
"is_provide_training_metric",
Expand Down Expand Up @@ -513,6 +513,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

GetString(params, "convert_model", &convert_model);

GetInt(params, "objective_seed", &objective_seed);

GetInt(params, "num_class", &num_class);
CHECK(num_class >0);

Expand Down Expand Up @@ -541,17 +543,15 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
CHECK(tweedie_variance_power >=1.0);
CHECK(tweedie_variance_power <2.0);

GetInt(params, "max_position", &max_position);
CHECK(max_position >0);
GetInt(params, "lambdarank_truncation_level", &lambdarank_truncation_level);
CHECK(lambdarank_truncation_level >0);

GetBool(params, "lambdamart_norm", &lambdamart_norm);
GetBool(params, "lambdarank_norm", &lambdarank_norm);

if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToArray<double>(tmp_str, ',');
}

GetInt(params, "objective_seed", &objective_seed);

GetInt(params, "metric_freq", &metric_freq);
CHECK(metric_freq >0);

Expand Down Expand Up @@ -675,6 +675,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[predict_disable_shape_check: " << predict_disable_shape_check << "]\n";
str_buf << "[convert_model_language: " << convert_model_language << "]\n";
str_buf << "[convert_model: " << convert_model << "]\n";
str_buf << "[objective_seed: " << objective_seed << "]\n";
str_buf << "[num_class: " << num_class << "]\n";
str_buf << "[is_unbalance: " << is_unbalance << "]\n";
str_buf << "[scale_pos_weight: " << scale_pos_weight << "]\n";
Expand All @@ -685,10 +686,9 @@ std::string Config::SaveMembersToString() const {
str_buf << "[fair_c: " << fair_c << "]\n";
str_buf << "[poisson_max_delta_step: " << poisson_max_delta_step << "]\n";
str_buf << "[tweedie_variance_power: " << tweedie_variance_power << "]\n";
str_buf << "[max_position: " << max_position << "]\n";
str_buf << "[lambdamart_norm: " << lambdamart_norm << "]\n";
str_buf << "[lambdarank_truncation_level: " << lambdarank_truncation_level << "]\n";
str_buf << "[lambdarank_norm: " << lambdarank_norm << "]\n";
str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n";
str_buf << "[objective_seed: " << objective_seed << "]\n";
str_buf << "[metric_freq: " << metric_freq << "]\n";
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
Expand Down
1 change: 0 additions & 1 deletion src/objective/objective_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "binary_objective.hpp"
#include "multiclass_objective.hpp"
#include "rank_objective.hpp"
#include "rank_xendcg_objective.hpp"
#include "regression_objective.hpp"
#include "xentropy_objective.hpp"

Expand Down
Loading

0 comments on commit 8a9dd46

Please sign in to comment.