Drop binary format for memory snapshot.
trivialfis committed Dec 16, 2020
1 parent 749364f commit fcd8232
Showing 5 changed files with 23 additions and 103 deletions.
41 changes: 10 additions & 31 deletions doc/tutorials/saving_model.rst
@@ -9,10 +9,9 @@ open format that can be easily reused. The support for binary format will be continued in
the future until the JSON format is no longer experimental and has satisfying performance.
This tutorial aims to share some basic insights into the JSON serialisation method used in
XGBoost. Unless explicitly mentioned, the following sections assume you are using the
-experimental JSON format, which can be enabled by passing
-``enable_experimental_json_serialization=True`` as a training parameter, or by providing the
-file name with ``.json`` as the file extension when saving/loading the model:
-``booster.save_model('model.json')``. More details below.
+JSON format, which can be enabled by providing the file name with ``.json`` as the file
+extension when saving/loading the model: ``booster.save_model('model.json')``. More details
+below.
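As an illustration of the extension-based dispatch described above, a minimal sketch with placeholder data (not part of this diff):

.. code-block:: python

    import numpy as np
    import xgboost

    dtrain = xgboost.DMatrix(np.random.rand(100, 10),
                             label=np.random.randint(2, size=100))
    bst = xgboost.train({'objective': 'binary:logistic'}, dtrain)

    # The '.json' extension alone selects the JSON format.
    bst.save_model('model.json')
    loaded = xgboost.Booster(model_file='model.json')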

Before we get started, XGBoost is a gradient boosting library with a focus on tree models,
which means that inside XGBoost there are two distinct parts:
@@ -66,26 +65,7 @@ a filename with ``.json`` as the file extension:

    xgb.save(bst, 'model_file_name.json')

-To use JSON to store memory snapshots, add ``enable_experimental_json_serialization`` as a
-training parameter. In Python this can be done by:
-
-.. code-block:: python
-
-    bst = xgboost.train({'enable_experimental_json_serialization': True}, dtrain)
-    with open('filename', 'wb') as fd:
-        pickle.dump(bst, fd)
-
-Notice the ``filename`` is for the Python built-in function ``open``, not for XGBoost; hence
-the parameter ``enable_experimental_json_serialization`` is required to enable the JSON format.
-
-Similarly, in the R package, add ``enable_experimental_json_serialization`` to the training
-parameters:
-
-.. code-block:: r
-
-    params <- list(enable_experimental_json_serialization = TRUE, ...)
-    bst <- xgb.train(params, dtrain, nrounds = 10)
-    saveRDS(bst, 'filename.rds')
+For memory snapshots, JSON is the default starting from XGBoost 1.3.
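On 1.3 the removed instructions reduce to plain pickling; a minimal sketch, assuming ``bst`` is a trained ``Booster`` (not part of this diff):

.. code-block:: python

    import pickle

    # No enable_experimental_json_serialization needed: the snapshot is JSON.
    with open('model.pkl', 'wb') as fd:
        pickle.dump(bst, fd)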

***************************************************************
A note on backward compatibility of models and memory snapshots
@@ -110,11 +90,11 @@ Custom objective and metric
***************************

XGBoost accepts user provided objective and metric functions as an extension. These
-functions are not saved in model file as they are language dependent feature. With
+functions are not saved in the model file as they are language dependent features. With
Python, users can pickle the model to include these functions in the saved binary. One
drawback is that the output from pickle is not a stable serialization format and doesn't work
-on different Python version or XGBoost version, not to mention different language
-environment. Another way to workaround this limitation is to provide these functions
+on different Python or XGBoost versions, not to mention different language
+environments. Another way to work around this limitation is to provide these functions
again after the model is loaded. If the customized function is useful, please consider
making a PR for implementing it inside XGBoost; this way we can have your functions
working with different language bindings.
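To make the "provide these functions again" workaround concrete, a sketch in which the objective, data, and file names are placeholders (not part of this diff):

.. code-block:: python

    import numpy as np
    import xgboost

    def squared_error(preds, dtrain):
        # Gradient and hessian of 0.5 * (pred - label)^2.
        grad = preds - dtrain.get_label()
        hess = np.ones_like(preds)
        return grad, hess

    dtrain = xgboost.DMatrix(np.random.rand(100, 10), label=np.random.rand(100))
    bst = xgboost.train({}, dtrain, num_boost_round=5, obj=squared_error)
    bst.save_model('model.json')          # the function itself is not saved

    loaded = xgboost.Booster(model_file='model.json')
    # Supply the objective again to continue training with the loaded model.
    bst2 = xgboost.train({}, dtrain, num_boost_round=5, obj=squared_error,
                         xgb_model=loaded)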
@@ -128,9 +108,9 @@ models are valuable. One way to restore it in the future is to load it back with a
specific version of Python and XGBoost, then export the model by calling `save_model`. To help
ease the migration, we created a simple script for converting a pickled XGBoost 0.90
Scikit-Learn interface object to an XGBoost 1.0.0 native model. Please note that the script
-suits simple use cases, and it's advised not to use pickle when stability is needed.
-It's located in ``xgboost/doc/python`` with the name ``convert_090to100.py``. See
-comments in the script for more details.
+suits simple use cases, and it's advised not to use pickle when stability is needed. It's
+located in ``xgboost/doc/python`` with the name ``convert_090to100.py``. See comments in
+the script for more details.

A similar procedure may be used to recover a model persisted in an old RDS file. In R, you
can install an older version of XGBoost using the ``remotes`` package:
@@ -172,7 +152,6 @@ Will print out something similar to (not actual output as it's too long for demonstration):

    {
      "Learner": {
        "generic_parameter": {
-         "enable_experimental_json_serialization": "0",
          "gpu_id": "0",
          "gpu_page_size": "0",
          "n_jobs": "0",
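For reference (not part of the diff), configuration output like the above comes from the Booster's config dump; a minimal sketch assuming ``bst`` is a trained ``Booster``:

.. code-block:: python

    import json

    config = json.loads(bst.save_config())
    print(json.dumps(config, indent=2))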
5 changes: 0 additions & 5 deletions include/xgboost/generic_parameters.h
@@ -31,7 +31,6 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
  bool fail_on_invalid_gpu_id {false};
  // gpu page size in external memory mode, 0 means using the default.
  size_t gpu_page_size;
-  bool enable_experimental_json_serialization {true};
  bool validate_parameters {false};

  void CheckDeprecated() {
@@ -74,10 +73,6 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
        .set_default(0)
        .set_lower_bound(0)
        .describe("GPU page size when running in external memory mode.");
-    DMLC_DECLARE_FIELD(enable_experimental_json_serialization)
-        .set_default(true)
-        .describe("Enable using JSON for memory serialization (Python Pickle, "
-                  "rabit checkpoints etc.).");
    DMLC_DECLARE_FIELD(validate_parameters)
        .set_default(false)
        .describe("Enable checking whether parameters are used or not.");
44 changes: 10 additions & 34 deletions src/learner.cc
@@ -874,40 +874,16 @@ class LearnerIO : public LearnerConfiguration {
  }

  void Save(dmlc::Stream* fo) const override {
-    if (generic_parameters_.enable_experimental_json_serialization) {
-      Json memory_snapshot{Object()};
-      memory_snapshot["Model"] = Object();
-      auto &model = memory_snapshot["Model"];
-      this->SaveModel(&model);
-      memory_snapshot["Config"] = Object();
-      auto &config = memory_snapshot["Config"];
-      this->SaveConfig(&config);
-      std::string out_str;
-      Json::Dump(memory_snapshot, &out_str);
-      fo->Write(out_str.c_str(), out_str.size());
-    } else {
-      std::string binary_buf;
-      common::MemoryBufferStream s(&binary_buf);
-      this->SaveModel(&s);
-      Json config{ Object() };
-      // Do not use std::size_t as it's not portable.
-      int64_t const json_offset = binary_buf.size();
-      this->SaveConfig(&config);
-      std::string config_str;
-      Json::Dump(config, &config_str);
-      // concatenate the model and config at final output; it's a temporary solution for
-      // continuing support for the binary model format
-      fo->Write(&serialisation_header_[0], serialisation_header_.size());
-      if (DMLC_IO_NO_ENDIAN_SWAP) {
-        fo->Write(&json_offset, sizeof(json_offset));
-      } else {
-        auto x = json_offset;
-        dmlc::ByteSwap(&x, sizeof(x), 1);
-        fo->Write(&x, sizeof(json_offset));
-      }
-      fo->Write(&binary_buf[0], binary_buf.size());
-      fo->Write(&config_str[0], config_str.size());
-    }
+    Json memory_snapshot{Object()};
+    memory_snapshot["Model"] = Object();
+    auto &model = memory_snapshot["Model"];
+    this->SaveModel(&model);
+    memory_snapshot["Config"] = Object();
+    auto &config = memory_snapshot["Config"];
+    this->SaveConfig(&config);
+    std::string out_str;
+    Json::Dump(memory_snapshot, &out_str);
+    fo->Write(out_str.c_str(), out_str.size());
  }

void Load(dmlc::Stream* fi) override {
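The practical effect of the simplified ``Save`` above, sketched in Python; an illustration rather than part of the commit, and the byte-level check assumes the JSON snapshot is embedded verbatim in the pickle stream:

.. code-block:: python

    import pickle

    import numpy as np
    import xgboost

    dtrain = xgboost.DMatrix(np.random.rand(32, 4),
                             label=np.random.randint(2, size=32))
    bst = xgboost.train({}, dtrain, num_boost_round=2)

    buf = pickle.dumps(bst)
    # With the binary branch gone, every snapshot is the JSON document
    # built in Save(): {"Model": ..., "Config": ...}.
    assert b'"Model"' in buf and b'"Config"' in buf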
8 changes: 0 additions & 8 deletions tests/python/test_pickling.py
@@ -42,17 +42,9 @@ def run_model_pickling(self, xgb_params):
        if os.path.exists(filename):
            os.remove(filename)

-    def test_model_pickling_binary(self):
-        params = {
-            'nthread': 1,
-            'tree_method': 'hist'
-        }
-        self.run_model_pickling(params)
-
    def test_model_pickling_json(self):
        params = {
            'nthread': 1,
            'tree_method': 'hist',
-            'enable_experimental_json_serialization': True
        }
        self.run_model_pickling(params)
28 changes: 3 additions & 25 deletions tests/python/test_training_continuation.py
@@ -9,7 +9,7 @@
class TestTrainingContinuation:
    num_parallel_tree = 3

-    def generate_parameters(self, use_json):
+    def generate_parameters(self):
        xgb_params_01_binary = {
            'nthread': 1,
        }
@@ -24,13 +24,6 @@ def generate_parameters(self, use_json):
            'num_class': 5,
            'num_parallel_tree': self.num_parallel_tree
        }
-        if use_json:
-            xgb_params_01_binary[
-                'enable_experimental_json_serialization'] = True
-            xgb_params_02_binary[
-                'enable_experimental_json_serialization'] = True
-            xgb_params_03_binary[
-                'enable_experimental_json_serialization'] = True

        return [
            xgb_params_01_binary, xgb_params_02_binary, xgb_params_03_binary
@@ -136,31 +129,16 @@ def run_training_continuation(self, xgb_params_01, xgb_params_02,
                              ntree_limit=gbdt_05.best_ntree_limit)
        np.testing.assert_almost_equal(res1, res2)

-    @pytest.mark.skipif(**tm.no_sklearn())
-    def test_training_continuation_binary(self):
-        params = self.generate_parameters(False)
-        self.run_training_continuation(params[0], params[1], params[2])
-
    @pytest.mark.skipif(**tm.no_sklearn())
    def test_training_continuation_json(self):
-        params = self.generate_parameters(True)
-        for p in params:
-            p['enable_experimental_json_serialization'] = True
-        self.run_training_continuation(params[0], params[1], params[2])
-
-    @pytest.mark.skipif(**tm.no_sklearn())
-    def test_training_continuation_updaters_binary(self):
-        updaters = 'grow_colmaker,prune,refresh'
-        params = self.generate_parameters(False)
-        for p in params:
-            p['updater'] = updaters
+        params = self.generate_parameters()
        self.run_training_continuation(params[0], params[1], params[2])

    @pytest.mark.skipif(**tm.no_sklearn())
    def test_training_continuation_updaters_json(self):
        # Picked up from R tests.
        updaters = 'grow_colmaker,prune,refresh'
-        params = self.generate_parameters(True)
+        params = self.generate_parameters()
        for p in params:
            p['updater'] = updaters
        self.run_training_continuation(params[0], params[1], params[2])
