Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix zero bin in categorical split #3305

Merged
merged 10 commits on
Aug 15, 2020
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions include/LightGBM/bin.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,9 @@ class MultiValBin {

inline uint32_t BinMapper::ValueToBin(double value) const {
if (std::isnan(value)) {
if (missing_type_ == MissingType::NaN) {
if (bin_type_ == BinType::CategoricalBin) {
return 0;
} else if (missing_type_ == MissingType::NaN) {
return num_bin_ - 1;
} else {
value = 0.0f;
Expand All @@ -482,12 +484,12 @@ inline uint32_t BinMapper::ValueToBin(double value) const {
int int_value = static_cast<int>(value);
// convert negative value to NaN bin
if (int_value < 0) {
return num_bin_ - 1;
return 0;
}
if (categorical_2_bin_.count(int_value)) {
return categorical_2_bin_.at(int_value);
} else {
return num_bin_ - 1;
return 0;
}
}
}
Expand Down
43 changes: 15 additions & 28 deletions src/io/bin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,6 @@ namespace LightGBM {
}
}
}
num_bin_ = 0;
int rest_cnt = static_cast<int>(total_sample_cnt - na_cnt);
if (rest_cnt > 0) {
const int SPARSE_RATIO = 100;
Expand All @@ -449,23 +448,25 @@ namespace LightGBM {
}
// sort by counts
Common::SortForPair<int, int>(&counts_int, &distinct_values_int, 0, true);
// avoid first bin is zero
if (distinct_values_int[0] == 0) {
if (counts_int.size() == 1) {
counts_int.push_back(0);
distinct_values_int.push_back(distinct_values_int[0] + 1);
}
std::swap(counts_int[0], counts_int[1]);
std::swap(distinct_values_int[0], distinct_values_int[1]);
}
// will ignore the categorical of small counts
int cut_cnt = static_cast<int>((total_sample_cnt - na_cnt) * 0.99f);
int cut_cnt = static_cast<int>(
Common::RoundInt((total_sample_cnt - na_cnt) * 0.99f));
size_t cur_cat = 0;
categorical_2_bin_.clear();
bin_2_categorical_.clear();
int used_cnt = 0;
max_bin = std::min(static_cast<int>(distinct_values_int.size()), max_bin);
int distinct_cnt = static_cast<int>(distinct_values_int.size());
if (na_cnt > 0) {
++distinct_cnt;
}
max_bin = std::min(distinct_cnt, max_bin);
cnt_in_bin.clear();

// Push the dummy bin for NaN
bin_2_categorical_.push_back(-1);
categorical_2_bin_[-1] = 0;
cnt_in_bin.push_back(0);
num_bin_ = 1;
while (cur_cat < distinct_values_int.size()
&& (used_cnt < cut_cnt || num_bin_ < max_bin)) {
if (counts_int[cur_cat] < min_data_in_bin && cur_cat > 1) {
Expand All @@ -478,21 +479,14 @@ namespace LightGBM {
++num_bin_;
++cur_cat;
}
// need an additional bin for NaN
if (cur_cat == distinct_values_int.size() && na_cnt > 0) {
// use -1 to represent NaN
bin_2_categorical_.push_back(-1);
categorical_2_bin_[-1] = num_bin_;
cnt_in_bin.push_back(0);
++num_bin_;
}
// Use MissingType::None to represent this bin contains all categoricals
if (cur_cat == distinct_values_int.size() && na_cnt == 0) {
missing_type_ = MissingType::None;
} else {
missing_type_ = MissingType::NaN;
}
cnt_in_bin.back() += static_cast<int>(total_sample_cnt - used_cnt);
// fix count of NaN bin
cnt_in_bin[0] = static_cast<int>(total_sample_cnt - used_cnt);
}
}

Expand All @@ -511,13 +505,6 @@ namespace LightGBM {
default_bin_ = ValueToBin(0);
most_freq_bin_ =
static_cast<uint32_t>(ArrayArgs<int>::ArgMax(cnt_in_bin));
if (bin_type_ == BinType::CategoricalBin) {
if (most_freq_bin_ == 0) {
CHECK_GT(num_bin_, 1);
// FIXME: how to enable `most_freq_bin_ = 0` for categorical features
most_freq_bin_ = 1;
}
}
const double max_sparse_rate =
static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt;
// When most_freq_bin_ != default_bin_, there are some additional data loading costs.
Expand Down
6 changes: 4 additions & 2 deletions src/io/dense_bin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ class DenseBin : public Bin {
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
int8_t offset = most_freq_bin == 0 ? 1 : 0;
if (most_freq_bin > 0 &&
Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices;
default_count = &lte_count;
}
Expand All @@ -330,7 +332,7 @@ class DenseBin : public Bin {
} else if (!USE_MIN_BIN && bin == 0) {
default_indices[(*default_count)++] = idx;
} else if (Common::FindInBitset(threshold, num_threahold,
bin - min_bin)) {
bin - min_bin + offset)) {
lte_indices[lte_count++] = idx;
} else {
gt_indices[gt_count++] = idx;
Expand Down
5 changes: 3 additions & 2 deletions src/io/sparse_bin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ class SparseBin : public Bin {
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
int8_t offset = most_freq_bin == 0 ? 1 : 0;
if (most_freq_bin > 0 && Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices;
default_count = &lte_count;
}
Expand All @@ -376,7 +377,7 @@ class SparseBin : public Bin {
} else if (!USE_MIN_BIN && bin == 0) {
default_indices[(*default_count)++] = idx;
} else if (Common::FindInBitset(threshold, num_threahold,
bin - min_bin)) {
bin - min_bin + offset)) {
lte_indices[lte_count++] = idx;
} else {
gt_indices[gt_count++] = idx;
Expand Down
38 changes: 19 additions & 19 deletions src/treelearner/feature_histogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,10 @@ class FeatureHistogram {
}

double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
bool is_full_categorical = meta_->missing_type == MissingType::None;
int used_bin = meta_->num_bin - 1 + is_full_categorical;
const int8_t offset = meta_->offset;
const int bin_start = 1 - offset;
const int bin_end = meta_->num_bin - offset;
int used_bin = -1;

std::vector<int> sorted_idx;
double l2 = meta_->config->lambda_l2;
Expand All @@ -312,11 +314,11 @@ class FeatureHistogram {
int rand_threshold = 0;
if (use_onehot) {
if (USE_RAND) {
if (used_bin > 0) {
rand_threshold = meta_->rand.NextInt(0, used_bin);
if (bin_end - bin_start > 0) {
rand_threshold = meta_->rand.NextInt(bin_start, bin_end);
}
}
for (int t = 0; t < used_bin; ++t) {
for (int t = bin_start; t < bin_end; ++t) {
const auto grad = GET_GRAD(data_, t);
const auto hess = GET_HESS(data_, t);
data_size_t cnt =
Expand Down Expand Up @@ -366,7 +368,7 @@ class FeatureHistogram {
}
}
} else {
for (int i = 0; i < used_bin; ++i) {
for (int i = bin_start; i < bin_end; ++i) {
if (Common::RoundInt(GET_HESS(data_, i) * cnt_factor) >=
meta_->config->cat_smooth) {
sorted_idx.push_back(i);
Expand All @@ -379,11 +381,11 @@ class FeatureHistogram {
auto ctr_fun = [this](double sum_grad, double sum_hess) {
return (sum_grad) / (sum_hess + meta_->config->cat_smooth);
};
std::sort(sorted_idx.begin(), sorted_idx.end(),
[this, &ctr_fun](int i, int j) {
return ctr_fun(GET_GRAD(data_, i), GET_HESS(data_, i)) <
ctr_fun(GET_GRAD(data_, j), GET_HESS(data_, j));
});
std::stable_sort(
sorted_idx.begin(), sorted_idx.end(), [this, &ctr_fun](int i, int j) {
return ctr_fun(GET_GRAD(data_, i), GET_HESS(data_, i)) <
ctr_fun(GET_GRAD(data_, j), GET_HESS(data_, j));
});

std::vector<int> find_direction(1, 1);
std::vector<int> start_position(1, 0);
Expand Down Expand Up @@ -489,19 +491,19 @@ class FeatureHistogram {
if (use_onehot) {
output->num_cat_threshold = 1;
output->cat_threshold =
std::vector<uint32_t>(1, static_cast<uint32_t>(best_threshold));
std::vector<uint32_t>(1, static_cast<uint32_t>(best_threshold + offset));
} else {
output->num_cat_threshold = best_threshold + 1;
output->cat_threshold =
std::vector<uint32_t>(output->num_cat_threshold);
if (best_dir == 1) {
for (int i = 0; i < output->num_cat_threshold; ++i) {
auto t = sorted_idx[i];
auto t = sorted_idx[i] + offset;
output->cat_threshold[i] = t;
}
} else {
for (int i = 0; i < output->num_cat_threshold; ++i) {
auto t = sorted_idx[used_bin - 1 - i];
auto t = sorted_idx[used_bin - 1 - i] + offset;
output->cat_threshold[i] = t;
}
}
Expand Down Expand Up @@ -649,16 +651,14 @@ class FeatureHistogram {
double gain_shift = GetLeafGainGivenOutput<true>(
sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2, parent_output);
double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
bool is_full_categorical = meta_->missing_type == MissingType::None;
int used_bin = meta_->num_bin - 1 + is_full_categorical;
if (threshold >= static_cast<uint32_t>(used_bin)) {
if (threshold >= static_cast<uint32_t>(meta_->num_bin) || threshold < meta_->offset) {
guolinke marked this conversation as resolved.
Show resolved Hide resolved
output->gain = kMinScore;
Log::Warning("Invalid categorical threshold split");
return;
}
const double cnt_factor = num_data / sum_hessian;
const auto grad = GET_GRAD(data_, threshold);
const auto hess = GET_HESS(data_, threshold);
const auto grad = GET_GRAD(data_, threshold - meta_->offset);
const auto hess = GET_HESS(data_, threshold - meta_->offset);
data_size_t cnt =
static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));

Expand Down