Skip to content

Commit

Permalink
Pass tree in the necessary functions so it can be used in ComputeBestSplitForFeature.
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles Auguste committed Mar 24, 2020
1 parent 001bafd commit 37757e8
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 37 deletions.
10 changes: 5 additions & 5 deletions src/treelearner/data_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
}

template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
TREELEARNER_T::ConstructHistograms(
this->col_sampler_.is_feature_used_bytree(), true);
// construct local histograms
Expand All @@ -163,11 +163,11 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(),
block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
this->FindBestSplitsFromHistograms(
this->col_sampler_.is_feature_used_bytree(), true);
this->col_sampler_.is_feature_used_bytree(), true, tree);
}

template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features =
Expand All @@ -194,7 +194,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
smaller_node_used_features[feature_index],
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->leaf_index()),
this->smaller_leaf_splits_.get(),
&smaller_bests_per_thread[tid]);
&smaller_bests_per_thread[tid], tree);

// only root leaf
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) continue;
Expand All @@ -208,7 +208,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
larger_node_used_features[feature_index],
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->leaf_index()),
this->larger_leaf_splits_.get(),
&larger_bests_per_thread[tid]);
&larger_bests_per_thread[tid], tree);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down
6 changes: 4 additions & 2 deletions src/treelearner/feature_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
}

template <typename TREELEARNER_T>
void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract);
void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(
const std::vector<int8_t> &is_feature_used, bool use_subtract,
const Tree *tree) {
TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
SplitInfo smaller_best_split, larger_best_split;
// get best split at smaller leaf
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
Expand Down
10 changes: 5 additions & 5 deletions src/treelearner/parallel_tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class FeatureParallelTreeLearner: public TREELEARNER_T {

protected:
void BeforeTrain() override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;

private:
/*! \brief rank of local machine */
Expand Down Expand Up @@ -59,8 +59,8 @@ class DataParallelTreeLearner: public TREELEARNER_T {

protected:
void BeforeTrain() override;
void FindBestSplits() override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void FindBestSplits(const Tree* tree) override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;

inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
Expand Down Expand Up @@ -114,8 +114,8 @@ class VotingParallelTreeLearner: public TREELEARNER_T {
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestSplits() override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void FindBestSplits(const Tree* tree) override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;

inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
Expand Down
26 changes: 12 additions & 14 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before finding best split
if (BeforeFindBestSplit(tree_prt, left_leaf, right_leaf)) {
// find best threshold for every feature
FindBestSplits();
FindBestSplits(tree_prt);
}
// Get a leaf with max split gain
int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
Expand Down Expand Up @@ -301,7 +301,7 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
return true;
}

void SerialTreeLearner::FindBestSplits() {
void SerialTreeLearner::FindBestSplits(const Tree* tree) {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
Expand All @@ -315,7 +315,7 @@ void SerialTreeLearner::FindBestSplits() {
}
bool use_subtract = parent_leaf_histogram_array_ != nullptr;
ConstructHistograms(is_feature_used, use_subtract);
FindBestSplitsFromHistograms(is_feature_used, use_subtract);
FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
}

void SerialTreeLearner::ConstructHistograms(
Expand Down Expand Up @@ -344,7 +344,7 @@ void SerialTreeLearner::ConstructHistograms(
}

void SerialTreeLearner::FindBestSplitsFromHistograms(
const std::vector<int8_t>& is_feature_used, bool use_subtract) {
const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) {
Common::FunctionTimer fun_timer(
"SerialTreeLearner::FindBestSplitsFromHistograms", global_timer);
std::vector<SplitInfo> smaller_best(share_state_->num_threads);
Expand All @@ -370,7 +370,7 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(
real_fidx,
smaller_node_used_features[feature_index],
smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_.get(), &smaller_best[tid]);
smaller_leaf_splits_.get(), &smaller_best[tid], tree);

// only has root leaf
if (larger_leaf_splits_ == nullptr ||
Expand All @@ -392,7 +392,7 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(
real_fidx,
larger_node_used_features[feature_index],
larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_splits_.get(), &larger_best[tid]);
larger_leaf_splits_.get(), &larger_best[tid], tree);

OMP_LOOP_EX_END();
}
Expand Down Expand Up @@ -428,7 +428,7 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
// before processing next node from queue, store info for current left/right leaf
// store "best split" for left and right, even if they might be overwritten by forced split
if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
FindBestSplits();
FindBestSplits(tree);
}
// then, compute own splits
SplitInfo left_split;
Expand Down Expand Up @@ -631,7 +631,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
best_split_per_leaf_);
// update leave outputs if needed
for (auto leaf : leaves_need_update) {
RecomputeBestSplitForLeaf(leaf, &best_split_per_leaf_[leaf]);
RecomputeBestSplitForLeaf(leaf, &best_split_per_leaf_[leaf], tree);
}
}

Expand Down Expand Up @@ -678,7 +678,7 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
void SerialTreeLearner::ComputeBestSplitForFeature(
FeatureHistogram* histogram_array_, int feature_index, int real_fidx,
bool is_feature_used, int num_data, const LeafSplits* leaf_splits,
SplitInfo* best_split) {
SplitInfo* best_split, const Tree* tree) {
if (!is_feature_used) {
return;
}
Expand All @@ -697,7 +697,7 @@ void SerialTreeLearner::ComputeBestSplitForFeature(
}
}

void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split, const Tree* tree) {
FeatureHistogram* histogram_array_;
if (!histogram_pool_.Get(leaf, &histogram_array_)) {
Log::Warning(
Expand Down Expand Up @@ -725,10 +725,8 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
}
const int tid = omp_get_thread_num();
int real_fidx = train_data_->RealFeatureIndex(feature_index);
ComputeBestSplitForFeature(
histogram_array_, feature_index, real_fidx,
true,
num_data, &leaf_splits, &bests[tid]);
ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, true,
num_data, &leaf_splits, &bests[tid], tree);

OMP_LOOP_EX_END();
}
Expand Down
8 changes: 4 additions & 4 deletions src/treelearner/serial_tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ class SerialTreeLearner: public TreeLearner {
int feature_index, int real_fidx,
bool is_feature_used, int num_data,
const LeafSplits* leaf_splits,
SplitInfo* best_split);
SplitInfo* best_split, const Tree* tree);

void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);

void RecomputeBestSplitForLeaf(int leaf, SplitInfo* split);
void RecomputeBestSplitForLeaf(int leaf, SplitInfo* split, const Tree* tree);

/*!
* \brief Some initial works before training
Expand All @@ -134,11 +134,11 @@ class SerialTreeLearner: public TreeLearner {
*/
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);

virtual void FindBestSplits();
virtual void FindBestSplits(const Tree* tree);

virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);

virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree);

/*!
* \brief Partition tree and data according best split.
Expand Down
14 changes: 7 additions & 7 deletions src/treelearner/voting_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vec
}

template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
// use local data to find local best splits
std::vector<int8_t> is_feature_used(this->num_features_, 0);
#pragma omp parallel for schedule(static)
Expand Down Expand Up @@ -279,7 +279,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
true, this->smaller_leaf_splits_->num_data_in_leaf(),
this->smaller_leaf_splits_.get(),
&smaller_bestsplit_per_features[feature_index]);
&smaller_bestsplit_per_features[feature_index], tree);
// only has root leaf
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) { continue; }

Expand All @@ -293,7 +293,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->larger_leaf_histogram_array_, feature_index, real_feature_index,
true, this->larger_leaf_splits_->num_data_in_leaf(),
this->larger_leaf_splits_.get(),
&larger_bestsplit_per_features[feature_index]);
&larger_bestsplit_per_features[feature_index], tree);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down Expand Up @@ -344,11 +344,11 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(), block_len_.data(),
output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);

this->FindBestSplitsFromHistograms(is_feature_used, false);
this->FindBestSplitsFromHistograms(is_feature_used, false, tree);
}

template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features =
Expand Down Expand Up @@ -376,7 +376,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
smaller_leaf_histogram_array_global_.get(), feature_index,
real_feature_index, smaller_node_used_features[feature_index],
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->leaf_index()),
smaller_leaf_splits_global_.get(), &smaller_bests_per_thread[tid]);
smaller_leaf_splits_global_.get(), &smaller_bests_per_thread[tid], tree);
}

if (larger_is_feature_aggregated_[feature_index]) {
Expand All @@ -392,7 +392,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
real_feature_index,
larger_node_used_features[feature_index],
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->leaf_index()),
larger_leaf_splits_global_.get(), &larger_bests_per_thread[tid]);
larger_leaf_splits_global_.get(), &larger_bests_per_thread[tid], tree);
}
OMP_LOOP_EX_END();
}
Expand Down

0 comments on commit 37757e8

Please sign in to comment.