Skip to content

Commit

Permalink
Use mmap for external memory. (#9282)
Browse files Browse the repository at this point in the history
- Have basic infrastructure for mmap.
- Release file write handle.
  • Loading branch information
trivialfis authored Jun 19, 2023
1 parent d8beb51 commit ee6809e
Show file tree
Hide file tree
Showing 16 changed files with 588 additions and 264 deletions.
6 changes: 3 additions & 3 deletions demo/guide-python/external_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def main(tmpdir: str) -> xgboost.Booster:
missing = np.NaN
Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)

# Other tree methods including ``hist`` and ``gpu_hist`` also work, see tutorial in
# doc for details.
# Other tree methods including ``approx``, and ``gpu_hist`` are supported. GPU
# behaves differently than CPU tree methods. See tutorial in doc for details.
booster = xgboost.train(
{"tree_method": "approx", "max_depth": 2},
{"tree_method": "hist", "max_depth": 4},
Xy,
evals=[(Xy, "Train")],
num_boost_round=10,
Expand Down
153 changes: 105 additions & 48 deletions doc/tutorials/external_memory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,25 @@
Using XGBoost External Memory Version
#####################################

XGBoost supports loading data from external memory using builtin data parser. And
starting from version 1.5, users can also define a custom iterator to load data in chunks.
The feature is still experimental and not yet ready for production use. In this tutorial
we will introduce both methods. Please note that training on data from external memory is
not supported by ``exact`` tree method.
When working with large datasets, training XGBoost models can be challenging as the entire
dataset needs to be loaded into memory. This can be costly and sometimes
infeasible. Staring from 1.5, users can define a custom iterator to load data in chunks
for running XGBoost algorithms. External memory can be used for both training and
prediction, but training is the primary use case and it will be our focus in this
tutorial. For prediction and evaluation, users can iterate through the data themseleves
while training requires the full dataset to be loaded into the memory.

During training, there are two different modes for external memory support available in
XGBoost, one for CPU-based algorithms like ``hist`` and ``approx``, another one for the
GPU-based training algorithm. We will introduce them in the following sections.

.. note::

Training on data from external memory is not supported by the ``exact`` tree method.

.. note::

The feature is still experimental as of 2.0. The performance is not well optimized.

*************
Data Iterator
Expand All @@ -15,8 +29,8 @@ Data Iterator
Starting from XGBoost 1.5, users can define their own data loader using Python or C
interface. There are some examples in the ``demo`` directory for quick start. This is a
generalized version of text input external memory, where users no longer need to prepare a
text file that XGBoost recognizes. To enable the feature, user need to define a data
iterator with 2 class methods ``next`` and ``reset`` then pass it into ``DMatrix``
text file that XGBoost recognizes. To enable the feature, users need to define a data
iterator with 2 class methods: ``next`` and ``reset``, then pass it into the ``DMatrix``
constructor.

.. code-block:: python
Expand Down Expand Up @@ -60,20 +74,96 @@ constructor.
# Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some caveats
# as noted in following sections.
booster = xgboost.train({"tree_method": "approx"}, Xy)
booster = xgboost.train({"tree_method": "hist"}, Xy)
The above snippet is a simplified version of :ref:`sphx_glr_python_examples_external_memory.py`.
For an example in C, please see ``demo/c-api/external-memory/``. The iterator is the
common interface for using external memory with XGBoost, you can pass the resulting
``DMatrix`` object for training, prediction, and evaluation.

It is important to set the batch size based on the memory available. A good starting point
is to set the batch size to 10GB per batch if you have 64GB of memory. It is *not*
recommended to set small batch sizes like 32 samples per batch, as this can seriously hurt
performance in gradient boosting.

***********
CPU Version
***********

In the previous section, we demonstrated how to train a tree-based model using the
``hist`` tree method on a CPU. This method involves iterating through data batches stored
in a cache during tree construction. For optimal performance, we recommend using the
``grow_policy=depthwise`` setting, which allows XGBoost to build an entire layer of tree
nodes with only a few batch iterations. Conversely, using the ``lossguide`` policy
requires XGBoost to iterate over the data set for each tree node, resulting in slower
performance.

If external memory is used, the performance of CPU training is limited by IO
(input/output) speed. This means that the disk IO speed primarily determines the training
speed. During benchmarking, we used an NVMe connected to a PCIe-4 slot, other types of
storage can be too slow for practical usage. In addition, your system may perform caching
to reduce the overhead of file reading.

**********************************
GPU Version (GPU Hist tree method)
**********************************

External memory is supported by GPU algorithms (i.e. when ``tree_method`` is set to
``gpu_hist``). However, the algorithm used for GPU is different from the one used for
CPU. When training on a CPU, the tree method iterates through all batches from external
memory for each step of the tree construction algorithm. On the other hand, the GPU
algorithm concatenates all batches into one and stores it in GPU memory. To reduce overall
memory usage, users can utilize subsampling. The good news is that the GPU hist tree
method supports gradient-based sampling, enabling users to set a low sampling rate without
compromising accuracy.

.. code-block:: python
param = {
...
'subsample': 0.2,
'sampling_method': 'gradient_based',
}
For more information about the sampling algorithm and its use in external memory training,
see `this paper <https://arxiv.org/abs/2005.09148>`_.

.. warning::

When GPU is running out of memory during iteration on external memory, user might
recieve a segfault instead of an OOM exception.

*******
Remarks
*******

When using external memory with XBGoost, data is divided into smaller chunks so that only
a fraction of it needs to be stored in memory at any given time. It's important to note
that this method only applies to the predictor data (``X``), while other data, like labels
and internal runtime structures are concatenated. This means that memory reduction is most
effective when dealing with wide datasets where ``X`` is larger compared to other data
like ``y``, while it has little impact on slim datasets.

Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It is not
yet tested against system errors like disconnected network devices (`SIGBUS`). Also, it's
worth noting that most tests have been conducted on Linux distributions.

Another important point to keep in mind is that creating the initial cache for XGBoost may
take some time. The interface to external memory is through custom iterators, which may or
may not be thread-safe. Therefore, initialization is performed sequentially.

The above snippet is a simplified version of ``demo/guide-python/external_memory.py``. For
an example in C, please see ``demo/c-api/external-memory/``.

****************
Text File Inputs
****************

There is no big difference between using external memory version and in-memory version.
The only difference is the filename format.
This is the original form of external memory support, users are encouraged to use custom
data iterator instead. There is no big difference between using external memory version of
text input and the in-memory version. The only difference is the filename format.

The external memory version takes in the following `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format:
The external memory version takes in the following `URI
<https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format:

.. code-block:: none
Expand All @@ -91,9 +181,8 @@ To load from csv files, use the following syntax:
where ``label_column`` should point to the csv column acting as the label.

To provide a simple example for illustration, extracting the code from
`demo/guide-python/external_memory.py <https://github.com/dmlc/xgboost/blob/master/demo/guide-python/external_memory.py>`_. If
you have a dataset stored in a file similar to ``agaricus.txt.train`` with LIBSVM format, the external memory support can be enabled by:
If you have a dataset stored in a file similar to ``demo/data/agaricus.txt.train`` with LIBSVM
format, the external memory support can be enabled by:

.. code-block:: python
Expand All @@ -104,35 +193,3 @@ XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to
more notes about text input formats, see :doc:`/tutorials/input_format`.

For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``.


**********************************
GPU Version (GPU Hist tree method)
**********************************
External memory is supported in GPU algorithms (i.e. when ``tree_method`` is set to ``gpu_hist``).

If you are still getting out-of-memory errors after enabling external memory, try subsampling the
data to further reduce GPU memory usage:

.. code-block:: python
param = {
...
'subsample': 0.1,
'sampling_method': 'gradient_based',
}
For more information, see `this paper <https://arxiv.org/abs/2005.09148>`_. Internally
the tree method still concatenate all the chunks into 1 final histogram index due to
performance reason, but in compressed format. So its scalability has an upper bound but
still has lower memory cost in general.

***********
CPU Version
***********

For CPU histogram based tree methods (``approx``, ``hist``) it's recommended to use
``grow_policy=depthwise`` for performance reason. Iterating over data batches is slow,
with ``depthwise`` policy XGBoost can build a entire layer of tree nodes with a few
iterations, while with ``lossguide`` XGBoost needs to iterate over the data set for each
tree node.
8 changes: 5 additions & 3 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,14 @@ def __init__(
X: Sequence,
y: Sequence,
w: Optional[Sequence],
cache: Optional[str] = "./",
cache: Optional[str],
) -> None:
assert len(X) == len(y)
self.X = X
self.y = y
self.w = w
self.it = 0
super().__init__(cache)
super().__init__(cache_prefix=cache)

def next(self, input_data: Callable) -> int:
if self.it == len(self.X):
Expand Down Expand Up @@ -347,7 +347,9 @@ def get_external_dmat(self) -> xgb.DMatrix:
if w is not None:
weight.append(w)

it = IteratorForTest(predictor, response, weight if weight else None)
it = IteratorForTest(
predictor, response, weight if weight else None, cache="cache"
)
return xgb.DMatrix(it)

def __repr__(self) -> str:
Expand Down
74 changes: 42 additions & 32 deletions rabit/include/rabit/internal/io.h
Original file line number Diff line number Diff line change
@@ -1,73 +1,83 @@
/*!
* Copyright (c) 2014-2019 by Contributors
/**
* Copyright 2014-2023, XGBoost Contributors
* \file io.h
* \brief utilities with different serializable implementations
* \author Tianqi Chen
*/
#ifndef RABIT_INTERNAL_IO_H_
#define RABIT_INTERNAL_IO_H_
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>

#include <algorithm>
#include <numeric>
#include <cstddef> // for size_t
#include <cstdio>
#include <cstring> // for memcpy
#include <limits>
#include <numeric>
#include <string>
#include <vector>

#include "rabit/internal/utils.h"
#include "rabit/serializable.h"

namespace rabit {
namespace utils {
/*! \brief re-use definition of dmlc::SeekStream */
using SeekStream = dmlc::SeekStream;
/*! \brief fixed size memory buffer */
/**
* @brief Fixed size memory buffer as a stream.
*/
struct MemoryFixSizeBuffer : public SeekStream {
public:
// similar to SEEK_END in libc
static size_t constexpr kSeekEnd = std::numeric_limits<size_t>::max();
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();

protected:
MemoryFixSizeBuffer() = default;

public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)),
buffer_size_(buffer_size) {
curr_ptr_ = 0;
}
/**
* @brief Ctor
*
* @param p_buffer Pointer to the source buffer with size `buffer_size`.
* @param buffer_size Size of the source buffer
*/
MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size)
: p_buffer_(reinterpret_cast<char *>(p_buffer)), buffer_size_(buffer_size) {}
~MemoryFixSizeBuffer() override = default;
size_t Read(void *ptr, size_t size) override {
size_t nread = std::min(buffer_size_ - curr_ptr_, size);

std::size_t Read(void *ptr, std::size_t size) override {
std::size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
void Write(const void *ptr, size_t size) override {
void Write(const void *ptr, std::size_t size) override {
if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_,
"write position exceed fixed buffer size");
CHECK_LE(curr_ptr_ + size, buffer_size_);
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
void Seek(size_t pos) override {
void Seek(std::size_t pos) override {
if (pos == kSeekEnd) {
curr_ptr_ = buffer_size_;
} else {
curr_ptr_ = static_cast<size_t>(pos);
curr_ptr_ = static_cast<std::size_t>(pos);
}
}
size_t Tell() override {
return curr_ptr_;
}
virtual bool AtEnd() const {
return curr_ptr_ == buffer_size_;
}
/**
* @brief Current position in the buffer (stream).
*/
std::size_t Tell() override { return curr_ptr_; }
virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }

private:
protected:
/*! \brief in memory buffer */
char *p_buffer_;
char *p_buffer_{nullptr};
/*! \brief current pointer */
size_t buffer_size_;
std::size_t buffer_size_{0};
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryFixSizeBuffer
std::size_t curr_ptr_{0};
};

/*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public SeekStream {
Expand Down
Loading

0 comments on commit ee6809e

Please sign in to comment.