Skip to content

Commit

Permalink
Support column split in multi-target hist (#9171)
Browse files Browse the repository at this point in the history
  • Loading branch information
rongou authored May 26, 2023
1 parent acd3630 commit 5b69534
Show file tree
Hide file tree
Showing 17 changed files with 386 additions and 96 deletions.
42 changes: 42 additions & 0 deletions src/collective/communicator-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
#pragma once
#include <string>
#include <vector>

#include "communicator.h"

Expand Down Expand Up @@ -224,5 +225,46 @@ inline void Allreduce(double *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}

template <typename T>
struct AllgatherVResult {
  // Exclusive-scan offsets: offsets[i] is where input i (across all workers) starts in `result`.
  std::vector<std::size_t> offsets;
  // Size of each input, from every worker (world_size * inputs_per_worker entries).
  std::vector<std::size_t> sizes;
  // All inputs from all workers, concatenated in worker order.
  std::vector<T> result;
};

/**
 * @brief Gathers variable-length data from all processes and distributes it to all processes.
 *
 * We assume each worker has the same number of inputs, but each input may be of a different size.
 *
 * @param inputs All the inputs from the local worker, concatenated.
 * @param sizes  Sizes of each input on the local worker.
 * @return Global offsets, sizes, and the concatenated inputs from every worker.
 */
template <typename T>
inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
                                      std::vector<std::size_t> const &sizes) {
  auto num_inputs = sizes.size();

  // Gather the sizes across all workers.
  std::vector<std::size_t> all_sizes(num_inputs * GetWorldSize());
  std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs * GetRank());
  collective::Allgather(all_sizes.data(), all_sizes.size() * sizeof(std::size_t));

  // Calculate input offsets (std::exclusive_scan).
  std::vector<std::size_t> offsets(all_sizes.size());
  for (std::size_t i = 1; i < offsets.size(); i++) {
    offsets[i] = offsets[i - 1] + all_sizes[i - 1];
  }

  // No inputs at all: avoid calling back() on empty vectors below. Every worker has the same
  // number of inputs, so all workers take this exit together and no collective call is skipped
  // asymmetrically.
  if (offsets.empty()) {
    return {offsets, all_sizes, {}};
  }

  // Gather all the inputs.
  auto total_input_size = offsets.back() + all_sizes.back();
  std::vector<T> all_inputs(total_input_size);
  std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs * GetRank()]);
  // We cannot use allgather here, since each worker might have a different size. The buffer is
  // zero-initialized and each slot is written by exactly one worker, so a max-reduction leaves
  // every slot holding that worker's value (assumes the gathered values are non-negative —
  // holds for the unsigned payloads used here).
  Allreduce<Operation::kMax>(all_inputs.data(), all_inputs.size());

  return {offsets, all_sizes, all_inputs};
}

} // namespace collective
} // namespace xgboost
4 changes: 2 additions & 2 deletions src/common/partition_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class PartitionBuilder {
BitVector* decision_bits, BitVector* missing_bits) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
std::size_t nid = nodes[node_in_set].nid;
bst_feature_t fid = tree[nid].SplitIndex();
bst_feature_t fid = tree.SplitIndex(nid);
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
auto node_cats = tree.NodeCats(nid);
auto const& cut_values = gmat.cut.Values();
Expand Down Expand Up @@ -270,7 +270,7 @@ class PartitionBuilder {
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
std::size_t nid = nodes[node_in_set].nid;
bool default_left = tree[nid].DefaultLeft();
bool default_left = tree.DefaultLeft(nid);

auto pred = [&](auto ridx) {
bool go_left = default_left;
Expand Down
22 changes: 11 additions & 11 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <utility>

#include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../data/adapter.h"
#include "categorical.h"
#include "hist_util.h"
Expand Down Expand Up @@ -143,6 +142,7 @@ struct QuantileAllreduce {

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
MetaInfo const& info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches) {
Expand All @@ -168,7 +168,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);

// Gather all column pointers
collective::Allreduce<collective::Operation::kSum>(sketches_scan.data(), sketches_scan.size());
collective::GlobalSum(info, sketches_scan.data(), sketches_scan.size());
for (int32_t i = 0; i < world; ++i) {
size_t back = (i + 1) * (n_columns + 1) - 1;
auto n_entries = sketches_scan.at(back);
Expand Down Expand Up @@ -196,7 +196,8 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(

static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
"Unexpected size of sketch entry.");
collective::Allreduce<collective::Operation::kSum>(
collective::GlobalSum(
info,
reinterpret_cast<float *>(global_sketches.data()),
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
}
Expand All @@ -222,8 +223,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
collective::Allreduce<collective::Operation::kSum>(global_feat_ptrs.data(),
global_feat_ptrs.size());
collective::GlobalSum(info, global_feat_ptrs.data(), global_feat_ptrs.size());

// move all categories into a flatten vector to prepare for allreduce
size_t total = feature_ptr.back();
Expand All @@ -236,8 +236,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
// indptr for indexing workers
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
collective::Allreduce<collective::Operation::kSum>(global_worker_ptr.data(),
global_worker_ptr.size());
collective::GlobalSum(info, global_worker_ptr.data(), global_worker_ptr.size());
std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
// total number of categories in all workers with all features
auto gtotal = global_worker_ptr.back();
Expand All @@ -249,8 +248,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
CHECK_EQ(rank_size, total);
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
// gather values from all workers.
collective::Allreduce<collective::Operation::kSum>(global_categories.data(),
global_categories.size());
collective::GlobalSum(info, global_categories.data(), global_categories.size());
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
categories_.size()};
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
Expand Down Expand Up @@ -323,7 +321,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);

std::vector<typename WQSketch::Entry> global_sketches;
this->GatherSketchInfo(reduced, &worker_segments, &sketches_scan, &global_sketches);
this->GatherSketchInfo(info, reduced, &worker_segments, &sketches_scan, &global_sketches);

std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);

Expand Down Expand Up @@ -371,7 +369,9 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
InvalidCategory();
}
auto &cut_values = cuts->cut_values_.HostVector();
auto max_cat = *std::max_element(categories.cbegin(), categories.cend());
// With column-wise data split, the categories may be empty.
auto max_cat =
categories.empty() ? 0.0f : *std::max_element(categories.cbegin(), categories.cend());
CheckMaxCat(max_cat, categories.size());
for (bst_cat_t i = 0; i <= AsCat(max_cat); ++i) {
cut_values.push_back(i);
Expand Down
3 changes: 2 additions & 1 deletion src/common/quantile.h
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,8 @@ class SketchContainerImpl {
return group_ind;
}
// Gather sketches from all workers.
void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
void GatherSketchInfo(MetaInfo const& info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<bst_row_t> *p_worker_segments,
std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches);
Expand Down
3 changes: 3 additions & 0 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
this->feature_type_names = that.feature_type_names;
auto &h_feature_types = feature_types.HostVector();
LoadFeatureType(this->feature_type_names, &h_feature_types);
} else if (!that.feature_types.Empty()) {
this->feature_types.Resize(that.feature_types.Size());
this->feature_types.Copy(that.feature_types);
}
if (!that.feature_weights.Empty()) {
this->feature_weights.Resize(that.feature_weights.Size());
Expand Down
120 changes: 107 additions & 13 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "xgboost/linalg.h" // for Constants, Vector

namespace xgboost::tree {
template <typename ExpandEntry>
class HistEvaluator {
private:
struct NodeEntry {
Expand Down Expand Up @@ -285,10 +284,42 @@ class HistEvaluator {
return left_sum;
}

/**
* @brief Gather the expand entries from all the workers.
* @param entries Local expand entries on this worker.
* @return Global expand entries gathered from all workers.
*/
std::vector<CPUExpandEntry> Allgather(std::vector<CPUExpandEntry> const &entries) {
auto const world = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const num_entries = entries.size();

// First, gather all the primitive fields.
std::vector<CPUExpandEntry> all_entries(num_entries * world);
std::vector<uint32_t> cat_bits;
std::vector<std::size_t> cat_bits_sizes;
for (auto i = 0; i < num_entries; i++) {
all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
}
collective::Allgather(all_entries.data(), all_entries.size() * sizeof(CPUExpandEntry));

// Gather all the cat_bits.
auto gathered = collective::AllgatherV(cat_bits, cat_bits_sizes);

common::ParallelFor(num_entries * world, ctx_->Threads(), [&] (auto i) {
// Copy the cat_bits back into all expand entries.
all_entries[i].split.cat_bits.resize(gathered.sizes[i]);
std::copy_n(gathered.result.cbegin() + gathered.offsets[i], gathered.sizes[i],
all_entries[i].split.cat_bits.begin());
});

return all_entries;
}

public:
void EvaluateSplits(const common::HistCollection &hist, common::HistogramCuts const &cut,
common::Span<FeatureType const> feature_types, const RegTree &tree,
std::vector<ExpandEntry> *p_entries) {
std::vector<CPUExpandEntry> *p_entries) {
auto n_threads = ctx_->Threads();
auto& entries = *p_entries;
// All nodes are on the same level, so we can store the shared ptr.
Expand All @@ -306,7 +337,7 @@ class HistEvaluator {
return features[nidx_in_set]->Size();
}, grain_size);

std::vector<ExpandEntry> tloc_candidates(n_threads * entries.size());
std::vector<CPUExpandEntry> tloc_candidates(n_threads * entries.size());
for (size_t i = 0; i < entries.size(); ++i) {
for (decltype(n_threads) j = 0; j < n_threads; ++j) {
tloc_candidates[i * n_threads + j] = entries[i];
Expand Down Expand Up @@ -365,22 +396,18 @@ class HistEvaluator {
if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
auto const world = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const num_entries = entries.size();
std::vector<ExpandEntry> buffer{num_entries * world};
std::copy_n(entries.cbegin(), num_entries, buffer.begin() + num_entries * rank);
collective::Allgather(buffer.data(), buffer.size() * sizeof(ExpandEntry));
for (auto worker = 0; worker < world; ++worker) {
auto all_entries = Allgather(entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update(buffer[worker * num_entries + nidx_in_set].split);
entries[nidx_in_set].split.Update(
all_entries[worker * entries.size() + nidx_in_set].split);
}
}
}
}

// Add splits to tree, handles all statistic
void ApplyTreeSplit(ExpandEntry const& candidate, RegTree *p_tree) {
void ApplyTreeSplit(CPUExpandEntry const& candidate, RegTree *p_tree) {
auto evaluator = tree_evaluator_.GetEvaluator();
RegTree &tree = *p_tree;

Expand Down Expand Up @@ -465,6 +492,7 @@ class HistMultiEvaluator {
FeatureInteractionConstraintHost interaction_constraints_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
Context const *ctx_;
bool is_col_split_{false};

private:
static double MultiCalcSplitGain(TrainParam const &param,
Expand Down Expand Up @@ -543,6 +571,57 @@ class HistMultiEvaluator {
return false;
}

/**
* @brief Gather the expand entries from all the workers.
* @param entries Local expand entries on this worker.
* @return Global expand entries gathered from all workers.
*/
std::vector<MultiExpandEntry> Allgather(std::vector<MultiExpandEntry> const &entries) {
auto const world = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const num_entries = entries.size();

// First, gather all the primitive fields.
std::vector<MultiExpandEntry> all_entries(num_entries * world);
std::vector<uint32_t> cat_bits;
std::vector<std::size_t> cat_bits_sizes;
std::vector<GradientPairPrecise> gradients;
for (auto i = 0; i < num_entries; i++) {
all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes,
&gradients);
}
collective::Allgather(all_entries.data(), all_entries.size() * sizeof(MultiExpandEntry));

// Gather all the cat_bits.
auto gathered_cat_bits = collective::AllgatherV(cat_bits, cat_bits_sizes);

// Gather all the gradients.
auto const num_gradients = gradients.size();
std::vector<GradientPairPrecise> all_gradients(num_gradients * world);
std::copy_n(gradients.cbegin(), num_gradients, all_gradients.begin() + num_gradients * rank);
collective::Allgather(all_gradients.data(), all_gradients.size() * sizeof(GradientPairPrecise));

auto const total_entries = num_entries * world;
auto const gradients_per_entry = num_gradients / num_entries;
auto const gradients_per_side = gradients_per_entry / 2;
common::ParallelFor(total_entries, ctx_->Threads(), [&] (auto i) {
// Copy the cat_bits back into all expand entries.
all_entries[i].split.cat_bits.resize(gathered_cat_bits.sizes[i]);
std::copy_n(gathered_cat_bits.result.cbegin() + gathered_cat_bits.offsets[i],
gathered_cat_bits.sizes[i], all_entries[i].split.cat_bits.begin());

// Copy the gradients back into all expand entries.
all_entries[i].split.left_sum.resize(gradients_per_side);
std::copy_n(all_gradients.cbegin() + i * gradients_per_entry, gradients_per_side,
all_entries[i].split.left_sum.begin());
all_entries[i].split.right_sum.resize(gradients_per_side);
std::copy_n(all_gradients.cbegin() + i * gradients_per_entry + gradients_per_side,
gradients_per_side, all_entries[i].split.right_sum.begin());
});

return all_entries;
}

public:
void EvaluateSplits(RegTree const &tree, common::Span<const common::HistCollection *> hist,
common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
Expand Down Expand Up @@ -597,6 +676,18 @@ class HistMultiEvaluator {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}

if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
auto all_entries = Allgather(entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update(
all_entries[worker * entries.size() + nidx_in_set].split);
}
}
}
}

linalg::Vector<float> InitRoot(linalg::VectorView<GradientPairPrecise const> root_sum) {
Expand Down Expand Up @@ -660,7 +751,10 @@ class HistMultiEvaluator {

explicit HistMultiEvaluator(Context const *ctx, MetaInfo const &info, TrainParam const *param,
std::shared_ptr<common::ColumnSampler> sampler)
: param_{param}, column_sampler_{std::move(sampler)}, ctx_{ctx} {
: param_{param},
column_sampler_{std::move(sampler)},
ctx_{ctx},
is_col_split_{info.IsColumnSplit()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand Down
Loading

0 comments on commit 5b69534

Please sign in to comment.