Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support vertical federated learning #8932

Merged
merged 9 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,15 @@ class MetaInfo {
*/
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);

/**
* @brief Synchronize the number of columns across all workers.
*
* Normally we just need to find the maximum number of columns across all workers, but
* in vertical federated learning, since each worker loads its own list of columns,
* we need to sum them.
*/
void SynchronizeNumberOfColumns();

private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
Expand Down Expand Up @@ -325,6 +334,10 @@ class SparsePage {
* \brief Check wether the column index is sorted.
*/
bool IsIndicesSorted(int32_t n_threads) const;
/**
* \brief Reindex the column index with an offset.
*/
void Reindex(uint64_t feature_offset, int32_t n_threads);

void SortRows(int32_t n_threads);

Expand Down Expand Up @@ -632,6 +645,17 @@ class DMatrix {
*/
virtual DMatrix *SliceCol(int num_slices, int slice_id) = 0;

/**
* \brief Reindex the features based on a global view.
*
* In some cases (e.g. vertical federated learning), features are loaded locally with indices
* starting from 0. However, all the algorithms assume the features are globally indexed, so we
* reindex the features based on the offset needed to obtain the global view.
*
* \param offset The offset to be added to the feature index
*/
virtual void ReindexFeatures(uint64_t offset) = 0;

protected:
virtual BatchSet<SparsePage> GetRowBatches() = 0;
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
Expand Down
31 changes: 26 additions & 5 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,14 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
}

void MetaInfo::SynchronizeNumberOfColumns() {
if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
}
}

void MetaInfo::Validate(std::int32_t device) const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
Expand Down Expand Up @@ -903,10 +911,17 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
LOG(FATAL) << "Encountered parser error:\n" << e.what();
}

/* sync up number of features after matrix loaded.
* partitioned data will fail the train/val validation check
* since partitioned data not knowing the real number of features. */
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);

if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = dmat->Info().num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
dmat->ReindexFeatures(offset);
}

dmat->Info().data_split_mode = data_split_mode;
dmat->Info().SynchronizeNumberOfColumns();

if (need_split && data_split_mode == DataSplitMode::kCol) {
if (!cache_file.empty()) {
Expand All @@ -917,7 +932,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
delete dmat;
return sliced;
} else {
dmat->Info().data_split_mode = data_split_mode;
return dmat;
}
}
Expand Down Expand Up @@ -1048,6 +1062,13 @@ void SparsePage::SortIndices(int32_t n_threads) {
});
}

void SparsePage::Reindex(uint64_t feature_offset, int32_t n_threads) {
auto& h_data = this->data.HostVector();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This potentially pulls data from the device to the host.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, but it's not much different from some of the other methods there.

common::ParallelFor(h_data.size(), n_threads, [&](auto i) {
h_data[i].index += feature_offset;
});
}

void SparsePage::SortRows(int32_t n_threads) {
auto& h_offset = this->offset.HostVector();
auto& h_data = this->data.HostVector();
Expand Down
2 changes: 1 addition & 1 deletion src/data/iterative_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
// From here on Info() has the correct data shape
Info().num_row_ = accumulated_rows;
Info().num_nonzero_ = nnz;
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
Info().SynchronizeNumberOfColumns();
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
return f > accumulated_rows;
})) << "Something went wrong during iteration.";
Expand Down
2 changes: 1 addition & 1 deletion src/data/iterative_dmatrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,

iter.Reset();
// Synchronise worker columns
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.SynchronizeNumberOfColumns();
}

BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {
Expand Down
3 changes: 3 additions & 0 deletions src/data/iterative_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ class IterativeDMatrix : public DMatrix {
LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
return nullptr;
}
void ReindexFeatures(uint64_t offset) override {
LOG(FATAL) << "Reindexing features is not supported for Quantile DMatrix.";
}
BatchSet<SparsePage> GetRowBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
Expand Down
3 changes: 3 additions & 0 deletions src/data/proxy_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class DMatrixProxy : public DMatrix {
LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
return nullptr;
}
void ReindexFeatures(uint64_t offset) override {
LOG(FATAL) << "Reindexing features is not supported for Proxy DMatrix.";
}
BatchSet<SparsePage> GetRowBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
Expand Down
13 changes: 8 additions & 5 deletions src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
return out;
}

void SimpleDMatrix::ReindexFeatures(uint64_t offset) {
if (offset == 0) {
return;
}
sparse_page_->Reindex(offset, Ctx()->Threads());
}

BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available.
auto begin_iter = BatchIterator<SparsePage>(
Expand Down Expand Up @@ -215,10 +222,6 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
info_.num_col_ = adapter->NumColumns();
}


// Synchronise worker columns
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);

if (adapter->NumRows() == kAdapterUnknownSize) {
using IteratorAdapterT
= IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>;
Expand Down Expand Up @@ -346,7 +349,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
}
// Synchronise worker columns
info_.num_col_ = adapter->NumColumns();
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.SynchronizeNumberOfColumns();
info_.num_row_ = total_batch_size;
info_.num_nonzero_ = data_vec.size();
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
Expand Down
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
info_.num_col_ = adapter->NumColumns();
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.SynchronizeNumberOfColumns();
}

template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
Expand Down
1 change: 1 addition & 0 deletions src/data/simple_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class SimpleDMatrix : public DMatrix {
bool SingleColBlock() const override { return true; }
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
DMatrix* SliceCol(int num_slices, int slice_id) override;
void ReindexFeatures(uint64_t offset) override;

/*! \brief magic number used to identify SimpleDMatrix binary files */
static const int kMagic = 0xffffab01;
Expand Down
2 changes: 1 addition & 1 deletion src/data/sparse_page_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
this->info_.num_col_ = n_features;
this->info_.num_nonzero_ = nnz;

collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.SynchronizeNumberOfColumns();
CHECK_NE(info_.num_col_, 0);
}

Expand Down
3 changes: 3 additions & 0 deletions src/data/sparse_page_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ class SparsePageDMatrix : public DMatrix {
LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
return nullptr;
}
void ReindexFeatures(uint64_t offset) override {
LOG(FATAL) << "Reindexing features is not supported for external memory.";
}

private:
BatchSet<SparsePage> GetRowBatches() override;
Expand Down
45 changes: 43 additions & 2 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ class LearnerConfiguration : public Learner {
info.Validate(Ctx()->gpu_id);
// We estimate it from input data.
linalg::Tensor<float, 1> base_score;
UsePtr(obj_)->InitEstimation(info, &base_score);
InitEstimation(info, &base_score);
CHECK_EQ(base_score.Size(), 1);
mparam_.base_score = base_score(0);
CHECK(!std::isnan(mparam_.base_score));
Expand Down Expand Up @@ -857,6 +857,25 @@ class LearnerConfiguration : public Learner {
mparam_.num_target = n_targets;
}
}

void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if we just calculate the gradient using individual workers? Is the gradient still the same? If so, we can just let them calculate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we don't have labels in non-0 workers, they won't be able to calculate the gradient.

// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
// We assume labels are only available on worker 0, so the estimation is calculated there
// and added to other workers.
if (collective::GetRank() == 0) {
UsePtr(obj_)->InitEstimation(info, base_score);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
} else {
base_score->Reshape(1);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
}
} else {
UsePtr(obj_)->InitEstimation(info, base_score);
}
}
};

std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT
Expand Down Expand Up @@ -1303,7 +1322,7 @@ class LearnerImpl : public LearnerIO {
monitor_.Stop("PredictRaw");

monitor_.Start("GetGradient");
obj_->GetGradient(predt.predictions, train->Info(), iter, &gpair_);
GetGradient(predt.predictions, train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");
TrainingObserver::Instance().Observe(gpair_, "Gradients");

Expand Down Expand Up @@ -1482,6 +1501,28 @@ class LearnerImpl : public LearnerIO {
}

private:
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
HostDeviceVector<GradientPair>* out_gpair) {
// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
// We assume labels are only available on worker 0, so the gradients are calculated there
// and broadcast to other workers.
if (collective::GetRank() == 0) {
obj_->GetGradient(preds, info, iteration, out_gpair);
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
} else {
CHECK_EQ(info.labels.Size(), 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be difficult for users to specify their own worker rank once we put xgboost in an automated pipeline. I look at your nvflare example, the rank is not assigned by user.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can check if the label size is 0 here to determine who needs to calculate the gradient. But in general we need stable ranks for the trained model to be useful for inference. That's more of an nvflare requirement. I'll ask them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way to automatically agree on who should be the one to own the label? Maybe it's easier to have a fully automated pipeline if everyone has equal access to labels? Just curious from a user's perspective.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometime (most times?) it's not possible for all the parties to have access to the labels. For example, a hospital may have the diagnosis results of a patient, but labs only have access to blood work, DNA tests, etc.

I think the best way to guarantee the ordering for now is to always launch the workers in the same sequence. Since federated learning is usually done by a single admin, this is reasonable solution. I'll ask the NVFLARE team to see if they can add some new features to better support this.

<< "In vertical federated learning, labels should only be on the first worker";
out_gpair->Resize(preds.Size());
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
}
} else {
obj_->GetGradient(preds, info, iteration, out_gpair);
}
}

/*! \brief random number transformation seed. */
static int32_t constexpr kRandSeedMagic = 127;
// gradient pairs
Expand Down
2 changes: 1 addition & 1 deletion src/objective/init_estimation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* b
new_obj->GetGradient(dummy_predt, info, 0, &gpair);
bst_target_t n_targets = this->Targets(info);
linalg::Vector<float> leaf_weight;
tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
tree::FitStump(this->ctx_, info, gpair, n_targets, &leaf_weight);

// workaround, we don't support multi-target due to binary model serialization for
// base margin.
Expand Down
15 changes: 10 additions & 5 deletions src/tree/fit_stump.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
namespace xgboost {
namespace tree {
namespace cpu_impl {
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
void FitStump(Context const* ctx, MetaInfo const& info,
linalg::TensorView<GradientPair const, 2> gpair,
linalg::VectorView<float> out) {
auto n_targets = out.Size();
CHECK_EQ(n_targets, gpair.Shape(1));
Expand All @@ -43,8 +44,12 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
}
}
CHECK(h_sum.CContiguous());
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);

// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
Copy link
Member

@trivialfis trivialfis Mar 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can simply run it for all workers to remove a condition? We have a validate method in the learner model param, which is a good place for checking whether the label is correctly distributed across workers for federated learning. If labels are the same for all workers, the base_score should also be the same. Also, we don't need an additional info parameter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is in learner.cc we only call InitEstimation for worker 0, which in turn calls this method. If we don't skip this allreduce we'd get a mismatch in non-0 workers.

if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
}

for (std::size_t i = 0; i < h_sum.Size(); ++i) {
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
Expand All @@ -64,15 +69,15 @@ inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl

void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out) {
out->SetDevice(ctx->gpu_id);
out->Reshape(n_targets);
auto n_samples = gpair.Size() / n_targets;

gpair.SetDevice(ctx->gpu_id);
auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
}
} // namespace tree
Expand Down
3 changes: 2 additions & 1 deletion src/tree/fit_stump.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "../common/common.h" // AssertGPUSupport
#include "xgboost/base.h" // GradientPair
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/linalg.h" // TensorView

Expand All @@ -30,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
/**
* @brief Fit a tree stump as an estimation of base_score.
*/
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out);
} // namespace tree
} // namespace xgboost
Expand Down
6 changes: 6 additions & 0 deletions tests/cpp/plugin/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,9 @@ int GenerateRandomPort(int low, int high) {
int port = dist(rng);
return port;
}

std::string GetServerAddress() {
int port = GenerateRandomPort(50000, 60000);
std::string address = std::string("localhost:") + std::to_string(port);
return address;
}
Loading