From 31b74c745414ffb4fb4033b9e9bc35caae71cac6 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Jun 2023 12:10:45 -0700 Subject: [PATCH 1/9] support colsplit with predict instance --- include/xgboost/predictor.h | 12 +++--- src/predictor/cpu_predictor.cc | 38 +++++++++++++++--- src/predictor/gpu_predictor.cu | 2 +- tests/cpp/predictor/test_cpu_predictor.cc | 48 +++++++---------------- 4 files changed, 55 insertions(+), 45 deletions(-) diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 50665341a803..615bc0f398bc 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -134,16 +134,18 @@ class Predictor { * usually more efficient than online prediction This function is NOT * threadsafe, make sure you only call from one thread. * - * \param inst The instance to predict. - * \param [in,out] out_preds The output preds. - * \param model The model to predict from - * \param tree_end (Optional) The tree end index. + * \param inst The instance to predict. + * \param [in,out] out_preds The output preds. + * \param model The model to predict from + * \param tree_end (Optional) The tree end index. + * \param is_column_split (Optional) If the data is split column-wise. 
*/ virtual void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, const gbm::GBTreeModel& model, - unsigned tree_end = 0) const = 0; + unsigned tree_end = 0, + bool is_column_split = false) const = 0; /** * \brief predict the leaf index of each tree, the output will be nsample * diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index aa89729897fa..00009fea2c29 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -191,6 +191,18 @@ struct SparsePageView { size_t Size() const { return view.Size(); } }; +struct SingleInstanceView { + bst_row_t base_rowid{}; + SparsePage::Inst const &inst; + + explicit SingleInstanceView(SparsePage::Inst const &instance) : inst{instance} {} + SparsePage::Inst operator[](size_t i) { + CHECK_EQ(i, 0); + return inst; + } + static size_t Size() { return 1; } +}; + struct GHistIndexMatrixView { private: GHistIndexMatrix const &page_; @@ -409,6 +421,13 @@ class ColumnSplitHelper { } } + void PredictInstance(SparsePage::Inst const &inst, std::vector *out_preds) { + CHECK(xgboost::collective::IsDistributed()) + << "column-split prediction is only supported for distributed training"; + + PredictBatchKernel(SingleInstanceView{inst}, out_preds); + } + private: using BitVector = RBitField8; @@ -728,18 +747,25 @@ class CPUPredictor : public Predictor { return true; } - void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit) const override { + void PredictInstance(const SparsePage::Inst &inst, std::vector *out_preds, + const gbm::GBTreeModel &model, unsigned ntree_limit, + bool is_column_split) const override { CHECK(!model.learner_model_param->IsVectorLeaf()) << "predict instance" << MTNotImplemented(); - std::vector feat_vecs; - feat_vecs.resize(1, RegTree::FVec()); - feat_vecs[0].Init(model.learner_model_param->num_feature); ntree_limit *= model.learner_model_param->num_output_group; if 
(ntree_limit == 0 || ntree_limit > model.trees.size()) { ntree_limit = static_cast(model.trees.size()); } out_preds->resize(model.learner_model_param->num_output_group); + + if (is_column_split) { + ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit); + helper.PredictInstance(inst, out_preds); + return; + } + + std::vector feat_vecs; + feat_vecs.resize(1, RegTree::FVec()); + feat_vecs[0].Init(model.learner_model_param->num_feature); auto base_score = model.learner_model_param->BaseScore(ctx_)(0); // loop over output groups for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) { diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 11662f9b8608..4b834e78fb90 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -929,7 +929,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictInstance(const SparsePage::Inst&, std::vector*, - const gbm::GBTreeModel&, unsigned) const override { + const gbm::GBTreeModel&, unsigned, bool) const override { LOG(FATAL) << "[Internal error]: " << __func__ << " is not implemented in GPU Predictor."; } diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 401d33c4d04d..84538f1f2344 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -17,7 +17,10 @@ #include "test_predictor.h" namespace xgboost { -TEST(CpuPredictor, Basic) { + +namespace { +template +void TestBasic() { auto lparam = CreateEmptyGenericParam(GPUIDX); std::unique_ptr cpu_predictor = std::unique_ptr(Predictor::Create("cpu_predictor", &lparam)); @@ -32,6 +35,11 @@ TEST(CpuPredictor, Basic) { gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); + if constexpr (is_column_split) { + auto const world_size = collective::GetWorldSize(); + auto const rank = collective::GetRank(); + dmat = 
std::shared_ptr{dmat->SliceCol(world_size, rank)}; + } // Test predict batch PredictionCacheEntry out_predictions; @@ -48,7 +56,7 @@ TEST(CpuPredictor, Basic) { auto page = batch.GetView(); for (size_t i = 0; i < batch.Size(); i++) { std::vector instance_out_predictions; - cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model); + cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model, 0, is_column_split); ASSERT_EQ(instance_out_predictions[0], 1.5); } @@ -89,41 +97,15 @@ TEST(CpuPredictor, Basic) { } } } +} // anonymous namespace -namespace { -void TestColumnSplitPredictBatch() { - size_t constexpr kRows = 5; - size_t constexpr kCols = 5; - auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - - auto lparam = CreateEmptyGenericParam(GPUIDX); - std::unique_ptr cpu_predictor = - std::unique_ptr(Predictor::Create("cpu_predictor", &lparam)); - - LearnerModelParam mparam{MakeMP(kCols, .0, 1)}; - - Context ctx; - ctx.UpdateAllowUnknown(Args{}); - gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); - - // Test predict batch - PredictionCacheEntry out_predictions; - cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); - auto sliced = std::unique_ptr{dmat->SliceCol(world_size, rank)}; - cpu_predictor->PredictBatch(sliced.get(), &out_predictions, model, 0); - - std::vector& out_predictions_h = out_predictions.predictions.HostVector(); - for (size_t i = 0; i < out_predictions.predictions.Size(); i++) { - ASSERT_EQ(out_predictions_h[i], 1.5); - } +TEST(CpuPredictor, Basic) { + TestBasic(); } -} // anonymous namespace -TEST(CpuPredictor, ColumnSplit) { +TEST(CpuPredictor, ColumnSplitBasic) { auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, TestColumnSplitPredictBatch); + RunWithInMemoryCommunicator(kWorldSize, TestBasic); } TEST(CpuPredictor, IterationRange) { From 
f201942f6933c5be0e7638edc58dac7f9aba4fa3 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Jun 2023 13:10:26 -0700 Subject: [PATCH 2/9] support predict leaf --- include/xgboost/predictor.h | 12 +++--- src/predictor/cpu_predictor.cc | 46 ++++++++++++++++++----- src/predictor/gpu_predictor.cu | 3 +- tests/cpp/predictor/test_cpu_predictor.cc | 2 +- 4 files changed, 46 insertions(+), 17 deletions(-) diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 615bc0f398bc..2cdb1af9f65e 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -151,15 +151,17 @@ class Predictor { * \brief predict the leaf index of each tree, the output will be nsample * * ntree vector this is only valid in gbtree predictor. * - * \param [in,out] dmat The input feature matrix. - * \param [in,out] out_preds The output preds. - * \param model Model to make predictions from. - * \param tree_end (Optional) The tree end index. + * \param [in,out] dmat The input feature matrix. + * \param [in,out] out_preds The output preds. + * \param model Model to make predictions from. + * \param tree_end (Optional) The tree end index. + * \param is_column_split (Optional) If the data is split column-wise. 
*/ virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector* out_preds, const gbm::GBTreeModel& model, - unsigned tree_end = 0) const = 0; + unsigned tree_end = 0, + bool is_column_split = false) const = 0; /** * \brief feature contributions to individual predictions; the output will be diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 00009fea2c29..ab4ba781a4a0 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -428,6 +428,17 @@ class ColumnSplitHelper { PredictBatchKernel(SingleInstanceView{inst}, out_preds); } + void PredictLeaf(DMatrix *p_fmat, std::vector *out_preds) { + CHECK(xgboost::collective::IsDistributed()) + << "column-split prediction is only supported for distributed training"; + + for (auto const &batch : p_fmat->GetBatches()) { + CHECK_EQ(out_preds->size(), + p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group); + PredictBatchKernel(SparsePageView{&batch}, out_preds); + } + } + private: using BitVector = RBitField8; @@ -517,24 +528,31 @@ class ColumnSplitHelper { return nid; } + template bst_float PredictOneTree(std::size_t tree_id, std::size_t row_id) { auto const &tree = *model_.trees[tree_id]; auto const leaf = GetLeafIndex(tree, tree_id, row_id); - return tree[leaf].LeafValue(); + if constexpr (predict_leaf) { + return static_cast(leaf); + } else { + return tree[leaf].LeafValue(); + } } + template void PredictAllTrees(std::vector *out_preds, std::size_t batch_offset, std::size_t predict_offset, std::size_t num_group, std::size_t block_size) { auto &preds = *out_preds; for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) { auto const gid = model_.tree_info[tree_id]; for (size_t i = 0; i < block_size; ++i) { - preds[(predict_offset + i) * num_group + gid] += PredictOneTree(tree_id, batch_offset + i); + preds[(predict_offset + i) * num_group + gid] += + PredictOneTree(tree_id, batch_offset + i); } } } - template + template void 
PredictBatchKernel(DataView batch, std::vector *out_preds) { auto const num_group = model_.learner_model_param->num_output_group; @@ -563,8 +581,8 @@ class ColumnSplitHelper { auto const batch_offset = block_id * block_of_rows_size; auto const block_size = std::min(static_cast(nsize - batch_offset), static_cast(block_of_rows_size)); - PredictAllTrees(out_preds, batch_offset, batch_offset + batch.base_rowid, num_group, - block_size); + PredictAllTrees(out_preds, batch_offset, batch_offset + batch.base_rowid, + num_group, block_size); }); ClearBitVectors(); @@ -776,18 +794,26 @@ class CPUPredictor : public Predictor { } void PredictLeaf(DMatrix *p_fmat, HostDeviceVector *out_preds, - const gbm::GBTreeModel &model, unsigned ntree_limit) const override { + const gbm::GBTreeModel &model, unsigned ntree_limit, + bool is_column_split) const override { auto const n_threads = this->ctx_->Threads(); - std::vector feat_vecs; - const int num_feature = model.learner_model_param->num_feature; - InitThreadTemp(n_threads, &feat_vecs); - const MetaInfo &info = p_fmat->Info(); // number of valid trees if (ntree_limit == 0 || ntree_limit > model.trees.size()) { ntree_limit = static_cast(model.trees.size()); } + const MetaInfo &info = p_fmat->Info(); std::vector &preds = out_preds->HostVector(); preds.resize(info.num_row_ * ntree_limit); + + if (is_column_split) { + ColumnSplitHelper helper(n_threads, model, 0, ntree_limit); + helper.PredictLeaf(p_fmat, &preds); + return; + } + + std::vector feat_vecs; + const int num_feature = model.learner_model_param->num_feature; + InitThreadTemp(n_threads, &feat_vecs); // start collecting the prediction for (const auto &batch : p_fmat->GetBatches()) { // parallel over local batch diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 4b834e78fb90..1b3d625665c6 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -936,7 +936,8 @@ class GPUPredictor : public xgboost::Predictor { void 
PredictLeaf(DMatrix *p_fmat, HostDeviceVector *predictions, const gbm::GBTreeModel &model, - unsigned tree_end) const override { + unsigned tree_end, + bool is_column_split) const override { dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id); diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 84538f1f2344..bc9051c395b2 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -62,7 +62,7 @@ void TestBasic() { // Test predict leaf HostDeviceVector leaf_out_predictions; - cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); + cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model, 0, is_column_split); auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector(); for (auto v : h_leaf_out_predictions) { ASSERT_EQ(v, 0); From 35c379d323333b1b0a6ec961d5de87834512940f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Jun 2023 15:25:42 -0700 Subject: [PATCH 3/9] no need to change predict leaf interface --- include/xgboost/predictor.h | 4 +--- src/predictor/cpu_predictor.cc | 5 ++--- src/predictor/gpu_predictor.cu | 3 +-- tests/cpp/predictor/test_cpu_predictor.cc | 2 +- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 2cdb1af9f65e..c740a35cda5b 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -155,13 +155,11 @@ class Predictor { * \param [in,out] out_preds The output preds. * \param model Model to make predictions from. * \param tree_end (Optional) The tree end index. - * \param is_column_split (Optional) If the data is split column-wise. 
*/ virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector* out_preds, const gbm::GBTreeModel& model, - unsigned tree_end = 0, - bool is_column_split = false) const = 0; + unsigned tree_end = 0) const = 0; /** * \brief feature contributions to individual predictions; the output will be diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index ab4ba781a4a0..3993ce3f8dad 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -794,8 +794,7 @@ class CPUPredictor : public Predictor { } void PredictLeaf(DMatrix *p_fmat, HostDeviceVector *out_preds, - const gbm::GBTreeModel &model, unsigned ntree_limit, - bool is_column_split) const override { + const gbm::GBTreeModel &model, unsigned ntree_limit) const override { auto const n_threads = this->ctx_->Threads(); // number of valid trees if (ntree_limit == 0 || ntree_limit > model.trees.size()) { @@ -805,7 +804,7 @@ class CPUPredictor : public Predictor { std::vector &preds = out_preds->HostVector(); preds.resize(info.num_row_ * ntree_limit); - if (is_column_split) { + if (p_fmat->Info().IsColumnSplit()) { ColumnSplitHelper helper(n_threads, model, 0, ntree_limit); helper.PredictLeaf(p_fmat, &preds); return; diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 1b3d625665c6..4b834e78fb90 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -936,8 +936,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictLeaf(DMatrix *p_fmat, HostDeviceVector *predictions, const gbm::GBTreeModel &model, - unsigned tree_end, - bool is_column_split) const override { + unsigned tree_end) const override { dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id); diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index bc9051c395b2..35f7c908818b 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ 
b/tests/cpp/predictor/test_cpu_predictor.cc @@ -62,7 +62,7 @@ void TestBasic() { // Test predict leaf HostDeviceVector leaf_out_predictions; - cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model, 0, is_column_split); + cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model, 0); auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector(); for (auto v : h_leaf_out_predictions) { ASSERT_EQ(v, 0); From 75a1629b78fe09b3e31cb05b34d160c6c68cdc0e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Jun 2023 15:55:27 -0700 Subject: [PATCH 4/9] skip predict contribution --- src/predictor/cpu_predictor.cc | 2 ++ tests/cpp/predictor/test_cpu_predictor.cc | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 3993ce3f8dad..bdab7a039e96 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -847,6 +847,8 @@ class CPUPredictor : public Predictor { int condition, unsigned condition_feature) const override { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) + << "Predict contribution support for column-wise data split is not yet implemented."; auto const n_threads = this->ctx_->Threads(); const int num_feature = model.learner_model_param->num_feature; std::vector feat_vecs; diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 35f7c908818b..a1a580044c0d 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -68,6 +68,11 @@ void TestBasic() { ASSERT_EQ(v, 0); } + if (is_column_split) { + // Predict contribution is not supported for column split. 
+ return; + } + // Test predict contribution HostDeviceVector out_contribution_hdv; auto& out_contribution = out_contribution_hdv.HostVector(); From c6e0a54d3dc255c2bd9401185ba7ae29f1258a02 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 1 Jun 2023 16:16:20 -0700 Subject: [PATCH 5/9] clean up basic tests --- tests/cpp/predictor/test_cpu_predictor.cc | 114 ++++++---------------- 1 file changed, 31 insertions(+), 83 deletions(-) diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index a1a580044c0d..b7c4b92f3369 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -19,14 +19,13 @@ namespace xgboost { namespace { -template -void TestBasic() { +void TestBasic(DMatrix* dmat) { auto lparam = CreateEmptyGenericParam(GPUIDX); std::unique_ptr cpu_predictor = std::unique_ptr(Predictor::Create("cpu_predictor", &lparam)); - size_t constexpr kRows = 5; - size_t constexpr kCols = 5; + size_t const kRows = dmat->Info().num_row_; + size_t const kCols = dmat->Info().num_col_; LearnerModelParam mparam{MakeMP(kCols, .0, 1)}; @@ -34,17 +33,10 @@ void TestBasic() { ctx.UpdateAllowUnknown(Args{}); gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); - auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - if constexpr (is_column_split) { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - dmat = std::shared_ptr{dmat->SliceCol(world_size, rank)}; - } - // Test predict batch PredictionCacheEntry out_predictions; cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); - cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); + cpu_predictor->PredictBatch(dmat, &out_predictions, model, 0); std::vector& out_predictions_h = out_predictions.predictions.HostVector(); for (size_t i = 0; i < out_predictions.predictions.Size(); i++) { @@ -52,23 +44,24 @@ void TestBasic() { } // Test predict 
instance - auto const &batch = *dmat->GetBatches().begin(); + auto const& batch = *dmat->GetBatches().begin(); auto page = batch.GetView(); for (size_t i = 0; i < batch.Size(); i++) { std::vector instance_out_predictions; - cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model, 0, is_column_split); + cpu_predictor->PredictInstance(page[i], &instance_out_predictions, model, 0, + dmat->Info().IsColumnSplit()); ASSERT_EQ(instance_out_predictions[0], 1.5); } // Test predict leaf HostDeviceVector leaf_out_predictions; - cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model, 0); + cpu_predictor->PredictLeaf(dmat, &leaf_out_predictions, model, 0); auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector(); for (auto v : h_leaf_out_predictions) { ASSERT_EQ(v, 0); } - if (is_column_split) { + if (dmat->Info().IsColumnSplit()) { // Predict contribution is not supported for column split. return; } @@ -76,7 +69,7 @@ void TestBasic() { // Test predict contribution HostDeviceVector out_contribution_hdv; auto& out_contribution = out_contribution_hdv.HostVector(); - cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model); + cpu_predictor->PredictContribution(dmat, &out_contribution_hdv, model); ASSERT_EQ(out_contribution.size(), kRows * (kCols + 1)); for (size_t i = 0; i < out_contribution.size(); ++i) { auto const& contri = out_contribution[i]; @@ -89,8 +82,7 @@ void TestBasic() { } } // Test predict contribution (approximate method) - cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model, - 0, nullptr, true); + cpu_predictor->PredictContribution(dmat, &out_contribution_hdv, model, 0, nullptr, true); for (size_t i = 0; i < out_contribution.size(); ++i) { auto const& contri = out_contribution[i]; // shift 1 for bias, as test tree is a decision dump, only global bias is @@ -105,12 +97,29 @@ void TestBasic() { } // anonymous namespace TEST(CpuPredictor, Basic) { - TestBasic(); + size_t 
constexpr kRows = 5; + size_t constexpr kCols = 5; + auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); + TestBasic(dmat.get()); +} + +namespace { +void TestColumnSplit() { + size_t constexpr kRows = 5; + size_t constexpr kCols = 5; + auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); + + auto const world_size = collective::GetWorldSize(); + auto const rank = collective::GetRank(); + dmat = std::unique_ptr{dmat->SliceCol(world_size, rank)}; + + TestBasic(dmat.get()); } +} // anonymous namespace TEST(CpuPredictor, ColumnSplitBasic) { auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, TestBasic); + RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit); } TEST(CpuPredictor, IterationRange) { @@ -120,69 +129,8 @@ TEST(CpuPredictor, IterationRange) { TEST(CpuPredictor, ExternalMemory) { size_t constexpr kPageSize = 64, kEntriesPerCol = 3; size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2; - std::unique_ptr dmat = CreateSparsePageDMatrix(kEntries); - auto lparam = CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr cpu_predictor = - std::unique_ptr(Predictor::Create("cpu_predictor", &lparam)); - - LearnerModelParam mparam{MakeMP(dmat->Info().num_col_, .0, 1)}; - - Context ctx; - ctx.UpdateAllowUnknown(Args{}); - gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); - - // Test predict batch - PredictionCacheEntry out_predictions; - cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); - cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); - std::vector &out_predictions_h = out_predictions.predictions.HostVector(); - ASSERT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_); - for (const auto& v : out_predictions_h) { - ASSERT_EQ(v, 1.5); - } - - // Test predict leaf - HostDeviceVector leaf_out_predictions; - cpu_predictor->PredictLeaf(dmat.get(), &leaf_out_predictions, model); - auto const& h_leaf_out_predictions = 
leaf_out_predictions.ConstHostVector(); - ASSERT_EQ(h_leaf_out_predictions.size(), dmat->Info().num_row_); - for (const auto& v : h_leaf_out_predictions) { - ASSERT_EQ(v, 0); - } - - // Test predict contribution - HostDeviceVector out_contribution_hdv; - auto& out_contribution = out_contribution_hdv.HostVector(); - cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model); - ASSERT_EQ(out_contribution.size(), dmat->Info().num_row_ * (dmat->Info().num_col_ + 1)); - for (size_t i = 0; i < out_contribution.size(); ++i) { - auto const& contri = out_contribution[i]; - // shift 1 for bias, as test tree is a decision dump, only global bias is filled with LeafValue(). - if ((i + 1) % (dmat->Info().num_col_ + 1) == 0) { - ASSERT_EQ(out_contribution.back(), 1.5f); - } else { - ASSERT_EQ(contri, 0); - } - } - - // Test predict contribution (approximate method) - HostDeviceVector out_contribution_approximate_hdv; - auto& out_contribution_approximate = out_contribution_approximate_hdv.HostVector(); - cpu_predictor->PredictContribution( - dmat.get(), &out_contribution_approximate_hdv, model, 0, nullptr, true); - ASSERT_EQ(out_contribution_approximate.size(), - dmat->Info().num_row_ * (dmat->Info().num_col_ + 1)); - for (size_t i = 0; i < out_contribution.size(); ++i) { - auto const& contri = out_contribution[i]; - // shift 1 for bias, as test tree is a decision dump, only global bias is filled with LeafValue(). 
- if ((i + 1) % (dmat->Info().num_col_ + 1) == 0) { - ASSERT_EQ(out_contribution.back(), 1.5f); - } else { - ASSERT_EQ(contri, 0); - } - } + TestBasic(dmat.get()); } TEST(CpuPredictor, InplacePredict) { From ac475f74014e986b7425e5874dda683ac6667fbb Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 2 Jun 2023 12:08:20 -0700 Subject: [PATCH 6/9] remove unnecessary parameter --- tests/cpp/predictor/test_cpu_predictor.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index b7c4b92f3369..279ba6118bde 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -55,7 +55,7 @@ void TestBasic(DMatrix* dmat) { // Test predict leaf HostDeviceVector leaf_out_predictions; - cpu_predictor->PredictLeaf(dmat, &leaf_out_predictions, model, 0); + cpu_predictor->PredictLeaf(dmat, &leaf_out_predictions, model); auto const& h_leaf_out_predictions = leaf_out_predictions.ConstHostVector(); for (auto v : h_leaf_out_predictions) { ASSERT_EQ(v, 0); From b4b5694481010eca8031afa5d795102666e8f2f6 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 2 Jun 2023 20:59:20 -0700 Subject: [PATCH 7/9] revert whitespaces --- include/xgboost/predictor.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index c740a35cda5b..615bc0f398bc 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -151,10 +151,10 @@ class Predictor { * \brief predict the leaf index of each tree, the output will be nsample * * ntree vector this is only valid in gbtree predictor. * - * \param [in,out] dmat The input feature matrix. - * \param [in,out] out_preds The output preds. - * \param model Model to make predictions from. - * \param tree_end (Optional) The tree end index. + * \param [in,out] dmat The input feature matrix. + * \param [in,out] out_preds The output preds. 
+ * \param model Model to make predictions from. + * \param tree_end (Optional) The tree end index. */ virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector* out_preds, From 9452aec8f0fde1b049dbb02c434d009abe5053d8 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 2 Jun 2023 21:02:23 -0700 Subject: [PATCH 8/9] add check for interaction contribution --- src/predictor/cpu_predictor.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index bdab7a039e96..14fd918e2780 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -930,6 +930,8 @@ class CPUPredictor : public Predictor { bool approximate) const override { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict interaction contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " + "column-wise data split is not yet implemented."; const MetaInfo& info = p_fmat->Info(); const int ngroup = model.learner_model_param->num_output_group; size_t const ncolumns = model.learner_model_param->num_feature; From 73e719596c387c75bbdbf9ca272d2aa22f736ad2 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Sat, 3 Jun 2023 11:08:17 -0700 Subject: [PATCH 9/9] review feedback --- src/predictor/cpu_predictor.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 14fd918e2780..96c1fbe18cb5 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -196,10 +196,7 @@ struct SingleInstanceView { SparsePage::Inst const &inst; explicit SingleInstanceView(SparsePage::Inst const &instance) : inst{instance} {} - SparsePage::Inst operator[](size_t i) { - CHECK_EQ(i, 0); - return inst; - } + SparsePage::Inst operator[](size_t) { return inst; } static size_t Size() { return 1; } };