Preserve order of saved updaters config. #9355

Merged · 1 commit · Jul 5, 2023
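In short, the saved updater configuration changes from a JSON object keyed by updater name to an array of per-updater objects carrying a "name" field, so the configured updater sequence survives a save/load round trip. A self-contained C++ sketch of the two layouts (illustrative only; real entries also hold each updater's parameters):

#include <iostream>

// Illustrative sketch, not part of this PR: the container shape of the saved
// "updater" configuration before and after the change. Per-updater parameters
// are omitted.
int main() {
  // Before: an object keyed by updater name. Keyed objects give no ordering
  // guarantee, so the configured sequence can come back reordered.
  const char* before = R"({"updater": {"grow_colmaker": {}, "prune": {}}})";
  // After: an array of entries, each with a "name" field. Array elements keep
  // their order across save and load.
  const char* after =
      R"({"updater": [{"name": "grow_colmaker"}, {"name": "prune"}]})";
  std::cout << before << "\n" << after << "\n";
  return 0;
}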
34 changes: 11 additions & 23 deletions R-package/tests/testthat/test_model_compatibility.R
@@ -76,32 +76,20 @@ test_that("Models from previous versions of XGBoost can be loaded", {
    name <- m[3]
    is_rds <- endsWith(model_file, '.rds')
    is_json <- endsWith(model_file, '.json')
-   cpp_warning <- capture.output({
-     # Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
-     if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0) {
-       booster <- readRDS(model_file)
-       expect_warning(predict(booster, newdata = pred_data))
-       booster <- readRDS(model_file)
-       expect_warning(run_booster_check(booster, name))
-     } else {
-       if (is_rds) {
-         booster <- readRDS(model_file)
-       } else {
-         booster <- xgb.load(model_file)
-       }
-       predict(booster, newdata = pred_data)
-       run_booster_check(booster, name)
-     }
-   })
-   cpp_warning <- paste0(cpp_warning, collapse = ' ')
-   if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') >= 0) {
-     # Expect a C++ warning when a model is loaded from RDS and it was generated by old XGBoost`
-     m <- grepl(paste0('.*If you are loading a serialized model ',
-                       '\\(like pickle in Python, RDS in R\\).*',
-                       'for more details about differences between ',
-                       'saving model and serializing.*'), cpp_warning, perl = TRUE)
-     expect_true(length(m) > 0 && all(m))
-   }
+   # Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
+   if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0) {
+     booster <- readRDS(model_file)
+     expect_warning(predict(booster, newdata = pred_data))
+     booster <- readRDS(model_file)
+     expect_warning(run_booster_check(booster, name))
+   } else {
+     if (is_rds) {
+       booster <- readRDS(model_file)
+     } else {
+       booster <- xgb.load(model_file)
+     }
+     predict(booster, newdata = pred_data)
+     run_booster_check(booster, name)
+   }
  })
})
28 changes: 28 additions & 0 deletions src/common/error_msg.h
@@ -51,5 +51,33 @@ inline void MaxFeatureSize(std::uint64_t n_features) {
constexpr StringView InplacePredictProxy() {
return "Inplace predict accepts only DMatrixProxy as input.";
}

+inline void MaxSampleSize(std::size_t n) {
+  LOG(FATAL) << "Sample size too large for the current updater. Maximum number of samples:" << n
+             << ". Consider using a different updater or tree_method.";
+}
+
+constexpr StringView OldSerialization() {
+  return R"doc(If you are loading a serialized model (like pickle in Python, RDS in R) or
+configuration generated by an older version of XGBoost, please export the model by calling
+`Booster.save_model` from that version first, then load it back in current version. See:
+
+https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html
+
+for more details about differences between saving model and serializing.
+)doc";
+}
+
+inline void WarnOldSerialization() {
+  // Display it once is enough. Otherwise this can be really verbose in distributed
+  // environments.
+  static thread_local bool logged{false};
+  if (logged) {
+    return;
+  }
+
+  LOG(WARNING) << OldSerialization();
+  logged = true;
+}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_
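WarnOldSerialization() above deliberately emits the message at most once per thread, so distributed jobs are not flooded with repeated warnings. A minimal standalone sketch of the same warn-once idiom, with std::cerr standing in for LOG(WARNING):

#include <iostream>

// Sketch of the warn-once idiom used by WarnOldSerialization(); std::cerr
// stands in for XGBoost's LOG(WARNING).
void WarnOnce() {
  static thread_local bool logged{false};  // one flag per thread
  if (logged) {
    return;
  }
  std::cerr << "legacy serialization format detected\n";
  logged = true;
}

int main() {
  WarnOnce();  // prints the warning
  WarnOnce();  // silent: this thread already warned
  return 0;
}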
14 changes: 10 additions & 4 deletions src/data/simple_dmatrix.cc
@@ -21,8 +21,7 @@
#include "xgboost/c_api.h"
#include "xgboost/data.h"

-namespace xgboost {
-namespace data {
+namespace xgboost::data {
MetaInfo& SimpleDMatrix::Info() { return info_; }

const MetaInfo& SimpleDMatrix::Info() const { return info_; }
@@ -97,6 +96,10 @@ BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches(Context const* ctx) {
// column page doesn't exist, generate it
if (!column_page_) {
+    auto n = std::numeric_limits<decltype(Entry::index)>::max();
+    if (this->sparse_page_->Size() > n) {
+      error::MaxSampleSize(n);
+    }
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx->Threads())));
}
auto begin_iter = BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_));
@@ -106,6 +109,10 @@ BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches(Context const* ctx) {
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches(Context const* ctx) {
// Sorted column page doesn't exist, generate it
if (!sorted_column_page_) {
+    auto n = std::numeric_limits<decltype(Entry::index)>::max();
+    if (this->sparse_page_->Size() > n) {
+      error::MaxSampleSize(n);
+    }
sorted_column_page_.reset(
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx->Threads())));
sorted_column_page_->SortRows(ctx->Threads());
@@ -427,5 +434,4 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i

fmat_ctx_ = ctx;
}
-} // namespace data
-} // namespace xgboost
+} // namespace xgboost::data
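Both new guards compare the number of rows in the sparse page against the largest value representable by Entry::index before building a transposed column page, since row positions must fit into that index type once the page is transposed. A minimal sketch of the guard, assuming a 32-bit RowIndexT plays the role of decltype(Entry::index):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>

// Assumed stand-in for decltype(Entry::index); the real type is defined by
// XGBoost's Entry struct.
using RowIndexT = std::uint32_t;

// Returns true when n_samples can be stored losslessly in RowIndexT.
bool FitsInIndex(std::size_t n_samples) {
  constexpr std::size_t kMax = std::numeric_limits<RowIndexT>::max();
  return n_samples <= kMax;
}

int main() {
  std::cout << FitsInIndex(1000000) << "\n";  // 1: well within range
  // 0 on platforms where size_t is wider than 32 bits: one row too many.
  std::cout << FitsInIndex(static_cast<std::size_t>(std::numeric_limits<RowIndexT>::max()) + 1)
            << "\n";
  return 0;
}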
42 changes: 28 additions & 14 deletions src/gbm/gbtree.cc
@@ -18,7 +18,7 @@
#include <vector>

#include "../common/common.h"
#include "../common/error_msg.h" // for UnknownDevice, InplacePredictProxy
#include "../common/error_msg.h" // for UnknownDevice, WarnOldSerialization, InplacePredictProxy
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "../common/timer.h"
@@ -391,19 +391,32 @@ void GBTree::LoadConfig(Json const& in) {
LOG(WARNING) << msg << " Changing `tree_method` to `hist`.";
}

-  auto const& j_updaters = get<Object const>(in["updater"]);
+  std::vector<Json> updater_seq;
+  if (IsA<Object>(in["updater"])) {
+    // before 2.0
+    error::WarnOldSerialization();
+    for (auto const& kv : get<Object const>(in["updater"])) {
+      auto name = kv.first;
+      auto config = kv.second;
+      config["name"] = name;
+      updater_seq.push_back(config);
+    }
+  } else {
+    // after 2.0
+    auto const& j_updaters = get<Array const>(in["updater"]);
+    updater_seq = j_updaters;
+  }

updaters_.clear();

-  for (auto const& kv : j_updaters) {
-    auto name = kv.first;
+  for (auto const& config : updater_seq) {
+    auto name = get<String>(config["name"]);
if (n_gpus == 0 && name == "grow_gpu_hist") {
name = "grow_quantile_histmaker";
LOG(WARNING) << "Changing updater from `grow_gpu_hist` to `grow_quantile_histmaker`.";
}
-    std::unique_ptr<TreeUpdater> up{
-        TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task)};
-    up->LoadConfig(kv.second);
-    updaters_.push_back(std::move(up));
+    updaters_.emplace_back(TreeUpdater::Create(name, ctx_, &model_.learner_model_param->task));
+    updaters_.back()->LoadConfig(config);
}

specified_updater_ = get<Boolean>(in["specified_updater"]);
@@ -425,13 +438,14 @@ void GBTree::SaveConfig(Json* p_out) const {
// language binding doesn't need to know about the forest size.
out["gbtree_model_param"] = ToJson(model_.param);

out["updater"] = Object();
out["updater"] = Array{};
auto& j_updaters = get<Array>(out["updater"]);

auto& j_updaters = out["updater"];
for (auto const& up : updaters_) {
j_updaters[up->Name()] = Object();
auto& j_up = j_updaters[up->Name()];
up->SaveConfig(&j_up);
for (auto const& up : this->updaters_) {
Json up_config{Object{}};
up_config["name"] = String{up->Name()};
up->SaveConfig(&up_config);
j_updaters.emplace_back(up_config);
}
out["specified_updater"] = Boolean{specified_updater_};
}
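SaveConfig() now writes the updaters as an array of objects, each carrying its name, and LoadConfig() accepts both the pre-2.0 object layout (with a warning) and the new array layout. The array matters because a keyed JSON object behaves like a map: iteration follows key order rather than the order in which the updaters were configured. A standalone sketch of the difference, using the hypothetical user-specified sequence refresh followed by prune:

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Illustrative sketch, not XGBoost code: storing updater configs keyed by name
// versus storing them as an ordered sequence.
int main() {
  // Object-like storage: a map iterates in key order, so "refresh" then
  // "prune" comes back as "prune" then "refresh".
  std::map<std::string, std::string> by_name;
  by_name.emplace("refresh", "{}");
  by_name.emplace("prune", "{}");
  for (auto const& kv : by_name) {
    std::cout << kv.first << ' ';  // prints: prune refresh
  }
  std::cout << '\n';

  // Array-like storage: a vector of (name, config) pairs keeps insertion
  // order, which is what the new array layout relies on.
  std::vector<std::pair<std::string, std::string>> in_order;
  in_order.emplace_back("refresh", "{}");
  in_order.emplace_back("prune", "{}");
  for (auto const& kv : in_order) {
    std::cout << kv.first << ' ';  // prints: refresh prune
  }
  std::cout << '\n';
  return 0;
}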
25 changes: 5 additions & 20 deletions src/learner.cc
@@ -40,7 +40,7 @@
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "common/charconv.h" // for to_chars, to_chars_result, NumericLimits, from_...
#include "common/common.h" // for ToString, Split
#include "common/error_msg.h" // for MaxFeatureSize
#include "common/error_msg.h" // for MaxFeatureSize, WarnOldSerialization
#include "common/io.h" // for PeekableInStream, ReadAll, FixedSizeStream, Mem...
#include "common/observer.h" // for TrainingObserver
#include "common/random.h" // for GlobalRandom
@@ -357,21 +357,6 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
using LearnerAPIThreadLocalStore =
dmlc::ThreadLocalStore<std::map<Learner const *, XGBAPIThreadLocalEntry>>;

-namespace {
-StringView ModelMsg() {
-  return StringView{
-      R"doc(
-If you are loading a serialized model (like pickle in Python, RDS in R) generated by
-older XGBoost, please export the model by calling `Booster.save_model` from that version
-first, then load it back in current version. See:
-
-https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
-
-for more details about differences between saving model and serializing.
-)doc"};
-}
-} // anonymous namespace

class LearnerConfiguration : public Learner {
private:
std::mutex config_lock_;
@@ -531,7 +516,7 @@ class LearnerConfiguration : public Learner {
}

if (!Version::Same(origin_version)) {
-      LOG(WARNING) << ModelMsg();
+      error::WarnOldSerialization();
return; // skip configuration if version is not matched
}

@@ -562,7 +547,7 @@
for (size_t i = 0; i < n_metrics; ++i) {
auto old_serialization = IsA<String>(j_metrics[i]);
if (old_serialization) {
-        LOG(WARNING) << ModelMsg();
+        error::WarnOldSerialization();
metric_names_[i] = get<String>(j_metrics[i]);
} else {
metric_names_[i] = get<String>(j_metrics[i]["name"]);
@@ -1173,7 +1158,7 @@ class LearnerIO : public LearnerConfiguration {
Json memory_snapshot;
if (header[1] == '"') {
memory_snapshot = Json::Load(StringView{buffer});
-      LOG(WARNING) << ModelMsg();
+      error::WarnOldSerialization();
} else if (std::isalpha(header[1])) {
memory_snapshot = Json::Load(StringView{buffer}, std::ios::binary);
} else {
@@ -1192,7 +1177,7 @@
header.resize(serialisation_header_.size());
CHECK_EQ(fp.Read(&header[0], header.size()), serialisation_header_.size());
// Avoid printing the content in loaded header, which might be random binary code.
-    CHECK(header == serialisation_header_) << ModelMsg();
+    CHECK(header == serialisation_header_) << error::OldSerialization();
int64_t sz {-1};
CHECK_EQ(fp.Read(&sz, sizeof(sz)), sizeof(sz));
if (!DMLC_IO_NO_ENDIAN_SWAP) {
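Besides routing the warnings through error::WarnOldSerialization(), the final hunk keeps the existing check that a fixed serialization header matches before the payload is read. A self-contained sketch of that pattern, with a made-up magic string standing in for serialisation_header_:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <sstream>

// Hypothetical magic value; the real header lives in serialisation_header_.
constexpr char kMagic[] = "BINF";

// Read and validate the header, then read the payload size, mirroring the
// structure of the hunk above.
bool ReadSnapshot(std::istream& in) {
  char header[sizeof(kMagic) - 1] = {0};
  in.read(header, sizeof(header));
  if (!in || std::memcmp(header, kMagic, sizeof(header)) != 0) {
    std::cerr << "unrecognized binary snapshot header\n";
    return false;
  }
  std::int64_t sz = -1;  // payload length follows the header
  in.read(reinterpret_cast<char*>(&sz), sizeof(sz));
  return static_cast<bool>(in);
}

int main() {
  std::ostringstream out;
  out.write(kMagic, sizeof(kMagic) - 1);
  std::int64_t sz = 0;
  out.write(reinterpret_cast<const char*>(&sz), sizeof(sz));
  std::istringstream in(out.str());
  std::cout << (ReadSnapshot(in) ? "ok" : "rejected") << "\n";
  return 0;
}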
40 changes: 30 additions & 10 deletions tests/cpp/gbm/test_gbtree.cc
@@ -174,32 +174,52 @@ TEST(GBTree, JsonIO) {
Context ctx;
LearnerModelParam mparam{MakeMP(kCols, .5, 1)};

-  std::unique_ptr<GradientBooster> gbm {
-      CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &ctx) };
+  std::unique_ptr<GradientBooster> gbm{
+      CreateTrainedGBM("gbtree", Args{{"tree_method", "exact"}, {"default_direction", "left"}},
+                       kRows, kCols, &mparam, &ctx)};

-  Json model {Object()};
+  Json model{Object()};
model["model"] = Object();
-  auto& j_model = model["model"];
+  auto j_model = model["model"];

model["config"] = Object();
-  auto& j_param = model["config"];
+  auto j_config = model["config"];

gbm->SaveModel(&j_model);
-  gbm->SaveConfig(&j_param);
+  gbm->SaveConfig(&j_config);

std::string model_str;
Json::Dump(model, &model_str);

model = Json::Load({model_str.c_str(), model_str.size()});
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
j_model = model["model"];
j_config = model["config"];
ASSERT_EQ(get<String>(j_model["name"]), "gbtree");

-  auto const& gbtree_model = model["model"]["model"];
+  auto gbtree_model = j_model["model"];
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1ul);
ASSERT_EQ(get<Integer>(get<Object>(get<Array>(gbtree_model["trees"]).front()).at("id")), 0);
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1ul);

auto j_train_param = model["config"]["gbtree_model_param"];
auto j_train_param = j_config["gbtree_model_param"];
ASSERT_EQ(get<String>(j_train_param["num_parallel_tree"]), "1");

+  auto check_config = [](Json j_up_config) {
+    auto colmaker = get<Array const>(j_up_config).front();
+    auto pruner = get<Array const>(j_up_config).back();
+    ASSERT_EQ(get<String const>(colmaker["name"]), "grow_colmaker");
+    ASSERT_EQ(get<String const>(pruner["name"]), "prune");
+    ASSERT_EQ(get<String const>(colmaker["colmaker_train_param"]["default_direction"]), "left");
+  };
+  check_config(j_config["updater"]);

+  std::unique_ptr<GradientBooster> loaded(gbm::GBTree::Create("gbtree", &ctx, &mparam));
+  loaded->LoadModel(j_model);
+  loaded->LoadConfig(j_config);
+
+  // roundtrip test
+  Json j_config_rt{Object{}};
+  loaded->SaveConfig(&j_config_rt);
+  check_config(j_config_rt["updater"]);
}

TEST(Dart, JsonIO) {
Expand Down