Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add weight in tree model output #2269

Merged
merged 3 commits into from
Jul 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,9 @@
" 'split_gain',\n",
" 'internal_value',\n",
" 'internal_count',\n",
" 'leaf_count'],\n",
" 'internal_weight',\n",
" 'leaf_count',\n",
" 'leaf_weight'],\n",
" value=['None']),\n",
" precision=(0, 10))\n",
" tree = None\n",
Expand All @@ -382,7 +384,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
"version": "3.7.1"
},
"varInspector": {
"cols": {
Expand Down
23 changes: 18 additions & 5 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,17 @@ class Tree {
* \param right_value Model Right child output
* \param left_cnt Count of left child
* \param right_cnt Count of right child
* \param left_weight Weight of left child
* \param right_weight Weight of right child
* \param gain Split gain
* \param missing_type missing type
* \param default_left default direction for missing value
* \return The index of new leaf.
*/
int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value,
int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left);
int left_cnt, int right_cnt, double left_weight, double right_weight,
float gain, MissingType missing_type, bool default_left);

/*!
* \brief Performing a split on tree leaves, with categorical feature
Expand All @@ -72,12 +75,14 @@ class Tree {
* \param right_value Model Right child output
* \param left_cnt Count of left child
* \param right_cnt Count of right child
* \param left_weight Weight of left child
* \param right_weight Weight of right child
* \param gain Split gain
* \return The index of new leaf.
*/
int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
int left_cnt, int right_cnt, float gain, MissingType missing_type);
int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type);

/*! \brief Get the output of one leaf */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
Expand Down Expand Up @@ -297,8 +302,8 @@ class Tree {
}
}

inline void Split(int leaf, int feature, int real_feature,
double left_value, double right_value, int left_cnt, int right_cnt, float gain);
inline void Split(int leaf, int feature, int real_feature, double left_value, double right_value, int left_cnt, int right_cnt,
double left_weight, double right_weight, float gain);
/*!
* \brief Find leaf index of which record belongs by features
* \param feature_values Feature value of this record
Expand Down Expand Up @@ -383,10 +388,14 @@ class Tree {
std::vector<int> leaf_parent_;
/*! \brief Output of leaves */
std::vector<double> leaf_value_;
/*! \brief Weight of leaves */
std::vector<double> leaf_weight_;
/*! \brief DataCount of leaves */
std::vector<int> leaf_count_;
/*! \brief Output of non-leaf nodes */
std::vector<double> internal_value_;
/*! \brief Weight of non-leaf nodes */
std::vector<double> internal_weight_;
/*! \brief DataCount of non-leaf nodes */
std::vector<int> internal_count_;
/*! \brief Depth for leaves */
Expand All @@ -396,7 +405,8 @@ class Tree {
};

inline void Tree::Split(int leaf, int feature, int real_feature,
double left_value, double right_value, int left_cnt, int right_cnt, float gain) {
double left_value, double right_value, int left_cnt, int right_cnt,
double left_weight, double right_weight, float gain) {
int new_node_idx = num_leaves_ - 1;
// update parent info
int parent = leaf_parent_[leaf];
Expand All @@ -420,11 +430,14 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
leaf_parent_[leaf] = new_node_idx;
leaf_parent_[num_leaves_] = new_node_idx;
// save current leaf value to internal node before change
// The weight of the new internal node is the sum of its two children's weights.
// Copying leaf_weight_[leaf] instead would be wrong for the first split: the root
// leaf starts with weight 0 (set in the Tree constructor), so the root internal
// node would be recorded with weight 0. This mirrors how internal_count_ is
// derived from left_cnt + right_cnt below.
internal_weight_[new_node_idx] = left_weight + right_weight;
internal_value_[new_node_idx] = leaf_value_[leaf];
internal_count_[new_node_idx] = left_cnt + right_cnt;
leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value;
leaf_weight_[leaf] = left_weight;
leaf_count_[leaf] = left_cnt;
leaf_value_[num_leaves_] = std::isnan(right_value) ? 0.0f : right_value;
leaf_weight_[num_leaves_] = right_weight;
leaf_count_[num_leaves_] = right_cnt;
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
Expand Down
10 changes: 7 additions & 3 deletions python-package/lightgbm/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def add(root, parent=None, decision=None):
label = 'split_feature_index: {0}'.format(root['split_feature'])
label += r'\nthreshold: {0}'.format(_float2str(root['threshold'], precision))
for info in show_info:
if info in {'split_gain', 'internal_value'}:
if info in {'split_gain', 'internal_value', 'internal_weight'}:
label += r'\n{0}: {1}'.format(info, _float2str(root[info], precision))
elif info == 'internal_count':
label += r'\n{0}: {1}'.format(info, root[info])
Expand All @@ -409,6 +409,8 @@ def add(root, parent=None, decision=None):
label += r'\nleaf_value: {0}'.format(_float2str(root['leaf_value'], precision))
if 'leaf_count' in show_info:
label += r'\nleaf_count: {0}'.format(root['leaf_count'])
if 'leaf_weight' in show_info:
label += r'\nleaf_weight: {0}'.format(_float2str(root['leaf_weight'], precision))
graph.node(name, label=label)
if parent is not None:
graph.edge(parent, name, decision)
Expand Down Expand Up @@ -438,7 +440,8 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
The index of a target tree to convert.
show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
**kwargs
Expand Down Expand Up @@ -515,7 +518,8 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
Figure size.
show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
**kwargs
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/gbdt_model_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

namespace LightGBM {

const std::string kModelVersion = "v2";
const std::string kModelVersion = "v3";
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved

std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
std::stringstream str_buf;
Expand Down
33 changes: 28 additions & 5 deletions src/io/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ Tree::Tree(int max_leaves)
split_gain_.resize(max_leaves_ - 1);
leaf_parent_.resize(max_leaves_);
leaf_value_.resize(max_leaves_);
leaf_weight_.resize(max_leaves_);
leaf_count_.resize(max_leaves_);
internal_value_.resize(max_leaves_ - 1);
internal_weight_.resize(max_leaves_ - 1);
internal_count_.resize(max_leaves_ - 1);
leaf_depth_.resize(max_leaves_);
// root is in the depth 0
leaf_depth_[0] = 0;
num_leaves_ = 1;
leaf_value_[0] = 0.0f;
leaf_weight_[0] = 0.0f;
leaf_parent_[0] = -1;
shrinkage_ = 1.0f;
num_cat_ = 0;
Expand All @@ -47,8 +50,8 @@ Tree::~Tree() {

int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value,
int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type, bool default_left) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Expand All @@ -68,8 +71,8 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,

int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, float gain, MissingType missing_type) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
Expand Down Expand Up @@ -221,10 +224,14 @@ std::string Tree::ToString() const {
<< Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
str_buf << "leaf_value="
<< Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
str_buf << "leaf_weight="
<< Common::ArrayToString(leaf_weight_, num_leaves_) << '\n';
str_buf << "leaf_count="
<< Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
str_buf << "internal_value="
<< Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
str_buf << "internal_weight="
<< Common::ArrayToStringFast(internal_weight_, num_leaves_ - 1) << '\n';
str_buf << "internal_count="
<< Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
if (num_cat_ > 0) {
Expand Down Expand Up @@ -294,6 +301,7 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "\"missing_type\":\"NaN\"," << '\n';
}
str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
str_buf << "\"internal_weight\":" << internal_weight_[index] << "," << '\n';
str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
Expand All @@ -304,6 +312,7 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "{" << '\n';
str_buf << "\"leaf_index\":" << index << "," << '\n';
str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
str_buf << "}";
}
Expand Down Expand Up @@ -472,7 +481,7 @@ std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
Tree::Tree(const char* str, size_t* used_len) {
auto p = str;
std::unordered_map<std::string, std::string> key_vals;
const int max_num_line = 15;
const int max_num_line = 17;
int read_line = 0;
while (read_line < max_num_line) {
if (*p == '\r' || *p == '\n') break;
Expand Down Expand Up @@ -557,6 +566,20 @@ Tree::Tree(const char* str, size_t* used_len) {
internal_value_.resize(num_leaves_ - 1);
}

if (key_vals.count("internal_weight")) {
internal_weight_ = Common::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
}
else {
internal_weight_.resize(num_leaves_ - 1);
}

if (key_vals.count("leaf_weight")) {
leaf_weight_ = Common::StringToArrayFast<double>(key_vals["leaf_weight"], num_leaves_);
}
else {
leaf_weight_.resize(num_leaves_);
}

if (key_vals.count("leaf_count")) {
leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
} else {
Expand Down
8 changes: 8 additions & 0 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,8 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int*
static_cast<double>(current_split_info.right_output),
static_cast<data_size_t>(current_split_info.left_count),
static_cast<data_size_t>(current_split_info.right_count),
static_cast<double>(current_split_info.left_sum_hessian),
static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
current_split_info.default_left);
Expand Down Expand Up @@ -711,6 +713,8 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int*
static_cast<double>(current_split_info.right_output),
static_cast<data_size_t>(current_split_info.left_count),
static_cast<data_size_t>(current_split_info.right_count),
static_cast<double>(current_split_info.left_sum_hessian),
static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
Expand Down Expand Up @@ -792,6 +796,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.left_sum_hessian),
static_cast<double>(best_split_info.right_sum_hessian),
static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
best_split_info.default_left);
Expand All @@ -815,6 +821,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.left_sum_hessian),
static_cast<double>(best_split_info.right_sum_hessian),
static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(best_leaf, train_data_, inner_feature_index,
Expand Down
4 changes: 3 additions & 1 deletion tests/python_package_test/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_create_tree_digraph(self):
self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83)

graph = lgb.create_tree_digraph(gbm, tree_index=3,
show_info=['split_gain', 'internal_value'],
show_info=['split_gain', 'internal_value', 'internal_weight'],
name='Tree4', node_attr={'color': 'red'})
graph.render(view=False)
self.assertIsInstance(graph, graphviz.Digraph)
Expand All @@ -137,8 +137,10 @@ def test_create_tree_digraph(self):
self.assertIn('leaf_index', graph_body)
self.assertIn('split_gain', graph_body)
self.assertIn('internal_value', graph_body)
self.assertIn('internal_weight', graph_body)
self.assertNotIn('internal_count', graph_body)
self.assertNotIn('leaf_count', graph_body)
self.assertNotIn('leaf_weight', graph_body)

@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_metrics(self):
Expand Down