Skip to content

Commit

Permalink
Fix categorical data with external memory. (dmlc#10433)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 18, 2024
1 parent 63b49f3 commit 97c6033
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 7 deletions.
2 changes: 1 addition & 1 deletion demo/guide-python/external_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def make_batches(
class Iterator(xgboost.DataIter):
"""A custom iterator for loading files in batches."""

def __init__(self, file_paths: List[Tuple[str, str]]):
def __init__(self, file_paths: List[Tuple[str, str]]) -> None:
self._file_paths = file_paths
self._it = 0
# XGBoost will generate some cache files under current directory with the prefix
Expand Down
3 changes: 1 addition & 2 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2024 by XGBoost Contributors
* Copyright 2017-2024, XGBoost Contributors
* \file hist_util.h
* \brief Utility for fast histogram aggregation
* \author Philip Cho, Tianqi Chen
Expand All @@ -11,7 +11,6 @@
#include <cstdint> // for uint32_t
#include <limits>
#include <map>
#include <memory>
#include <utility>
#include <vector>

Expand Down
5 changes: 2 additions & 3 deletions src/data/gradient_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
*/
#include "gradient_index.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <utility> // for forward
Expand Down Expand Up @@ -126,8 +125,8 @@ INSTANTIATION_PUSH(data::ColumnarAdapterBatch)
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
auto make_index = [this, n_index](auto t, common::BinTypeSize t_size) {
// Must resize instead of allocating a new one. This function is called every time a
// new batch is pushed, and we grow the size accordingly without loosing the data the
// previous batches.
// new batch is pushed, and we grow the size accordingly without losing the data in
// the previous batches.
using T = decltype(t);
std::size_t n_bytes = sizeof(T) * n_index;
CHECK_GE(n_bytes, this->data.size());
Expand Down
13 changes: 12 additions & 1 deletion src/data/histogram_cut_format.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2021-2023, XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*/
#ifndef XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
#define XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
Expand All @@ -23,6 +23,15 @@ inline bool ReadHistogramCuts(common::HistogramCuts *cuts, common::AlignedResour
if (!common::ReadVec(fi, &cuts->min_vals_.HostVector())) {
return false;
}
bool has_cat{false};
if (!fi->Read(&has_cat)) {
return false;
}
decltype(cuts->MaxCategory()) max_cat{0};
if (!fi->Read(&max_cat)) {
return false;
}
cuts->SetCategorical(has_cat, max_cat);
return true;
}

Expand All @@ -32,6 +41,8 @@ inline std::size_t WriteHistogramCuts(common::HistogramCuts const &cuts,
bytes += common::WriteVec(fo, cuts.Values());
bytes += common::WriteVec(fo, cuts.Ptrs());
bytes += common::WriteVec(fo, cuts.MinValues());
bytes += fo->Write(cuts.HasCategorical());
bytes += fo->Write(cuts.MaxCategory());
return bytes;
}
} // namespace xgboost::data
Expand Down
15 changes: 15 additions & 0 deletions tests/python/test_data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,21 @@ def test_single_batch(tree_method: str = "approx") -> None:
assert from_np.get_dump() == from_it.get_dump()


def test_with_cat_single() -> None:
    """Compare categorical training via a `SingleBatch` input against plain arrays.

    The same generated categorical data is turned into a `DMatrix` twice: once
    wrapped in `SingleBatch`, once passed directly as `X, y`. A booster is trained
    on each for the same number of rounds, and the JSON-serialised models are
    required to be equal.
    """
    n_rounds = 3
    X, y = tm.make_categorical(
        n_samples=128, n_features=3, n_categories=6, onehot=False
    )

    # Model trained on data supplied through the `SingleBatch` wrapper.
    dm_from_it = xgb.DMatrix(SingleBatch(data=X, label=y), enable_categorical=True)
    booster_from_it = xgb.train({}, dm_from_it, num_boost_round=n_rounds)

    # Reference model trained on the arrays passed in directly.
    dm_from_arrays = xgb.DMatrix(X, y, enable_categorical=True)
    booster_from_arrays = xgb.train({}, dm_from_arrays, num_boost_round=n_rounds)

    # Compare the full serialised models.
    raw_from_it = booster_from_it.save_raw(raw_format="json")
    raw_from_arrays = booster_from_arrays.save_raw(raw_format="json")
    assert raw_from_it == raw_from_arrays


def run_data_iterator(
n_samples_per_batch: int,
n_features: int,
Expand Down

0 comments on commit 97c6033

Please sign in to comment.