Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[breaking] Change DMatrix construction to be distributed #9623

Merged
merged 3 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,7 @@ class DMatrix {
*
* \param uri The URI of input.
* \param silent Whether print information during loading.
* \param data_split_mode In distributed mode, split the input according this mode; otherwise,
* it's just an indicator on how the input was split beforehand.
* \param data_split_mode Indicate how the data was split beforehand.
* \return The created DMatrix.
*/
static DMatrix* Load(const std::string& uri, bool silent = true,
Expand Down
56 changes: 5 additions & 51 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}

void MetaInfo::SynchronizeNumberOfColumns() {
if (IsVerticalFederated()) {
if (IsColumnSplit()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
Expand Down Expand Up @@ -850,39 +850,13 @@ DMatrix* TryLoadBinary(std::string fname, bool silent) {
} // namespace

DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
auto need_split = false;
if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
} else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
need_split = true;
}

std::string fname, cache_file;
auto dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) {
cache_file = uri.substr(dlm_pos + 1, uri.length());
fname = uri.substr(0, dlm_pos);
CHECK_EQ(cache_file.find('#'), std::string::npos)
<< "Only one `#` is allowed in file path for cache file specification.";
if (need_split && data_split_mode == DataSplitMode::kRow) {
std::ostringstream os;
std::vector<std::string> cache_shards = common::Split(cache_file, ':');
for (size_t i = 0; i < cache_shards.size(); ++i) {
size_t pos = cache_shards[i].rfind('.');
if (pos == std::string::npos) {
os << cache_shards[i] << ".r" << collective::GetRank() << "-"
<< collective::GetWorldSize();
} else {
os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-"
<< collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length());
}
if (i + 1 != cache_shards.size()) {
os << ':';
}
}
cache_file = os.str();
}
} else {
fname = uri;
}
Expand All @@ -894,19 +868,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
}

int partid = 0, npart = 1;
if (need_split && data_split_mode == DataSplitMode::kRow) {
partid = collective::GetRank();
npart = collective::GetWorldSize();
} else {
// test option to load in part
npart = 1;
}

if (npart != 1) {
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
}

DMatrix* dmat{nullptr};
DMatrix* dmat{};

if (cache_file.empty()) {
fname = data::ValidateFileFormat(fname);
Expand All @@ -916,6 +878,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
cache_file, data_split_mode);
} else {
CHECK(data_split_mode != DataSplitMode::kCol)
<< "Column-wise data split is not supported for external memory.";
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart)};
dmat = new data::SparsePageDMatrix{&iter,
iter.Proxy(),
Expand All @@ -926,17 +890,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
cache_file};
}

if (need_split && data_split_mode == DataSplitMode::kCol) {
if (!cache_file.empty()) {
LOG(FATAL) << "Column-wise data split is not support for external memory.";
}
LOG(CONSOLE) << "Splitting data by column";
auto* sliced = dmat->SliceCol(npart, partid);
delete dmat;
return sliced;
} else {
return dmat;
}
return dmat;
}

template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
Expand Down
4 changes: 2 additions & 2 deletions src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
}

void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
if (info_.IsVerticalFederated()) {
if (info_.IsColumnSplit()) {
std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0ul);
if (offset == 0) {
return;
}
Expand Down
7 changes: 4 additions & 3 deletions src/data/simple_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ class SimpleDMatrix : public DMatrix {
/**
* \brief Reindex the features based on a global view.
*
* In some cases (e.g. vertical federated learning), features are loaded locally with indices
* starting from 0. However, all the algorithms assume the features are globally indexed, so we
* reindex the features based on the offset needed to obtain the global view.
* In some cases (e.g. column-wise data split and vertical federated learning), features are
* loaded locally with indices starting from 0. However, all the algorithms assume the features
* are globally indexed, so we reindex the features based on the offset needed to obtain the
* global view.
*/
void ReindexFeatures(Context const* ctx);

Expand Down
18 changes: 18 additions & 0 deletions tests/cpp/data/test_simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,21 @@ TEST(SimpleDMatrix, Threads) {
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 0, "")};
ASSERT_EQ(p_fmat->Ctx()->Threads(), AllThreadsForTest());
}

namespace {
void VerifyColumnSplit() {
size_t constexpr kRows {16};
size_t constexpr kCols {8};
auto dmat =
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, false, 1, DataSplitMode::kCol);

ASSERT_EQ(dmat->Info().num_col_, kCols * collective::GetWorldSize());
ASSERT_EQ(dmat->Info().num_row_, kRows);
ASSERT_EQ(dmat->Info().data_split_mode, DataSplitMode::kCol);
}
} // anonymous namespace

TEST(SimpleDMatrix, ColumnSplit) {
auto constexpr kWorldSize{3};
RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit);
}
7 changes: 3 additions & 4 deletions tests/cpp/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,16 @@ void RandomDataGenerator::GenerateCSR(
CHECK_EQ(columns->Size(), value->Size());
}

[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label,
bool float_label,
size_t classes) const {
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(
bool with_label, bool float_label, size_t classes, DataSplitMode data_split_mode) const {
HostDeviceVector<float> data;
HostDeviceVector<bst_row_t> rptrs;
HostDeviceVector<bst_feature_t> columns;
this->GenerateCSR(&data, &rptrs, &columns);
data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), rows_,
data.Size(), cols_);
std::shared_ptr<DMatrix> out{
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, "", data_split_mode)};

if (with_label) {
RandomDataGenerator gen{rows_, n_targets_, 0.0f};
Expand Down
6 changes: 3 additions & 3 deletions tests/cpp/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,9 @@ class RandomDataGenerator {
void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<bst_row_t>* row_ptr,
HostDeviceVector<bst_feature_t>* columns) const;

[[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false,
bool float_label = true,
size_t classes = 1) const;
[[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(
bool with_label = false, bool float_label = true, size_t classes = 1,
DataSplitMode data_split_mode = DataSplitMode::kRow) const;

[[nodiscard]] std::shared_ptr<DMatrix> GenerateSparsePageDMatrix(std::string prefix,
bool with_label) const;
Expand Down
Loading