From b3a2141f46122c013dfe96e6c42fd3cd750d5ec5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 11 Mar 2023 05:20:04 +0800 Subject: [PATCH] Tests. --- src/common/quantile.cc | 5 +- tests/cpp/helpers.cc | 57 ++++++---- tests/cpp/helpers.h | 17 ++- tests/cpp/predictor/test_cpu_predictor.cc | 6 + tests/cpp/predictor/test_predictor.cc | 133 +++++++++++++++++++--- tests/cpp/predictor/test_predictor.h | 13 ++- 6 files changed, 189 insertions(+), 42 deletions(-) diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 87eb0ec208cd..aaf271934474 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -359,6 +359,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b HistogramCuts *cuts) { size_t required_cuts = std::min(summary.size, static_cast(max_bin)); auto &cut_values = cuts->cut_values_.HostVector(); + // we use the min_value as the first (0th) element, hence starting from 1. for (size_t i = 1; i < required_cuts; ++i) { bst_float cpt = summary.data[i].value; if (i == 1 || cpt > cut_values.back()) { @@ -419,8 +420,8 @@ void SketchContainerImpl::MakeCuts(HistogramCuts* cuts) { } else { AddCutPoint(a, max_num_bins, cuts); // push a value that is greater than anything - const bst_float cpt = (a.size > 0) ? a.data[a.size - 1].value - : cuts->min_vals_.HostVector()[fid]; + const bst_float cpt = + (a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid]; // this must be bigger than last value in a scale const bst_float last = cpt + (fabs(cpt) + 1e-5f); cuts->cut_values_.HostVector().push_back(last); diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index ebb56d2d3633..9236f569fb2c 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -224,19 +224,18 @@ std::string RandomDataGenerator::GenerateArrayInterface( return out; } -std::pair, std::string> -RandomDataGenerator::GenerateArrayInterfaceBatch( - HostDeviceVector *storage, size_t batches) const { - this->GenerateDense(storage); +std::pair, std::string> MakeArrayInterfaceBatch( + HostDeviceVector const* storage, std::size_t n_samples, bst_feature_t n_features, + std::size_t batches, std::int32_t device) { std::vector result(batches); std::vector objects; - size_t const rows_per_batch = rows_ / batches; + size_t const rows_per_batch = n_samples / batches; - auto make_interface = [storage, this](size_t offset, size_t rows) { + auto make_interface = [storage, device, n_features](std::size_t offset, std::size_t rows) { Json array_interface{Object()}; array_interface["data"] = std::vector(2); - if (device_ >= 0) { + if (device >= 0) { array_interface["data"][0] = Integer(reinterpret_cast(storage->DevicePointer() + offset)); array_interface["stream"] = Null{}; @@ -249,22 +248,22 @@ RandomDataGenerator::GenerateArrayInterfaceBatch( array_interface["shape"] = std::vector(2); array_interface["shape"][0] = rows; - array_interface["shape"][1] = cols_; + array_interface["shape"][1] = n_features; array_interface["typestr"] = String(", std::string> RandomDataGenerator::GenerateArrayInterfaceBatch( + HostDeviceVector* storage, size_t batches) const { + this->GenerateDense(storage); + return MakeArrayInterfaceBatch(storage, rows_, cols_, batches, device_); +} + std::string RandomDataGenerator::GenerateColumnarArrayInterface( std::vector> *data) const { CHECK(data); @@ -400,11 +405,14 @@ int NumpyArrayIterForTest::Next() { return 1; } -std::shared_ptr -GetDMatrixFromData(const std::vector &x, int num_rows, int num_columns){ +std::shared_ptr GetDMatrixFromData(const std::vector& x, std::size_t num_rows, + bst_feature_t num_columns) { data::DenseAdapter adapter(x.data(), num_rows, num_columns); - return std::shared_ptr(new data::SimpleDMatrix( - &adapter, std::numeric_limits::quiet_NaN(), 1)); + auto p_fmat = std::shared_ptr( + new data::SimpleDMatrix(&adapter, std::numeric_limits::quiet_NaN(), 1)); + CHECK_EQ(p_fmat->Info().num_row_, num_rows); + CHECK_EQ(p_fmat->Info().num_col_, num_columns); + return p_fmat; } std::unique_ptr CreateSparsePageDMatrix(bst_row_t n_samples, bst_feature_t n_features, @@ -572,12 +580,23 @@ std::unique_ptr CreateTrainedGBM(std::string name, Args kwargs, return gbm; } -ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols, - size_t batches) : rows_{rows}, cols_{cols}, n_batches_{batches} { +ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches) + : rows_{rows}, cols_{cols}, n_batches_{batches} { XGProxyDMatrixCreate(&proxy_); rng_.reset(new RandomDataGenerator{rows_, cols_, sparsity}); + std::tie(batches_, interface_) = rng_->GenerateArrayInterfaceBatch(&data_, n_batches_); +} + +ArrayIterForTest::ArrayIterForTest(Context const* ctx, HostDeviceVector const& data, + std::size_t n_samples, bst_feature_t n_features, + std::size_t n_batches) + : rows_{n_samples}, cols_{n_features}, n_batches_{n_batches} { + XGProxyDMatrixCreate(&proxy_); + this->data_.Resize(data.Size()); + CHECK_EQ(this->data_.Size(), rows_ * cols_ * n_batches); + this->data_.Copy(data); std::tie(batches_, interface_) = - rng_->GenerateArrayInterfaceBatch(&data_, n_batches_); + MakeArrayInterfaceBatch(&data_, rows_, cols_, n_batches_, ctx->gpu_id); } ArrayIterForTest::~ArrayIterForTest() { XGDMatrixFree(proxy_); } diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index ec0abf32b452..279e3f75951e 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -188,7 +188,7 @@ class SimpleRealUniformDistribution { }; template -Json GetArrayInterface(HostDeviceVector *storage, size_t rows, size_t cols) { +Json GetArrayInterface(HostDeviceVector const* storage, size_t rows, size_t cols) { Json array_interface{Object()}; array_interface["data"] = std::vector(2); if (storage->DeviceCanRead()) { @@ -318,8 +318,8 @@ GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) { return x; } -std::shared_ptr GetDMatrixFromData(const std::vector &x, - int num_rows, int num_columns); +std::shared_ptr GetDMatrixFromData(const std::vector& x, std::size_t num_rows, + bst_feature_t num_columns); /** * \brief Create Sparse Page using data iterator. @@ -394,7 +394,7 @@ typedef void *DMatrixHandle; // NOLINT(*); class ArrayIterForTest { protected: HostDeviceVector data_; - size_t iter_ {0}; + size_t iter_{0}; DMatrixHandle proxy_; std::unique_ptr rng_; @@ -418,6 +418,11 @@ class ArrayIterForTest { auto Proxy() -> decltype(proxy_) { return proxy_; } explicit ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches); + /** + * \brief Create iterator with user provided data. + */ + explicit ArrayIterForTest(Context const* ctx, HostDeviceVector const& data, + std::size_t n_samples, bst_feature_t n_features, std::size_t n_batches); virtual ~ArrayIterForTest(); }; @@ -433,6 +438,10 @@ class NumpyArrayIterForTest : public ArrayIterForTest { public: explicit NumpyArrayIterForTest(float sparsity, size_t rows = Rows(), size_t cols = Cols(), size_t batches = Batches()); + explicit NumpyArrayIterForTest(Context const* ctx, HostDeviceVector const& data, + std::size_t n_samples, bst_feature_t n_features, + std::size_t n_batches) + : ArrayIterForTest{ctx, data, n_samples, n_features, n_batches} {} int Next() override; ~NumpyArrayIterForTest() override = default; }; diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 9a0ebee18c53..401d33c4d04d 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -305,4 +305,10 @@ TEST(CpuPredictor, Sparse) { TestSparsePrediction(0.2, "cpu_predictor"); TestSparsePrediction(0.8, "cpu_predictor"); } + +TEST(CpuPredictor, Multi) { + Context ctx; + ctx.nthread = 1; + TestVectorLeafPrediction(&ctx); +} } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 3e8a94c75ab9..4570a010df67 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -1,28 +1,34 @@ -/*! - * Copyright 2020-2021 by Contributors +/** + * Copyright 2020-2023 by XGBoost Contributors */ - #include "test_predictor.h" #include -#include -#include -#include -#include - -#include "../../../src/common/bitfield.h" -#include "../../../src/common/categorical.h" -#include "../../../src/common/io.h" -#include "../../../src/data/adapter.h" -#include "../../../src/data/proxy_dmatrix.h" -#include "../helpers.h" +#include // for Context +#include // for DMatrix, BatchIterator, BatchSet, MetaInfo +#include // for HostDeviceVector +#include // for PredictionCacheEntry, Predictor, Predic... + +#include // for max +#include // for numeric_limits +#include // for unordered_map + +#include "../../../src/common/bitfield.h" // for LBitField32 +#include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix +#include "../../../src/data/proxy_dmatrix.h" // for DMatrixProxy +#include "../helpers.h" // for GetDMatrixFromData, RandomDataGenerator +#include "xgboost/json.h" // for Json, Object, get, String +#include "xgboost/linalg.h" // for MakeVec, Tensor, TensorView, Vector +#include "xgboost/logging.h" // for CHECK +#include "xgboost/span.h" // for operator!=, SpanIterator, Span +#include "xgboost/tree_model.h" // for RegTree namespace xgboost { TEST(Predictor, PredictionCache) { size_t constexpr kRows = 16, kCols = 4; PredictionContainer container; - DMatrix* m; + DMatrix *m; // Add a cache that is immediately expired. auto add_cache = [&]() { auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); @@ -412,4 +418,101 @@ void TestSparsePrediction(float sparsity, std::string predictor) { } } } + +void TestVectorLeafPrediction(Context const *ctx) { + std::unique_ptr cpu_predictor = + std::unique_ptr(Predictor::Create("cpu_predictor", ctx)); + + size_t constexpr kRows = 5; + size_t constexpr kCols = 5; + + LearnerModelParam mparam{static_cast(kCols), + linalg::Vector{{0.5}, {1}, Context::kCpuId}, 1, 3, + MultiStrategy::kMonolithic}; + + std::vector> trees; + trees.emplace_back(new RegTree{mparam.LeafLength(), mparam.num_feature}); + + std::vector p_w(mparam.LeafLength(), 0.0f); + std::vector l_w(mparam.LeafLength(), 1.0f); + std::vector r_w(mparam.LeafLength(), 2.0f); + + auto &tree = trees.front(); + tree->ExpandNode(0, static_cast(1), 2.0, true, + linalg::MakeVec(p_w.data(), p_w.size()), linalg::MakeVec(l_w.data(), l_w.size()), + linalg::MakeVec(r_w.data(), r_w.size())); + ASSERT_TRUE(tree->IsMultiTarget()); + ASSERT_TRUE(mparam.IsVectorLeaf()); + + gbm::GBTreeModel model{&mparam, ctx}; + model.CommitModel(std::move(trees), 0); + + auto run_test = [&](float expected, HostDeviceVector *p_data) { + { + auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols); + PredictionCacheEntry predt_cache; + cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model); + ASSERT_EQ(predt_cache.predictions.Size(), kRows * mparam.LeafLength()); + cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1); + auto const &h_predt = predt_cache.predictions.HostVector(); + for (auto v : h_predt) { + ASSERT_EQ(v, expected); + } + } + + { + // inplace + PredictionCacheEntry predt_cache; + auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols); + cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model); + auto arr = GetArrayInterface(p_data, kRows, kCols); + std::string str; + Json::Dump(arr, &str); + auto proxy = std::shared_ptr(new data::DMatrixProxy{}); + dynamic_cast(proxy.get())->SetArrayData(str.data()); + cpu_predictor->InplacePredict(proxy, model, std::numeric_limits::quiet_NaN(), + &predt_cache, 0, 1); + auto const &h_predt = predt_cache.predictions.HostVector(); + for (auto v : h_predt) { + ASSERT_EQ(v, expected); + } + } + + { + // ghist + PredictionCacheEntry predt_cache; + auto &h_data = p_data->HostVector(); + // give it at least two bins, otherwise the histogram cuts only have min and max values. + for (std::size_t i = 0; i < 5; ++i) { + h_data[i] = 1.0; + } + auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols); + + cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model); + + auto iter = NumpyArrayIterForTest{ctx, *p_data, kRows, static_cast(kCols), + static_cast(1)}; + p_fmat = + std::make_shared(&iter, iter.Proxy(), nullptr, Reset, Next, + std::numeric_limits::quiet_NaN(), 0, 256); + + cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model); + cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1); + auto const &h_predt = predt_cache.predictions.HostVector(); + // the smallest v uses the min_value from histogram cuts, which leads to a left leaf + // during prediction. + for (std::size_t i = 5; i < h_predt.size(); ++i) { + ASSERT_EQ(h_predt[i], expected) << i; + } + } + }; + + // go to right + HostDeviceVector data(kRows * kCols, model.trees.front()->SplitCond(RegTree::kRoot) + 1.0); + run_test(2.5, &data); + + // go to left + data.HostVector().assign(data.Size(), model.trees.front()->SplitCond(RegTree::kRoot) - 1.0); + run_test(1.5, &data); +} } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index 61b05b31bb91..56c1523a1cf1 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -1,9 +1,16 @@ +/** + * Copyright 2020-2023 by XGBoost Contributors + */ #ifndef XGBOOST_TEST_PREDICTOR_H_ #define XGBOOST_TEST_PREDICTOR_H_ +#include // for Context #include -#include + #include +#include + +#include "../../../src/gbm/gbtree_model.h" // for GBTreeModel #include "../helpers.h" namespace xgboost { @@ -48,7 +55,7 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols, PredictionCacheEntry precise_out_predictions; predictor->InitOutPredictions(p_dmat->Info(), &precise_out_predictions.predictions, model); predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0); - ASSERT_FALSE(p_dmat->PageExists()); + CHECK(!p_dmat->PageExists()); } } @@ -69,6 +76,8 @@ void TestCategoricalPredictLeaf(StringView name); void TestIterationRange(std::string name); void TestSparsePrediction(float sparsity, std::string predictor); + +void TestVectorLeafPrediction(Context const* ctx); } // namespace xgboost #endif // XGBOOST_TEST_PREDICTOR_H_