Skip to content

Commit

Permalink
Tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 13, 2023
1 parent be1c4d5 commit b3a2141
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 42 deletions.
5 changes: 3 additions & 2 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
HistogramCuts *cuts) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
auto &cut_values = cuts->cut_values_.HostVector();
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
Expand Down Expand Up @@ -419,8 +420,8 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
} else {
AddCutPoint<WQSketch>(a, max_num_bins, cuts);
// push a value that is greater than anything
const bst_float cpt = (a.size > 0) ? a.data[a.size - 1].value
: cuts->min_vals_.HostVector()[fid];
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
cuts->cut_values_.HostVector().push_back(last);
Expand Down
57 changes: 38 additions & 19 deletions tests/cpp/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,18 @@ std::string RandomDataGenerator::GenerateArrayInterface(
return out;
}

std::pair<std::vector<std::string>, std::string>
RandomDataGenerator::GenerateArrayInterfaceBatch(
HostDeviceVector<float> *storage, size_t batches) const {
this->GenerateDense(storage);
std::pair<std::vector<std::string>, std::string> MakeArrayInterfaceBatch(
HostDeviceVector<float> const* storage, std::size_t n_samples, bst_feature_t n_features,
std::size_t batches, std::int32_t device) {
std::vector<std::string> result(batches);
std::vector<Json> objects;

size_t const rows_per_batch = rows_ / batches;
size_t const rows_per_batch = n_samples / batches;

auto make_interface = [storage, this](size_t offset, size_t rows) {
auto make_interface = [storage, device, n_features](std::size_t offset, std::size_t rows) {
Json array_interface{Object()};
array_interface["data"] = std::vector<Json>(2);
if (device_ >= 0) {
if (device >= 0) {
array_interface["data"][0] =
Integer(reinterpret_cast<int64_t>(storage->DevicePointer() + offset));
array_interface["stream"] = Null{};
Expand All @@ -249,22 +248,22 @@ RandomDataGenerator::GenerateArrayInterfaceBatch(

array_interface["shape"] = std::vector<Json>(2);
array_interface["shape"][0] = rows;
array_interface["shape"][1] = cols_;
array_interface["shape"][1] = n_features;

array_interface["typestr"] = String("<f4");
array_interface["version"] = 3;
return array_interface;
};

auto j_interface = make_interface(0, rows_);
auto j_interface = make_interface(0, n_samples);
size_t offset = 0;
for (size_t i = 0; i < batches - 1; ++i) {
objects.emplace_back(make_interface(offset, rows_per_batch));
offset += rows_per_batch * cols_;
offset += rows_per_batch * n_features;
}

size_t const remaining = rows_ - offset / cols_;
CHECK_LE(offset, rows_ * cols_);
size_t const remaining = n_samples - offset / n_features;
CHECK_LE(offset, n_samples * n_features);
objects.emplace_back(make_interface(offset, remaining));

for (size_t i = 0; i < batches; ++i) {
Expand All @@ -276,6 +275,12 @@ RandomDataGenerator::GenerateArrayInterfaceBatch(
return {result, interface_str};
}

/**
 * \brief Generate dense data into `storage` and describe it as `batches` array interfaces.
 *
 * The generator's own rows_, cols_ and device_ are forwarded to MakeArrayInterfaceBatch,
 * which builds the interfaces from pointers into `storage`.
 *
 * \return The pair produced by MakeArrayInterfaceBatch for the generated data.
 */
std::pair<std::vector<std::string>, std::string> RandomDataGenerator::GenerateArrayInterfaceBatch(
    HostDeviceVector<float>* storage, size_t batches) const {
  // The data must exist before the interfaces are built on top of it.
  this->GenerateDense(storage);
  auto interfaces = MakeArrayInterfaceBatch(storage, rows_, cols_, batches, device_);
  return interfaces;
}

std::string RandomDataGenerator::GenerateColumnarArrayInterface(
std::vector<HostDeviceVector<float>> *data) const {
CHECK(data);
Expand Down Expand Up @@ -400,11 +405,14 @@ int NumpyArrayIterForTest::Next() {
return 1;
}

std::shared_ptr<DMatrix>
GetDMatrixFromData(const std::vector<float> &x, int num_rows, int num_columns){
std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
bst_feature_t num_columns) {
data::DenseAdapter adapter(x.data(), num_rows, num_columns);
return std::shared_ptr<DMatrix>(new data::SimpleDMatrix(
&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
auto p_fmat = std::shared_ptr<DMatrix>(
new data::SimpleDMatrix(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
CHECK_EQ(p_fmat->Info().num_row_, num_rows);
CHECK_EQ(p_fmat->Info().num_col_, num_columns);
return p_fmat;
}

std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_row_t n_samples, bst_feature_t n_features,
Expand Down Expand Up @@ -572,12 +580,23 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs,
return gbm;
}

ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols,
size_t batches) : rows_{rows}, cols_{cols}, n_batches_{batches} {
ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches)
: rows_{rows}, cols_{cols}, n_batches_{batches} {
XGProxyDMatrixCreate(&proxy_);
rng_.reset(new RandomDataGenerator{rows_, cols_, sparsity});
std::tie(batches_, interface_) = rng_->GenerateArrayInterfaceBatch(&data_, n_batches_);
}

/**
 * \brief Create an iterator over user provided data instead of generated data.
 *
 * The input is copied into the iterator's own storage, and the per-batch array
 * interfaces are then built from that copy by MakeArrayInterfaceBatch.
 *
 * \param ctx        Context; its gpu_id is forwarded as the `device` argument.
 * \param data       Values to copy. Its size must equal n_samples * n_features * n_batches.
 * \param n_samples  Stored as rows_.
 * \param n_features Stored as cols_.
 * \param n_batches  Number of batches to expose.
 */
ArrayIterForTest::ArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
                                   std::size_t n_samples, bst_feature_t n_features,
                                   std::size_t n_batches)
    : rows_{n_samples}, cols_{n_features}, n_batches_{n_batches} {
  XGProxyDMatrixCreate(&proxy_);
  this->data_.Resize(data.Size());
  // NOTE(review): MakeArrayInterfaceBatch divides its n_samples argument by the number of
  // batches, while this check expects rows_ * cols_ * n_batches values. Confirm which
  // meaning of n_samples callers intend when n_batches > 1.
  CHECK_EQ(this->data_.Size(), rows_ * cols_ * n_batches);
  this->data_.Copy(data);
  // rng_ is never created by this constructor, so the interfaces must come from the free
  // function. Going through rng_ would dereference a null pointer and would also replace
  // the user provided values with generated ones.
  std::tie(batches_, interface_) =
      MakeArrayInterfaceBatch(&data_, rows_, cols_, n_batches_, ctx->gpu_id);
}

ArrayIterForTest::~ArrayIterForTest() { XGDMatrixFree(proxy_); }
Expand Down
17 changes: 13 additions & 4 deletions tests/cpp/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class SimpleRealUniformDistribution {
};

template <typename T>
Json GetArrayInterface(HostDeviceVector<T> *storage, size_t rows, size_t cols) {
Json GetArrayInterface(HostDeviceVector<T> const* storage, size_t rows, size_t cols) {
Json array_interface{Object()};
array_interface["data"] = std::vector<Json>(2);
if (storage->DeviceCanRead()) {
Expand Down Expand Up @@ -318,8 +318,8 @@ GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
return x;
}

std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float> &x,
int num_rows, int num_columns);
std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
bst_feature_t num_columns);

/**
* \brief Create Sparse Page using data iterator.
Expand Down Expand Up @@ -394,7 +394,7 @@ typedef void *DMatrixHandle; // NOLINT(*);
class ArrayIterForTest {
protected:
HostDeviceVector<float> data_;
size_t iter_ {0};
size_t iter_{0};
DMatrixHandle proxy_;
std::unique_ptr<RandomDataGenerator> rng_;

Expand All @@ -418,6 +418,11 @@ class ArrayIterForTest {
auto Proxy() -> decltype(proxy_) { return proxy_; }

explicit ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches);
/**
* \brief Create iterator with user provided data.
*/
explicit ArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
std::size_t n_samples, bst_feature_t n_features, std::size_t n_batches);
virtual ~ArrayIterForTest();
};

Expand All @@ -433,6 +438,10 @@ class NumpyArrayIterForTest : public ArrayIterForTest {
public:
explicit NumpyArrayIterForTest(float sparsity, size_t rows = Rows(), size_t cols = Cols(),
size_t batches = Batches());
/**
 * \brief Create iterator with user provided data; all arguments are forwarded unchanged
 *        to the ArrayIterForTest constructor with the same parameter list.
 */
explicit NumpyArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
std::size_t n_samples, bst_feature_t n_features,
std::size_t n_batches)
: ArrayIterForTest{ctx, data, n_samples, n_features, n_batches} {}
int Next() override;
~NumpyArrayIterForTest() override = default;
};
Expand Down
6 changes: 6 additions & 0 deletions tests/cpp/predictor/test_cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,10 @@ TEST(CpuPredictor, Sparse) {
TestSparsePrediction(0.2, "cpu_predictor");
TestSparsePrediction(0.8, "cpu_predictor");
}

// Runs the shared vector-leaf prediction test with a context whose nthread is set to 1.
TEST(CpuPredictor, Multi) {
  Context single_thread_ctx;
  single_thread_ctx.nthread = 1;
  TestVectorLeafPrediction(&single_thread_ctx);
}
} // namespace xgboost
133 changes: 118 additions & 15 deletions tests/cpp/predictor/test_predictor.cc
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
/*!
* Copyright 2020-2021 by Contributors
/**
* Copyright 2020-2023 by XGBoost Contributors
*/

#include "test_predictor.h"

#include <gtest/gtest.h>
#include <xgboost/context.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/predictor.h>

#include "../../../src/common/bitfield.h"
#include "../../../src/common/categorical.h"
#include "../../../src/common/io.h"
#include "../../../src/data/adapter.h"
#include "../../../src/data/proxy_dmatrix.h"
#include "../helpers.h"
#include <xgboost/context.h> // for Context
#include <xgboost/data.h> // for DMatrix, BatchIterator, BatchSet, MetaInfo
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/predictor.h> // for PredictionCacheEntry, Predictor, Predic...

#include <algorithm> // for max
#include <limits> // for numeric_limits
#include <unordered_map> // for unordered_map

#include "../../../src/common/bitfield.h" // for LBitField32
#include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix
#include "../../../src/data/proxy_dmatrix.h" // for DMatrixProxy
#include "../helpers.h" // for GetDMatrixFromData, RandomDataGenerator
#include "xgboost/json.h" // for Json, Object, get, String
#include "xgboost/linalg.h" // for MakeVec, Tensor, TensorView, Vector
#include "xgboost/logging.h" // for CHECK
#include "xgboost/span.h" // for operator!=, SpanIterator, Span
#include "xgboost/tree_model.h" // for RegTree

namespace xgboost {
TEST(Predictor, PredictionCache) {
size_t constexpr kRows = 16, kCols = 4;

PredictionContainer container;
DMatrix* m;
DMatrix *m;
// Add a cache that is immediately expired.
auto add_cache = [&]() {
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
Expand Down Expand Up @@ -412,4 +418,101 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
}
}
}

/**
 * \brief Check prediction with a single vector-leaf tree through three input paths.
 *
 * A model with one tree (one split at the root) is built by hand, then the "cpu_predictor"
 * is run on the same values via a DMatrix made by GetDMatrixFromData, via in-place
 * prediction on a DMatrixProxy, and via an IterativeDMatrix fed by NumpyArrayIterForTest.
 *
 * \param ctx Context passed to the predictor, the model and the data iterator.
 */
void TestVectorLeafPrediction(Context const *ctx) {
std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", ctx));

size_t constexpr kRows = 5;
size_t constexpr kCols = 5;

// NOTE(review): the vector holds the single value 0.5, presumably the base score; the
// trailing `1, 3` look like group and target counts -- confirm against LearnerModelParam.
LearnerModelParam mparam{static_cast<bst_feature_t>(kCols),
linalg::Vector<float>{{0.5}, {1}, Context::kCpuId}, 1, 3,
MultiStrategy::kMonolithic};

std::vector<std::unique_ptr<RegTree>> trees;
trees.emplace_back(new RegTree{mparam.LeafLength(), mparam.num_feature});

// Constant weight vectors of length LeafLength(): 0 for the parent, 1 for the left child
// and 2 for the right child.
std::vector<float> p_w(mparam.LeafLength(), 0.0f);
std::vector<float> l_w(mparam.LeafLength(), 1.0f);
std::vector<float> r_w(mparam.LeafLength(), 2.0f);

// Split node 0 on feature 1 with split value 2.0.
auto &tree = trees.front();
tree->ExpandNode(0, static_cast<bst_feature_t>(1), 2.0, true,
linalg::MakeVec(p_w.data(), p_w.size()), linalg::MakeVec(l_w.data(), l_w.size()),
linalg::MakeVec(r_w.data(), r_w.size()));
ASSERT_TRUE(tree->IsMultiTarget());
ASSERT_TRUE(mparam.IsVectorLeaf());

gbm::GBTreeModel model{&mparam, ctx};
model.CommitModel(std::move(trees), 0);

// Runs all three prediction paths on *p_data and compares the outputs with `expected`.
// The last section overwrites the first five values of *p_data, so the caller's buffer
// is modified by the time this returns.
auto run_test = [&](float expected, HostDeviceVector<float> *p_data) {
{
// DMatrix built directly from the host values.
auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
PredictionCacheEntry predt_cache;
cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
ASSERT_EQ(predt_cache.predictions.Size(), kRows * mparam.LeafLength());
cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1);
auto const &h_predt = predt_cache.predictions.HostVector();
for (auto v : h_predt) {
ASSERT_EQ(v, expected);
}
}

{
// inplace
PredictionCacheEntry predt_cache;
// This DMatrix is only used for its Info() when initialising the output buffer; the
// prediction itself reads the data through the proxy below.
auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
auto arr = GetArrayInterface(p_data, kRows, kCols);
std::string str;
Json::Dump(arr, &str);
auto proxy = std::shared_ptr<DMatrix>(new data::DMatrixProxy{});
dynamic_cast<data::DMatrixProxy *>(proxy.get())->SetArrayData(str.data());
cpu_predictor->InplacePredict(proxy, model, std::numeric_limits<float>::quiet_NaN(),
&predt_cache, 0, 1);
auto const &h_predt = predt_cache.predictions.HostVector();
for (auto v : h_predt) {
ASSERT_EQ(v, expected);
}
}

{
// ghist
PredictionCacheEntry predt_cache;
auto &h_data = p_data->HostVector();
// give it at least two bins, otherwise the histogram cuts only have min and max values.
for (std::size_t i = 0; i < 5; ++i) {
h_data[i] = 1.0;
}
auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);

cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);

// Replace the dense DMatrix with an IterativeDMatrix that reads the same values as a
// single batch.
auto iter = NumpyArrayIterForTest{ctx, *p_data, kRows, static_cast<bst_feature_t>(kCols),
static_cast<std::size_t>(1)};
p_fmat =
std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
std::numeric_limits<float>::quiet_NaN(), 0, 256);

// Initialise the output again, this time from the IterativeDMatrix's Info().
cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1);
auto const &h_predt = predt_cache.predictions.HostVector();
// the smallest v uses the min_value from histogram cuts, which leads to a left leaf
// during prediction.
// NOTE(review): the first five outputs are skipped; confirm that 5 is the intended
// count given that each row produces LeafLength() outputs.
for (std::size_t i = 5; i < h_predt.size(); ++i) {
ASSERT_EQ(h_predt[i], expected) << i;
}
}
};

// Expected values are the leaf weight (2 for right, 1 for left) plus 0.5.
// go to right
HostDeviceVector<float> data(kRows * kCols, model.trees.front()->SplitCond(RegTree::kRoot) + 1.0);
run_test(2.5, &data);

// go to left
// assign() also resets the five values that the previous run_test call overwrote.
data.HostVector().assign(data.Size(), model.trees.front()->SplitCond(RegTree::kRoot) - 1.0);
run_test(1.5, &data);
}
} // namespace xgboost
13 changes: 11 additions & 2 deletions tests/cpp/predictor/test_predictor.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
/**
* Copyright 2020-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_TEST_PREDICTOR_H_
#define XGBOOST_TEST_PREDICTOR_H_

#include <xgboost/context.h> // for Context
#include <xgboost/predictor.h>
#include <string>

#include <cstddef>
#include <string>

#include "../../../src/gbm/gbtree_model.h" // for GBTreeModel
#include "../helpers.h"

namespace xgboost {
Expand Down Expand Up @@ -48,7 +55,7 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
PredictionCacheEntry precise_out_predictions;
predictor->InitOutPredictions(p_dmat->Info(), &precise_out_predictions.predictions, model);
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
ASSERT_FALSE(p_dmat->PageExists<Page>());
CHECK(!p_dmat->PageExists<Page>());
}
}

Expand All @@ -69,6 +76,8 @@ void TestCategoricalPredictLeaf(StringView name);
void TestIterationRange(std::string name);

void TestSparsePrediction(float sparsity, std::string predictor);

void TestVectorLeafPrediction(Context const* ctx);
} // namespace xgboost

#endif // XGBOOST_TEST_PREDICTOR_H_

0 comments on commit b3a2141

Please sign in to comment.