From d87affbf35aaf87e5c8e6a799395be6e7629f224 Mon Sep 17 00:00:00 2001
From: jiamingy
Date: Sun, 1 May 2022 18:53:16 +0800
Subject: [PATCH] Implement multi-target for hist.

Add initial support for vector-leaf (multi-target) trees to the hist tree
method. This introduces the MultiTargetTree structure and the `multi_strategy`
training parameter, extends RegTree, GBTree, Learner and the CPU/GPU
predictors to handle multi-target models, and splits the quantile hist
updater's histogram building, sampling and IO helpers into reusable
components. Prediction cache updates are not yet supported for vector leaf.
---
 R-package/src/Makevars.in | 1 +
 R-package/src/Makevars.win | 1 +
 demo/guide-python/multioutput_regression.py | 17 +-
 include/xgboost/base.h | 30 +-
 include/xgboost/learner.h | 33 +-
 include/xgboost/linalg.h | 126 ++-
 include/xgboost/multi_target_tree_model.h | 125 +++
 include/xgboost/tree_model.h | 160 +++-
 include/xgboost/tree_updater.h | 1 +
 src/c_api/c_api.cc | 7 +-
 src/c_api/c_api.cu | 5 +-
 src/c_api/c_api_utils.h | 8 +-
 src/collective/rabit_communicator.h | 2 +-
 src/common/hist_util.cc | 37 +-
 src/common/hist_util.h | 9 +-
 src/common/host_device_vector.cc | 1 -
 src/common/host_device_vector.cu | 2 +-
 src/common/partition_builder.h | 184 ++---
 src/common/quantile.h | 13 -
 src/common/stats.cu | 17 +-
 src/common/threading_utils.h | 2 +-
 src/data/iterative_dmatrix.cc | 3 +-
 src/data/simple_dmatrix.cc | 7 +-
 src/gbm/gbtree.cc | 57 +-
 src/gbm/gbtree.h | 28 +-
 src/gbm/gbtree_model.cc | 4 +-
 src/gbm/gbtree_model.h | 7 +-
 src/learner.cc | 104 ++-
 src/predictor/cpu_predictor.cc | 284 ++++---
 src/predictor/gpu_predictor.cu | 21 +-
 src/predictor/predict_fn.h | 21 +-
 src/predictor/predictor.cc | 15 +-
 src/tree/common_row_partitioner.h | 88 ++-
 src/tree/fit_stump.cc | 6 +-
 src/tree/hist/evaluate_splits.h | 235 +++++-
 src/tree/hist/expand_entry.h | 97 ++-
 src/tree/hist/histogram.h | 171 ++---
 src/tree/hist/sampler.h | 109 +++
 src/tree/io_utils.h | 48 ++
 src/tree/multi_target_tree_model.cc | 188 +++++
 src/tree/param.h | 49 +-
 src/tree/tree_model.cc | 204 +++--
 src/tree/updater_approx.cc | 50 +-
 src/tree/updater_gpu_hist.cu | 10 +-
 src/tree/updater_quantile_hist.cc | 801 +++++++++++++------
 src/tree/updater_quantile_hist.h | 178 +----
 tests/cpp/common/test_linalg.cc | 65 +-
 tests/cpp/common/test_linalg.cu | 8 +-
 tests/cpp/common/test_stats.cc | 14 +-
 tests/cpp/common/test_stats.cu | 10 +-
 tests/cpp/data/test_simple_dmatrix.cc | 6 +-
 tests/cpp/gbm/test_gbtree.cc | 2 +-
 tests/cpp/helpers.cc | 20 +-
 tests/cpp/helpers.h | 14 +-
 tests/cpp/metric/test_auc.cc | 2 +-
 tests/cpp/metric/test_elementwise_metric.cc | 2 +-
 tests/cpp/predictor/test_gpu_predictor.cu | 6 +-
tests/cpp/test_learner.cc | 2 +- tests/cpp/tree/hist/test_evaluate_splits.cc | 3 +- tests/cpp/tree/hist/test_histogram.cc | 30 +- tests/cpp/tree/hist/test_sampler.cc | 57 ++ tests/cpp/tree/test_approx.cc | 6 +- tests/cpp/tree/test_quantile_hist.cc | 3 +- tests/python/test_model_compatibility.py | 4 +- tests/python/test_with_sklearn.py | 2 +- 65 files changed, 2570 insertions(+), 1252 deletions(-) create mode 100644 include/xgboost/multi_target_tree_model.h create mode 100644 src/tree/hist/sampler.h create mode 100644 src/tree/io_utils.h create mode 100644 src/tree/multi_target_tree_model.cc create mode 100644 tests/cpp/tree/hist/test_sampler.cc diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 6d9113ed55de..168bcc409777 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -59,6 +59,7 @@ OBJECTS= \ $(PKGROOT)/src/tree/fit_stump.o \ $(PKGROOT)/src/tree/tree_model.o \ $(PKGROOT)/src/tree/tree_updater.o \ + $(PKGROOT)/src/tree/multi_target_tree_model.o \ $(PKGROOT)/src/tree/updater_approx.o \ $(PKGROOT)/src/tree/updater_colmaker.o \ $(PKGROOT)/src/tree/updater_prune.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 1914384e8ad4..113b9f27830f 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -58,6 +58,7 @@ OBJECTS= \ $(PKGROOT)/src/tree/param.o \ $(PKGROOT)/src/tree/fit_stump.o \ $(PKGROOT)/src/tree/tree_model.o \ + $(PKGROOT)/src/tree/multi_target_tree_model.o \ $(PKGROOT)/src/tree/tree_updater.o \ $(PKGROOT)/src/tree/updater_approx.o \ $(PKGROOT)/src/tree/updater_colmaker.o \ diff --git a/demo/guide-python/multioutput_regression.py b/demo/guide-python/multioutput_regression.py index 375377e4e4b5..018f7e6445dc 100644 --- a/demo/guide-python/multioutput_regression.py +++ b/demo/guide-python/multioutput_regression.py @@ -44,10 +44,19 @@ def rmse_model(plot_result: bool): """Draw a circle with 2-dim coordinate as target variables.""" X, y = gen_circle() # Train a regressor on it - reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64) + reg = xgb.XGBRegressor( + tree_method="hist", + n_estimators=16, + n_jobs=16, + max_depth=8, + multi_strategy="mono", + subsample=0.6, + ) reg.fit(X, y, eval_set=[(X, y)]) + # reg.save_model("model.json") y_predt = reg.predict(X) + # print("y_predt:", y_predt, y) if plot_result: plot_predt(y, y_predt, "multi") @@ -81,13 +90,15 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]: X, y = gen_circle() Xy = xgb.DMatrix(X, y) results: Dict[str, Dict[str, List[float]]] = {} - # Make sure the `num_target` is passed to XGBoost when custom objective is used. + # Make sure the `num_class` is passed to XGBoost when custom objective is used. # When builtin objective is used, XGBoost can figure out the number of targets # automatically. booster = xgb.train( { "tree_method": "hist", - "num_target": y.shape[1], + "num_class": y.shape[1], + "multi_strategy": "mono", + "objective": "reg:squarederror", # fixme }, dtrain=Xy, num_boost_round=100, diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 34312223c0cb..4de191a0e79e 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2015-2023 by Contributors * \file base.h * \brief defines configuration macros of xgboost. */ @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -110,19 +111,19 @@ namespace xgboost { /*! \brief unsigned integer type used for feature index. 
*/ -using bst_uint = uint32_t; // NOLINT +using bst_uint = std::uint32_t; // NOLINT /*! \brief integer type. */ -using bst_int = int32_t; // NOLINT +using bst_int = std::int32_t; // NOLINT /*! \brief unsigned long integers */ -using bst_ulong = uint64_t; // NOLINT +using bst_ulong = std::uint64_t; // NOLINT /*! \brief float type, used for storing statistics */ using bst_float = float; // NOLINT /*! \brief Categorical value type. */ -using bst_cat_t = int32_t; // NOLINT +using bst_cat_t = std::int32_t; // NOLINT /*! \brief Type for data column (feature) index. */ -using bst_feature_t = uint32_t; // NOLINT +using bst_feature_t = std::uint32_t; // NOLINT /*! \brief Type for histogram bin index. */ -using bst_bin_t = int32_t; // NOLINT +using bst_bin_t = std::int32_t; // NOLINT /*! \brief Type for data row index. * * Be careful `std::size_t' is implementation-defined. Meaning that the binary @@ -131,11 +132,11 @@ using bst_bin_t = int32_t; // NOLINT */ using bst_row_t = std::size_t; // NOLINT /*! \brief Type for tree node index. */ -using bst_node_t = int32_t; // NOLINT +using bst_node_t = std::int32_t; // NOLINT /*! \brief Type for ranking group index. */ -using bst_group_t = uint32_t; // NOLINT -/*! \brief Type for indexing target variables. */ -using bst_target_t = std::size_t; // NOLINT +using bst_group_t = std::uint32_t; // NOLINT +/*! \brief Type for indexing into output targets. */ +using bst_target_t = std::uint32_t; // NOLINT namespace detail { /*! \brief Implementation of gradient statistics pair. Template specialisation @@ -171,11 +172,14 @@ class GradientPairInternal { } // Copy constructor if of same value type, marked as default to be trivially_copyable - GradientPairInternal(const GradientPairInternal &g) = default; + GradientPairInternal(GradientPairInternal const &g) = default; + GradientPairInternal(GradientPairInternal &&g) = default; + GradientPairInternal &operator=(GradientPairInternal const &that) = default; + GradientPairInternal &operator=(GradientPairInternal &&that) = default; // Copy constructor if different value type - use getters and setters to // perform conversion - template + template ::value>* = nullptr> XGBOOST_DEVICE explicit GradientPairInternal(const GradientPairInternal &g) { SetGrad(g.GetGrad()); SetHess(g.GetHess()); diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 35d3cf586d6a..551946ec186a 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -162,6 +163,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { */ virtual int32_t BoostedRounds() const = 0; virtual uint32_t Groups() const = 0; + virtual bst_target_t Targets() const = 0; void LoadModel(Json const& in) override = 0; void SaveModel(Json* out) const override = 0; @@ -305,11 +307,21 @@ struct LearnerModelParam { linalg::Tensor base_score_; public: - /* \brief number of features */ - uint32_t num_feature { 0 }; - /* \brief number of classes, if it is multi-class classification */ - uint32_t num_output_group { 0 }; - /* \brief Current task, determined by objective. */ + /** + * \brief The number of features. + */ + bst_feature_t num_feature{0}; + /** + * \brief The number of classes or targets if the current strategy is composite. + */ + uint32_t num_output_group{0}; + /** + * \brief The number of output targets. + */ + bst_target_t num_target{0}; + /** + * \brief Current task, determined by objective. 
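 * For multi-target regression the task stays kRegression; the width of the
 * model output is carried by num_target above (see OutputLength() below).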
+ */ ObjInfo task{ObjInfo::kRegression}; LearnerModelParam() = default; @@ -319,13 +331,20 @@ struct LearnerModelParam { linalg::Tensor base_margin, ObjInfo t); LearnerModelParam(LearnerModelParamLegacy const& user_param, ObjInfo t); LearnerModelParam(bst_feature_t n_features, linalg::Tensor base_margin, - uint32_t n_groups) - : base_score_{std::move(base_margin)}, num_feature{n_features}, num_output_group{n_groups} {} + uint32_t n_groups, bst_target_t n_targets) + : base_score_{std::move(base_margin)}, + num_feature{n_features}, + num_output_group{n_groups}, + num_target{n_targets} {} linalg::TensorView BaseScore(Context const* ctx) const; linalg::TensorView BaseScore(int32_t device) const; void Copy(LearnerModelParam const& that); + bool IsVectorLeaf() const { return num_output_group == 1 && num_target > 1; } + bst_target_t OutputLength() const { + return this->IsVectorLeaf() ? this->num_target : this->num_output_group; + } /* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */ bool Initialized() const { return num_feature != 0 && num_output_group != 0; } diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index ca816bcdb7a4..f1d3b7472f86 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -16,6 +16,7 @@ #include #include #include // std::int32_t +#include // std::size_t #include #include #include @@ -153,14 +154,34 @@ inline LINALG_HD int Popc(uint64_t v) { #endif // compiler } +template +LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head) { + static_assert(std::is_integral>::value, "Invalid index type."); + arr[D - 1] = head; +} + +/** + * \brief Convert index from parameter pack to C-style array. + */ +template +LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head, Rest &&...index) { + static_assert(sizeof...(Rest) < D, "Index overflow."); + static_assert(std::is_integral>::value, "Invalid index type."); + arr[D - sizeof...(Rest) - 1] = head; + IndexToArr(arr, std::forward(index)...); +} + template -constexpr auto Arr2Tup(T (&arr)[N], std::index_sequence) { +constexpr auto ArrToTuple(T (&arr)[N], std::index_sequence) { return std::make_tuple(arr[Idx]...); } +/** + * \brief Convert C-styple array to std::tuple. + */ template -constexpr auto Arr2Tup(T (&arr)[N]) { - return Arr2Tup(arr, std::make_index_sequence{}); +constexpr auto ArrToTuple(T (&arr)[N]) { + return ArrToTuple(arr, std::make_index_sequence{}); } // uint division optimization inspired by the CIndexer in cupy. Division operation is @@ -183,7 +204,7 @@ LINALG_HD auto UnravelImpl(I idx, common::Span shape) { } } index[0] = idx; - return Arr2Tup(index); + return ArrToTuple(index); } template @@ -246,6 +267,11 @@ constexpr detail::RangeTag Range(I beg, I end) { return {beg, end}; } +enum class Order : std::uint8_t { + kC, // Row major + kF, // Col major +}; + /** * \brief A tensor view with static type and dimension. It implements indexing and slicing. 
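 *
 * A short usage sketch of the view (a sketch only; the names, shapes and values
 * here are illustrative, not part of this patch):
 * \code
 *   Context ctx;
 *   std::vector<float> storage(12, 0.0f);
 *   auto t = linalg::MakeTensorView(&ctx, storage, 3, 4);  // 3x4, row-major
 *   t(0, 1) = 1.0f;                                        // element access
 *   auto row = t.Slice(0, linalg::All());                  // view of the first row
 * \endcode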
* @@ -371,7 +397,11 @@ class TensorView { * \param device Device ordinal */ template - LINALG_HD TensorView(common::Span data, I const (&shape)[D], int32_t device) + LINALG_HD TensorView(common::Span data, I const (&shape)[D], std::int32_t device) + : TensorView{data, shape, Order::kC, device} {} + + template + LINALG_HD TensorView(common::Span data, I const (&shape)[D], Order order, std::int32_t device) : data_{data}, ptr_{data_.data()}, device_{device} { static_assert(D > 0 && D <= kDim, "Invalid shape."); // shape @@ -380,7 +410,19 @@ class TensorView { shape_[i] = 1; } // stride - detail::CalcStride(shape_, stride_); + switch (order) { + case Order::kC: { + detail::CalcStride(shape_, stride_); + break; + } + case Order::kF: { + detail::CalcStride(shape_, stride_); + break; + } + default: { + SPAN_CHECK(false); + } + } // size this->CalcSize(); } @@ -524,16 +566,20 @@ class TensorView { /** * \brief Constructor for automatic type deduction. */ -template ::value> * = nullptr> -auto MakeTensorView(Container &data, I const (&shape)[D], int32_t device) { // NOLINT +auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOLINT using T = typename Container::value_type; - return TensorView{data, shape, device}; + std::size_t in_shape[sizeof...(S)]; + detail::IndexToArr(in_shape, std::forward(shape)...); + return TensorView{data, in_shape, ctx->gpu_id}; } -template -LINALG_HD auto MakeTensorView(common::Span data, I const (&shape)[D], int32_t device) { - return TensorView{data, shape, device}; +template +LINALG_HD auto MakeTensorView(Context const *ctx, common::Span data, S &&...shape) { + std::size_t in_shape[sizeof...(S)]; + detail::IndexToArr(in_shape, std::forward(shape)...); + return TensorView{data, in_shape, ctx->gpu_id}; } /** @@ -665,6 +711,7 @@ class Tensor { private: HostDeviceVector data_; ShapeT shape_{0}; + Order order_{Order::kC}; template void Initialize(I const (&shape)[D], std::int32_t device) { @@ -690,11 +737,12 @@ class Tensor { * See \ref TensorView for parameters of this constructor. */ template - explicit Tensor(I const (&shape)[D], int32_t device) - : Tensor{common::Span{shape}, device} {} + explicit Tensor(I const (&shape)[D], Order order, std::int32_t device) + : Tensor{common::Span{shape}, order, device} {} template - explicit Tensor(common::Span shape, int32_t device) { + explicit Tensor(common::Span shape, Order order, std::int32_t device) + : order_{order} { // No device unroll as this is a host only function. std::copy(shape.data(), shape.data() + D, shape_); for (auto i = D; i < kDim; ++i) { @@ -713,7 +761,8 @@ class Tensor { * Initialize from 2 host iterators. 
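 *
 * A minimal sketch of the intended call (shape and values are illustrative):
 * \code
 *   std::vector<float> data(6, 1.0f);
 *   linalg::Tensor<float, 2> t{data.cbegin(), data.cend(), {2, 3},
 *                              linalg::Order::kC, Context::kCpuId};
 * \endcode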
*/ template - explicit Tensor(It begin, It end, I const (&shape)[D], int32_t device) { + explicit Tensor(It begin, It end, I const (&shape)[D], Order order, std::int32_t device) + : order_{order} { auto &h_vec = data_.HostVector(); h_vec.insert(h_vec.begin(), begin, end); // shape @@ -721,8 +770,9 @@ class Tensor { } template - explicit Tensor(std::initializer_list data, I const (&shape)[D], - int32_t device = Context::kCpuId) { + explicit Tensor(std::initializer_list data, I const (&shape)[D], Order order, + std::int32_t device) + : order_{order} { auto &h_vec = data_.HostVector(); h_vec = data; // shape @@ -752,20 +802,20 @@ class Tensor { if (device >= 0) { data_.SetDevice(device); auto span = data_.DeviceSpan(); - return {span, shape_, device}; + return {span, shape_, order_, device}; } else { auto span = data_.HostSpan(); - return {span, shape_, device}; + return {span, shape_, order_, device}; } } TensorView View(int32_t device) const { if (device >= 0) { data_.SetDevice(device); auto span = data_.ConstDeviceSpan(); - return {span, shape_, device}; + return {span, shape_, order_, device}; } else { auto span = data_.ConstHostSpan(); - return {span, shape_, device}; + return {span, shape_, order_, device}; } } @@ -826,6 +876,20 @@ class Tensor { void Reshape(size_t (&shape)[D]) { this->Reshape(common::Span{shape}); } + /** + * \brief Get a host view on the slice. + */ + template + auto Slice(S &&...slices) const { + return this->HostView().Slice(std::forward(slices)...); + } + /** + * \brief Get a host view on the slice. + */ + template + auto Slice(S &&...slices) { + return this->HostView().Slice(std::forward(slices)...); + } /** * \brief Set device ordinal for this tensor. @@ -834,9 +898,26 @@ class Tensor { int32_t DeviceIdx() const { return data_.DeviceIdx(); } }; +template +using Matrix = Tensor; + template using Vector = Tensor; +/** + * \brief Create an array without initialization. + */ +template +auto Empty(Context const *ctx, Index &&...index) { + Tensor t; + t.SetDevice(ctx->gpu_id); + t.Reshape(index...); + return t; +} + +/** + * \brief Create an array with value v. + */ template auto Constant(Context const *ctx, T v, Index &&...index) { Tensor t; @@ -846,7 +927,6 @@ auto Constant(Context const *ctx, T v, Index &&...index) { return t; } - /** * \brief Like `np.zeros`, return a new array of given shape and type, filled with zeros. */ diff --git a/include/xgboost/multi_target_tree_model.h b/include/xgboost/multi_target_tree_model.h new file mode 100644 index 000000000000..d0831898d867 --- /dev/null +++ b/include/xgboost/multi_target_tree_model.h @@ -0,0 +1,125 @@ +/** + * Copyright 2022 by XGBoost contributors + * + * \brief Core data structure for vector leaf + */ +#ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_ +#define XGBOOST_MULTI_TARGET_TREE_MODEL_H_ +#include +#include // FeatureType +#include // VectorView +#include // Model +#include // Span +#include // StringView + +#include // std::uint8_t +#include +#include // std::vector + +#include "xgboost/string_view.h" + +namespace xgboost { +/** + * \brief CSR-like matrix for categorical splits. + * + * The fields of split_categories_segments_[i] are set such that the range + * node_ptr[beg:(beg+size)] stores the bitset for the matching categories for the + * i-th node. + */ +struct CategoricalSplitMatrix { + struct Segment { + std::size_t beg{0}; + std::size_t size{0}; + }; + common::Span split_type; + common::Span categories; + common::Span node_ptr; +}; + +/** + * \brief Tree structure for multi-target model. 
+ */ +class MultiTargetTree : public Model { + private: + static bst_node_t constexpr InvalidNodeId() { return -1; } + + private: + bst_target_t n_targets_; + std::vector left_; + std::vector right_; + std::vector parent_; + std::vector split_index_; + std::vector default_left_; + std::vector split_conds_; + std::vector weights_; + + bst_node_t n_nodes_{1}; + bst_feature_t n_features_; + + public: + explicit MultiTargetTree(bst_target_t n_targets, bst_feature_t n_features) + : n_targets_{n_targets}, + left_(1ul, InvalidNodeId()), + right_(1ul, InvalidNodeId()), + parent_(1ul, InvalidNodeId()), + split_index_(1ul, 0), + default_left_(1ul, 0), + split_conds_(1ul, std::numeric_limits::quiet_NaN()), + weights_(1ul, std::numeric_limits::quiet_NaN()), + n_features_{n_features} { + CHECK_GE(n_targets_, 1); + } + + void SetLeaf(bst_node_t nidx, linalg::VectorView weight) { + CHECK_EQ(nidx, 0); + auto to = nidx + 1; + CHECK_EQ(weight.Size(), n_targets_); + CHECK_EQ(n_nodes_, 1); + weights_.resize(to * weight.Size()); + for (size_t i = 0; i < weight.Size(); ++i) { + weights_[i] = weight(i); + } + } + + void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, + linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight); + + bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); } + bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); } + bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); } + bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); } + + bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; } + float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; } + bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; } + bst_node_t DefaultChild(bst_node_t nidx) const { + return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx); + } + + bst_target_t NumTargets() const { return n_targets_; } + + size_t Size() const { return n_nodes_; } + + bst_node_t Depth(bst_node_t nidx) const { + bst_node_t depth{0}; + while (Parent(nidx) != InvalidNodeId()) { + ++depth; + nidx = Parent(nidx); + } + return depth; + } + + linalg::VectorView LeafValue(bst_node_t nidx) const { + CHECK(IsLeaf(nidx)); + auto beg = nidx * n_targets_; + auto v = common::Span{weights_}.subspan(beg, n_targets_); + return linalg::MakeVec(v.data(), v.size()); + } + + void LoadModel(Json const& in) override; + void SaveModel(Json* out) const override; +}; +} // namespace xgboost +#endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_ diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 70c71cac1ad9..36d1a08ba40e 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -23,6 +24,7 @@ #include #include #include +#include "xgboost/linalg.h" namespace xgboost { @@ -124,6 +126,34 @@ struct RTreeNodeStat { } }; +/** + * \brief Helper for defining copyable data structure that contains unique pointers. 
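 *
 * Copying performs a deep copy of the pointee when it is non-null; this keeps
 * RegTree copyable while it owns an optional MultiTargetTree.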
+ */ +template +class CopyUniquePtr { + std::unique_ptr ptr_{nullptr}; + + public: + CopyUniquePtr() = default; + CopyUniquePtr(CopyUniquePtr const& that) { + ptr_.reset(nullptr); + if (that.ptr_) { + ptr_ = std::make_unique(*that); + } + } + T* get() const noexcept { return ptr_.get(); } // NOLINT + + T& operator*() { return *ptr_; } + T* operator->() noexcept { return this->get(); } + + T const& operator*() const { return *ptr_; } + T const* operator->() const noexcept { return this->get(); } + + explicit operator bool() const { return static_cast(ptr_); } + bool operator!() const { return !ptr_; } + void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT +}; + /*! * \brief define regression tree to be the most common tree model. * This is the data structure used in xgboost's major tree models. @@ -312,7 +342,6 @@ class RegTree : public Model { /*! \brief model parameter */ TreeParam param; - /*! \brief constructor */ RegTree() { param.num_nodes = 1; param.num_deleted = 0; @@ -320,11 +349,19 @@ class RegTree : public Model { stats_.resize(param.num_nodes); split_types_.resize(param.num_nodes, FeatureType::kNumerical); split_categories_segments_.resize(param.num_nodes); - for (int i = 0; i < param.num_nodes; i ++) { + for (int i = 0; i < param.num_nodes; i++) { nodes_[i].SetLeaf(0.0f); nodes_[i].SetParent(kInvalidNodeId); } } + /*! \brief constructor */ + explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree() { + param.num_feature = n_features; + if (n_targets > 1) { + this->p_mt_tree_.reset(new MultiTargetTree{n_targets, n_features}); + } + } + /*! \brief get node given nid */ Node& operator[](int nid) { return nodes_[nid]; @@ -424,6 +461,11 @@ class RegTree : public Model { float right_sum, bst_node_t leaf_right_child = kInvalidNodeId); + void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, + linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight); + /** * \brief Expands a leaf node with categories * @@ -445,15 +487,38 @@ class RegTree : public Model { bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum); + void ExpandCategorical(bst_node_t nidx, bst_feature_t split_index, + common::Span split_cat, bool default_left, + linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight); + bool HasCategoricalSplit() const { return !split_categories_.empty(); } + /** + * \brief Whether this is a multi-target tree. + */ + bool IsMultiTarget() const { return static_cast(p_mt_tree_); } + bst_target_t NumTargets() const { + if (IsMultiTarget()) { + return this->p_mt_tree_->NumTargets(); + } + return 1; + } + auto GetMultiTargetTree() const { + CHECK(IsMultiTarget()); + return p_mt_tree_.get(); + } /*! * \brief get current depth * \param nid node id */ int GetDepth(int nid) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->Depth(nid); + } int depth = 0; while (!nodes_[nid].IsRoot()) { ++depth; @@ -461,6 +526,10 @@ class RegTree : public Model { } return depth; } + void SetLeaf(bst_node_t nidx, linalg::VectorView weight) { + CHECK(IsMultiTarget()); + return this->p_mt_tree_->SetLeaf(nidx, weight); + } /*! * \brief get maximum depth @@ -571,7 +640,9 @@ class RegTree : public Model { /*! * \brief Get split types for all nodes. 
*/ - std::vector const &GetSplitTypes() const { return split_types_; } + std::vector const& GetSplitTypes() const { + return split_types_; + } common::Span GetSplitCategories() const { return split_categories_; } /*! * \brief Get the bit storage for categories @@ -585,28 +656,71 @@ class RegTree : public Model { } auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; } - // The fields of split_categories_segments_[i] are set such that - // the range split_categories_[beg:(beg+size)] stores the bitset for - // the matching categories for the i-th node. - struct Segment { - size_t beg {0}; - size_t size {0}; - }; - - struct CategoricalSplitMatrix { - common::Span split_type; - common::Span categories; - common::Span node_ptr; - }; - CategoricalSplitMatrix GetCategoriesMatrix() const { CategoricalSplitMatrix view; view.split_type = common::Span(this->GetSplitTypes()); view.categories = this->GetSplitCategories(); - view.node_ptr = common::Span(split_categories_segments_); + view.node_ptr = common::Span(split_categories_segments_); return view; } + bst_feature_t SplitIndex(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->SplitIndex(nidx); + } + return (*this)[nidx].SplitIndex(); + } + float SplitCond(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->SplitCond(nidx); + } + return (*this)[nidx].SplitCond(); + } + bool DefaultLeft(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->DefaultLeft(nidx); + } + return (*this)[nidx].DefaultLeft(); + } + bool IsRoot(bst_node_t nidx) const { + if (IsMultiTarget()) { + return nidx == kRoot; + } + return (*this)[nidx].IsRoot(); + } + bst_node_t Parent(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->Parent(nidx); + } + return (*this)[nidx].Parent(); + } + bst_node_t LeftChild(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->LeftChild(nidx); + } + return (*this)[nidx].LeftChild(); + } + bst_node_t RightChild(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->RightChild(nidx); + } + return (*this)[nidx].RightChild(); + } + bool IsLeftChild(bst_node_t nidx) const { + if (IsMultiTarget()) { + CHECK_NE(nidx, kRoot); + auto p = this->p_mt_tree_->Parent(nidx); + return nidx == this->p_mt_tree_->LeftChild(p); + } + return (*this)[nidx].IsLeftChild(); + } + bst_node_t Size() const { + if (IsMultiTarget()) { + return this->p_mt_tree_->Size(); + } + return this->nodes_.size(); + } + private: template void LoadCategoricalSplit(Json const& in); @@ -622,8 +736,9 @@ class RegTree : public Model { // Categories for each internal node. std::vector split_categories_; // Ptr to split categories of each node. - std::vector split_categories_segments_; - + std::vector split_categories_segments_; + // ptr to multi-target tree with vector leaf. + CopyUniquePtr p_mt_tree_; // allocate a new node, // !!!!!! 
NOTE: may cause BUG here, nodes.resize bst_node_t AllocNode() { @@ -703,5 +818,10 @@ inline bool RegTree::FVec::IsMissing(size_t i) const { inline bool RegTree::FVec::HasMissing() const { return has_missing_; } + +// Multi-target tree not yet implemented error +inline StringView MTNotImplemented() { + return " support for multi-target tree is not yet implemented."; +} } // namespace xgboost #endif // XGBOOST_TREE_MODEL_H_ diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 5cf8fb05c10a..bb903b0a4040 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index cb85590ebb4c..09dc38c12388 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -970,9 +970,8 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, xgboost_CHECK_C_ARG_PTR(out_dim); xgboost_CHECK_C_ARG_PTR(out_shape); - CalcPredictShape(strict_shape, type, p_m->Info().num_row_, - p_m->Info().num_col_, chunksize, learner->Groups(), rounds, - &shape, out_dim); + CalcPredictShape(strict_shape, type, p_m->Info().num_row_, p_m->Info().num_col_, chunksize, + learner->Groups(), learner->Targets(), rounds, &shape, out_dim); *out_shape = dmlc::BeginPtr(shape); API_END(); } @@ -1000,7 +999,7 @@ void InplacePredictImpl(std::shared_ptr p_m, char const *c_json_config, xgboost_CHECK_C_ARG_PTR(out_dim); CalcPredictShape(strict_shape, type, n_samples, n_features, chunksize, learner->Groups(), - learner->BoostedRounds(), &shape, out_dim); + learner->Targets(), learner->BoostedRounds(), &shape, out_dim); xgboost_CHECK_C_ARG_PTR(out_result); xgboost_CHECK_C_ARG_PTR(out_shape); diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index 2af36f0acc98..fe87aa90b66c 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2022 by Contributors +// Copyright (c) 2019-2023 by Contributors #include "../common/threading_utils.h" #include "../data/device_adapter.cuh" #include "../data/proxy_dmatrix.h" @@ -129,7 +129,8 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface, xgboost_CHECK_C_ARG_PTR(out_dim); CalcPredictShape(strict_shape, type, n_samples, p_m->Info().num_col_, chunksize, - learner->Groups(), learner->BoostedRounds(), &shape, out_dim); + learner->Groups(), learner->Targets(), learner->BoostedRounds(), &shape, + out_dim); *out_shape = dmlc::BeginPtr(shape); *out_result = p_predt->ConstDevicePointer(); API_END(); diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index 17c88e8c54b9..283a599812f1 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -32,10 +32,14 @@ namespace xgboost { * \param out_dim Output dimension */ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows, size_t cols, - size_t chunksize, size_t groups, size_t rounds, - std::vector *out_shape, + size_t chunksize, std::size_t groups, bst_target_t n_targets, + size_t rounds, std::vector *out_shape, xgboost::bst_ulong *out_dim) { auto &shape = *out_shape; + if (n_targets > 1) { + CHECK_EQ(groups, 1); + groups = n_targets; + } if (type == PredictionType::kMargin && rows != 0) { // When kValue is used, softmax can change the chunksize. 
CHECK_EQ(chunksize, groups); diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index 712b76eff4da..19004afb7ea9 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -119,7 +119,7 @@ class RabitCommunicator : public Communicator { } template ::value> * = nullptr> - void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { + void DoBitwiseAllReduce(void *, std::size_t, Operation) { LOG(FATAL) << "Floating point types do not support bitwise operations."; } diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index c2e1506d4925..af347eb27106 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -193,9 +193,9 @@ class GHistBuildingManager { }; template -void RowsWiseBuildHistKernel(const std::vector &gpair, - const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist) { +void RowsWiseBuildHistKernel(Span gpair, + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, + GHistRow hist) { constexpr bool kAnyMissing = BuildingManager::kAnyMissing; constexpr bool kFirstPage = BuildingManager::kFirstPage; using BinIdxType = typename BuildingManager::BinIdxType; @@ -262,9 +262,9 @@ void RowsWiseBuildHistKernel(const std::vector &gpair, } template -void ColsWiseBuildHistKernel(const std::vector &gpair, - const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist) { +void ColsWiseBuildHistKernel(Span gpair, + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, + GHistRow hist) { constexpr bool kAnyMissing = BuildingManager::kAnyMissing; constexpr bool kFirstPage = BuildingManager::kFirstPage; using BinIdxType = typename BuildingManager::BinIdxType; @@ -315,9 +315,8 @@ void ColsWiseBuildHistKernel(const std::vector &gpair, } template -void BuildHistDispatch(const std::vector &gpair, - const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist) { +void BuildHistDispatch(Span gpair, const RowSetCollection::Elem row_indices, + const GHistIndexMatrix &gmat, GHistRow hist) { if (BuildingManager::kReadByColumn) { ColsWiseBuildHistKernel(gpair, row_indices, gmat, hist); } else { @@ -344,33 +343,31 @@ void BuildHistDispatch(const std::vector &gpair, } template -void GHistBuilder::BuildHist(const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, +void GHistBuilder::BuildHist(Span gpair, + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, GHistRow hist, bool force_read_by_column) const { /* force_read_by_column is used for testing the columnwise building of histograms. 
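 * Column-wise building is otherwise selected automatically when the histogram no
 * longer fits into the L2 cache (see kAdhocL2Size below) and there are no missing
 * values.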
* default force_read_by_column = false */ constexpr double kAdhocL2Size = 1024 * 1024 * 0.8; - const bool hist_fit_to_l2 = kAdhocL2Size > 2*sizeof(float)*gmat.cut.Ptrs().back(); + const bool hist_fit_to_l2 = kAdhocL2Size > 2 * sizeof(float) * gmat.cut.Ptrs().back(); bool first_page = gmat.base_rowid == 0; bool read_by_column = !hist_fit_to_l2 && !any_missing; auto bin_type_size = gmat.index.GetBinTypeSize(); GHistBuildingManager::DispatchAndExecute( - {first_page, read_by_column || force_read_by_column, bin_type_size}, - [&](auto t) { - using BuildingManager = decltype(t); - BuildHistDispatch(gpair, row_indices, gmat, hist); - }); + {first_page, read_by_column || force_read_by_column, bin_type_size}, [&](auto t) { + using BuildingManager = decltype(t); + BuildHistDispatch(gpair, row_indices, gmat, hist); + }); } -template void GHistBuilder::BuildHist(const std::vector &gpair, +template void GHistBuilder::BuildHist(Span gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, GHistRow hist, bool force_read_by_column) const; -template void GHistBuilder::BuildHist(const std::vector &gpair, +template void GHistBuilder::BuildHist(Span gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, GHistRow hist, bool force_read_by_column) const; diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 62d29f5315e0..ee54d2dce029 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -23,6 +23,7 @@ #include "row_set.h" #include "threading_utils.h" #include "timer.h" +#include "xgboost/base.h" // bst_feature_t, bst_bin_t namespace xgboost { class GHistIndexMatrix; @@ -320,10 +321,10 @@ struct Index { }; template -bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end, +bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t end, GradientIndex const& data, - uint32_t const fidx_begin, - uint32_t const fidx_end) { + bst_feature_t const fidx_begin, + bst_feature_t const fidx_end) { size_t previous_middle = std::numeric_limits::max(); while (end != begin) { size_t middle = begin + (end - begin) / 2; @@ -635,7 +636,7 @@ class GHistBuilder { // construct a histogram via histogram aggregation template - void BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, + void BuildHist(Span gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, GHistRow hist, bool force_read_by_column = false) const; uint32_t GetNumBins() const { diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index 030070d9aecd..6f43022bd4c4 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -179,7 +179,6 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t -template class HostDeviceVector; #if defined(__APPLE__) || defined(__EMSCRIPTEN__) /* diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index a5c5dbf8fa1b..b2e2ebd0b353 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -412,7 +412,7 @@ template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t template class HostDeviceVector; -template class HostDeviceVector; +template class HostDeviceVector; template class HostDeviceVector; #if defined(__APPLE__) diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h 
index d52bcef87d54..a156481ddf9b 100644 --- a/src/common/partition_builder.h +++ b/src/common/partition_builder.h @@ -107,98 +107,98 @@ class PartitionBuilder { return {nleft_elems, nright_elems}; } - template - void Partition(const size_t node_in_set, std::vector const &nodes, - const common::Range1d range, - const bst_bin_t split_cond, GHistIndexMatrix const& gmat, - const common::ColumnMatrix& column_matrix, - const RegTree& tree, const size_t* rid) { - common::Span rid_span(rid + range.begin(), rid + range.end()); - common::Span left = GetLeftBuffer(node_in_set, range.begin(), range.end()); - common::Span right = GetRightBuffer(node_in_set, range.begin(), range.end()); - std::size_t nid = nodes[node_in_set].nid; - bst_feature_t fid = tree[nid].SplitIndex(); - bool default_left = tree[nid].DefaultLeft(); - bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical; - auto node_cats = tree.NodeCats(nid); - - auto const& index = gmat.index; - auto const& cut_values = gmat.cut.Values(); - auto const& cut_ptrs = gmat.cut.Ptrs(); - - auto gidx_calc = [&](auto ridx) { - auto begin = gmat.RowIdx(ridx); - if (gmat.IsDense()) { - return static_cast(index[begin + fid]); - } - auto end = gmat.RowIdx(ridx + 1); - auto f_begin = cut_ptrs[fid]; - auto f_end = cut_ptrs[fid + 1]; - // bypassing the column matrix as we need the cut value instead of bin idx for categorical - // features. - return BinarySearchBin(begin, end, index, f_begin, f_end); - }; - - auto pred_hist = [&](auto ridx, auto bin_id) { - if (any_cat && is_cat) { - auto gidx = gidx_calc(ridx); - bool go_left = default_left; - if (gidx > -1) { - go_left = Decision(node_cats, cut_values[gidx]); - } - return go_left; - } else { - return bin_id <= split_cond; - } - }; - - auto pred_approx = [&](auto ridx) { - auto gidx = gidx_calc(ridx); - bool go_left = default_left; - if (gidx > -1) { - if (is_cat) { - go_left = Decision(node_cats, cut_values[gidx]); - } else { - go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value; - } - } - return go_left; - }; - - std::pair child_nodes_sizes; - if (!column_matrix.IsInitialized()) { - child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx); - } else { - if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) { - auto column = column_matrix.DenseColumn(fid); - if (default_left) { - child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, - gmat.base_rowid, pred_hist); - } else { - child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, - gmat.base_rowid, pred_hist); - } - } else { - CHECK_EQ(any_missing, true); - auto column = - column_matrix.SparseColumn(fid, rid_span.front() - gmat.base_rowid); - if (default_left) { - child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, - gmat.base_rowid, pred_hist); - } else { - child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, - gmat.base_rowid, pred_hist); - } - } - } - - const size_t n_left = child_nodes_sizes.first; - const size_t n_right = child_nodes_sizes.second; - - SetNLeftElems(node_in_set, range.begin(), n_left); - SetNRightElems(node_in_set, range.begin(), n_right); - } - + + template + void Partition(const size_t node_in_set, std::vector const& nodes, + const common::Range1d range, const bst_bin_t split_cond, + GHistIndexMatrix const& gmat, const common::ColumnMatrix& column_matrix, + const RegTree& tree, const size_t* rid) { + common::Span rid_span(rid + range.begin(), rid + range.end()); + common::Span left = GetLeftBuffer(node_in_set, 
range.begin(), range.end()); + common::Span right = GetRightBuffer(node_in_set, range.begin(), range.end()); + std::size_t nid = nodes[node_in_set].nid; + bst_feature_t fid = tree.SplitIndex(nid); + bool default_left = tree.DefaultLeft(nid); + bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical; + auto node_cats = tree.NodeCats(nid); + + auto const& index = gmat.index; + auto const& cut_values = gmat.cut.Values(); + auto const& cut_ptrs = gmat.cut.Ptrs(); + + auto gidx_calc = [&](auto ridx) { + auto begin = gmat.RowIdx(ridx); + if (gmat.IsDense()) { + return static_cast(index[begin + fid]); + } + auto end = gmat.RowIdx(ridx + 1); + auto f_begin = cut_ptrs[fid]; + auto f_end = cut_ptrs[fid + 1]; + // bypassing the column matrix as we need the cut value instead of bin idx for categorical + // features. + return BinarySearchBin(begin, end, index, f_begin, f_end); + }; + + auto pred_hist = [&](auto ridx, auto bin_id) { + if (any_cat && is_cat) { + auto gidx = gidx_calc(ridx); + bool go_left = default_left; + if (gidx > -1) { + go_left = Decision(node_cats, cut_values[gidx]); + } + return go_left; + } else { + return bin_id <= split_cond; + } + }; + + auto pred_approx = [&](auto ridx) { + auto gidx = gidx_calc(ridx); + bool go_left = default_left; + if (gidx > -1) { + if (is_cat) { + go_left = Decision(node_cats, cut_values[gidx]); + } else { + go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value; + } + } + return go_left; + }; + + std::pair child_nodes_sizes; + if (!column_matrix.IsInitialized()) { + child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx); + } else { + if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) { + auto column = column_matrix.DenseColumn(fid); + if (default_left) { + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, + gmat.base_rowid, pred_hist); + } else { + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, + gmat.base_rowid, pred_hist); + } + } else { + CHECK_EQ(any_missing, true); + auto column = + column_matrix.SparseColumn(fid, rid_span.front() - gmat.base_rowid); + if (default_left) { + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, + gmat.base_rowid, pred_hist); + } else { + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, + gmat.base_rowid, pred_hist); + } + } + } + + const size_t n_left = child_nodes_sizes.first; + const size_t n_right = child_nodes_sizes.second; + + SetNLeftElems(node_in_set, range.begin(), n_left); + SetNRightElems(node_in_set, range.begin(), n_right); + } + // allocate thread local memory, should be called for each specific task void AllocateForTask(size_t id) { if (mem_blocks_[id].get() == nullptr) { diff --git a/src/common/quantile.h b/src/common/quantile.h index 27c528e8e716..e703bae574f9 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -352,19 +352,6 @@ struct WQSummary { prev_rmax = data[i].rmax; } } - // check consistency of the summary - inline bool Check(const char *msg) const { - const float tol = 10.0f; - for (size_t i = 0; i < this->size; ++i) { - if (data[i].rmin + data[i].wmin > data[i].rmax + tol || - data[i].rmin < -1e-6f || data[i].rmax < -1e-6f) { - LOG(INFO) << "---------- WQSummary::Check did not pass ----------"; - this->Print(); - return false; - } - } - return true; - } }; /*! 
\brief try to do efficient pruning */ diff --git a/src/common/stats.cu b/src/common/stats.cu index 2e728a8bc333..76aec7dd18f1 100644 --- a/src/common/stats.cu +++ b/src/common/stats.cu @@ -1,13 +1,14 @@ /*! - * Copyright 2022 by XGBoost Contributors + * Copyright 2022-2023 by XGBoost Contributors */ #include // thrust::make_counting_iterator -#include "common.h" // common::OptionalWeights -#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend -#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile -#include "xgboost/context.h" // Context +#include "common.h" // common::OptionalWeights +#include "cuda_context.cuh" // CUDAContext +#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend +#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile +#include "xgboost/context.h" // Context #include "xgboost/host_device_vector.h" // HostDeviceVector #include "xgboost/linalg.h" // linalg::TensorView, UnravelIndex, Apply @@ -49,9 +50,11 @@ void Mean(Context const* ctx, linalg::VectorView v, linalg::VectorV thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return v(i) / n; }); std::size_t bytes; CHECK_EQ(out.Size(), 1); - cub::DeviceReduce::Sum(nullptr, bytes, it, out.Values().data(), v.Size()); + cub::DeviceReduce::Sum(nullptr, bytes, it, out.Values().data(), v.Size(), + ctx->CUDACtx()->Stream()); dh::TemporaryArray temp{bytes}; - cub::DeviceReduce::Sum(temp.data().get(), bytes, it, out.Values().data(), v.Size()); + cub::DeviceReduce::Sum(temp.data().get(), bytes, it, out.Values().data(), v.Size(), + ctx->CUDACtx()->Stream()); } } // namespace cuda_impl } // namespace common diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index 656e570ae812..c403e687cf00 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -295,7 +295,7 @@ class MemStackAllocator { /** * \brief Constant that can be used for initializing static thread local memory. 
*/ -std::int32_t constexpr DefaultMaxThreads() { return 128; } +std::int32_t constexpr DefaultMaxThreads() { return 64; } } // namespace common } // namespace xgboost diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 19dd3490d040..1953588ff736 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -102,7 +102,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing, return HostAdapterDispatch(proxy, [&](auto const& value) { size_t n_threads = ctx_.Threads(); size_t n_features = column_sizes.size(); - linalg::Tensor column_sizes_tloc({n_threads, n_features}, Context::kCpuId); + linalg::Tensor column_sizes_tloc({n_threads, n_features}, linalg::Order::kC, + Context::kCpuId); auto view = column_sizes_tloc.HostView(); common::ParallelFor(value.Size(), n_threads, common::Sched::Static(256), [&](auto i) { auto const& line = value.GetLine(i); diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 808ecd8b370b..1c5692f8073f 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -181,8 +181,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size()); } if (batch.BaseMargin() != nullptr) { - info_.base_margin_ = decltype(info_.base_margin_){ - batch.BaseMargin(), batch.BaseMargin() + batch.Size(), {batch.Size()}, Context::kCpuId}; + info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(), + batch.BaseMargin() + batch.Size(), + {batch.Size()}, + linalg::Order::kC, + Context::kCpuId}; } if (batch.Qid() != nullptr) { qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size()); diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 9cb63d9165d0..c4a758734240 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -10,6 +10,7 @@ #include #include +#include // std::uint32_t #include #include #include @@ -26,9 +27,12 @@ #include "xgboost/host_device_vector.h" #include "xgboost/json.h" #include "xgboost/logging.h" +#include "xgboost/model.h" +#include "xgboost/multi_target_tree_model.h" #include "xgboost/objective.h" #include "xgboost/predictor.h" -#include "xgboost/string_view.h" +#include "xgboost/string_view.h" // StringView +#include "xgboost/tree_model.h" // RegTree #include "xgboost/tree_updater.h" namespace xgboost { @@ -131,6 +135,11 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) { // set, since only experts are expected to do so. return; } + if (model_.learner_model_param->IsVectorLeaf()) { + CHECK(tparam_.tree_method == TreeMethod::kHist) + << "Only hist tree method is supported for multi-target with vector leaf."; + } + // tparam_ is set before calling this function. 
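  // With vector leaf only the hist tree method is supported; this is enforced by
  // the check above before the heuristic runs.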
if (tparam_.tree_method != TreeMethod::kAuto) { return; @@ -175,12 +184,12 @@ void GBTree::ConfigureUpdaters() { case TreeMethod::kExact: tparam_.updater_seq = "grow_colmaker,prune"; break; - case TreeMethod::kHist: - LOG(INFO) << - "Tree method is selected to be 'hist', which uses a " - "single updater grow_quantile_histmaker."; + case TreeMethod::kHist: { + LOG(INFO) << "Tree method is selected to be 'hist', which uses a single updater " + "grow_quantile_histmaker."; tparam_.updater_seq = "grow_quantile_histmaker"; break; + } case TreeMethod::kGPUHist: { common::AssertGPUSupport(); tparam_.updater_seq = "grow_gpu_hist"; @@ -209,11 +218,9 @@ void CopyGradient(HostDeviceVector const* in_gpair, int32_t n_thre GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair); } else { std::vector &tmp_h = out_gpair->HostVector(); - auto nsize = static_cast(out_gpair->Size()); - const auto &gpair_h = in_gpair->ConstHostVector(); - common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) { - tmp_h[i] = gpair_h[i * n_groups + group_id]; - }); + const auto& gpair_h = in_gpair->ConstHostVector(); + common::ParallelFor(out_gpair->Size(), n_threads, + [&](auto i) { tmp_h[i] = gpair_h[i * n_groups + group_id]; }); } } @@ -240,6 +247,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, const int ngroup = model_.learner_model_param->num_output_group; ConfigureWithKnownData(this->cfg_, p_fmat); monitor_.Start("BoostNewTrees"); + // Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let // `gpu_id` be the single source of determining what algorithms to run, but that will // break a lots of existing code. @@ -254,7 +262,13 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, LOG(FATAL) << "Current objective doesn't support external memory."; } - if (ngroup == 1) { + if (model_.learner_model_param->IsVectorLeaf()) { + std::vector> ret; + BoostNewTrees(in_gpair, p_fmat, 0, &ret); + UpdateTreeLeaf(p_fmat, predt->predictions, obj, &ret); + // No update prediction cache yet. 
+ new_trees.push_back(std::move(ret)); + } else if (ngroup == 1) { std::vector> ret; BoostNewTrees(in_gpair, p_fmat, 0, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, &ret); @@ -346,7 +360,8 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fma << "Set `process_type` to `update` if you want to update existing " "trees."; // create new tree - std::unique_ptr ptr(new RegTree()); + std::unique_ptr ptr(new RegTree{this->model_.learner_model_param->num_target, + this->model_.learner_model_param->num_feature}); ptr->param.UpdateAllowUnknown(this->cfg_); new_trees.push_back(ptr.get()); ret->push_back(std::move(ptr)); @@ -368,9 +383,14 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fma } } // update the trees - CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_) - << "Mismatching size between number of rows from input data and size of " - "gradient vector."; + if (model_.learner_model_param->IsVectorLeaf()) { + CHECK_EQ(gpair->Size(), model_.learner_model_param->num_target * p_fmat->Info().num_row_); + } else { + CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_) + << "Mismatching size between number of rows from input data and size of " + "gradient vector."; + } + node_position_.resize(new_trees.size()); for (auto& up : updaters_) { up->Update(gpair, p_fmat, common::Span>{node_position_}, @@ -539,11 +559,10 @@ void GBTree::PredictBatch(DMatrix* p_fmat, if (out_preds->version == 0) { // out_preds->Size() can be non-zero as it's initialized here before any // tree is built at the 0^th iterator. - predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, - model_); + predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, model_); } - uint32_t tree_begin, tree_end; + std::uint32_t tree_begin, tree_end; std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees."; if (tree_end > tree_begin) { @@ -552,7 +571,7 @@ void GBTree::PredictBatch(DMatrix* p_fmat, if (reset) { out_preds->version = 0; } else { - uint32_t delta = layer_end - out_preds->version; + std::uint32_t delta = layer_end - out_preds->version; out_preds->Update(delta); } } diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 38dcb25ead60..5efd3f1e6fb1 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -145,14 +146,22 @@ struct DartTrainParam : public XGBoostParameter { namespace detail { // From here on, layer becomes concrete trees. 
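// With vector leaf, one boosting layer holds num_parallel_tree trees; otherwise it
// holds one tree per output group, e.g. 3 groups with num_parallel_tree = 2 give
// 6 trees per layer.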
-inline std::pair LayerToTree(gbm::GBTreeModel const &model, - size_t layer_begin, - size_t layer_end) { - bst_group_t groups = model.learner_model_param->num_output_group; - uint32_t tree_begin = layer_begin * groups * model.param.num_parallel_tree; - uint32_t tree_end = layer_end * groups * model.param.num_parallel_tree; +inline std::pair LayerToTree(gbm::GBTreeModel const& model, + std::uint32_t layer_begin, + std::uint32_t layer_end) { + std::uint32_t tree_begin; + std::uint32_t tree_end; + if (model.learner_model_param->IsVectorLeaf()) { + tree_begin = layer_begin * model.param.num_parallel_tree; + tree_end = layer_end * model.param.num_parallel_tree; + } else { + bst_group_t groups = model.learner_model_param->num_output_group; + tree_begin = layer_begin * groups * model.param.num_parallel_tree; + tree_end = layer_end * groups * model.param.num_parallel_tree; + } + if (tree_end == 0) { - tree_end = static_cast(model.trees.size()); + tree_end = model.trees.size(); } if (model.trees.size() != 0) { CHECK_LE(tree_begin, tree_end); @@ -249,7 +258,10 @@ class GBTree : public GradientBooster { int32_t BoostedRounds() const override { CHECK_NE(model_.param.num_parallel_tree, 0); CHECK_NE(model_.learner_model_param->num_output_group, 0); - return model_.trees.size() / this->LayerTrees(); + + return this->model_.learner_model_param->IsVectorLeaf() + ? (model_.trees.size() / model_.param.num_parallel_tree) + : (model_.trees.size() / this->LayerTrees()); } bool ModelFitted() const override { diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 4e9cc6655eaa..cc0a3aba183e 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -99,7 +99,9 @@ void GBTreeModel::LoadModel(Json const& in) { CHECK(ctx_); common::ParallelFor(trees_json.size(), ctx_->Threads(), [&](auto t) { auto tree_id = get(trees_json[t]["id"]); - trees.at(tree_id).reset(new RegTree()); + CHECK(this->learner_model_param->Initialized()); + trees.at(tree_id).reset( + new RegTree{this->learner_model_param->num_target, this->learner_model_param->num_feature}); trees.at(tree_id)->LoadModel(trees_json[t]); }); diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 1f2bdfa639e1..f2e5dddae10c 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -7,9 +7,9 @@ #include #include -#include -#include -#include +#include // Context +#include // LearnerModelParam +#include // Model #include #include @@ -19,6 +19,7 @@ #include #include "../common/threading_utils.h" +#include "xgboost/multi_target_tree_model.h" namespace xgboost { diff --git a/src/learner.cc b/src/learner.cc index 2462aec2397f..7275af67eab1 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -14,6 +14,7 @@ #include #include #include +#include #include #include // std::numeric_limits #include @@ -21,6 +22,7 @@ #include #include #include +#include // std::underlying_type_t #include #include @@ -47,12 +49,22 @@ #include "xgboost/objective.h" #include "xgboost/parameter.h" #include "xgboost/predictor.h" +#include "xgboost/string_view.h" namespace { - const char* kMaxDeltaStepDefaultValue = "0.7"; } // anonymous namespace +namespace xgboost { +enum class Strategy : std::int32_t { + kComposite = 0, + kMono = 1, +}; + +std::string StrategyStr(Strategy s) { return s == Strategy::kComposite ? "compo" : "mono"; } +} // namespace xgboost +DECLARE_FIELD_ENUM_CLASS(xgboost::Strategy); + namespace xgboost { Learner::~Learner() = default; namespace { @@ -85,8 +97,8 @@ struct LearnerModelParamLegacy : public dmlc::Parameter /*! 
\brief the version of XGBoost. */ std::uint32_t major_version; std::uint32_t minor_version; - - uint32_t num_target{1}; + /*! \brief Strategy used for training multi-target models. */ + Strategy multi_strategy; /** * \brief Whether we should calculate the base score from training data. * @@ -103,7 +115,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter LearnerModelParamLegacy() { std::memset(this, 0, sizeof(LearnerModelParamLegacy)); base_score = ObjFunction::DefaultBaseScore(); - num_target = 1; + multi_strategy = Strategy::kComposite; major_version = std::get<0>(Version::Self()); minor_version = std::get<1>(Version::Self()); boost_from_average = true; @@ -130,10 +142,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter CHECK(ret.ec == std::errc()); obj["num_class"] = std::string{integers, static_cast(std::distance(integers, ret.ptr))}; - ret = to_chars(integers, integers + NumericLimits::kToCharsSize, - static_cast(num_target)); - obj["num_target"] = - std::string{integers, static_cast(std::distance(integers, ret.ptr))}; + obj["multi_strategy"] = StrategyStr(multi_strategy); ret = to_chars(integers, integers + NumericLimits::kToCharsSize, static_cast(boost_from_average)); @@ -147,9 +156,10 @@ struct LearnerModelParamLegacy : public dmlc::Parameter std::map m; m["num_feature"] = get(j_param.at("num_feature")); m["num_class"] = get(j_param.at("num_class")); - auto n_targets_it = j_param.find("num_target"); - if (n_targets_it != j_param.cend()) { - m["num_target"] = get(n_targets_it->second); + + auto strategy_it = j_param.find("multi_strategy"); + if (strategy_it != j_param.cend()) { + m["multi_strategy"] = get(strategy_it->second); } auto bse_it = j_param.find("boost_from_average"); if (bse_it != j_param.cend()) { @@ -171,7 +181,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter dmlc::ByteSwap(&x.contain_eval_metrics, sizeof(x.contain_eval_metrics), 1); dmlc::ByteSwap(&x.major_version, sizeof(x.major_version), 1); dmlc::ByteSwap(&x.minor_version, sizeof(x.minor_version), 1); - dmlc::ByteSwap(&x.num_target, sizeof(x.num_target), 1); + // dmlc::ByteSwap(&x.num_target, sizeof(x.num_target), 1); dmlc::ByteSwap(&x.boost_from_average, sizeof(x.boost_from_average), 1); dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0])); return x; @@ -199,15 +209,16 @@ struct LearnerModelParamLegacy : public dmlc::Parameter DMLC_DECLARE_FIELD(num_feature) .set_default(0) .describe( - "Number of features in training data," - " this parameter will be automatically detected by learner."); + "Number of features in training data, this parameter will be automatically detected by " + "learner."); DMLC_DECLARE_FIELD(num_class).set_default(0).set_lower_bound(0).describe( "Number of class option for multi-class classifier. 
" " By default equals 0 and corresponds to binary classifier."); - DMLC_DECLARE_FIELD(num_target) - .set_default(1) - .set_lower_bound(1) - .describe("Number of target for multi-target regression."); + DMLC_DECLARE_FIELD(multi_strategy) + .add_enum("compo", Strategy::kComposite) + .add_enum("mono", Strategy::kMono) + .set_default(Strategy::kComposite) + .describe("Strategy used for training multi-target models."); DMLC_DECLARE_FIELD(boost_from_average) .set_default(true) .describe("Whether we should calculate the base score from training data."); @@ -215,15 +226,13 @@ struct LearnerModelParamLegacy : public dmlc::Parameter }; LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, ObjInfo t) - : num_feature{user_param.num_feature}, task{t} { - auto n_classes = std::max(static_cast(user_param.num_class), 1u); - auto n_targets = user_param.num_target; - num_output_group = std::max(n_classes, n_targets); - // For version < 1.6, n_targets == 0 - CHECK(n_classes <= 1 || n_targets <= 1) - << "Multi-class multi-output is not yet supported. n_classes:" << n_classes - << ", n_targets:" << n_targets; -} + : num_feature{user_param.num_feature}, + num_output_group{user_param.multi_strategy == Strategy::kComposite + ? std::max(user_param.num_class, 1) + : 1u}, + num_target{user_param.multi_strategy == Strategy::kMono ? std::max(user_param.num_class, 1) + : 1u}, + task{t} {} LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy const& user_param, linalg::Tensor base_margin, ObjInfo t) @@ -269,6 +278,7 @@ void LearnerModelParam::Copy(LearnerModelParam const& that) { num_feature = that.num_feature; num_output_group = that.num_output_group; + num_target = that.num_target; task = that.task; } @@ -348,7 +358,7 @@ class LearnerConfiguration : public Learner { this->ConfigureTargets(); auto task = UsePtr(obj_)->Task(); - linalg::Tensor base_score({1}, Ctx()->gpu_id); + linalg::Tensor base_score({1}, linalg::Order::kC, Ctx()->gpu_id); auto h_base_score = base_score.HostView(); // transform to margin @@ -724,7 +734,6 @@ class LearnerConfiguration : public Learner { << "0 feature is supplied. Are you using raw Booster interface?"; // Remove these once binary IO is gone. cfg_["num_feature"] = common::ToString(mparam_.num_feature); - cfg_["num_class"] = common::ToString(mparam_.num_class); } void ConfigureGBM(LearnerTrainParam const& old, Args const& args) { @@ -755,9 +764,17 @@ class LearnerConfiguration : public Learner { if (obj_ == nullptr || tparam_.objective != old.objective) { obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_)); } + + bool has_nc {cfg_.find("num_class") != cfg_.cend()}; + // Inject num_class into configuration. 
+ // FIXME(jiamingy): Remove the duplicated parameter in softmax + cfg_["num_class"] = common::ToString(mparam_.num_class); auto& args = *p_args; args = {cfg_.cbegin(), cfg_.cend()}; // renew obj_->Configure(args); + if (!has_nc) { + cfg_.erase("num_class"); + } } void ConfigureMetrics(Args const& args) { @@ -780,8 +797,11 @@ class LearnerConfiguration : public Learner { */ void ConfigureTargets() { CHECK(this->obj_); + if (mparam_.num_class > 1) { + return; + } auto const& cache = this->GetPredictionCache()->Container(); - size_t n_targets = 1; + bst_target_t n_targets = 1; for (auto const& d : cache) { if (n_targets == 1) { n_targets = this->obj_->Targets(d.first->Info()); @@ -790,13 +810,12 @@ class LearnerConfiguration : public Learner { CHECK(n_targets == t || 1 == t) << "Inconsistent labels."; } } - if (mparam_.num_target != 1) { - CHECK(n_targets == 1 || n_targets == mparam_.num_target) + if (mparam_.num_class > 1) { + CHECK(n_targets == 1 || n_targets == static_cast(mparam_.num_class)) << "Inconsistent configuration of num_target. Configuration result from input data:" - << n_targets << ", configuration from parameter:" << mparam_.num_target; - } else { - mparam_.num_target = n_targets; + << n_targets << ", configuration from parameter:" << mparam_.num_class; } + mparam_.num_class = n_targets; } }; @@ -834,6 +853,9 @@ class LearnerIO : public LearnerConfiguration { auto const& gradient_booster = learner.at("gradient_booster"); name = get(gradient_booster["name"]); tparam_.UpdateAllowUnknown(Args{{"booster", name}}); + // fixme + learner_model_param_ = LearnerModelParam{mparam_, obj_->Task()}; + gbm_.reset( GradientBooster::Create(tparam_.booster, &ctx_, &learner_model_param_)); gbm_->LoadModel(gradient_booster); @@ -950,9 +972,6 @@ class LearnerIO : public LearnerConfiguration { if (!DMLC_IO_NO_ENDIAN_SWAP) { mparam_ = mparam_.ByteSwap(); } - if (mparam_.num_target == 0) { - mparam_.num_target = 1; - } CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format"; CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format"; @@ -1005,6 +1024,7 @@ class LearnerIO : public LearnerConfiguration { ? std::numeric_limits::quiet_NaN() : obj_->ProbToMargin(mparam_.base_score)}, {1}, + linalg::Order::kC, Context::kCpuId}, obj_->Task()); @@ -1034,7 +1054,6 @@ class LearnerIO : public LearnerConfiguration { mparam_.major_version = std::get<0>(Version::Self()); mparam_.minor_version = std::get<1>(Version::Self()); - cfg_["num_class"] = common::ToString(mparam_.num_class); cfg_["num_feature"] = common::ToString(mparam_.num_feature); auto n = tparam_.__DICT__(); @@ -1047,6 +1066,8 @@ class LearnerIO : public LearnerConfiguration { // JSON serialization format. 
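Stepping back from the IO plumbing above: the new multi_strategy field travels through JSON as one of two plain strings. A stand-alone sketch of the mapping (the enum and StrategyStr mirror the patch; StrategyFromStr is a hypothetical inverse added only for illustration):

#include <cstdint>
#include <stdexcept>
#include <string>

enum class Strategy : std::int32_t { kComposite = 0, kMono = 1 };

// Matches the "compo"/"mono" strings registered via DMLC_DECLARE_FIELD above.
std::string StrategyStr(Strategy s) { return s == Strategy::kComposite ? "compo" : "mono"; }

// Hypothetical inverse, shown here only to make the round-trip explicit.
Strategy StrategyFromStr(std::string const& s) {
  if (s == "compo") return Strategy::kComposite;
  if (s == "mono") return Strategy::kMono;
  throw std::invalid_argument("unknown multi_strategy: " + s);
}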
void SaveModel(dmlc::Stream* fo) const override { this->CheckModelInitialized(); + CHECK(!this->learner_model_param_.IsVectorLeaf()) + << "Please use JSON/UBJ format for model serialization with multi-output models."; LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify std::vector > extra_attr; @@ -1356,6 +1377,11 @@ class LearnerImpl : public LearnerIO { this->CheckModelInitialized(); return this->learner_model_param_.num_output_group; } + bst_target_t Targets() const override { + CHECK(!this->need_configuration_); + this->CheckModelInitialized(); + return this->learner_model_param_.num_target; + } XGBAPIThreadLocalEntry& GetThreadLocal() const override { return (*LearnerAPIThreadLocalStore::Get())[this]; diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 51224ff84562..9b099faf7b95 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -4,7 +4,8 @@ #include #include -#include +#include // std::size_t +#include #include #include @@ -18,9 +19,13 @@ #include "cpu_treeshap.h" // CalculateContributions #include "predict_fn.h" #include "xgboost/base.h" +#include "xgboost/context.h" #include "xgboost/data.h" #include "xgboost/host_device_vector.h" +#include "xgboost/learner.h" +#include "xgboost/linalg.h" #include "xgboost/logging.h" +#include "xgboost/multi_target_tree_model.h" // MultiTargetTree #include "xgboost/predictor.h" #include "xgboost/tree_model.h" #include "xgboost/tree_updater.h" @@ -30,24 +35,24 @@ namespace predictor { DMLC_REGISTRY_FILE_TAG(cpu_predictor); +namespace scalar { template bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat, - RegTree::CategoricalSplitMatrix const& cats) { - bst_node_t nid = 0; - while (!tree[nid].IsLeaf()) { - unsigned split_index = tree[nid].SplitIndex(); + CategoricalSplitMatrix const &cats) { + bst_node_t nidx{0}; + while (!tree[nidx].IsLeaf()) { + bst_feature_t split_index = tree[nidx].SplitIndex(); auto fvalue = feat.GetFvalue(split_index); - nid = GetNextNode( - tree[nid], nid, fvalue, has_missing && feat.IsMissing(split_index), cats); + nidx = GetNextNode( + tree[nidx], nidx, fvalue, has_missing && feat.IsMissing(split_index), cats); } - return nid; + return nidx; } bst_float PredValue(const SparsePage::Inst &inst, const std::vector> &trees, - const std::vector &tree_info, int bst_group, - RegTree::FVec *p_feats, unsigned tree_begin, - unsigned tree_end) { + const std::vector &tree_info, int bst_group, RegTree::FVec *p_feats, + unsigned tree_begin, unsigned tree_end) { bst_float psum = 0.0f; p_feats->Fill(inst); for (size_t i = tree_begin; i < tree_end; ++i) { @@ -69,40 +74,91 @@ bst_float PredValue(const SparsePage::Inst &inst, } template -bst_float -PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree, - RegTree::CategoricalSplitMatrix const& cats) { - const bst_node_t leaf = p_feats.HasMissing() ? - GetLeafIndex(tree, p_feats, cats) : - GetLeafIndex(tree, p_feats, cats); +bst_float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree, + CategoricalSplitMatrix const &cats) { + const bst_node_t leaf = p_feats.HasMissing() + ? 
GetLeafIndex(tree, p_feats, cats) + : GetLeafIndex(tree, p_feats, cats); return tree[leaf].LeafValue(); } void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin, - const size_t tree_end, std::vector *out_preds, - const size_t predict_offset, const size_t num_group, - const std::vector &thread_temp, - const size_t offset, const size_t block_size) { - std::vector &preds = *out_preds; + const size_t tree_end, const size_t predict_offset, + const std::vector &thread_temp, const size_t offset, + const size_t block_size, linalg::TensorView out_predt) { for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) { const size_t gid = model.tree_info[tree_id]; auto const &tree = *model.trees[tree_id]; - auto const& cats = tree.GetCategoriesMatrix(); + auto const &cats = tree.GetCategoriesMatrix(); auto has_categorical = tree.HasCategoricalSplit(); if (has_categorical) { for (size_t i = 0; i < block_size; ++i) { - preds[(predict_offset + i) * num_group + gid] += + out_predt(predict_offset + i, gid) += PredValueByOneTree(thread_temp[offset + i], tree, cats); } } else { for (size_t i = 0; i < block_size; ++i) { - preds[(predict_offset + i) * num_group + gid] += - PredValueByOneTree(thread_temp[offset + i], tree, cats); + out_predt(predict_offset + i, gid) += + PredValueByOneTree(thread_temp[offset + i], tree, cats); + } + } + } +} +} // namespace scalar + +namespace multi { +template +bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat, + CategoricalSplitMatrix const &cats) { + bst_node_t nidx{0}; + while (!tree.IsLeaf(nidx)) { + unsigned split_index = tree.SplitIndex(nidx); + auto fvalue = feat.GetFvalue(split_index); + nidx = GetNextNodeMulti( + tree, nidx, fvalue, has_missing && feat.IsMissing(split_index), cats); + } + return nidx; +} + +template +void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tree, + CategoricalSplitMatrix const &cats, linalg::VectorView out_predt) { + bst_node_t const leaf = p_feats.HasMissing() + ? 
GetLeafIndex(tree, p_feats, cats) + : GetLeafIndex(tree, p_feats, cats); + auto leaf_value = tree.LeafValue(leaf); + assert(out_predt.Shape(0) == leaf_value.Shape(0) && "shape mismatch."); + for (size_t i = 0; i < leaf_value.Size(); ++i) { + out_predt(i) += leaf_value(i); + } +} + +void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin, + const size_t tree_end, const size_t predict_offset, + const std::vector &thread_temp, const size_t offset, + const size_t block_size, linalg::TensorView out_predt) { + for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) { + auto const &tree = *model.trees.at(tree_id); + auto cats = tree.GetCategoriesMatrix(); + bool has_categorical = tree.HasCategoricalSplit(); + + if (has_categorical) { + for (std::size_t i = 0; i < block_size; ++i) { + auto t_predts = out_predt.Slice(predict_offset + i, linalg::All()); + PredValueByOneTree(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats, + t_predts); + } + } else { + for (std::size_t i = 0; i < block_size; ++i) { + auto t_predts = out_predt.Slice(predict_offset + i, linalg::All()); + PredValueByOneTree(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats, + t_predts); } } } } +} // namespace multi template void FVecFill(const size_t block_size, const size_t batch_offset, const int num_feature, @@ -228,15 +284,13 @@ class AdapterView { }; template -void PredictBatchByBlockOfRowsKernel( - DataView batch, std::vector *out_preds, - gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end, - std::vector *p_thread_temp, int32_t n_threads) { +void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model, + int32_t tree_begin, int32_t tree_end, + std::vector *p_thread_temp, int32_t n_threads, + linalg::TensorView out_predt) { auto &thread_temp = *p_thread_temp; - int32_t const num_group = model.learner_model_param->num_output_group; - CHECK_EQ(model.param.size_leaf_vector, 0) - << "size_leaf_vector is enforced to 0 so far"; + CHECK_EQ(model.param.size_leaf_vector, 0) << "size_leaf_vector is enforced to 0 so far"; // parallel over local batch const auto nsize = static_cast(batch.Size()); const int num_feature = model.learner_model_param->num_feature; @@ -244,16 +298,19 @@ void PredictBatchByBlockOfRowsKernel( common::ParallelFor(n_blocks, n_threads, [&](bst_omp_uint block_id) { const size_t batch_offset = block_id * block_of_rows_size; - const size_t block_size = - std::min(nsize - batch_offset, block_of_rows_size); + const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size); const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size; - FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, - p_thread_temp); + FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp); // process block of rows through all trees to keep cache locality - PredictByAllTrees(model, tree_begin, tree_end, out_preds, - batch_offset + batch.base_rowid, num_group, thread_temp, - fvec_offset, block_size); + if (model.learner_model_param->IsVectorLeaf()) { + multi::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, + thread_temp, fvec_offset, block_size, out_predt); + } else { + scalar::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, + thread_temp, fvec_offset, block_size, out_predt); + } + FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp); }); } @@ -294,40 +351,8 @@ class CPUPredictor : public Predictor { } } - void 
PredictGHistIndex(DMatrix *p_fmat, gbm::GBTreeModel const &model, int32_t tree_begin, - int32_t tree_end, std::vector *out_preds) const { - auto const n_threads = this->ctx_->Threads(); - - constexpr double kDensityThresh = .5; - size_t total = - std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_, static_cast(1)); - double density = static_cast(p_fmat->Info().num_nonzero_) / static_cast(total); - bool blocked = density > kDensityThresh; - - std::vector feat_vecs; - InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1), &feat_vecs); - std::vector workspace(p_fmat->Info().num_col_ * kUnroll * n_threads); - auto ft = p_fmat->Info().feature_types.ConstHostVector(); - for (auto const &batch : p_fmat->GetBatches({})) { - if (blocked) { - PredictBatchByBlockOfRowsKernel( - GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, - out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads); - } else { - PredictBatchByBlockOfRowsKernel( - GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, - out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads); - } - } - } - void PredictDMatrix(DMatrix *p_fmat, std::vector *out_preds, gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const { - if (!p_fmat->PageExists()) { - this->PredictGHistIndex(p_fmat, model, tree_begin, tree_end, out_preds); - return; - } - auto const n_threads = this->ctx_->Threads(); constexpr double kDensityThresh = .5; size_t total = @@ -337,16 +362,38 @@ class CPUPredictor : public Predictor { std::vector feat_vecs; InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1), &feat_vecs); - for (auto const &batch : p_fmat->GetBatches()) { - CHECK_EQ(out_preds->size(), - p_fmat->Info().num_row_ * model.learner_model_param->num_output_group); - if (blocked) { - PredictBatchByBlockOfRowsKernel( - SparsePageView{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads); - } else { - PredictBatchByBlockOfRowsKernel( - SparsePageView{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads); + std::size_t n_samples = p_fmat->Info().num_row_; + std::size_t n_groups = model.learner_model_param->OutputLength(); + CHECK_EQ(out_preds->size(), n_samples * n_groups); + linalg::TensorView out_predt{*out_preds, {n_samples, n_groups}, ctx_->gpu_id}; + + if (!p_fmat->PageExists()) { + std::vector workspace(p_fmat->Info().num_col_ * kUnroll * n_threads); + auto ft = p_fmat->Info().feature_types.ConstHostVector(); + for (auto const &batch : p_fmat->GetBatches({})) { + if (blocked) { + PredictBatchByBlockOfRowsKernel( + GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model, + tree_begin, tree_end, &feat_vecs, n_threads, out_predt); + } else { + PredictBatchByBlockOfRowsKernel( + GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model, + tree_begin, tree_end, &feat_vecs, n_threads, out_predt); + } + } + } else { + for (auto const &batch : p_fmat->GetBatches()) { + if (blocked) { + PredictBatchByBlockOfRowsKernel( + SparsePageView{&batch}, model, tree_begin, tree_end, &feat_vecs, n_threads, + out_predt); + + } else { + PredictBatchByBlockOfRowsKernel(SparsePageView{&batch}, model, + tree_begin, tree_end, &feat_vecs, + n_threads, out_predt); + } } } } @@ -354,24 +401,22 @@ class CPUPredictor : public Predictor { public: explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {} - void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, - const 
gbm::GBTreeModel &model, uint32_t tree_begin, - uint32_t tree_end = 0) const override { - auto* out_preds = &predts->predictions; + void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, const gbm::GBTreeModel &model, + uint32_t tree_begin, uint32_t tree_end = 0) const override { + auto *out_preds = &predts->predictions; // This is actually already handled in gbm, but large amount of tests rely on the // behaviour. if (tree_end == 0) { tree_end = model.trees.size(); } - this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin, - tree_end); + this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin, tree_end); } template void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr p_m, const gbm::GBTreeModel &model, float missing, - PredictionCacheEntry *out_preds, - uint32_t tree_begin, uint32_t tree_end) const { + PredictionCacheEntry *out_preds, uint32_t tree_begin, + uint32_t tree_end) const { auto const n_threads = this->ctx_->Threads(); auto m = dmlc::get>(x); CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) @@ -384,13 +429,16 @@ class CPUPredictor : public Predictor { info.num_row_ = m->NumRows(); this->InitOutPredictions(info, &(out_preds->predictions), model); } + std::vector workspace(m->NumColumns() * kUnroll * n_threads); auto &predictions = out_preds->predictions.HostVector(); std::vector thread_temp; InitThreadTemp(n_threads * kBlockSize, &thread_temp); + std::size_t n_groups = model.learner_model_param->OutputLength(); + linalg::TensorView out_predt{predictions, {m->NumRows(), n_groups}, Context::kCpuId}; PredictBatchByBlockOfRowsKernel, kBlockSize>( - AdapterView(m.get(), missing, common::Span{workspace}, n_threads), - &predictions, model, tree_begin, tree_end, &thread_temp, n_threads); + AdapterView(m.get(), missing, common::Span{workspace}, n_threads), model, + tree_begin, tree_end, &thread_temp, n_threads, out_predt); } bool InplacePredict(std::shared_ptr p_m, const gbm::GBTreeModel &model, float missing, @@ -420,6 +468,7 @@ class CPUPredictor : public Predictor { void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit) const override { + CHECK(!model.learner_model_param->IsVectorLeaf()) << "predict instance" << MTNotImplemented(); std::vector feat_vecs; feat_vecs.resize(1, RegTree::FVec()); feat_vecs[0].Init(model.learner_model_param->num_feature); @@ -432,31 +481,30 @@ class CPUPredictor : public Predictor { auto base_score = model.learner_model_param->BaseScore(ctx_)(0); // loop over output groups for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) { - (*out_preds)[gid] = - PredValue(inst, model.trees, model.tree_info, gid, &feat_vecs[0], 0, ntree_limit) + - base_score; + (*out_preds)[gid] = scalar::PredValue(inst, model.trees, model.tree_info, gid, &feat_vecs[0], + 0, ntree_limit) + + base_score; } } - void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit) const override { + void PredictLeaf(DMatrix *p_fmat, HostDeviceVector *out_preds, + const gbm::GBTreeModel &model, unsigned ntree_limit) const override { auto const n_threads = this->ctx_->Threads(); std::vector feat_vecs; const int num_feature = model.learner_model_param->num_feature; InitThreadTemp(n_threads, &feat_vecs); - const MetaInfo& info = p_fmat->Info(); + const MetaInfo &info = p_fmat->Info(); // number of valid trees if (ntree_limit == 0 || ntree_limit > model.trees.size()) { ntree_limit 
= static_cast(model.trees.size()); } - std::vector& preds = out_preds->HostVector(); + std::vector &preds = out_preds->HostVector(); preds.resize(info.num_row_ * ntree_limit); // start collecting the prediction for (const auto &batch : p_fmat->GetBatches()) { // parallel over local batch auto page = batch.GetView(); - const auto nsize = static_cast(batch.Size()); - common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) { + common::ParallelFor(page.Size(), n_threads, [&](auto i) { const int tid = omp_get_thread_num(); auto ridx = static_cast(batch.base_rowid + i); RegTree::FVec &feats = feat_vecs[tid]; @@ -464,23 +512,28 @@ class CPUPredictor : public Predictor { feats.Init(num_feature); } feats.Fill(page[i]); - for (unsigned j = 0; j < ntree_limit; ++j) { - auto const& tree = *model.trees[j]; - auto const& cats = tree.GetCategoriesMatrix(); - bst_node_t tid = GetLeafIndex(tree, feats, cats); - preds[ridx * ntree_limit + j] = static_cast(tid); + for (std::uint32_t j = 0; j < ntree_limit; ++j) { + auto const &tree = *model.trees[j]; + auto const &cats = tree.GetCategoriesMatrix(); + bst_node_t nidx; + if (tree.IsMultiTarget()) { + nidx = multi::GetLeafIndex(*tree.GetMultiTargetTree(), feats, cats); + } else { + nidx = scalar::GetLeafIndex(tree, feats, cats); + } + preds[ridx * ntree_limit + j] = static_cast(nidx); } feats.Drop(page[i]); }); } } - void PredictContribution(DMatrix *p_fmat, - HostDeviceVector *out_contribs, + void PredictContribution(DMatrix *p_fmat, HostDeviceVector *out_contribs, const gbm::GBTreeModel &model, uint32_t ntree_limit, - std::vector const *tree_weights, - bool approximate, int condition, - unsigned condition_feature) const override { + std::vector const *tree_weights, bool approximate, + int condition, unsigned condition_feature) const override { + CHECK(!model.learner_model_param->IsVectorLeaf()) + << "predict contribution" << MTNotImplemented(); auto const n_threads = this->ctx_->Threads(); const int num_feature = model.learner_model_param->num_feature; std::vector feat_vecs; @@ -556,11 +609,12 @@ class CPUPredictor : public Predictor { } } - void PredictInteractionContributions( - DMatrix *p_fmat, HostDeviceVector *out_contribs, - const gbm::GBTreeModel &model, unsigned ntree_limit, - std::vector const *tree_weights, - bool approximate) const override { + void PredictInteractionContributions(DMatrix *p_fmat, HostDeviceVector *out_contribs, + const gbm::GBTreeModel &model, unsigned ntree_limit, + std::vector const *tree_weights, + bool approximate) const override { + CHECK(!model.learner_model_param->IsVectorLeaf()) + << "predict interaction contribution" << MTNotImplemented(); const MetaInfo& info = p_fmat->Info(); const int ngroup = model.learner_model_param->num_output_group; size_t const ncolumns = model.learner_model_param->num_feature; diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 35daf701c9d3..c2502f667b71 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -31,16 +31,15 @@ namespace predictor { DMLC_REGISTRY_FILE_TAG(gpu_predictor); struct TreeView { - RegTree::CategoricalSplitMatrix cats; + CategoricalSplitMatrix cats; common::Span d_tree; XGBOOST_DEVICE - TreeView(size_t tree_begin, size_t tree_idx, - common::Span d_nodes, + TreeView(size_t tree_begin, size_t tree_idx, common::Span d_nodes, common::Span d_tree_segments, common::Span d_tree_split_types, common::Span d_cat_tree_segments, - common::Span d_cat_node_segments, + common::Span d_cat_node_segments, common::Span 
d_categories) { auto begin = d_tree_segments[tree_idx - tree_begin]; auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] - @@ -255,7 +254,7 @@ PredictLeafKernel(Data data, common::Span d_nodes, common::Span d_tree_split_types, common::Span d_cat_tree_segments, - common::Span d_cat_node_segments, + common::Span d_cat_node_segments, common::Span d_categories, size_t tree_begin, size_t tree_end, size_t num_features, @@ -290,7 +289,7 @@ PredictKernel(Data data, common::Span d_nodes, common::Span d_tree_group, common::Span d_tree_split_types, common::Span d_cat_tree_segments, - common::Span d_cat_node_segments, + common::Span d_cat_node_segments, common::Span d_categories, size_t tree_begin, size_t tree_end, size_t num_features, size_t num_rows, size_t entry_start, bool use_shared, int num_group, float missing) { @@ -334,7 +333,7 @@ class DeviceModel { // Pointer to each tree, segmenting the node array. HostDeviceVector categories_tree_segments; // Pointer to each node, segmenting categories array. - HostDeviceVector categories_node_segments; + HostDeviceVector categories_node_segments; HostDeviceVector categories; size_t tree_beg_; // NOLINT @@ -401,8 +400,8 @@ class DeviceModel { } categories_node_segments = - HostDeviceVector(h_tree_segments.back(), {}, gpu_id); - std::vector &h_categories_node_segments = + HostDeviceVector(h_tree_segments.back(), {}, gpu_id); + std::vector &h_categories_node_segments = categories_node_segments.HostVector(); for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr(); @@ -542,10 +541,10 @@ void ExtractPaths( if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types), common::IsCatOp{})) { dh::PinnedMemory pinned; - auto h_max_cat = pinned.GetSpan(1); + auto h_max_cat = pinned.GetSpan(1); auto max_elem_it = dh::MakeTransformIterator( dh::tbegin(d_cat_node_segments), - [] __device__(RegTree::Segment seg) { return seg.size; }); + [] __device__(CategoricalSplitMatrix::Segment seg) { return seg.size; }); size_t max_cat_it = thrust::max_element(thrust::device, max_elem_it, max_elem_it + d_cat_node_segments.size()) - diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h index 5d0c175fcf65..c470fd2d0234 100644 --- a/src/predictor/predict_fn.h +++ b/src/predictor/predict_fn.h @@ -11,7 +11,7 @@ namespace predictor { template inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue, bool is_missing, - RegTree::CategoricalSplitMatrix const &cats) { + CategoricalSplitMatrix const &cats) { if (has_missing && is_missing) { return node.DefaultChild(); } else { @@ -24,6 +24,25 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs } } } + +template +inline XGBOOST_DEVICE bst_node_t GetNextNodeMulti(MultiTargetTree const &tree, + bst_node_t const nidx, float fvalue, + bool is_missing, + CategoricalSplitMatrix const &cats) { + if (has_missing && is_missing) { + return tree.DefaultChild(nidx); + } else { + if (has_categorical && common::IsCat(cats.split_type, nidx)) { + auto node_categories = + cats.categories.subspan(cats.node_ptr[nidx].beg, cats.node_ptr[nidx].size); + return common::Decision(node_categories, fvalue) ? 
tree.LeftChild(nidx) : tree.RightChild(nidx); + } else { + return tree.LeftChild(nidx) + !(fvalue < tree.SplitCond(nidx)); + } + } +} + } // namespace predictor } // namespace xgboost #endif // XGBOOST_PREDICTOR_PREDICT_FN_H_ diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 38ac9492f4d5..dc0e034c7e96 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -74,15 +74,24 @@ void ValidateBaseMarginShape(linalg::Tensor const& margin, bst_row_t n void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, const gbm::GBTreeModel& model) const { CHECK_NE(model.learner_model_param->num_output_group, 0); - size_t n_classes = model.learner_model_param->num_output_group; - size_t n = n_classes * info.num_row_; + size_t n; + if (model.learner_model_param->IsVectorLeaf()) { + std::size_t n_targets = model.learner_model_param->num_target; + n = n_targets * info.num_row_; + } else { + size_t n_classes = model.learner_model_param->num_output_group; + n = n_classes * info.num_row_; + } + const HostDeviceVector* base_margin = info.base_margin_.Data(); if (ctx_->gpu_id >= 0) { out_preds->SetDevice(ctx_->gpu_id); } if (!base_margin->Empty()) { out_preds->Resize(n); - ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes); + ValidateBaseMarginShape(info.base_margin_, info.num_row_, + model.learner_model_param->num_output_group); out_preds->Copy(*base_margin); } else { // cannot rely on the Resize to fill as it might skip if the size is already correct. diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h index a5f4aac2d58d..326379e73582 100644 --- a/src/tree/common_row_partitioner.h +++ b/src/tree/common_row_partitioner.h @@ -6,13 +6,18 @@ #ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ #define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ -#include // std::numeric_limits +#include // std::all_of +#include // std::uint32_t +#include // std::numeric_limits #include +#include "../common/linalg_op.h" // cbegin #include "../common/numeric.h" // Iota #include "../common/partition_builder.h" #include "hist/expand_entry.h" // CPUExpandEntry -#include "xgboost/context.h" // Context +#include "xgboost/base.h" +#include "xgboost/context.h" // Context +#include "xgboost/linalg.h" // TensorView namespace xgboost { namespace tree { @@ -36,14 +41,16 @@ class CommonRowPartitioner { row_set_collection_.Init(); } - void FindSplitConditions(const std::vector& nodes, const RegTree& tree, + template + void FindSplitConditions(const std::vector& nodes, const RegTree& tree, const GHistIndexMatrix& gmat, std::vector* split_conditions) { + auto const& ptrs = gmat.cut.Ptrs(); for (size_t i = 0; i < nodes.size(); ++i) { - const int32_t nid = nodes[i].nid; - const bst_uint fid = tree[nid].SplitIndex(); - const bst_float split_pt = tree[nid].SplitCond(); - const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; - const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; + bst_node_t const nidx = nodes[i].nid; + bst_feature_t const fidx = tree.SplitIndex(nidx); + float const split_pt = tree.SplitCond(nidx); + std::uint32_t const lower_bound = ptrs[fidx]; + std::uint32_t const upper_bound = ptrs[fidx + 1]; + bst_bin_t split_cond = -1; // convert floating-point split_pt into corresponding bin_id // split_cond = -1 indicates that split_pt is less than all known cut points @@ -57,20 +64,22 @@ class CommonRowPartitioner { } } - void AddSplitsToRowSet(const std::vector& nodes, RegTree const* p_tree) { 
+ template + void AddSplitsToRowSet(const std::vector& nodes, RegTree const* p_tree) { const size_t n_nodes = nodes.size(); for (unsigned int i = 0; i < n_nodes; ++i) { - const int32_t nid = nodes[i].nid; + const int32_t nidx = nodes[i].nid; const size_t n_left = partition_builder_.GetNLeftElems(i); const size_t n_right = partition_builder_.GetNRightElems(i); - CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild()); - row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(), - n_left, n_right); + CHECK_EQ(p_tree->LeftChild(nidx) + 1, p_tree->RightChild(nidx)); + row_set_collection_.AddSplit(nidx, p_tree->LeftChild(nidx), p_tree->RightChild(nidx), n_left, + n_right); } } + template void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, - std::vector const& nodes, RegTree const* p_tree) { + std::vector const& nodes, RegTree const* p_tree) { auto const& column_matrix = gmat.Transpose(); if (column_matrix.IsInitialized()) { if (gmat.cut.HasCategorical()) { @@ -88,10 +97,10 @@ class CommonRowPartitioner { } } - template + template void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, const common::ColumnMatrix& column_matrix, - std::vector const& nodes, RegTree const* p_tree) { + std::vector const& nodes, RegTree const* p_tree) { if (column_matrix.AnyMissing()) { this->template UpdatePosition(ctx, gmat, column_matrix, nodes, p_tree); } else { @@ -99,33 +108,21 @@ class CommonRowPartitioner { } } - template + template void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, const common::ColumnMatrix& column_matrix, - std::vector const& nodes, RegTree const* p_tree) { - switch (column_matrix.GetTypeSize()) { - case common::kUint8BinsTypeSize: - this->template UpdatePosition(ctx, gmat, column_matrix, - nodes, p_tree); - break; - case common::kUint16BinsTypeSize: - this->template UpdatePosition(ctx, gmat, column_matrix, - nodes, p_tree); - break; - case common::kUint32BinsTypeSize: - this->template UpdatePosition(ctx, gmat, column_matrix, - nodes, p_tree); - break; - default: - // no default behavior - CHECK(false) << column_matrix.GetTypeSize(); - } + std::vector const& nodes, RegTree const* p_tree) { + common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto t) { + using T = decltype(t); + this->template UpdatePosition(ctx, gmat, column_matrix, nodes, + p_tree); + }); } - template + template void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, const common::ColumnMatrix& column_matrix, - std::vector const& nodes, RegTree const* p_tree) { + std::vector const& nodes, RegTree const* p_tree) { // 1. 
Find split condition for each split size_t n_nodes = nodes.size(); @@ -199,11 +196,20 @@ class CommonRowPartitioner { } void LeafPartition(Context const* ctx, RegTree const& tree, - common::Span gpair, + linalg::TensorView gpair, std::vector* p_out_position) const { - partition_builder_.LeafPartition( - ctx, tree, this->Partitions(), p_out_position, - [&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; }); + if (gpair.Shape(1) > 1) { + partition_builder_.LeafPartition( + ctx, tree, this->Partitions(), p_out_position, [&](size_t idx) -> bool { + auto sample = gpair.Slice(idx, linalg::All()); + return std::all_of(linalg::cbegin(sample), linalg::cend(sample), + [](GradientPair const& g) { return g.GetHess() - .0f == .0f; }); + }); + } else { + partition_builder_.LeafPartition( + ctx, tree, this->Partitions(), p_out_position, + [&](size_t idx) -> bool { return gpair(idx, 0).GetHess() - .0f == .0f; }); + } } }; diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc index 82efff2c77ac..f4de8e60c438 100644 --- a/src/tree/fit_stump.cc +++ b/src/tree/fit_stump.cc @@ -71,10 +71,8 @@ void FitStump(Context const* ctx, HostDeviceVector const& gpair, auto n_samples = gpair.Size() / n_targets; gpair.SetDevice(ctx->gpu_id); - linalg::TensorView gpair_t{ - ctx->IsCPU() ? gpair.ConstHostSpan() : gpair.ConstDeviceSpan(), - {n_samples, n_targets}, - ctx->gpu_id}; + auto gpair_t = linalg::MakeTensorView( + ctx, ctx->IsCPU() ? gpair.ConstHostSpan() : gpair.ConstDeviceSpan(), n_samples, n_targets); ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView()) : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id)); } diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index f76565e9a98e..04d266be6b57 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -4,21 +4,24 @@ #ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ #define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ -#include -#include -#include -#include -#include -#include +#include // std::copy +#include // numeric_limits +#include // shared_ptr +#include // accumulate +#include // move +#include // vector #include "../../common/categorical.h" -#include "../../common/hist_util.h" +#include "../../common/hist_util.h" // GHistRow,HistogramCuts +#include "../../common/linalg_op.h" // cbegin,cend,begin #include "../../common/random.h" -#include "../../data/gradient_index.h" #include "../constraints.h" #include "../param.h" #include "../split_evaluator.h" +#include "expand_entry.h" // MultiExpandEntry +#include "xgboost/base.h" #include "xgboost/context.h" +#include "xgboost/linalg.h" // Constants,Tensor namespace xgboost { namespace tree { @@ -391,8 +394,6 @@ class HistEvaluator { tree[candidate.nid].SplitIndex(), left_weight, right_weight); - auto max_node = std::max(left_child, tree[candidate.nid].RightChild()); - max_node = std::max(candidate.nid, max_node); snode_.resize(tree.GetNodes().size()); snode_.at(left_child).stats = candidate.split.left_sum; snode_.at(left_child).root_gain = evaluator.CalcGain( @@ -409,15 +410,14 @@ class HistEvaluator { auto Evaluator() const { return tree_evaluator_.GetEvaluator(); } auto const& Stats() const { return snode_; } - float InitRoot(GradStats const& root_sum) { + float InitRoot(GradStats const &root_sum) { snode_.resize(1); auto root_evaluator = tree_evaluator_.GetEvaluator(); snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()}; - snode_[0].root_gain = root_evaluator.CalcGain(RegTree::kRoot, param_, - GradStats{snode_[0].stats}); 
- auto weight = root_evaluator.CalcWeight(RegTree::kRoot, param_, - GradStats{snode_[0].stats}); + snode_[0].root_gain = + root_evaluator.CalcGain(RegTree::kRoot, param_, GradStats{snode_[0].stats}); + auto weight = root_evaluator.CalcWeight(RegTree::kRoot, param_, GradStats{snode_[0].stats}); return weight; } @@ -436,6 +436,211 @@ class HistEvaluator { } }; +class HistMultiEvaluator { + std::vector gain_; + linalg::Matrix stats_; + TrainParam param_; + FeatureInteractionConstraintHost interaction_constraints_; + std::shared_ptr column_sampler_; + Context const *ctx_; + + private: + static double MultiCalcSplitGain(TrainParam const &param, + linalg::VectorView left_sum, + linalg::VectorView right_sum, + linalg::VectorView left_weight, + linalg::VectorView right_weight) { + CalcWeight(param, left_sum, left_weight); + CalcWeight(param, right_sum, right_weight); + + auto left_gain = CalcGainGivenWeight(param, left_sum, left_weight); + auto right_gain = CalcGainGivenWeight(param, right_sum, right_weight); + return left_gain + right_gain; + } + + template + bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx, + common::Span hist, + linalg::VectorView parent_sum, double parent_gain, + SplitEntryContainer> *p_best) const { + auto const &cut_ptr = cut.Ptrs(); + auto const &cut_val = cut.Values(); + auto const &min_val = cut.MinValues(); + + auto sum = linalg::Empty(ctx_, 2, hist.size()); + auto left_sum = sum.HostView().Slice(0, linalg::All()); + auto right_sum = sum.HostView().Slice(1, linalg::All()); + + bst_bin_t ibegin, iend; + if (d_step > 0) { + ibegin = static_cast(cut_ptr[fidx]); + iend = static_cast(cut_ptr.at(fidx + 1)); + } else { + ibegin = static_cast(cut_ptr[fidx + 1]) - 1; + iend = static_cast(cut_ptr[fidx]) - 1; + } + const auto imin = static_cast(cut_ptr[fidx]); + + auto weight = linalg::Empty(ctx_, 2, hist.size()); + auto left_weight = weight.Slice(0, linalg::All()); + auto right_weight = weight.Slice(1, linalg::All()); + + for (bst_bin_t i = ibegin; i != iend; i += d_step) { + for (bst_target_t t = 0; t < hist.size(); ++t) { + auto t_hist = hist[t]; + auto t_p = parent_sum(t); + left_sum(t) += t_hist[i]; + right_sum(t) = t_p - left_sum(t); + } + + if (d_step > 0) { + auto split_pt = cut_val[i]; + auto loss_chg = MultiCalcSplitGain(param_, right_sum, left_sum, right_weight, left_weight) - + parent_gain; + p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum); + } else { + float split_pt; + if (i == imin) { + split_pt = min_val[fidx]; + } else { + split_pt = cut_val[i - 1]; + } + auto loss_chg = MultiCalcSplitGain(param_, right_sum, left_sum, left_weight, right_weight) - + parent_gain; + p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum); + } + } + // Return true if there are missing values. Doesn't handle floating-point error well. 
+ if (d_step == +1) { + return !std::equal(linalg::cbegin(left_sum), linalg::cend(left_sum), + linalg::cbegin(parent_sum)); + } + return false; + } + + public: + void EvaluateSplits(RegTree const &tree, common::Span hist, + common::HistogramCuts const &cut, std::vector *p_entries) { + auto &entries = *p_entries; + std::vector>> features(entries.size()); + + for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { + auto nidx = entries[nidx_in_set].nid; + features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx)); + } + CHECK(!features.empty()); + + std::int32_t n_threads = ctx_->Threads(); + std::size_t const grain_size = std::max(1, features.front()->Size() / n_threads); + common::BlockedSpace2d space( + entries.size(), [&](std::size_t nidx_in_set) { return features[nidx_in_set]->Size(); }, + grain_size); + + std::vector tloc_candidates(n_threads * entries.size()); + for (std::size_t i = 0; i < entries.size(); ++i) { + for (std::int32_t j = 0; j < n_threads; ++j) { + tloc_candidates[i * n_threads + j] = entries[i]; + } + } + common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set, common::Range1d r) { + auto tidx = omp_get_thread_num(); + auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx]; + auto best = &entry->split; + auto parent_sum = stats_.Slice(entry->nid, linalg::All()); + std::vector node_hist; + for (auto t_hist : hist) { + node_hist.push_back((*t_hist)[entry->nid]); + } + auto features_set = features[nidx_in_set]->ConstHostSpan(); + + for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) { + auto fidx = features_set[fidx_in_set]; + if (!interaction_constraints_.Query(entry->nid, fidx)) { + continue; + } + auto parent_gain = gain_[entry->nid]; + bool missing = this->EnumerateSplit<+1>(cut, fidx, node_hist, parent_sum, parent_gain, best); + if (missing) { + this->EnumerateSplit<-1>(cut, fidx, node_hist, parent_sum, parent_gain, best); + } + } + }); + + for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { + for (auto tidx = 0; tidx < n_threads; ++tidx) { + entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split); + } + } + } + + linalg::Tensor InitRoot(linalg::VectorView root_sum) { + auto n_targets = root_sum.Size(); + stats_ = linalg::Constant(ctx_, GradientPairPrecise{}, 1, n_targets); + gain_.resize(1); + + linalg::Tensor weight({n_targets}, linalg::Order::kC, ctx_->gpu_id); + CalcWeight(param_, root_sum, weight.HostView()); + auto root_gain = CalcGainGivenWeight(param_, root_sum, weight.HostView()); + gain_.front() = root_gain; + + auto h_stats = stats_.HostView(); + std::copy(linalg::cbegin(root_sum), linalg::cend(root_sum), linalg::begin(h_stats)); + + return weight; + } + + void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) { + auto n_targets = p_tree->NumTargets(); + auto parent_sum = stats_.Slice(candidate.nid, linalg::All()); + + auto weight = linalg::Empty(ctx_, 3, n_targets); + auto base_weight = weight.Slice(0, linalg::All()); + CalcWeight(param_, parent_sum, base_weight); + + auto left_weight = weight.Slice(1, linalg::All()); + auto left_sum = + linalg::MakeVec(candidate.split.left_sum.data(), candidate.split.left_sum.size()); + CalcWeight(param_, left_sum, left_weight); + + auto right_weight = weight.Slice(2, linalg::All()); + auto right_sum = + linalg::MakeVec(candidate.split.right_sum.data(), candidate.split.right_sum.size()); + CalcWeight(param_, right_sum, right_weight); + + 
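These CalcWeight and CalcGainGivenWeight calls reduce, in the unconstrained case, to a simple per-target formula; a minimal stand-alone sketch, assuming plain reg_lambda L2 regularization and none of the monotone or interaction constraint handling in split_evaluator.h:

#include <cstddef>
#include <vector>

struct GradSum { double grad{0.0}; double hess{0.0}; };

// Per-target optimal leaf weight: w_t = -G_t / (H_t + lambda).
void CalcWeightSketch(double lambda, std::vector<GradSum> const& sum, std::vector<double>* out) {
  out->resize(sum.size());
  for (std::size_t t = 0; t < sum.size(); ++t) {
    (*out)[t] = -sum[t].grad / (sum[t].hess + lambda);
  }
}

// Gain at the optimal weight, summed over targets: sum_t G_t^2 / (H_t + lambda).
double CalcGainSketch(double lambda, std::vector<GradSum> const& sum) {
  double gain = 0.0;
  for (auto const& g : sum) {
    gain += g.grad * g.grad / (g.hess + lambda);
  }
  return gain;
}

The loss change reported for a candidate split is then gain(left) + gain(right) - gain(parent), which is what MultiCalcSplitGain and the gain_ bookkeeping here compute.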
p_tree->ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value, + candidate.split.DefaultLeft(), base_weight, left_weight, right_weight); + CHECK(p_tree->IsMultiTarget()); + auto left_child = p_tree->LeftChild(candidate.nid); + CHECK_GT(left_child, candidate.nid); + auto right_child = p_tree->RightChild(candidate.nid); + CHECK_GT(right_child, candidate.nid); + + std::size_t n_nodes = p_tree->Size(); + gain_.resize(n_nodes); + gain_[left_child] = CalcGainGivenWeight(param_, left_sum, left_weight); + gain_[right_child] = CalcGainGivenWeight(param_, right_sum, right_weight); + + if (n_nodes >= stats_.Shape(0)) { + stats_.Reshape(n_nodes * 2, stats_.Shape(1)); + } + CHECK_EQ(stats_.Shape(1), n_targets); + auto left_sum_stat = stats_.Slice(left_child, linalg::All()); + std::copy(candidate.split.left_sum.cbegin(), candidate.split.left_sum.cend(), + linalg::begin(left_sum_stat)); + auto right_sum_stat = stats_.Slice(right_child, linalg::All()); + std::copy(candidate.split.right_sum.cbegin(), candidate.split.right_sum.cend(), + linalg::begin(right_sum_stat)); + } + + explicit HistMultiEvaluator(Context const *ctx, MetaInfo const &info, TrainParam param, + std::shared_ptr sampler) + : param_{std::move(param)}, column_sampler_{std::move(sampler)}, ctx_{ctx} { + interaction_constraints_.Configure(param, info.num_col_); + column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode, + param_.colsample_bylevel, param_.colsample_bytree); + } +}; + /** * \brief CPU implementation of update prediction cache, which calculates the leaf value * for the last tree and accumulates it to prediction vector. diff --git a/src/tree/hist/expand_entry.h b/src/tree/hist/expand_entry.h index 885a109bfabe..2c08678e2207 100644 --- a/src/tree/hist/expand_entry.h +++ b/src/tree/hist/expand_entry.h @@ -4,26 +4,45 @@ #ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ #define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ +#include // all_of #include + #include "../param.h" +#include "xgboost/base.h" namespace xgboost { namespace tree { +/** + * \brief Structure for storing tree split candidate. 
+ */ +template +struct ExpandEntryImpl { + bst_node_t nid; + bst_node_t depth; + + float GetLossChange() const { return static_cast(this)->split.loss_chg; } + bst_node_t GetNodeId() const { return nid; } + + static bool ChildIsValid(TrainParam const& param, bst_node_t depth, bst_node_t num_leaves) { + if (param.max_depth > 0 && depth >= param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; + return true; + } + + bool IsValid(TrainParam const& param, bst_node_t num_leaves) const { + return static_cast(this)->IsValidImpl(param, num_leaves); + } +}; -struct CPUExpandEntry { - int nid; - int depth; +struct CPUExpandEntry : public ExpandEntryImpl { SplitEntry split; + CPUExpandEntry() = default; - XGBOOST_DEVICE - CPUExpandEntry(int nid, int depth, SplitEntry split) - : nid(nid), depth(depth), split(std::move(split)) {} - CPUExpandEntry(int nid, int depth, float loss_chg) - : nid(nid), depth(depth) { - split.loss_chg = loss_chg; - } + CPUExpandEntry(bst_node_t nidx, bst_node_t depth, SplitEntry split) + : ExpandEntryImpl{nidx, depth}, split(std::move(split)) {} + CPUExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {} - bool IsValid(const TrainParam& param, int num_leaves) const { + bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const { if (split.loss_chg <= kRtEps) return false; if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) { return false; @@ -40,21 +59,61 @@ struct CPUExpandEntry { return true; } - float GetLossChange() const { return split.loss_chg; } - bst_node_t GetNodeId() const { return nid; } + friend std::ostream& operator<<(std::ostream& os, CPUExpandEntry const& e) { + os << "ExpandEntry:\n"; + os << "nidx: " << e.nid << "\n"; + os << "depth: " << e.depth << "\n"; + os << "loss: " << e.split.loss_chg << "\n"; + os << "split:\n" << e.split << std::endl; + return os; + } +}; - static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) { - if (param.max_depth > 0 && depth >= param.max_depth) return false; - if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; +struct MultiExpandEntry : public ExpandEntryImpl { + SplitEntryContainer> split; + + MultiExpandEntry() = default; + MultiExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {} + + bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const { + if (split.loss_chg <= kRtEps) return false; + auto is_zero = [](auto const& sum) { + return std::all_of(sum.cbegin(), sum.cend(), + [&](auto const& g) { return g.GetHess() - .0 == .0; }); + }; + if (is_zero(split.left_sum) || is_zero(split.right_sum)) { + return false; + } + if (split.loss_chg < param.min_split_loss) { + return false; + } + if (param.max_depth > 0 && depth == param.max_depth) { + return false; + } + if (param.max_leaves > 0 && num_leaves == param.max_leaves) { + return false; + } return true; } - friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) { - os << "ExpandEntry:\n"; + friend std::ostream& operator<<(std::ostream& os, MultiExpandEntry const& e) { + os << "ExpandEntry: \n"; os << "nidx: " << e.nid << "\n"; os << "depth: " << e.depth << "\n"; os << "loss: " << e.split.loss_chg << "\n"; - os << "split:\n" << e.split << std::endl; + os << "split cond:" << e.split.split_value << "\n"; + os << "split ind:" << e.split.SplitIndex() << "\n"; + os << "left_sum: ["; + for (auto v : e.split.left_sum) { + os << v << ", "; + } + os << "]\n"; + + os << "right_sum: 
["; + for (auto v : e.split.right_sum) { + os << v << ", "; + } + os << "]\n"; return os; } }; diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index acc13f6817a5..fd70ddff5d99 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -12,7 +12,7 @@ #include "../../common/hist_util.h" #include "../../data/gradient_index.h" #include "expand_entry.h" -#include "xgboost/tree_model.h" +#include "xgboost/tree_model.h" // RegTree namespace xgboost { namespace tree { @@ -59,15 +59,14 @@ class HistogramBuilder { GHistIndexMatrix const &gidx, std::vector const &nodes_for_explicit_hist_build, common::RowSetCollection const &row_set_collection, - const std::vector &gpair_h, - bool force_read_by_column) { + common::Span gpair_h, bool force_read_by_column) { const size_t n_nodes = nodes_for_explicit_hist_build.size(); CHECK_GT(n_nodes, 0); std::vector target_hists(n_nodes); for (size_t i = 0; i < n_nodes; ++i) { - const int32_t nid = nodes_for_explicit_hist_build[i].nid; - target_hists[i] = hist_[nid]; + auto const nidx = nodes_for_explicit_hist_build[i].nid; + target_hists[i] = hist_[nidx]; } if (page_idx == 0) { // FIXME(jiamingy): Handle different size of space. Right now we use the maximum @@ -93,46 +92,37 @@ class HistogramBuilder { }); } - void - AddHistRows(int *starting_index, int *sync_count, - std::vector const &nodes_for_explicit_hist_build, - std::vector const &nodes_for_subtraction_trick, - RegTree *p_tree) { + void AddHistRows(int *starting_index, int *sync_count, + std::vector const &nodes_for_explicit_hist_build, + std::vector const &nodes_for_subtraction_trick, + RegTree const *p_tree) { if (is_distributed_) { - this->AddHistRowsDistributed(starting_index, sync_count, - nodes_for_explicit_hist_build, + this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build, nodes_for_subtraction_trick, p_tree); } else { - this->AddHistRowsLocal(starting_index, sync_count, - nodes_for_explicit_hist_build, + this->AddHistRowsLocal(starting_index, sync_count, nodes_for_explicit_hist_build, nodes_for_subtraction_trick); } } /** Main entry point of this class, build histogram for tree nodes. 
*/ void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx, - RegTree *p_tree, common::RowSetCollection const &row_set_collection, + RegTree const *p_tree, common::RowSetCollection const &row_set_collection, std::vector const &nodes_for_explicit_hist_build, std::vector const &nodes_for_subtraction_trick, - std::vector const &gpair, - bool force_read_by_column = false) { + common::Span gpair, bool force_read_by_column = false) { int starting_index = std::numeric_limits::max(); int sync_count = 0; if (page_id == 0) { - this->AddHistRows(&starting_index, &sync_count, - nodes_for_explicit_hist_build, + this->AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build, nodes_for_subtraction_trick, p_tree); } if (gidx.IsDense()) { - this->BuildLocalHistograms(page_id, space, gidx, - nodes_for_explicit_hist_build, - row_set_collection, gpair, - force_read_by_column); + this->BuildLocalHistograms(page_id, space, gidx, nodes_for_explicit_hist_build, + row_set_collection, gpair, force_read_by_column); } else { - this->BuildLocalHistograms(page_id, space, gidx, - nodes_for_explicit_hist_build, - row_set_collection, gpair, - force_read_by_column); + this->BuildLocalHistograms(page_id, space, gidx, nodes_for_explicit_hist_build, + row_set_collection, gpair, force_read_by_column); } CHECK_GE(n_batches_, 1); @@ -153,8 +143,7 @@ class HistogramBuilder { common::RowSetCollection const &row_set_collection, std::vector const &nodes_for_explicit_hist_build, std::vector const &nodes_for_subtraction_trick, - std::vector const &gpair, - bool force_read_by_column = false) { + common::Span gpair, bool force_read_by_column = false) { const size_t n_nodes = nodes_for_explicit_hist_build.size(); // create space of size (# rows in each node) common::BlockedSpace2d space( @@ -164,83 +153,72 @@ class HistogramBuilder { return row_set_collection[nidx].Size(); }, 256); - this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, - nodes_for_explicit_hist_build, nodes_for_subtraction_trick, - gpair, force_read_by_column); + this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, nodes_for_explicit_hist_build, + nodes_for_subtraction_trick, gpair, force_read_by_column); } - void SyncHistogramDistributed( - RegTree *p_tree, - std::vector const &nodes_for_explicit_hist_build, - std::vector const &nodes_for_subtraction_trick, - int starting_index, int sync_count) { + void SyncHistogramDistributed(RegTree const *p_tree, + std::vector const &nodes_for_explicit_hist_build, + std::vector const &nodes_for_subtraction_trick, + int starting_index, int sync_count) { const size_t nbins = builder_.GetNumBins(); common::BlockedSpace2d space( - nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, - 1024); - common::ParallelFor2d( - space, n_threads_, [&](size_t node, common::Range1d r) { - const auto &entry = nodes_for_explicit_hist_build[node]; - auto this_hist = this->hist_[entry.nid]; - // Merging histograms from each thread into once - buffer_.ReduceHist(node, r.begin(), r.end()); - // Store posible parent node - auto this_local = hist_local_worker_[entry.nid]; - common::CopyHist(this_local, this_hist, r.begin(), r.end()); + nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024); + common::ParallelFor2d(space, n_threads_, [&](size_t node, common::Range1d r) { + const auto &entry = nodes_for_explicit_hist_build[node]; + auto this_hist = this->hist_[entry.nid]; + // Merging histograms from each thread into one + buffer_.ReduceHist(node, 
r.begin(), r.end()); + // Store posible parent node + auto this_local = hist_local_worker_[entry.nid]; + common::CopyHist(this_local, this_hist, r.begin(), r.end()); - if (!(*p_tree)[entry.nid].IsRoot()) { - const size_t parent_id = (*p_tree)[entry.nid].Parent(); - const int subtraction_node_id = - nodes_for_subtraction_trick[node].nid; - auto parent_hist = this->hist_local_worker_[parent_id]; - auto sibling_hist = this->hist_[subtraction_node_id]; - common::SubtractionHist(sibling_hist, parent_hist, this_hist, - r.begin(), r.end()); - // Store posible parent node - auto sibling_local = hist_local_worker_[subtraction_node_id]; - common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end()); - } - }); + if (!p_tree->IsRoot(entry.nid)) { + const size_t parent_id = p_tree->Parent(entry.nid); + const int subtraction_node_id = nodes_for_subtraction_trick[node].nid; + auto parent_hist = this->hist_local_worker_[parent_id]; + auto sibling_hist = this->hist_[subtraction_node_id]; + common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end()); + // Store posible parent node + auto sibling_local = hist_local_worker_[subtraction_node_id]; + common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end()); + } + }); collective::Allreduce( reinterpret_cast(this->hist_[starting_index].data()), builder_.GetNumBins() * sync_count * 2); - ParallelSubtractionHist(space, nodes_for_explicit_hist_build, - nodes_for_subtraction_trick, p_tree); + ParallelSubtractionHist(space, nodes_for_explicit_hist_build, nodes_for_subtraction_trick, + p_tree); common::BlockedSpace2d space2( - nodes_for_subtraction_trick.size(), [&](size_t) { return nbins; }, - 1024); - ParallelSubtractionHist(space2, nodes_for_subtraction_trick, - nodes_for_explicit_hist_build, p_tree); + nodes_for_subtraction_trick.size(), [&](size_t) { return nbins; }, 1024); + ParallelSubtractionHist(space2, nodes_for_subtraction_trick, nodes_for_explicit_hist_build, + p_tree); } - void SyncHistogramLocal(RegTree *p_tree, + void SyncHistogramLocal(RegTree const *p_tree, std::vector const &nodes_for_explicit_hist_build, std::vector const &nodes_for_subtraction_trick) { const size_t nbins = this->builder_.GetNumBins(); common::BlockedSpace2d space( - nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, - 1024); + nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024); - common::ParallelFor2d( - space, this->n_threads_, [&](size_t node, common::Range1d r) { - const auto &entry = nodes_for_explicit_hist_build[node]; - auto this_hist = this->hist_[entry.nid]; - // Merging histograms from each thread into once - this->buffer_.ReduceHist(node, r.begin(), r.end()); + common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) { + const auto &entry = nodes_for_explicit_hist_build[node]; + auto this_hist = this->hist_[entry.nid]; + // Merging histograms from each thread into once + this->buffer_.ReduceHist(node, r.begin(), r.end()); - if (!(*p_tree)[entry.nid].IsRoot()) { - const size_t parent_id = (*p_tree)[entry.nid].Parent(); - const int subtraction_node_id = - nodes_for_subtraction_trick[node].nid; - auto parent_hist = this->hist_[parent_id]; - auto sibling_hist = this->hist_[subtraction_node_id]; - common::SubtractionHist(sibling_hist, parent_hist, this_hist, - r.begin(), r.end()); - } - }); + if (!p_tree->IsRoot(entry.nid)) { + auto const parent_id = p_tree->Parent(entry.nid); + auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid; + auto parent_hist = 
this->hist_[parent_id]; + auto sibling_hist = this->hist_[subtraction_node_id]; + common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end()); + } + }); } public: @@ -257,10 +235,10 @@ class HistogramBuilder { common::ParallelFor2d( space, this->n_threads_, [&](size_t node, common::Range1d r) { const auto &entry = nodes[node]; - if (!((*p_tree)[entry.nid].IsLeftChild())) { + if (!(p_tree->IsLeftChild(entry.nid))) { auto this_hist = this->hist_[entry.nid]; - if (!(*p_tree)[entry.nid].IsRoot()) { + if (!p_tree->IsRoot(entry.nid)) { const int subtraction_node_id = subtraction_nodes[node].nid; auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()]; auto sibling_hist = hist_[subtraction_node_id]; @@ -289,11 +267,10 @@ class HistogramBuilder { this->hist_.AllocateAllData(); } - void AddHistRowsDistributed( - int *starting_index, int *sync_count, - std::vector const &nodes_for_explicit_hist_build, - std::vector const &nodes_for_subtraction_trick, - RegTree *p_tree) { + void AddHistRowsDistributed(int *starting_index, int *sync_count, + std::vector const &nodes_for_explicit_hist_build, + std::vector const &nodes_for_subtraction_trick, + RegTree const *p_tree) { const size_t explicit_size = nodes_for_explicit_hist_build.size(); const size_t subtaction_size = nodes_for_subtraction_trick.size(); std::vector merged_node_ids(explicit_size + subtaction_size); @@ -306,7 +283,7 @@ class HistogramBuilder { std::sort(merged_node_ids.begin(), merged_node_ids.end()); int n_left = 0; for (auto const &nid : merged_node_ids) { - if ((*p_tree)[nid].IsLeftChild()) { + if (p_tree->IsLeftChild(nid)) { this->hist_.AddHistRow(nid); (*starting_index) = std::min(nid, (*starting_index)); n_left++; @@ -314,7 +291,7 @@ class HistogramBuilder { } } for (auto const &nid : merged_node_ids) { - if (!((*p_tree)[nid].IsLeftChild())) { + if (!(p_tree->IsLeftChild(nid))) { this->hist_.AddHistRow(nid); this->hist_local_worker_.AddHistRow(nid); } @@ -327,9 +304,9 @@ class HistogramBuilder { // Construct a work space for building histogram. Eventually we should move this // function into histogram builder once hist tree method supports external memory. 
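For context, common::BlockedSpace2d (consumed by ConstructHistSpace below and by the ParallelFor2d loops above) flattens (node, row-range) pairs into fixed-size blocks so threads can balance work across tree nodes of very different sizes. A minimal self-contained sketch of that blocking idea, with hypothetical names; the real class additionally records a grain size and supports range lookups:

#include <algorithm>   // std::min
#include <cstddef>     // std::size_t
#include <functional>  // std::function
#include <vector>      // std::vector

// One work unit: rows [begin, end) of node `node`.
struct Block {
  std::size_t node, begin, end;
};

// Hypothetical stand-in for common::BlockedSpace2d: dimension one is the node
// index, dimension two is that node's rows, chopped into blocks of `grain`
// rows so a parallel-for can schedule blocks instead of whole nodes.
std::vector<Block> MakeBlockedSpace(std::size_t n_nodes,
                                    std::function<std::size_t(std::size_t)> size_of,
                                    std::size_t grain) {
  std::vector<Block> blocks;
  for (std::size_t nidx = 0; nidx < n_nodes; ++nidx) {
    std::size_t n = size_of(nidx);
    for (std::size_t beg = 0; beg < n; beg += grain) {
      blocks.push_back({nidx, beg, std::min(beg + grain, n)});
    }
  }
  return blocks;  // a ParallelFor2d-style loop then iterates these blocks
}

int main() {
  // Two nodes with 5 and 2 rows and a grain of 3 yield three blocks:
  // {0, [0, 3)}, {0, [3, 5)}, {1, [0, 2)}.
  auto space = MakeBlockedSpace(
      2, [](std::size_t nidx) { return nidx == 0 ? std::size_t{5} : std::size_t{2}; }, 3);
  return static_cast<int>(space.size()) - 3;  // exits 0 when the layout matches
}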
-template <typename Partitioner>
+template <typename Partitioner, typename ExpandEntry>
 common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
-                                          std::vector<CPUExpandEntry> const &nodes_to_build) {
+                                          std::vector<ExpandEntry> const &nodes_to_build) {
   std::vector<size_t> partition_size(nodes_to_build.size(), 0);
   for (auto const &partition : partitioners) {
     size_t k = 0;
diff --git a/src/tree/hist/sampler.h b/src/tree/hist/sampler.h
new file mode 100644
index 000000000000..803e40d547bf
--- /dev/null
+++ b/src/tree/hist/sampler.h
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2020-2023 by XGBoost Contributors
+ */
+#ifndef XGBOOST_TREE_HIST_SAMPLER_H_
+#define XGBOOST_TREE_HIST_SAMPLER_H_
+
+#include <cstddef>  // std::size_t
+#include <cstdint>  // std::uint64_t
+#include <random>   // bernoulli_distribution, linear_congruential_engine
+
+#include "../../common/random.h"  // GlobalRandom
+#include "../param.h"             // TrainParam
+#include "xgboost/base.h"         // GradientPair
+#include "xgboost/context.h"      // Context
+#include "xgboost/data.h"         // MetaInfo
+#include "xgboost/linalg.h"       // TensorView
+
+namespace xgboost {
+namespace tree {
+struct RandomReplace {
+ public:
+  // similar value as for minstd_rand
+  static constexpr std::uint64_t kBase = 16807;
+  static constexpr std::uint64_t kMod = static_cast<std::uint64_t>(1) << 63;
+
+  using EngineT = std::linear_congruential_engine<std::uint64_t, kBase, 0, kMod>;
+
+  /*
+    Right-to-left binary method: https://en.wikipedia.org/wiki/Modular_exponentiation
+  */
+  // Computes (base^exponent * initial_seed) mod `mod`, which is the state of the
+  // multiplicative LCG after `exponent` draws, so each sampling thread can jump
+  // ahead in O(log exponent) steps instead of discarding values one by one.
+  static std::uint64_t SimpleSkip(std::uint64_t exponent, std::uint64_t initial_seed,
+                                  std::uint64_t base, std::uint64_t mod) {
+    CHECK_LE(exponent, mod);
+    std::uint64_t result = 1;
+    while (exponent > 0) {
+      if (exponent % 2 == 1) {
+        result = (result * base) % mod;
+      }
+      base = (base * base) % mod;
+      exponent = exponent >> 1;
+    }
+    // with result we can now find the new seed
+    return (result * initial_seed) % mod;
+  }
+};
+
+// Only uniform sampling, no gradient-based yet.
+inline void SampleGradient(Context const* ctx, TrainParam param,
+                           linalg::MatrixView<GradientPair> out) {
+  CHECK(out.Contiguous());
+  CHECK_EQ(param.sampling_method, TrainParam::kUniform)
+      << "Only uniform sampling is supported, gradient-based sampling is only supported by GPU Hist.";
+
+  if (param.subsample >= 1.0) {
+    return;
+  }
+  bst_row_t n_samples = out.Shape(0);
+  auto& rnd = common::GlobalRandom();
+
+#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
+  std::bernoulli_distribution coin_flip(param.subsample);
+  CHECK_EQ(out.Shape(1), 1) << "Multi-target with sampling for R is not yet supported.";
+  for (size_t i = 0; i < n_samples; ++i) {
+    if (!(out(i, 0).GetHess() >= 0.0f && coin_flip(rnd)) || out(i, 0).GetGrad() == 0.0f) {
+      out(i, 0) = GradientPair(0);
+    }
+  }
+#else
+  std::uint64_t initial_seed = rnd();
+
+  auto n_threads = static_cast<std::size_t>(ctx->Threads());
+  std::size_t const discard_size = n_samples / n_threads;
+  std::bernoulli_distribution coin_flip(param.subsample);
+
+  dmlc::OMPException exc;
+#pragma omp parallel num_threads(n_threads)
+  {
+    exc.Run([&]() {
+      const size_t tid = omp_get_thread_num();
+      const size_t ibegin = tid * discard_size;
+      const size_t iend = (tid == (n_threads - 1)) ?
n_samples : ibegin + discard_size; + + const uint64_t displaced_seed = RandomReplace::SimpleSkip( + ibegin, initial_seed, RandomReplace::kBase, RandomReplace::kMod); + RandomReplace::EngineT eng(displaced_seed); + std::size_t n_targets = out.Shape(1); + if (n_targets > 1) { + for (std::size_t i = ibegin; i < iend; ++i) { + if (!coin_flip(eng)) { + for (std::size_t j = 0; j < n_targets; ++j) { + out(i, j) = GradientPair{}; + } + } + } + } else { + for (std::size_t i = ibegin; i < iend; ++i) { + if (!coin_flip(eng)) { + out(i, 0) = GradientPair{}; + } + } + } + }); + } + exc.Rethrow(); +#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG +} +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_HIST_SAMPLER_H_ diff --git a/src/tree/io_utils.h b/src/tree/io_utils.h new file mode 100644 index 000000000000..db705fa9840c --- /dev/null +++ b/src/tree/io_utils.h @@ -0,0 +1,48 @@ +#ifndef XGBOOST_TREE_IO_UTILS_H_ +#define XGBOOST_TREE_IO_UTILS_H_ +#include // enable_if_t,is_same +#include + +#include "xgboost/json.h" + +namespace xgboost { + +template +using FloatArrayT = std::conditional_t; +template +using U8ArrayT = std::conditional_t; +template +using I32ArrayT = std::conditional_t; +template +using I64ArrayT = std::conditional_t; +template +using IndexArrayT = std::conditional_t, I32ArrayT>; + +// typed array, not boolean +template +std::enable_if_t::value && !std::is_same::value, T> GetElem( + std::vector const& arr, size_t i) { + return arr[i]; +} +// typed array boolean +template +std::enable_if_t::value && std::is_same::value && + std::is_same::value, + bool> +GetElem(std::vector const& arr, size_t i) { + return arr[i] == 1; +} +// json array +template +std::enable_if_t< + std::is_same::value, + std::conditional_t::value, int64_t, + std::conditional_t::value, bool, float>>> +GetElem(std::vector const& arr, size_t i) { + if (std::is_same::value && !IsA(arr[i])) { + return get(arr[i]) == 1; + } + return get(arr[i]); +} +} // namespace xgboost +#endif // XGBOOST_TREE_IO_UTILS_H_ diff --git a/src/tree/multi_target_tree_model.cc b/src/tree/multi_target_tree_model.cc new file mode 100644 index 000000000000..2ca85ccd94f6 --- /dev/null +++ b/src/tree/multi_target_tree_model.cc @@ -0,0 +1,188 @@ +/** + * Copyright 2022 by XGBoost Contributors + */ +#include "xgboost/multi_target_tree_model.h" + +#include // copy +#include +#include +#include + +#include "io_utils.h" +#include "xgboost/base.h" +#include "xgboost/json.h" +#include "xgboost/logging.h" +#include "xgboost/string_view.h" + +namespace xgboost { +template +void LoadModelImpl(Json const& in, std::vector* p_weights, std::vector* p_lefts, + std::vector* p_rights, std::vector* p_parents, + std::vector* p_conds, std::vector* p_fidx, + std::vector* p_dft_left) { + auto get_float = [&](StringView name, std::vector* p_out) { + auto& values = get>(get(in).find(name)->second); + auto& out = *p_out; + out.resize(values.size()); + for (std::size_t i = 0; i < values.size(); ++i) { + out[i] = GetElem(values, i); + } + }; + get_float("weights", p_weights); + get_float("split_conditions", p_conds); + + auto get_nidx = [&](StringView name, std::vector* p_nidx) { + auto& nidx = get>(get(in).find(name)->second); + auto& out_nidx = *p_nidx; + out_nidx.resize(nidx.size()); + for (std::size_t i = 0; i < nidx.size(); ++i) { + out_nidx[i] = GetElem(nidx, i); + } + }; + get_nidx("left_children", p_lefts); + get_nidx("right_children", p_rights); + get_nidx("parents", p_parents); + + auto const& splits = get const>(in["split_indices"]); + 
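// Note on the helpers used above: GetElem from io_utils.h lets these loops
+  // read "split_indices" and friends from either representation the model may
+  // carry: a typed binary array (I32Array/I64Array), where GetElem is a plain
+  // arr[i], or a generic Json array, where it unboxes the Json value first
+  // (booleans are additionally mapped to 0/1). The typed/feature_is_64 flags
+  // in LoadModel below select the right instantiation.
+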
p_fidx->resize(splits.size()); + auto& out_fidx = *p_fidx; + for (std::size_t i = 0; i < splits.size(); ++i) { + out_fidx[i] = GetElem(splits, i); + } + + auto const& dft_left = get const>(in["default_left"]); + auto& out_dft_l = *p_dft_left; + out_dft_l.resize(dft_left.size()); + for (std::size_t i = 0; i < dft_left.size(); ++i) { + out_dft_l[i] = GetElem(dft_left, i); + } +} + +void MultiTargetTree::LoadModel(Json const& in) { + bool typed = IsA(in["weights"]); + bool feature_is_64 = IsA(in["split_indices"]); + + if (typed && feature_is_64) { + LoadModelImpl(in, &weights_, &left_, &right_, &parent_, &split_conds_, + &split_index_, &default_left_); + } else if (typed && !feature_is_64) { + LoadModelImpl(in, &weights_, &left_, &right_, &parent_, &split_conds_, + &split_index_, &default_left_); + } else if (!typed && feature_is_64) { + LoadModelImpl(in, &weights_, &left_, &right_, &parent_, &split_conds_, + &split_index_, &default_left_); + } else { + LoadModelImpl(in, &weights_, &left_, &right_, &parent_, &split_conds_, + &split_index_, &default_left_); + } + + this->n_nodes_ = weights_.size(); +} + +void MultiTargetTree::SaveModel(Json* p_out) const { + CHECK(p_out); + auto& out = *p_out; + + // nodes + I32Array lefts(n_nodes_); + I32Array rights(n_nodes_); + I32Array parents(n_nodes_); + F32Array conds(n_nodes_); + U8Array default_left(n_nodes_); + F32Array weights(n_nodes_ * n_targets_); + + auto save_tree = [&](auto* p_indices_array) { + auto& indices_array = *p_indices_array; + for (bst_node_t nidx = 0; nidx < n_nodes_; ++nidx) { + CHECK_LT(nidx, left_.size()); + lefts.Set(nidx, left_[nidx]); + CHECK_LT(nidx, right_.size()); + rights.Set(nidx, right_[nidx]); + CHECK_LT(nidx, parent_.size()); + parents.Set(nidx, parent_[nidx]); + CHECK_LT(nidx, split_index_.size()); + indices_array.Set(nidx, split_index_[nidx]); + conds.Set(nidx, split_conds_[nidx]); + default_left.Set(nidx, default_left_[nidx]); + + // fixme: unify the code with expand + auto weight_view = common::Span{weights_}; + auto p_weight = weight_view.subspan(nidx * n_targets_, n_targets_); + auto weight_out = + common::Span(weights.GetArray()).subspan(nidx * n_targets_, n_targets_); + CHECK_EQ(p_weight.size(), weight_out.size()); + std::copy_n(p_weight.data(), p_weight.size(), weight_out.data()); + } + }; + + if (this->n_features_ > static_cast(std::numeric_limits::max())) { + I64Array indices_64(n_nodes_); + save_tree(&indices_64); + out["split_indices"] = std::move(indices_64); + } else { + I32Array indices_32(n_nodes_); + save_tree(&indices_32); + out["split_indices"] = std::move(indices_32); + } + + out["weights"] = std::move(weights); + out["left_children"] = std::move(lefts); + out["right_children"] = std::move(rights); + out["parents"] = std::move(parents); + + out["split_conditions"] = std::move(conds); + out["default_left"] = std::move(default_left); +} + +void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, + bool default_left, linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight) { + CHECK_GE(parent_.size(), 1); + CHECK_EQ(parent_.size(), left_.size()); + CHECK_EQ(left_.size(), right_.size()); + + size_t n = n_nodes_ + 2; + left_.resize(n, InvalidNodeId()); + right_.resize(n, InvalidNodeId()); + parent_.resize(n, InvalidNodeId()); + + auto left_child = parent_.size() - 2; + auto right_child = parent_.size() - 1; + + left_[nidx] = left_child; + right_[nidx] = right_child; + + if (nidx != 0) { + CHECK_NE(parent_[nidx], 
InvalidNodeId()); + } + + parent_[left_child] = nidx; + parent_[right_child] = nidx; + + split_index_.resize(n); + split_index_[nidx] = split_idx; + + split_conds_.resize(n); + split_conds_[nidx] = split_cond; + default_left_.resize(n); + default_left_[nidx] = static_cast(default_left); + + weights_.resize(n * n_targets_); + auto weight_view = common::Span{weights_}; + auto p_weight = weight_view.subspan(nidx * n_targets_, n_targets_); + CHECK_EQ(p_weight.size(), base_weight.Size()); + auto l_weight = weight_view.subspan(left_child * n_targets_, n_targets_); + CHECK_EQ(l_weight.size(), left_weight.Size()); + auto r_weight = weight_view.subspan(right_child * n_targets_, n_targets_); + CHECK_EQ(r_weight.size(), right_weight.Size()); + + for (size_t i = 0; i < base_weight.Size(); ++i) { + p_weight[i] = base_weight(i); + l_weight[i] = left_weight(i); + r_weight[i] = right_weight(i); + } + + n_nodes_ = n; +} +} // namespace xgboost diff --git a/src/tree/param.h b/src/tree/param.h index 3f5e4ec7bc71..14a9b8173b7a 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2021 by Contributors + * Copyright 2014-2022 by Contributors * \file param.h * \brief training parameters, statistics used to support tree construction. * \author Tianqi Chen @@ -14,10 +14,12 @@ #include #include -#include "xgboost/parameter.h" -#include "xgboost/data.h" #include "../common/categorical.h" +#include "../common/linalg_op.h" #include "../common/math.h" +#include "xgboost/data.h" +#include "xgboost/linalg.h" +#include "xgboost/parameter.h" namespace xgboost { namespace tree { @@ -293,6 +295,23 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad) return CalcWeight(p, sum_grad.GetGrad(), sum_grad.GetHess()); } +inline void CalcWeight(TrainParam const &p, linalg::VectorView grad_sum, + linalg::VectorView out_w) { + for (bst_target_t i = 0; i < out_w.Size(); ++i) { + out_w(i) = CalcWeight(p, grad_sum(i).GetGrad(), grad_sum(i).GetHess()); + } +} + +inline double CalcGainGivenWeight(TrainParam const &p, + linalg::VectorView sum_grad, + linalg::VectorView weight) { + double gain{0}; + for (bst_target_t i = 0; i < weight.Size(); ++i) { + gain += -weight(i) * ThresholdL1(sum_grad(i).GetGrad(), p.reg_alpha); + } + return gain; +} + /*! \brief core statistics used for tree construction */ struct XGBOOST_ALIGNAS(16) GradStats { using GradType = double; @@ -349,6 +368,19 @@ struct XGBOOST_ALIGNAS(16) GradStats { } }; +// Helper functions for copying gradient statistic, one for vector leaf, another for normal scalar. +template +std::vector &CopyStats(linalg::VectorView const &src, std::vector *dst) { // NOLINT + dst->resize(src.Size()); + std::copy(linalg::cbegin(src), linalg::cend(src), dst->begin()); + return *dst; +} + +inline GradStats &CopyStats(GradStats const &src, GradStats *dst) { // NOLINT + *dst = src; + return *dst; +} + /*! 
* \brief statistics that is helpful to store * and represent a split solution for the tree @@ -430,9 +462,10 @@ struct SplitEntryContainer { * \param default_left whether the missing value goes to left * \return whether the proposed split is better and can replace current split */ - bool Update(bst_float new_loss_chg, unsigned split_index, - bst_float new_split_value, bool default_left, bool is_cat, - const GradientT &left_sum, const GradientT &right_sum) { + template + bool Update(bst_float new_loss_chg, unsigned split_index, bst_float new_split_value, + bool default_left, bool is_cat, GradientSumT const &left_sum, + GradientSumT const &right_sum) { if (this->NeedReplace(new_loss_chg, split_index)) { this->loss_chg = new_loss_chg; if (default_left) { @@ -441,8 +474,8 @@ struct SplitEntryContainer { this->sindex = split_index; this->split_value = new_split_value; this->is_cat = is_cat; - this->left_sum = left_sum; - this->right_sum = right_sum; + CopyStats(left_sum, &this->left_sum); + CopyStats(right_sum, &this->right_sum); return true; } else { return false; diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 4bd2294d1522..48a974302936 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -3,23 +3,25 @@ * \file tree_model.cc * \brief model structure for tree */ -#include #include - -#include -#include +#include #include +#include -#include -#include #include #include -#include +#include +#include +#include -#include "param.h" -#include "../common/common.h" #include "../common/categorical.h" +#include "../common/common.h" #include "../predictor/predict_fn.h" +#include "io_utils.h" // GetElem +#include "param.h" +#include "xgboost/base.h" +#include "xgboost/data.h" +#include "xgboost/logging.h" namespace xgboost { // register tree parameter @@ -807,6 +809,36 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v this->split_types_.at(nid) = FeatureType::kNumerical; } +void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, + bool default_left, linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight) { + CHECK(this->p_mt_tree_); + this->p_mt_tree_->Expand(nidx, split_index, split_cond, default_left, base_weight, left_weight, + right_weight); + + split_types_.resize(this->Size(), FeatureType::kNumerical); + split_categories_segments_.resize(this->Size()); + this->split_types_.at(nidx) = FeatureType::kNumerical; +} + +void RegTree::ExpandCategorical(bst_node_t nidx, bst_feature_t split_index, + common::Span split_cat, bool default_left, + linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight) { + this->ExpandNode(nidx, split_index, std::numeric_limits::quiet_NaN(), default_left, + base_weight, left_weight, right_weight); + + size_t orig_size = split_categories_.size(); + this->split_categories_.resize(orig_size + split_cat.size()); + std::copy(split_cat.data(), split_cat.data() + split_cat.size(), + split_categories_.begin() + orig_size); + this->split_types_.at(nidx) = FeatureType::kCategorical; + this->split_categories_segments_.at(nidx).beg = orig_size; + this->split_categories_segments_.at(nidx).size = split_cat.size(); +} + void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, @@ -893,44 +925,17 @@ void RegTree::Save(dmlc::Stream* fo) const { } } } -// typed array, not boolean -template 
-std::enable_if_t::value && !std::is_same::value, T> GetElem( - std::vector const& arr, size_t i) { - return arr[i]; -} -// typed array boolean -template -std::enable_if_t::value && std::is_same::value && - std::is_same::value, - bool> -GetElem(std::vector const& arr, size_t i) { - return arr[i] == 1; -} -// json array -template -std::enable_if_t< - std::is_same::value, - std::conditional_t::value, int64_t, - std::conditional_t::value, bool, float>>> -GetElem(std::vector const& arr, size_t i) { - if (std::is_same::value && !IsA(arr[i])) { - return get(arr[i]) == 1; - } - return get(arr[i]); -} template void RegTree::LoadCategoricalSplit(Json const& in) { - using I64ArrayT = std::conditional_t; - using I32ArrayT = std::conditional_t; - - auto const& categories_segments = get(in["categories_segments"]); - auto const& categories_sizes = get(in["categories_sizes"]); - auto const& categories_nodes = get(in["categories_nodes"]); - auto const& categories = get(in["categories"]); - - size_t cnt = 0; + auto const& categories_segments = get>(in["categories_segments"]); + auto const& categories_sizes = get>(in["categories_sizes"]); + auto const& categories_nodes = get>(in["categories_nodes"]); + auto const& categories = get>(in["categories"]); + + auto split_type = get>(in["split_type"]); + bst_node_t n_nodes = split_type.size(); + std::size_t cnt = 0; bst_node_t last_cat_node = -1; if (!categories_nodes.empty()) { last_cat_node = GetElem(categories_nodes, cnt); @@ -938,7 +943,10 @@ void RegTree::LoadCategoricalSplit(Json const& in) { // `categories_segments' is only available for categorical nodes to prevent overhead for // numerical node. As a result, we need to track the categorical nodes we have processed // so far. - for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) { + split_types_.resize(n_nodes, FeatureType::kNumerical); + split_categories_segments_.resize(n_nodes); + for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) { + split_types_[nidx] = static_cast(GetElem(split_type, nidx)); if (nidx == last_cat_node) { auto j_begin = GetElem(categories_segments, cnt); auto j_end = GetElem(categories_sizes, cnt) + j_begin; @@ -985,15 +993,17 @@ template void RegTree::LoadCategoricalSplit(Json const& in); void RegTree::SaveCategoricalSplit(Json* p_out) const { auto& out = *p_out; - CHECK_EQ(this->split_types_.size(), param.num_nodes); - CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes); + CHECK_EQ(this->split_types_.size(), this->Size()); + CHECK_EQ(this->GetSplitCategoriesPtr().size(), this->Size()); I64Array categories_segments; I64Array categories_sizes; I32Array categories; // bst_cat_t = int32_t I32Array categories_nodes; // bst_note_t = int32_t + U8Array split_type(split_types_.size()); for (size_t i = 0; i < nodes_.size(); ++i) { + split_type.Set(i, static_cast>(this->NodeSplitType(i))); if (this->split_types_[i] == FeatureType::kCategorical) { categories_nodes.GetArray().emplace_back(i); auto begin = categories.Size(); @@ -1012,66 +1022,49 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const { } } + out["split_type"] = std::move(split_type); out["categories_segments"] = std::move(categories_segments); out["categories_sizes"] = std::move(categories_sizes); out["categories_nodes"] = std::move(categories_nodes); out["categories"] = std::move(categories); } -template , - typename U8ArrayT = std::conditional_t, - typename I32ArrayT = std::conditional_t, - typename I64ArrayT = std::conditional_t, - typename IndexArrayT = std::conditional_t> -bool LoadModelImpl(Json const& in, 
TreeParam* param, std::vector* p_stats, - std::vector* p_split_types, std::vector* p_nodes, - std::vector* p_split_categories_segments) { +template +void LoadModelImpl(Json const& in, TreeParam* param, std::vector* p_stats, + std::vector* p_nodes) { auto& stats = *p_stats; - auto& split_types = *p_split_types; auto& nodes = *p_nodes; - auto& split_categories_segments = *p_split_categories_segments; FromJson(in["tree_param"], param); auto n_nodes = param->num_nodes; CHECK_NE(n_nodes, 0); // stats - auto const& loss_changes = get(in["loss_changes"]); + auto const& loss_changes = get>(in["loss_changes"]); CHECK_EQ(loss_changes.size(), n_nodes); - auto const& sum_hessian = get(in["sum_hessian"]); + auto const& sum_hessian = get>(in["sum_hessian"]); CHECK_EQ(sum_hessian.size(), n_nodes); - auto const& base_weights = get(in["base_weights"]); + auto const& base_weights = get>(in["base_weights"]); CHECK_EQ(base_weights.size(), n_nodes); // nodes - auto const& lefts = get(in["left_children"]); + auto const& lefts = get>(in["left_children"]); CHECK_EQ(lefts.size(), n_nodes); - auto const& rights = get(in["right_children"]); + auto const& rights = get>(in["right_children"]); CHECK_EQ(rights.size(), n_nodes); - auto const& parents = get(in["parents"]); + auto const& parents = get>(in["parents"]); CHECK_EQ(parents.size(), n_nodes); - auto const& indices = get(in["split_indices"]); + auto const& indices = get>(in["split_indices"]); CHECK_EQ(indices.size(), n_nodes); - auto const& conds = get(in["split_conditions"]); + auto const& conds = get>(in["split_conditions"]); CHECK_EQ(conds.size(), n_nodes); - auto const& default_left = get(in["default_left"]); + auto const& default_left = get>(in["default_left"]); CHECK_EQ(default_left.size(), n_nodes); - bool has_cat = get(in).find("split_type") != get(in).cend(); - std::remove_const_t(in["split_type"]))>> - split_type; - if (has_cat) { - split_type = get(in["split_type"]); - } - // Initialization stats = std::remove_reference_t(n_nodes); nodes = std::remove_reference_t(n_nodes); - split_types = std::remove_reference_t(n_nodes); - split_categories_segments = std::remove_reference_t(n_nodes); static_assert(std::is_integral(lefts, 0))>::value, ""); static_assert(std::is_floating_point(loss_changes, 0))>::value, ""); - CHECK_EQ(n_nodes, split_categories_segments.size()); // Set node for (int32_t i = 0; i < n_nodes; ++i) { @@ -1088,41 +1081,42 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector* float cond{GetElem(conds, i)}; bool dft_left{GetElem(default_left, i)}; n = RegTree::Node{left, right, parent, ind, cond, dft_left}; - - if (has_cat) { - split_types[i] = static_cast(GetElem(split_type, i)); - } } - - return has_cat; } void RegTree::LoadModel(Json const& in) { - bool has_cat{false}; - bool typed = IsA(in["loss_changes"]); - bool feature_is_64 = IsA(in["split_indices"]); - if (typed && feature_is_64) { - has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, - &split_categories_segments_); - } else if (typed && !feature_is_64) { - has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, - &split_categories_segments_); - } else if (!typed && feature_is_64) { - has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, - &split_categories_segments_); - } else { - has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, - &split_categories_segments_); - } + bool typed = IsA(in["parents"]); + auto const& in_obj = get(in); + bool has_cat = in_obj.find("split_type") != in_obj.cend(); if (has_cat) { if (typed) { 
this->LoadCategoricalSplit(in); } else { this->LoadCategoricalSplit(in); } + } + + auto is_multi = in_obj.find("weights") != in_obj.cend(); + if (is_multi) { + this->GetMultiTargetTree()->LoadModel(in); + return; + } + + bool feature_is_64 = IsA(in["split_indices"]); + if (typed && feature_is_64) { + LoadModelImpl(in, ¶m, &stats_, &nodes_); + } else if (typed && !feature_is_64) { + LoadModelImpl(in, ¶m, &stats_, &nodes_); + } else if (!typed && feature_is_64) { + LoadModelImpl(in, ¶m, &stats_, &nodes_); } else { + LoadModelImpl(in, ¶m, &stats_, &nodes_); + } + + if (!has_cat) { this->split_categories_segments_.resize(this->param.num_nodes); + this->split_types_.resize(this->param.num_nodes); std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical); } @@ -1144,6 +1138,12 @@ void RegTree::LoadModel(Json const& in) { } void RegTree::SaveModel(Json* p_out) const { + this->SaveCategoricalSplit(p_out); + if (this->IsMultiTarget()) { + this->GetMultiTargetTree()->SaveModel(p_out); + return; + } + /* Here we are treating leaf node and internal node equally. Some information like * child node id doesn't make sense for leaf node but we will have to save them to * avoid creating a huge map. One difficulty is XGBoost has deleted node created by @@ -1170,7 +1170,6 @@ void RegTree::SaveModel(Json* p_out) const { F32Array conds(n_nodes); U8Array default_left(n_nodes); - U8Array split_type(n_nodes); CHECK_EQ(this->split_types_.size(), param.num_nodes); auto save_tree = [&](auto* p_indices_array) { @@ -1188,8 +1187,6 @@ void RegTree::SaveModel(Json* p_out) const { indices_array.Set(i, n.SplitIndex()); conds.Set(i, n.SplitCond()); default_left.Set(i, static_cast(!!n.DefaultLeft())); - - split_type.Set(i, static_cast(this->NodeSplitType(i))); } }; if (this->param.num_feature > static_cast(std::numeric_limits::max())) { @@ -1202,9 +1199,6 @@ void RegTree::SaveModel(Json* p_out) const { out["split_indices"] = std::move(indices_32); } - this->SaveCategoricalSplit(&out); - - out["split_type"] = std::move(split_type); out["loss_changes"] = std::move(loss_changes); out["sum_hessian"] = std::move(sum_hessian); out["base_weights"] = std::move(base_weights); diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index bc090ed3fdba..a289dedf5e80 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2021-2022 XGBoost contributors +/** + * Copyright 2021-2023 by XGBoost contributors * * \brief Implementation for the approx tree method. */ @@ -14,9 +14,12 @@ #include "driver.h" #include "hist/evaluate_splits.h" #include "hist/histogram.h" +#include "hist/sampler.h" // SampleGradient #include "param.h" #include "xgboost/base.h" +#include "xgboost/data.h" #include "xgboost/json.h" +#include "xgboost/linalg.h" #include "xgboost/tree_model.h" #include "xgboost/tree_updater.h" @@ -58,7 +61,7 @@ class GloablApproxBuilder { monitor_->Start(__func__); n_batches_ = 0; - int32_t n_total_bins = 0; + bst_bin_t n_total_bins = 0; partitioner_.clear(); // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up? 
for (auto const &page : p_fmat->GetBatches(BatchSpec(param_, hess, task_))) { @@ -220,8 +223,8 @@ class GloablApproxBuilder { for (auto const &candidate : valid_candidates) { int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); - CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx), {}}; - CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx), {}}; + CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx)}; + CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx)}; best_splits.push_back(l_best); best_splits.push_back(r_best); } @@ -256,8 +259,7 @@ class GlobalApproxUpdater : public TreeUpdater { ObjInfo task_; public: - explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) - : TreeUpdater(ctx), task_{task} { + explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} { monitor_.Init(__func__); } @@ -272,24 +274,11 @@ class GlobalApproxUpdater : public TreeUpdater { } void InitData(TrainParam const ¶m, HostDeviceVector const *gpair, - std::vector *sampled) { - auto const &h_gpair = gpair->ConstHostVector(); - sampled->resize(h_gpair.size()); - std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin()); - auto &rnd = common::GlobalRandom(); - - if (param.subsample != 1.0) { - CHECK(param.sampling_method != TrainParam::kGradientBased) - << "Gradient based sampling is not supported for approx tree method."; - std::bernoulli_distribution coin_flip(param.subsample); - std::transform(sampled->begin(), sampled->end(), sampled->begin(), [&](GradientPair &g) { - if (coin_flip(rnd)) { - return g; - } else { - return GradientPair{}; - } - }); - } + linalg::Matrix *sampled) { + *sampled = linalg::Empty(ctx_, gpair->Size(), 1); + sampled->Data()->Copy(*gpair); + + SampleGradient(ctx_, param, sampled->HostView()); } char const *Name() const override { return "grow_histmaker"; } @@ -303,18 +292,19 @@ class GlobalApproxUpdater : public TreeUpdater { pimpl_ = std::make_unique(param_, m->Info(), ctx_, column_sampler_, task_, &monitor_); - std::vector h_gpair; - InitData(param_, gpair, &h_gpair); + linalg::Matrix h_gpair; // Obtain the hessian values for weighted sketching - std::vector hess(h_gpair.size()); - std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(), + InitData(param_, gpair, &h_gpair); + std::vector hess(h_gpair.Size()); + auto const &s_gpair = h_gpair.Data()->ConstHostVector(); + std::transform(s_gpair.begin(), s_gpair.end(), hess.begin(), [](auto g) { return g.GetHess(); }); cached_ = m; size_t t_idx = 0; for (auto p_tree : trees) { - this->pimpl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]); + this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]); ++t_idx; } param_.learning_rate = lr; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 85371672639c..ca37bad3307c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -1,5 +1,5 @@ /*! 
- * Copyright 2017-2022 XGBoost contributors + * Copyright 2017-2023 XGBoost contributors */ #include #include @@ -160,11 +160,11 @@ class DeviceHistogramStorage { if (nidx_map_.find(nidx) != nidx_map_.cend()) { // Fetch from normal cache auto ptr = data_.data().get() + nidx_map_.at(nidx); - return common::Span(reinterpret_cast(ptr), n_bins_); + return {reinterpret_cast(ptr), static_cast(n_bins_)}; } else { // Fetch from overflow auto ptr = overflow_.data().get() + overflow_nidx_map_.at(nidx); - return common::Span(reinterpret_cast(ptr), n_bins_); + return {reinterpret_cast(ptr), static_cast(n_bins_)}; } } }; @@ -437,7 +437,7 @@ struct GPUHistMakerDevice { dh::caching_device_vector d_split_types; dh::caching_device_vector d_categories; - dh::caching_device_vector d_categories_segments; + dh::caching_device_vector d_categories_segments; if (!categories.empty()) { dh::CopyToD(h_split_types, &d_split_types); @@ -454,7 +454,7 @@ struct GPUHistMakerDevice { const common::Span d_nodes, common::Span d_feature_types, common::Span categories, - common::Span categories_segments, + common::Span categories_segments, HostDeviceVector* p_out_position) { auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); auto d_gpair = this->gpair; diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 525376730fec..ff4f5e603b59 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -6,304 +6,347 @@ */ #include "./updater_quantile_hist.h" -#include -#include -#include +#include // copy +#include // size_t +#include // uint32_t,int32_t,uint64_t +#include // shared_ptr, unique_ptr +#include +#include // string #include -#include - -#include "common_row_partitioner.h" -#include "constraints.h" -#include "hist/histogram.h" -#include "hist/evaluate_splits.h" -#include "param.h" +#include // vector + +#include "../collective/communicator-inl.h" // IsDistributed,Allreduce +#include "../common/hist_util.h" // HistogramCuts,GHistRow +#include "../common/linalg_op.h" // cbegin,cend,begin +#include "../common/random.h" // ColumnSampler +#include "../common/threading_utils.h" // BlockedSpace2d,ParallelFor2d +#include "../common/timer.h" // Monitor +#include "../data/gradient_index.h" // GHistIndexMatrix +#include "common_row_partitioner.h" // CommonRowPartitioner +#include "dmlc/omp.h" // omp_get_thread_num +#include "driver.h" // Driver +#include "hist/evaluate_splits.h" // HistEvaluator,HistMultiEvaluator +#include "hist/expand_entry.h" // MultiExpandEntry +#include "hist/histogram.h" // HistogramBuilder,ConstructHistSpace +#include "hist/sampler.h" // CPUSampleGradient +#include "param.h" // TrainParam, CalcWeight, CalcGainGivenWeight, SplitEntryContainer +#include "xgboost/base.h" // GradientPair,GradientPairPrecise,bst_node_t,bst_feature_t,bst_target_t +#include "xgboost/context.h" // Context +#include "xgboost/data.h" // DMatrix,MetaInfo +#include "xgboost/host_device_vector.h" // HostDeviceVector +#include "xgboost/json.h" // Json +#include "xgboost/linalg.h" // VectorView,TensorView,Tensor #include "xgboost/logging.h" -#include "xgboost/tree_updater.h" +#include "xgboost/span.h" // Span +#include "xgboost/task.h" // ObjInfo +#include "xgboost/tree_model.h" // RegTree, MTNotImplemented +#include "xgboost/tree_updater.h" // TreeUpdater namespace xgboost { namespace tree { DMLC_REGISTRY_FILE_TAG(updater_quantile_hist); -void QuantileHistMaker::Configure(const Args &args) { - param_.UpdateAllowUnknown(args); -} - -void 
QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
-                               common::Span<HostDeviceVector<bst_node_t>> out_position,
-                               const std::vector<RegTree *> &trees) {
-  // rescale learning rate according to size of trees
-  float lr = param_.learning_rate;
-  param_.learning_rate = lr / trees.size();
-
-  // build tree
-  const size_t n_trees = trees.size();
-  if (!pimpl_) {
-    pimpl_.reset(new Builder(n_trees, param_, dmat, task_, ctx_));
+template <typename ExpandEntry, typename Impl>
+class UpdateTreeMixIn {
+ protected:
+  TrainParam param_;
+  common::Monitor *monitor_;
+
+ private:
+  auto *Self() { return static_cast<Impl *>(this); }
+
+ public:
+  explicit UpdateTreeMixIn(TrainParam param, common::Monitor *monitor)
+      : param_{std::move(param)}, monitor_{monitor} {}
+  void UpdateTree(linalg::MatrixView<GradientPair const> gpair, DMatrix *p_fmat, RegTree *p_tree,
+                  HostDeviceVector<bst_node_t> *p_out_position) {
+    monitor_->Start(__func__);
+    Self()->InitData(p_fmat, p_tree);
+
+    Driver<ExpandEntry> driver{this->param_};
+    auto const &tree = *p_tree;
+    driver.Push(Self()->InitRoot(p_fmat, gpair, p_tree));
+    auto expand_set = driver.Pop();
+
+    /**
+     * Note for update position
+     * Root:
+     *   Not applied: No need to update position as initialization has got all the rows ordered.
+     *   Applied: Update position is run on applied nodes so the rows are partitioned.
+     * Non-root:
+     *   Not applied: That node is root of the subtree, same rule as root.
+     *   Applied: Ditto
+     */
+    while (!expand_set.empty()) {
+      // candidates that can be further split.
+      std::vector<ExpandEntry> valid_candidates;
+      // candidates that can be applied.
+      std::vector<ExpandEntry> applied;
+      for (auto const &candidate : expand_set) {
+        Self()->ApplyTreeSplit(candidate, p_tree);
+        CHECK_GT(p_tree->LeftChild(candidate.nid), candidate.nid);
+        applied.push_back(candidate);
+        if (driver.IsChildValid(candidate)) {
+          valid_candidates.emplace_back(candidate);
+        }
+      }
+
+      Self()->UpdatePosition(p_fmat, p_tree, applied);
+
+      std::vector<ExpandEntry> best_splits;
+      if (!valid_candidates.empty()) {
+        Self()->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair);
+        for (auto const &candidate : valid_candidates) {
+          auto left_child_nidx = tree.LeftChild(candidate.nid);
+          auto right_child_nidx = tree.RightChild(candidate.nid);
+          ExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx)};
+          ExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx)};
+          best_splits.push_back(l_best);
+          best_splits.push_back(r_best);
+        }
+        Self()->EvaluateSplits(p_fmat, p_tree, &best_splits);
+      }
+      driver.Push(best_splits.begin(), best_splits.end());
+      expand_set = driver.Pop();
+    }
+
+    auto &h_out_position = p_out_position->HostVector();
+    Self()->LeafPartition(tree, gpair, &h_out_position);
+    monitor_->Stop(__func__);
+  }
+};
+
+class MultiTargetHistBuilder : public UpdateTreeMixIn<MultiExpandEntry, MultiTargetHistBuilder> {
+ protected:
+  std::shared_ptr<common::ColumnSampler> col_sampler_;
+  HistMultiEvaluator evaluator_;
+  std::vector<HistogramBuilder<MultiExpandEntry>> histogram_builder_;
+  Context const *ctx_;
+
+  std::vector<CommonRowPartitioner> partitioner_;
+  // Pointer to last updated tree, used for update prediction cache.
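UpdateTreeMixIn above is an instance of the curiously recurring template pattern (CRTP): the driver loop invokes InitData/InitRoot/UpdatePosition/BuildHistogram/EvaluateSplits/ApplyTreeSplit/LeafPartition on the derived builder through a compile-time cast, so the scalar and multi-target builders share one update loop without virtual dispatch. A minimal sketch of the pattern, with hypothetical names:

#include <iostream>

// The base's driver calls hooks on Derived directly; no virtual functions.
template <typename Derived>
struct DriverBase {
  void Run() {
    auto *self = static_cast<Derived *>(this);  // safe: Derived inherits from us
    self->InitData();
    self->GrowOneNode();
  }
};

struct ScalarBuilder : public DriverBase<ScalarBuilder> {
  void InitData() { std::cout << "scalar init\n"; }
  void GrowOneNode() { std::cout << "scalar grow\n"; }
};

struct MultiBuilder : public DriverBase<MultiBuilder> {
  void InitData() { std::cout << "multi init\n"; }
  void GrowOneNode() { std::cout << "multi grow\n"; }
};

int main() {
  ScalarBuilder{}.Run();  // each builder reuses the same driver loop
  MultiBuilder{}.Run();
}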
+ RegTree const *p_last_tree_{nullptr}; + + ObjInfo task_; + + public: + void UpdatePosition(DMatrix *p_fmat, RegTree const *p_tree, + std::vector const &applied) { + monitor_->Start(__func__); + std::size_t page_id{0}; + for (auto const &page : p_fmat->GetBatches(HistBatch(this->param_))) { + this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree); + page_id++; + } + monitor_->Stop(__func__); } - size_t t_idx{0}; - for (auto p_tree : trees) { - auto &t_row_position = out_position[t_idx]; - this->pimpl_->UpdateTree(gpair, dmat, p_tree, &t_row_position); - ++t_idx; + void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) { + this->evaluator_.ApplyTreeSplit(candidate, p_tree); } - param_.learning_rate = lr; -} + void InitData(DMatrix *p_fmat, RegTree const *p_tree) { + monitor_->Start(__func__); -bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data, - linalg::VectorView out_preds) { - if (pimpl_) { - return pimpl_->UpdatePredictionCache(data, out_preds); - } else { - return false; - } -} - -CPUExpandEntry QuantileHistMaker::Builder::InitRoot( - DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h) { - CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f); - - size_t page_id = 0; - auto space = ConstructHistSpace(partitioner_, {node}); - for (auto const &gidx : p_fmat->GetBatches(HistBatch(param_))) { - std::vector nodes_to_build{node}; - std::vector nodes_to_sub; - this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree, - partitioner_.at(page_id).Partitions(), nodes_to_build, - nodes_to_sub, gpair_h); - ++page_id; + std::size_t page_id = 0; + bst_bin_t n_total_bins = 0; + partitioner_.clear(); + for (auto const &page : p_fmat->GetBatches(HistBatch(param_))) { + if (n_total_bins == 0) { + n_total_bins = page.cut.TotalBins(); + } else { + CHECK_EQ(n_total_bins, page.cut.TotalBins()); + } + partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid); + page_id++; + } + + bst_target_t n_targets = p_tree->NumTargets(); + histogram_builder_.clear(); + for (std::size_t i = 0; i < n_targets; ++i) { + histogram_builder_.emplace_back(); + histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, + collective::IsDistributed()); + } + + p_last_tree_ = p_tree; + monitor_->Stop(__func__); } - { - GradientPairPrecise grad_stat; - if (p_fmat->IsDense()) { - /** - * Specialized code for dense data: For dense data (with no missing value), the sum - * of gradient histogram is equal to snode[nid] - */ - auto const &gmat = *(p_fmat->GetBatches(HistBatch(param_)).begin()); - std::vector const &row_ptr = gmat.cut.Ptrs(); - CHECK_GE(row_ptr.size(), 2); - uint32_t const ibegin = row_ptr[0]; - uint32_t const iend = row_ptr[1]; - auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot]; - auto begin = hist.data(); - for (uint32_t i = ibegin; i < iend; ++i) { - GradientPairPrecise const &et = begin[i]; - grad_stat.Add(et.GetGrad(), et.GetHess()); + MultiExpandEntry InitRoot(DMatrix *p_fmat, linalg::MatrixView gpair, + RegTree *p_tree) { + monitor_->Start(__func__); + MultiExpandEntry best; + best.nid = RegTree::kRoot; + best.depth = 0; + + // FIXME(jiamingy): Extract this part as fitstump after init estimation is merged. 
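The block that follows computes the root gradient sum with one partial-sum row per thread (no atomics), folds the rows into row zero, and finishes with collective::Allreduce over the raw gradient/hessian doubles. A self-contained sketch of the same reduction shape, OpenMP only and with a hypothetical function name:

#include <omp.h>

#include <cstddef>  // std::ptrdiff_t
#include <vector>   // std::vector

// Each thread accumulates into its own slot; a serial pass folds the slots
// afterwards. In the patch the "slot" is a (1 x n_targets) matrix row and the
// fold is followed by collective::Allreduce over the summed doubles.
double ParallelSum(std::vector<double> const& x) {
  std::vector<double> tloc(omp_get_max_threads(), 0.0);
#pragma omp parallel for
  for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(x.size()); ++i) {
    tloc[omp_get_thread_num()] += x[i];
  }
  double sum = 0.0;
  for (double v : tloc) {
    sum += v;  // fold thread-local slots into the final result
  }
  return sum;
}

int main() {
  std::vector<double> grad(1000, 0.5);
  return ParallelSum(grad) == 500.0 ? 0 : 1;  // 0.5 sums exactly in double here
}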
+ auto n_targets = p_tree->NumTargets(); + linalg::Matrix root_sum_tloc = + linalg::Empty(ctx_, ctx_->Threads(), n_targets); + CHECK_EQ(root_sum_tloc.Shape(1), gpair.Shape(1)); + auto h_root_sum_tloc = root_sum_tloc.HostView(); + common::ParallelFor(gpair.Shape(0), ctx_->Threads(), [&](auto i) { + for (bst_target_t t{0}; t < n_targets; ++t) { + h_root_sum_tloc(omp_get_thread_num(), t) += GradientPairPrecise{gpair(i, t)}; } - } else { - for (auto const &grad : gpair_h) { - grad_stat.Add(grad.GetGrad(), grad.GetHess()); + }); + // Aggregate to the first row. + auto root_sum = h_root_sum_tloc.Slice(0, linalg::All()); + for (std::int32_t tidx{1}; tidx < ctx_->Threads(); ++tidx) { + for (bst_target_t t{0}; t < n_targets; ++t) { + root_sum(t) += h_root_sum_tloc(tidx, t); } - collective::Allreduce(reinterpret_cast(&grad_stat), 2); } + CHECK(root_sum.CContiguous()); + collective::Allreduce( + reinterpret_cast(root_sum.Values().data()), root_sum.Size() * 2); - auto weight = evaluator_->InitRoot(GradStats{grad_stat}); - p_tree->Stat(RegTree::kRoot).sum_hess = grad_stat.GetHess(); - p_tree->Stat(RegTree::kRoot).base_weight = weight; - (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight); + std::vector nodes{best}; + std::size_t i = 0; + auto space = ConstructHistSpace(partitioner_, nodes); + for (auto const &page : p_fmat->GetBatches(HistBatch(param_))) { + for (bst_target_t t{0}; t < n_targets; ++t) { + auto t_gpair = gpair.Slice(linalg::All(), t); + histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), + nodes, {}, t_gpair.Values()); + } + i++; + } - std::vector entries{node}; - monitor_->Start("EvaluateSplits"); - auto ft = p_fmat->Info().feature_types.ConstHostSpan(); + auto weight = evaluator_.InitRoot(root_sum); + auto weight_t = weight.HostView(); + std::transform(linalg::cbegin(weight_t), linalg::cend(weight_t), linalg::begin(weight_t), + [&](float w) { return w * param_.learning_rate; }); + + p_tree->SetLeaf(RegTree::kRoot, weight_t); + std::vector hists; + for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) { + hists.push_back(&histogram_builder_[t].Histogram()); + } for (auto const &gmat : p_fmat->GetBatches(HistBatch(param_))) { - evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree, &entries); + evaluator_.EvaluateSplits(*p_tree, hists, gmat.cut, &nodes); break; } - monitor_->Stop("EvaluateSplits"); - node = entries.front(); - } + monitor_->Stop(__func__); - return node; -} - -void QuantileHistMaker::Builder::BuildHistogram(DMatrix *p_fmat, RegTree *p_tree, - std::vector const &valid_candidates, - std::vector const &gpair) { - std::vector nodes_to_build(valid_candidates.size()); - std::vector nodes_to_sub(valid_candidates.size()); - - size_t n_idx = 0; - for (auto const &c : valid_candidates) { - auto left_nidx = (*p_tree)[c.nid].LeftChild(); - auto right_nidx = (*p_tree)[c.nid].RightChild(); - auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess(); - - auto build_nidx = left_nidx; - auto subtract_nidx = right_nidx; - if (fewer_right) { - std::swap(build_nidx, subtract_nidx); - } - nodes_to_build[n_idx] = CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}}; - nodes_to_sub[n_idx] = CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}}; - n_idx++; + return nodes.front(); } - size_t page_id{0}; - auto space = ConstructHistSpace(partitioner_, nodes_to_build); - for (auto const &gidx : p_fmat->GetBatches(HistBatch(param_))) { - histogram_builder_->BuildHist(page_id, space, 
gidx, p_tree, - partitioner_.at(page_id).Partitions(), nodes_to_build, - nodes_to_sub, gpair); - ++page_id; - } -} - -void QuantileHistMaker::Builder::LeafPartition(RegTree const &tree, - common::Span gpair, - std::vector *p_out_position) { - monitor_->Start(__func__); - if (!task_.UpdateTreeLeaf()) { - return; - } - for (auto const &part : partitioner_) { - part.LeafPartition(ctx_, tree, gpair, p_out_position); - } - monitor_->Stop(__func__); -} - -void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree, - const std::vector &gpair_h, - HostDeviceVector *p_out_position) { - monitor_->Start(__func__); - - Driver driver(param_); - driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h)); - auto const &tree = *p_tree; - auto expand_set = driver.Pop(); - - while (!expand_set.empty()) { - // candidates that can be further splited. - std::vector valid_candidates; - // candidaates that can be applied. - std::vector applied; - int32_t depth = expand_set.front().depth + 1; - for (auto const& candidate : expand_set) { - evaluator_->ApplyTreeSplit(candidate, p_tree); - applied.push_back(candidate); - if (driver.IsChildValid(candidate)) { - valid_candidates.emplace_back(candidate); + void BuildHistogram(DMatrix *p_fmat, RegTree const *p_tree, + std::vector const &valid_candidates, + linalg::MatrixView gpair) { + monitor_->Start(__func__); + std::vector nodes_to_build; + std::vector nodes_to_sub; + + for (auto const &c : valid_candidates) { + auto left_nidx = p_tree->LeftChild(c.nid); + auto right_nidx = p_tree->RightChild(c.nid); + + auto build_nidx = left_nidx; + auto subtract_nidx = right_nidx; + auto lit = + common::MakeIndexTransformIter([&](auto i) { return c.split.left_sum[i].GetHess(); }); + auto left_sum = std::accumulate(lit, lit + c.split.left_sum.size(), .0); + auto rit = + common::MakeIndexTransformIter([&](auto i) { return c.split.right_sum[i].GetHess(); }); + auto right_sum = std::accumulate(rit, rit + c.split.right_sum.size(), .0); + auto fewer_right = right_sum < left_sum; + if (fewer_right) { + std::swap(build_nidx, subtract_nidx); } + nodes_to_build.emplace_back(build_nidx, p_tree->GetDepth(build_nidx)); + nodes_to_sub.emplace_back(subtract_nidx, p_tree->GetDepth(subtract_nidx)); } - monitor_->Start("UpdatePosition"); - size_t page_id{0}; + std::size_t i = 0; + auto space = ConstructHistSpace(partitioner_, nodes_to_build); for (auto const &page : p_fmat->GetBatches(HistBatch(param_))) { - partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree); - ++page_id; - } - monitor_->Stop("UpdatePosition"); - - std::vector best_splits; - if (!valid_candidates.empty()) { - this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair_h); - for (auto const &candidate : valid_candidates) { - int left_child_nidx = tree[candidate.nid].LeftChild(); - int right_child_nidx = tree[candidate.nid].RightChild(); - CPUExpandEntry l_best{left_child_nidx, depth, 0.0}; - CPUExpandEntry r_best{right_child_nidx, depth, 0.0}; - best_splits.push_back(l_best); - best_splits.push_back(r_best); - } - auto const &histograms = histogram_builder_->Histogram(); - auto ft = p_fmat->Info().feature_types.ConstHostSpan(); - for (auto const &gmat : p_fmat->GetBatches(HistBatch(param_))) { - evaluator_->EvaluateSplits(histograms, gmat.cut, ft, *p_tree, &best_splits); - break; + for (std::size_t t = 0; t < p_tree->NumTargets(); ++t) { + auto t_gpair = gpair.Slice(linalg::All(), t); + CHECK(t_gpair.Contiguous()); + histogram_builder_[t].BuildHist(i, space, page, p_tree, 
partitioner_.at(i).Partitions(), + nodes_to_build, nodes_to_sub, t_gpair.Values()); } + i++; } - driver.Push(best_splits.begin(), best_splits.end()); - expand_set = driver.Pop(); + monitor_->Stop(__func__); } - auto &h_out_position = p_out_position->HostVector(); - this->LeafPartition(tree, gpair_h, &h_out_position); - monitor_->Stop(__func__); -} - -void QuantileHistMaker::Builder::UpdateTree(HostDeviceVector *gpair, DMatrix *p_fmat, - RegTree *p_tree, - HostDeviceVector *p_out_position) { - monitor_->Start(__func__); - - std::vector *gpair_ptr = &(gpair->HostVector()); - // in case 'num_parallel_trees != 1' no posibility to change initial gpair - if (GetNumberOfTrees() != 1) { - gpair_local_.resize(gpair_ptr->size()); - gpair_local_ = *gpair_ptr; - gpair_ptr = &gpair_local_; + void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree, + std::vector *best_splits) { + monitor_->Start(__func__); + std::vector hists; + for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) { + hists.push_back(&histogram_builder_[t].Histogram()); + } + for (auto const &gmat : p_fmat->GetBatches(HistBatch(param_))) { + evaluator_.EvaluateSplits(*p_tree, hists, gmat.cut, best_splits); + break; + } + monitor_->Stop(__func__); } - this->InitData(p_fmat, *p_tree, gpair_ptr); - - ExpandTree(p_fmat, p_tree, *gpair_ptr, p_out_position); - monitor_->Stop(__func__); -} + void LeafPartition(RegTree const &tree, linalg::MatrixView gpair, + std::vector *p_out_position) { + monitor_->Start(__func__); + if (!task_.UpdateTreeLeaf()) { + return; + } + for (auto const &part : partitioner_) { + part.LeafPartition(ctx_, tree, gpair, p_out_position); + } + monitor_->Stop(__func__); + } -bool QuantileHistMaker::Builder::UpdatePredictionCache(DMatrix const *data, - linalg::VectorView out_preds) const { - // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in - // conjunction with Update(). 
- if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) { - return false; + public: + explicit MultiTargetHistBuilder(Context const *ctx, MetaInfo const &info, TrainParam param, + std::shared_ptr column_sampler, + ObjInfo task, common::Monitor *monitor) + : UpdateTreeMixIn{std::move(param), monitor}, + col_sampler_{std::move(column_sampler)}, + evaluator_{ctx, info, param, col_sampler_}, + ctx_{ctx}, + task_{task} {} +}; + +struct HistBuilder : public UpdateTreeMixIn { + public: + // constructor + explicit HistBuilder(Context const *ctx, std::shared_ptr column_sampler, + const TrainParam ¶m, DMatrix const *fmat, ObjInfo task, + common::Monitor *monitor) + : UpdateTreeMixIn{param, monitor}, + evaluator_{param, fmat->Info(), ctx->Threads(), std::move(column_sampler)}, + p_last_fmat_(fmat), + histogram_builder_{new HistogramBuilder}, + task_{task}, + ctx_{ctx} { + monitor_->Init("Quantile::Builder"); } - monitor_->Start(__func__); - CHECK_EQ(out_preds.Size(), data->Info().num_row_); - UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, out_preds); - monitor_->Stop(__func__); - return true; -} - -void QuantileHistMaker::Builder::InitSampling(const DMatrix &fmat, - std::vector *gpair) { - monitor_->Start(__func__); - const auto &info = fmat.Info(); - auto& rnd = common::GlobalRandom(); - std::vector& gpair_ref = *gpair; - -#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG - std::bernoulli_distribution coin_flip(param_.subsample); - for (size_t i = 0; i < info.num_row_; ++i) { - if (!(gpair_ref[i].GetHess() >= 0.0f && coin_flip(rnd)) || gpair_ref[i].GetGrad() == 0.0f) { - gpair_ref[i] = GradientPair(0); + + bool UpdatePredictionCache(DMatrix const *data, linalg::VectorView out_preds) const { + // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in + // conjunction with Update(). + if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) { + return false; } + monitor_->Start(__func__); + CHECK_EQ(out_preds.Size(), data->Info().num_row_); + UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, out_preds); + monitor_->Stop(__func__); + return true; } -#else - uint64_t initial_seed = rnd(); - - auto n_threads = static_cast(ctx_->Threads()); - const size_t discard_size = info.num_row_ / n_threads; - std::bernoulli_distribution coin_flip(param_.subsample); - - dmlc::OMPException exc; - #pragma omp parallel num_threads(n_threads) - { - exc.Run([&]() { - const size_t tid = omp_get_thread_num(); - const size_t ibegin = tid * discard_size; - const size_t iend = (tid == (n_threads - 1)) ? 
info.num_row_ : ibegin + discard_size; - RandomReplace::MakeIf([&](size_t i, RandomReplace::EngineT& eng) { - return !(gpair_ref[i].GetHess() >= 0.0f && coin_flip(eng)); - }, GradientPair(0), initial_seed, ibegin, iend, &gpair_ref); - }); - } - exc.Rethrow(); -#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG - monitor_->Stop(__func__); -} -size_t QuantileHistMaker::Builder::GetNumberOfTrees() { return n_trees_; } - -void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, - std::vector *gpair) { - monitor_->Start(__func__); - const auto& info = fmat->Info(); - - { + + public: + // initialize temp data structure + void InitData(DMatrix *fmat, RegTree const *p_tree) { + monitor_->Start(__func__); + size_t page_id{0}; - int32_t n_total_bins{0}; + bst_bin_t n_total_bins{0}; partitioner_.clear(); for (auto const &page : fmat->GetBatches(HistBatch(param_))) { if (n_total_bins == 0) { @@ -316,22 +359,244 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, } histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, collective::IsDistributed()); + // store a pointer to the tree + p_last_tree_ = p_tree; + monitor_->Stop(__func__); + } + + void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree, + std::vector *best_splits) { + monitor_->Start(__func__); + auto const &histograms = histogram_builder_->Histogram(); + auto ft = p_fmat->Info().feature_types.ConstHostSpan(); + for (auto const &gmat : p_fmat->GetBatches(HistBatch(param_))) { + evaluator_.EvaluateSplits(histograms, gmat.cut, ft, *p_tree, best_splits); + break; + } + monitor_->Stop(__func__); + } + + void ApplyTreeSplit(CPUExpandEntry const &candidate, RegTree *p_tree) { + this->evaluator_.ApplyTreeSplit(candidate, p_tree); + } + + CPUExpandEntry InitRoot(DMatrix *p_fmat, linalg::MatrixView gpair, + RegTree *p_tree) { + CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0)); + + size_t page_id = 0; + auto space = ConstructHistSpace(partitioner_, {node}); + for (auto const &gidx : p_fmat->GetBatches(HistBatch(param_))) { + std::vector nodes_to_build{node}; + std::vector nodes_to_sub; + this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree, + partitioner_.at(page_id).Partitions(), nodes_to_build, + nodes_to_sub, gpair.Slice(linalg::All(), 0).Values()); + ++page_id; + } + + { + GradientPairPrecise grad_stat; + if (p_fmat->IsDense()) { + /** + * Specialized code for dense data: For dense data (with no missing value), the sum + * of gradient histogram is equal to snode[nid] + */ + auto const &gmat = *(p_fmat->GetBatches(HistBatch(param_)).begin()); + std::vector const &row_ptr = gmat.cut.Ptrs(); + CHECK_GE(row_ptr.size(), 2); + uint32_t const ibegin = row_ptr[0]; + uint32_t const iend = row_ptr[1]; + auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot]; + auto begin = hist.data(); + for (uint32_t i = ibegin; i < iend; ++i) { + GradientPairPrecise const &et = begin[i]; + grad_stat.Add(et.GetGrad(), et.GetHess()); + } + } else { + auto gpair_h = gpair.Slice(linalg::All(), 0).Values(); + for (auto const &grad : gpair_h) { + grad_stat.Add(grad.GetGrad(), grad.GetHess()); + } + collective::Allreduce(reinterpret_cast(&grad_stat), + 2); + } + + auto weight = evaluator_.InitRoot(GradStats{grad_stat}); + p_tree->Stat(RegTree::kRoot).sum_hess = grad_stat.GetHess(); + p_tree->Stat(RegTree::kRoot).base_weight = weight; + (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight); + + std::vector entries{node}; + monitor_->Start("EvaluateSplits"); + auto ft = 
p_fmat->Info().feature_types.ConstHostSpan(); + for (auto const &gmat : p_fmat->GetBatches(HistBatch(param_))) { + evaluator_.EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree, &entries); + break; + } + monitor_->Stop("EvaluateSplits"); + node = entries.front(); + } + + return node; + } + + void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree, + std::vector const &valid_candidates, + linalg::MatrixView gpair) { + std::vector nodes_to_build(valid_candidates.size()); + std::vector nodes_to_sub(valid_candidates.size()); + + size_t n_idx = 0; + for (auto const &c : valid_candidates) { + auto left_nidx = (*p_tree)[c.nid].LeftChild(); + auto right_nidx = (*p_tree)[c.nid].RightChild(); + auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess(); + + auto build_nidx = left_nidx; + auto subtract_nidx = right_nidx; + if (fewer_right) { + std::swap(build_nidx, subtract_nidx); + } + nodes_to_build[n_idx] = CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}}; + nodes_to_sub[n_idx] = CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}}; + n_idx++; + } + + size_t page_id{0}; + auto space = ConstructHistSpace(partitioner_, nodes_to_build); + for (auto const &gidx : p_fmat->GetBatches(HistBatch(param_))) { + histogram_builder_->BuildHist(page_id, space, gidx, p_tree, + partitioner_.at(page_id).Partitions(), nodes_to_build, + nodes_to_sub, gpair.Values()); + ++page_id; + } + } + + void UpdatePosition(DMatrix *p_fmat, RegTree const *p_tree, + std::vector const &applied) { + monitor_->Start(__func__); + std::size_t page_id{0}; + for (auto const &page : p_fmat->GetBatches(HistBatch(this->param_))) { + this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree); + page_id++; + } + monitor_->Stop(__func__); + } + + void LeafPartition(RegTree const &tree, linalg::MatrixView gpair, + std::vector *p_out_position) { + monitor_->Start(__func__); + if (!task_.UpdateTreeLeaf()) { + return; + } + for (auto const &part : partitioner_) { + part.LeafPartition(ctx_, tree, gpair, p_out_position); + } + monitor_->Stop(__func__); + } + + private: + HistEvaluator evaluator_; + std::vector partitioner_; + + // back pointers to tree and data matrix + const RegTree *p_last_tree_{nullptr}; + DMatrix const *const p_last_fmat_; + + std::unique_ptr> histogram_builder_; + ObjInfo task_; + // Context for number of threads + Context const *ctx_; +}; + +/*! 
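The BuildHistogram method above implements the histogram subtraction trick: for each applied split, only the child with the smaller hessian sum is built from the data, and the sibling histogram is derived by subtraction from the parent. Below is a schematic sketch of that identity, assuming a shared bin layout across nodes; the Hist alias and SubtractionTrick function are hypothetical, not the actual HistogramBuilder interface (only GradientPairPrecise is a real XGBoost type).

// Schematic sketch of the histogram subtraction trick: hist(parent) equals
// hist(left) + hist(right) bin-wise, so the sibling of an explicitly built
// child costs one subtraction per bin instead of a pass over the rows.
#include <cstddef>
#include <vector>

#include "xgboost/base.h"  // GradientPairPrecise

using Hist = std::vector<xgboost::GradientPairPrecise>;

Hist SubtractionTrick(Hist const &parent, Hist const &built_child) {
  Hist sibling(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    sibling[i] = xgboost::GradientPairPrecise{
        parent[i].GetGrad() - built_child[i].GetGrad(),
        parent[i].GetHess() - built_child[i].GetHess()};
  }
  return sibling;
}

InitRoot relies on the same bookkeeping for dense data: because the bins of any single feature partition all rows, summing that feature's bins recovers the root's total gradient statistics without another pass over the gradient vector, which is what the dense-data branch above does.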
\brief construct a tree using quantized feature values */ +class QuantileHistMaker : public TreeUpdater { + TrainParam param_; + std::unique_ptr p_impl_; + std::unique_ptr p_mtimpl_; + std::shared_ptr column_sampler_ = + std::make_shared(); + common::Monitor monitor_; + ObjInfo task_; + + public: + explicit QuantileHistMaker(Context const *ctx, ObjInfo task) : TreeUpdater{ctx}, task_{task} {} + void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); } + + void LoadConfig(Json const &in) override { + auto const &config = get(in); + FromJson(config.at("train_param"), &this->param_); + } + void SaveConfig(Json *p_out) const override { + auto &out = *p_out; + out["train_param"] = ToJson(param_); + } + + char const *Name() const override { return "grow_quantile_histmaker"; } + + void Update(HostDeviceVector *gpair, DMatrix *p_fmat, + common::Span> out_position, + const std::vector &trees) override { + // Rescale the learning rate according to the number of trees in the layer, so the + // whole layer jointly takes a single learning-rate step. + float lr = param_.learning_rate; + param_.learning_rate = lr / trees.size(); + + if (trees.front()->IsMultiTarget()) { + CHECK(param_.monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented(); + if (!p_mtimpl_) { + this->p_mtimpl_ = std::make_unique( + ctx_, p_fmat->Info(), param_, column_sampler_, task_, &monitor_); + } + } else { + if (!p_impl_) { + p_impl_.reset(new HistBuilder(ctx_, column_sampler_, param_, p_fmat, task_, &monitor_)); + } + } + + bst_target_t n_targets = trees.front()->NumTargets(); + auto h_gpair = + linalg::MakeTensorView(ctx_, gpair->HostSpan(), p_fmat->Info().num_row_, n_targets); + + linalg::Matrix sample_out; + auto h_sample_out = h_gpair; + if (trees.size() > 1 || n_targets > 1) { + sample_out = decltype(sample_out){h_gpair.Shape(), linalg::Order::kF, ctx_->gpu_id}; + h_sample_out = sample_out.HostView(); + } - if (param_.subsample < 1.0f) { - CHECK_EQ(param_.sampling_method, TrainParam::kUniform) - << "Only uniform sampling is supported, " - << "gradient-based sampling is only support by GPU Hist."; - InitSampling(*fmat, gpair); + // Each tree starts from a fresh copy of the gradient, since SampleGradient mutates + // its input in place. + for (auto tree_it = trees.begin(); tree_it != trees.end(); ++tree_it) { + if (trees.size() > 1 || n_targets > 1) { + common::ParallelFor(h_gpair.Size(), ctx_->Threads(), [&](auto i) { + linalg::detail::Apply(h_sample_out, linalg::UnravelIndex(i, h_gpair.Shape())) = + linalg::detail::Apply(h_gpair, linalg::UnravelIndex(i, h_gpair.Shape())); + }); + } + SampleGradient(ctx_, param_, h_sample_out); + auto *h_out_position = &out_position[tree_it - trees.begin()]; + if ((*tree_it)->IsMultiTarget()) { + this->p_mtimpl_->UpdateTree(h_sample_out, p_fmat, *tree_it, h_out_position); + } else { + this->p_impl_->UpdateTree(h_sample_out, p_fmat, *tree_it, h_out_position); + } } + + param_.learning_rate = lr; } - // store a pointer to the tree - p_last_tree_ = &tree; - evaluator_.reset( - new HistEvaluator{param_, info, this->ctx_->Threads(), column_sampler_}); + bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView out_preds) override { + if (p_impl_) { + return p_impl_->UpdatePredictionCache(data, out_preds); + } else if (p_mtimpl_) { + // Prediction cache is not yet supported for multi-target trees.
+ return false; + } else { + return false; + } + } - monitor_->Stop(__func__); -} + bool HasNodePosition() const override { return true; } +}; XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") .describe("Grow tree using quantized histogram.") diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index dfb9c45b0b37..94b10aa40571 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -10,179 +10,29 @@ #include #include -#include -#include -#include -#include -#include +#include // shared_ptr +#include // move +#include // vector -#include "xgboost/base.h" -#include "xgboost/data.h" -#include "xgboost/json.h" - -#include "hist/evaluate_splits.h" -#include "hist/histogram.h" -#include "hist/expand_entry.h" - -#include "common_row_partitioner.h" -#include "constraints.h" -#include "./param.h" -#include "./driver.h" -#include "../common/random.h" -#include "../common/timer.h" #include "../common/hist_util.h" +#include "../common/random.h" #include "../common/row_set.h" -#include "../common/partition_builder.h" -#include "../common/column_matrix.h" +#include "../common/timer.h" // Monitor +#include "./driver.h" +#include "./param.h" +#include "common_row_partitioner.h" // CommonRowPartitioner +#include "constraints.h" +#include "hist/evaluate_splits.h" // HistEvaluator +#include "hist/expand_entry.h" +#include "hist/histogram.h" +#include "xgboost/base.h" +#include "xgboost/data.h" // BatchParam namespace xgboost { -struct RandomReplace { - public: - // similar value as for minstd_rand - static constexpr uint64_t kBase = 16807; - static constexpr uint64_t kMod = static_cast(1) << 63; - - using EngineT = std::linear_congruential_engine; - - /* - Right-to-left binary method: https://en.wikipedia.org/wiki/Modular_exponentiation - */ - static uint64_t SimpleSkip(uint64_t exponent, uint64_t initial_seed, - uint64_t base, uint64_t mod) { - CHECK_LE(exponent, mod); - uint64_t result = 1; - while (exponent > 0) { - if (exponent % 2 == 1) { - result = (result * base) % mod; - } - base = (base * base) % mod; - exponent = exponent >> 1; - } - // with result we can now find the new seed - return (result * initial_seed) % mod; - } - - template - static void MakeIf(Condition condition, const typename ContainerData::value_type replace_value, - const uint64_t initial_seed, const size_t ibegin, - const size_t iend, ContainerData* gpair) { - ContainerData& gpair_ref = *gpair; - const uint64_t displaced_seed = SimpleSkip(ibegin, initial_seed, kBase, kMod); - EngineT eng(displaced_seed); - for (size_t i = ibegin; i < iend; ++i) { - if (condition(i, eng)) { - gpair_ref[i] = replace_value; - } - } - } -}; - namespace tree { inline BatchParam HistBatch(TrainParam const& param) { return {param.max_bin, param.sparse_threshold}; } - -/*! 
\brief construct a tree using quantized feature values */ -class QuantileHistMaker: public TreeUpdater { - public: - explicit QuantileHistMaker(Context const* ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {} - void Configure(const Args& args) override; - - void Update(HostDeviceVector* gpair, DMatrix* dmat, - common::Span> out_position, - const std::vector& trees) override; - - bool UpdatePredictionCache(const DMatrix *data, - linalg::VectorView out_preds) override; - - void LoadConfig(Json const& in) override { - auto const& config = get(in); - FromJson(config.at("train_param"), &this->param_); - } - void SaveConfig(Json* p_out) const override { - auto& out = *p_out; - out["train_param"] = ToJson(param_); - } - - char const* Name() const override { - return "grow_quantile_histmaker"; - } - - bool HasNodePosition() const override { return true; } - - protected: - // training parameter - TrainParam param_; - - // actual builder that runs the algorithm - struct Builder { - public: - // constructor - explicit Builder(const size_t n_trees, const TrainParam& param, DMatrix const* fmat, - ObjInfo task, Context const* ctx) - : n_trees_(n_trees), - param_(param), - p_last_fmat_(fmat), - histogram_builder_{new HistogramBuilder}, - task_{task}, - ctx_{ctx}, - monitor_{std::make_unique()} { - monitor_->Init("Quantile::Builder"); - } - // update one tree, growing - void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree, - HostDeviceVector* p_out_position); - - bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView out_preds) const; - - private: - // initialize temp data structure - void InitData(DMatrix* fmat, const RegTree& tree, std::vector* gpair); - - size_t GetNumberOfTrees(); - - void InitSampling(const DMatrix& fmat, std::vector* gpair); - - CPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree, - const std::vector& gpair_h); - - void BuildHistogram(DMatrix* p_fmat, RegTree* p_tree, - std::vector const& valid_candidates, - std::vector const& gpair); - - void LeafPartition(RegTree const& tree, common::Span gpair, - std::vector* p_out_position); - - void ExpandTree(DMatrix* p_fmat, RegTree* p_tree, const std::vector& gpair_h, - HostDeviceVector* p_out_position); - - private: - const size_t n_trees_; - const TrainParam& param_; - std::shared_ptr column_sampler_{ - std::make_shared()}; - - std::vector gpair_local_; - - std::unique_ptr> evaluator_; - std::vector partitioner_; - - // back pointers to tree and data matrix - const RegTree* p_last_tree_{nullptr}; - DMatrix const* const p_last_fmat_; - - std::unique_ptr> histogram_builder_; - ObjInfo task_; - // Context for number of threads - Context const* ctx_; - - std::unique_ptr monitor_; - }; - - protected: - std::unique_ptr pimpl_; - ObjInfo task_; -}; } // namespace tree } // namespace xgboost diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc index 3da4c482c85d..040f1eefa73b 100644 --- a/tests/cpp/common/test_linalg.cc +++ b/tests/cpp/common/test_linalg.cc @@ -6,7 +6,9 @@ #include #include -#include +#include // size_t +#include // iota +#include #include "../../../src/common/linalg_op.h" @@ -48,10 +50,11 @@ TEST(Linalg, VectorView) { } TEST(Linalg, TensorView) { + Context ctx; std::vector data(2 * 3 * 4, 0); std::iota(data.begin(), data.end(), 0); - auto t = MakeTensorView(data, {2, 3, 4}, -1); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); ASSERT_EQ(t.Shape()[0], 2); ASSERT_EQ(t.Shape()[1], 3); ASSERT_EQ(t.Shape()[2], 4); @@ -106,12 +109,12 @@ TEST(Linalg, TensorView) { { // 
Don't assign the initial dimension, tensor should be able to deduce the correct dim // for Slice. - auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(1, 2, All()); static_assert(decltype(s)::kDimension == 1, ""); } { - auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(1, linalg::All(), 1); ASSERT_EQ(s(0), 13); ASSERT_EQ(s(1), 17); @@ -119,7 +122,7 @@ TEST(Linalg, TensorView) { } { // range slice - auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(linalg::All(), linalg::Range(1, 3), 2); static_assert(decltype(s)::kDimension == 2, ""); std::vector sol{6, 10, 18, 22}; @@ -134,7 +137,7 @@ TEST(Linalg, TensorView) { } { // range slice - auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(1, linalg::Range(1, 3), linalg::Range(1, 3)); static_assert(decltype(s)::kDimension == 2, ""); std::vector sol{17, 18, 21, 22}; @@ -149,7 +152,7 @@ TEST(Linalg, TensorView) { } { // same as no slice. - auto t = MakeTensorView(data, {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4)); static_assert(decltype(s)::kDimension == 3, ""); auto all = t.Slice(linalg::All(), linalg::All(), linalg::All()); @@ -166,7 +169,7 @@ TEST(Linalg, TensorView) { { // copy and move constructor. - auto t = MakeTensorView(data, {2, 3, 4}, kCpuId); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto from_copy = t; auto from_move = std::move(t); for (size_t i = 0; i < t.Shape().size(); ++i) { @@ -177,7 +180,7 @@ TEST(Linalg, TensorView) { { // multiple slices - auto t = MakeTensorView(data, {2, 3, 4}, kCpuId); + auto t = MakeTensorView(&ctx, data, 2, 3, 4); auto s_0 = t.Slice(linalg::All(), linalg::Range(0, 2), linalg::Range(1, 4)); ASSERT_FALSE(s_0.CContiguous()); auto s_1 = s_0.Slice(1, 1, linalg::Range(0, 2)); @@ -208,7 +211,7 @@ TEST(Linalg, TensorView) { TEST(Linalg, Tensor) { { - Tensor t{{2, 3, 4}, kCpuId}; + Tensor t{{2, 3, 4}, Order::kC, kCpuId}; auto view = t.View(kCpuId); auto const &as_const = t; @@ -227,7 +230,7 @@ TEST(Linalg, Tensor) { } { // Reshape - Tensor t{{2, 3, 4}, kCpuId}; + Tensor t{{2, 3, 4}, Order::kC, kCpuId}; t.Reshape(4, 3, 2); ASSERT_EQ(t.Size(), 24); ASSERT_EQ(t.Shape(2), 2); @@ -245,7 +248,7 @@ TEST(Linalg, Tensor) { TEST(Linalg, Empty) { { - auto t = TensorView{{}, {0, 3}, kCpuId}; + auto t = TensorView{{}, {0, 3}, Order::kC, kCpuId}; for (int32_t i : {0, 1, 2}) { auto s = t.Slice(All(), i); ASSERT_EQ(s.Size(), 0); @@ -254,7 +257,7 @@ TEST(Linalg, Empty) { } } { - auto t = Tensor{{0, 3}, kCpuId}; + auto t = Tensor{{0, 3}, Order::kC, kCpuId}; ASSERT_EQ(t.Size(), 0); auto view = t.View(kCpuId); @@ -269,7 +272,7 @@ TEST(Linalg, Empty) { TEST(Linalg, ArrayInterface) { auto cpu = kCpuId; - auto t = Tensor{{3, 3}, cpu}; + auto t = Tensor{{3, 3}, Order::kC, cpu}; auto v = t.View(cpu); std::iota(v.Values().begin(), v.Values().end(), 0); auto arr = Json::Load(StringView{ArrayInterfaceStr(v)}); @@ -313,21 +316,49 @@ TEST(Linalg, Popc) { } TEST(Linalg, Stack) { - Tensor l{{2, 3, 4}, kCpuId}; + Tensor l{{2, 3, 4}, Order::kC, kCpuId}; ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(), [=](size_t i, float) { return i; }); - Tensor r_0{{2, 3, 4}, kCpuId}; + Tensor r_0{{2, 3, 4}, Order::kC, kCpuId}; ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(), [=](size_t i, float) { 
return i; }); Stack(&l, r_0); - Tensor r_1{{0, 3, 4}, kCpuId}; + Tensor r_1{{0, 3, 4}, Order::kC, kCpuId}; Stack(&l, r_1); ASSERT_EQ(l.Shape(0), 4); Stack(&r_1, l); ASSERT_EQ(r_1.Shape(0), l.Shape(0)); } + +TEST(Linalg, FOrder) { + std::size_t constexpr kRows = 16, kCols = 3; + std::vector data(kRows * kCols); + MatrixView mat{data, {kRows, kCols}, Order::kF, Context::kCpuId}; + float k{0}; + for (std::size_t i = 0; i < kRows; ++i) { + for (std::size_t j = 0; j < kCols; ++j) { + mat(i, j) = k; + k++; + } + } + auto column = mat.Slice(linalg::All(), 1); + ASSERT_TRUE(column.FContiguous()); + ASSERT_EQ(column.Stride(0), 1); + ASSERT_TRUE(column.CContiguous()); + k = 1; + for (auto it = linalg::cbegin(column); it != linalg::cend(column); ++it) { + ASSERT_EQ(*it, k); + k += kCols; + } + k = 1; + auto ptr = column.Values().data(); + for (auto it = ptr; it != ptr + kRows; ++it) { + ASSERT_EQ(*it, k); + k += kCols; + } +} } // namespace linalg } // namespace xgboost diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu index 14f89774b70a..1bc35f6130b1 100644 --- a/tests/cpp/common/test_linalg.cu +++ b/tests/cpp/common/test_linalg.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2021-2022 by XGBoost Contributors + * Copyright 2021-2023 by XGBoost Contributors */ #include @@ -11,7 +11,7 @@ namespace xgboost { namespace linalg { namespace { void TestElementWiseKernel() { - Tensor l{{2, 3, 4}, 0}; + Tensor l{{2, 3, 4}, linalg::Order::kC, 0}; { /** * Non-contiguous @@ -55,8 +55,10 @@ void TestElementWiseKernel() { } void TestSlice() { + Context ctx; + ctx.gpu_id = 1; thrust::device_vector data(2 * 3 * 4); - auto t = MakeTensorView(dh::ToSpan(data), {2, 3, 4}, 0); + auto t = MakeTensorView(&ctx, dh::ToSpan(data), 2, 3, 4); dh::LaunchN(1, [=] __device__(size_t) { auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4)); auto all = t.Slice(linalg::All(), linalg::All(), linalg::All()); diff --git a/tests/cpp/common/test_stats.cc b/tests/cpp/common/test_stats.cc index 03e50a9846e9..21b725275f2c 100644 --- a/tests/cpp/common/test_stats.cc +++ b/tests/cpp/common/test_stats.cc @@ -12,7 +12,8 @@ namespace xgboost { namespace common { TEST(Stats, Quantile) { { - linalg::Tensor arr({20.f, 0.f, 15.f, 50.f, 40.f, 0.f, 35.f}, {7}, Context::kCpuId); + linalg::Tensor arr({20.f, 0.f, 15.f, 50.f, 40.f, 0.f, 35.f}, {7}, linalg::Order::kC, + Context::kCpuId); std::vector index{0, 2, 3, 4, 6}; auto h_arr = arr.HostView(); auto beg = MakeIndexTransformIter([&](size_t i) { return h_arr(index[i]); }); @@ -37,8 +38,9 @@ TEST(Stats, Quantile) { } TEST(Stats, WeightedQuantile) { - linalg::Tensor arr({1.f, 2.f, 3.f, 4.f, 5.f}, {5}, Context::kCpuId); - linalg::Tensor weight({1.f, 1.f, 1.f, 1.f, 1.f}, {5}, Context::kCpuId); + linalg::Tensor arr({1.f, 2.f, 3.f, 4.f, 5.f}, {5}, linalg::Order::kC, Context::kCpuId); + linalg::Tensor weight({1.f, 1.f, 1.f, 1.f, 1.f}, {5}, linalg::Order::kC, + Context::kCpuId); auto h_arr = arr.HostView(); auto h_weight = weight.HostView(); @@ -58,7 +60,7 @@ TEST(Stats, WeightedQuantile) { } TEST(Stats, Median) { - linalg::Tensor values{{.0f, .0f, 1.f, 2.f}, {4}, Context::kCpuId}; + linalg::Tensor values{{.0f, .0f, 1.f, 2.f}, {4}, linalg::Order::kC, Context::kCpuId}; Context ctx; HostDeviceVector weights; auto m = Median(&ctx, values, weights); @@ -74,14 +76,14 @@ TEST(Stats, Median) { namespace { void TestMean(Context const* ctx) { std::size_t n{128}; - linalg::Vector data({n}, ctx->gpu_id); + linalg::Vector data({n}, linalg::Order::kC, ctx->gpu_id); auto h_v = 
data.HostView().Values(); std::iota(h_v.begin(), h_v.end(), .0f); auto nf = static_cast(n); float mean = nf * (nf - 1) / 2 / n; - linalg::Vector res{{1}, ctx->gpu_id}; + linalg::Vector res{{1}, linalg::Order::kC, ctx->gpu_id}; Mean(ctx, data, &res); auto h_res = res.HostView(); ASSERT_EQ(h_res.Size(), 1); diff --git a/tests/cpp/common/test_stats.cu b/tests/cpp/common/test_stats.cu index e71ec3efcb60..eddc0001c8f5 100644 --- a/tests/cpp/common/test_stats.cu +++ b/tests/cpp/common/test_stats.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2022 by XGBoost Contributors + * Copyright 2022-2023 by XGBoost Contributors */ #include @@ -19,10 +19,8 @@ namespace { class StatsGPU : public ::testing::Test { private: linalg::Tensor arr_{ - {1.f, 2.f, 3.f, 4.f, 5.f, - 2.f, 4.f, 5.f, 3.f, 1.f}, - {10}, 0}; - linalg::Tensor indptr_{{0, 5, 10}, {3}, 0}; + {1.f, 2.f, 3.f, 4.f, 5.f, 2.f, 4.f, 5.f, 3.f, 1.f}, {10}, linalg::Order::kC, 0}; + linalg::Tensor indptr_{{0, 5, 10}, {3}, linalg::Order::kC, 0}; HostDeviceVector resutls_; using TestSet = std::vector>; Context ctx_; @@ -44,7 +42,7 @@ class StatsGPU : public ::testing::Test { [=] __device__(size_t i) { return d_key(i); }); auto val_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { return d_arr(i); }); - linalg::Tensor weights{{10}, 0}; + linalg::Tensor weights{{10}, linalg::Order::kC, 0}; linalg::ElementWiseTransformDevice(weights.View(0), [=] XGBOOST_DEVICE(size_t, float) { return 1.0; }); auto w_it = weights.Data()->ConstDevicePointer(); diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 8fc8ff0178fd..3b99b648a191 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -257,7 +257,8 @@ TEST(SimpleDMatrix, Slice) { std::iota(upper.begin(), upper.end(), 1.0f); auto& margin = p_m->Info().base_margin_; - margin = decltype(p_m->Info().base_margin_){{kRows, kClasses}, Context::kCpuId}; + margin = + decltype(p_m->Info().base_margin_){{kRows, kClasses}, linalg::Order::kC, Context::kCpuId}; std::array ridxs {1, 3, 5}; std::unique_ptr out { p_m->Slice(ridxs) }; @@ -319,7 +320,8 @@ TEST(SimpleDMatrix, SliceCol) { std::iota(upper.begin(), upper.end(), 1.0f); auto& margin = p_m->Info().base_margin_; - margin = decltype(p_m->Info().base_margin_){{kRows, kClasses}, Context::kCpuId}; + margin = + decltype(p_m->Info().base_margin_){{kRows, kClasses}, linalg::Order::kC, Context::kCpuId}; size_t constexpr kSlicCols {4}; for (auto slice = 0; slice < 2; slice++) { diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index c96b9849775b..270eacf21710 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -412,7 +412,7 @@ std::pair TestModelSlice(std::string booster) { j++; } - // CHECK sliced model doesn't have dependency on old one + // CHECK sliced model doesn't have dependency on the old one learner.reset(); CHECK_EQ(sliced->GetNumFeature(), kCols); diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 741a228cfa7e..23df6c33be2d 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -96,8 +96,8 @@ void CheckObjFunction(std::unique_ptr const& obj, std::vector out_hess) { xgboost::MetaInfo info; info.num_row_ = labels.size(); - info.labels = - xgboost::linalg::Tensor{labels.cbegin(), labels.cend(), {labels.size()}, -1}; + info.labels = xgboost::linalg::Tensor{ + labels.cbegin(), labels.cend(), {labels.size()}, xgboost::linalg::Order::kC, -1}; info.weights_.HostVector() = weights; 
CheckObjFunctionImpl(obj, preds, labels, weights, info, out_grad, out_hess); @@ -132,8 +132,11 @@ void CheckRankingObjFunction(std::unique_ptr const& obj, std::vector out_hess) { xgboost::MetaInfo info; info.num_row_ = labels.size(); - info.labels = xgboost::linalg::Tensor{ - labels.cbegin(), labels.cend(), {labels.size(), static_cast(1)}, -1}; + info.labels = xgboost::linalg::Tensor{labels.cbegin(), + labels.cend(), + {labels.size(), static_cast(1)}, + xgboost::linalg::Order::kC, + -1}; info.weights_.HostVector() = weights; info.group_ptr_ = groups; @@ -147,8 +150,9 @@ xgboost::bst_float GetMetricEval(xgboost::Metric* metric, std::vector groups) { return GetMultiMetricEval( metric, preds, - xgboost::linalg::Tensor{labels.begin(), labels.end(), {labels.size()}, -1}, weights, - groups); + xgboost::linalg::Tensor{ + labels.begin(), labels.end(), {labels.size()}, xgboost::linalg::Order::kC, -1}, + weights, groups); } double GetMultiMetricEval(xgboost::Metric* metric, @@ -555,8 +559,8 @@ std::unique_ptr CreateTrainedGBM(std::string name, Args kwargs, for (size_t i = 0; i < kRows; ++i) { labels[i] = i; } - p_dmat->Info().labels = - linalg::Tensor{labels.cbegin(), labels.cend(), {labels.size()}, -1}; + p_dmat->Info().labels = linalg::Tensor{ + labels.cbegin(), labels.cend(), {labels.size()}, linalg::Order::kC, -1}; HostDeviceVector gpair; auto& h_gpair = gpair.HostVector(); h_gpair.resize(kRows); diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 00c38e452ec0..ed816d09f884 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -22,6 +22,7 @@ #include "../../src/data/array_interface.h" #include "../../src/gbm/gbtree_model.h" #include "filesystem.h" // dmlc::TemporaryDirectory +#include "xgboost/linalg.h" #if defined(__CUDACC__) #define DeclareUnifiedTest(name) GPU ## name @@ -189,12 +190,10 @@ Json GetArrayInterface(HostDeviceVector *storage, size_t rows, size_t cols) { Json array_interface{Object()}; array_interface["data"] = std::vector(2); if (storage->DeviceCanRead()) { - array_interface["data"][0] = - Integer(reinterpret_cast(storage->ConstDevicePointer())); + array_interface["data"][0] = Integer{reinterpret_cast(storage->ConstDevicePointer())}; array_interface["stream"] = nullptr; } else { - array_interface["data"][0] = - Integer(reinterpret_cast(storage->ConstHostPointer())); + array_interface["data"][0] = Integer{reinterpret_cast(storage->ConstHostPointer())}; } array_interface["data"][1] = Boolean(false); @@ -452,10 +451,11 @@ RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv); * \brief Make learner model param */ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint32_t n_groups, - int32_t device = Context::kCpuId) { + bst_target_t n_targets = 1, int32_t device = Context::kCpuId) { size_t shape[1]{1}; - LearnerModelParam mparam(n_features, linalg::Tensor{{base_score}, shape, device}, - n_groups); + LearnerModelParam mparam(n_features, + linalg::Tensor{{base_score}, shape, linalg::Order::kC, device}, + n_groups, n_targets); return mparam; } diff --git a/tests/cpp/metric/test_auc.cc b/tests/cpp/metric/test_auc.cc index 321f46cdca16..33aca4f7f3cc 100644 --- a/tests/cpp/metric/test_auc.cc +++ b/tests/cpp/metric/test_auc.cc @@ -21,7 +21,7 @@ TEST(Metric, DeclareUnifiedTest(BinaryAUC)) { // Invalid dataset MetaInfo info; - info.labels = linalg::Tensor{{0.0f, 0.0f}, {2}, -1}; + info.labels = linalg::Tensor{{0.0f, 0.0f}, {2}, linalg::Order::kC, -1}; float auc = metric->Eval({1, 1}, info); ASSERT_TRUE(std::isnan(auc)); 
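These test hunks thread an explicit linalg::Order argument through every Tensor and TensorView constructor. As a minimal sketch of what that argument selects, here is the plain offset arithmetic behind the two layouts, assuming only the standard row-major and column-major stride conventions (these helpers are illustrative, not the linalg API):

#include <cstddef>

// Row-major (Order::kC): consecutive elements of a row are adjacent in memory.
std::size_t OffsetC(std::size_t i, std::size_t j, std::size_t n_cols) {
  return i * n_cols + j;
}

// Column-major (Order::kF): consecutive elements of a column are adjacent.
std::size_t OffsetF(std::size_t i, std::size_t j, std::size_t n_rows) {
  return j * n_rows + i;
}

A column slice of an F-order matrix is therefore contiguous with stride 1, which is exactly what the Linalg.FOrder test above asserts; that property appears to be why the updater materializes per-target gradients in kF order, so each target's gradient column can be handed to the scalar code path as a dense span.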
*info.labels.Data() = HostDeviceVector{}; diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc index fde9e42f2020..b4b16d27672a 100644 --- a/tests/cpp/metric/test_elementwise_metric.cc +++ b/tests/cpp/metric/test_elementwise_metric.cc @@ -292,7 +292,7 @@ TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) { TEST(Metric, DeclareUnifiedTest(MultiRMSE)) { size_t n_samples = 32, n_targets = 8; - linalg::Tensor y{{n_samples, n_targets}, GPUIDX}; + linalg::Tensor y{{n_samples, n_targets}, linalg::Order::kC, GPUIDX}; auto &h_y = y.Data()->HostVector(); std::iota(h_y.begin(), h_y.end(), 0); diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 4a3293dbe73d..9b68d239206c 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -103,7 +103,7 @@ TEST(GPUPredictor, ExternalMemoryTest) { for (const auto& dmat: dmats) { dmat->Info().base_margin_ = decltype(dmat->Info().base_margin_){ - {dmat->Info().num_row_, static_cast(n_classes)}, 0}; + {dmat->Info().num_row_, static_cast(n_classes)}, linalg::Order::kC, 0}; dmat->Info().base_margin_.Data()->Fill(0.5); PredictionCacheEntry out_predictions; gpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); @@ -150,7 +150,7 @@ TEST(GPUPredictor, ShapStump) { Context ctx; ctx.gpu_id = 0; - LearnerModelParam mparam{MakeMP(1, .5, 1, ctx.gpu_id)}; + LearnerModelParam mparam{MakeMP(1, .5, 1, 1, ctx.gpu_id)}; gbm::GBTreeModel model(&mparam, &ctx); std::vector> trees; @@ -177,7 +177,7 @@ TEST(GPUPredictor, ShapStump) { TEST(GPUPredictor, Shap) { Context ctx; ctx.gpu_id = 0; - LearnerModelParam mparam{MakeMP(1, .5, 1, ctx.gpu_id)}; + LearnerModelParam mparam{MakeMP(1, .5, 1, 1, ctx.gpu_id)}; gbm::GBTreeModel model(&mparam, &ctx); std::vector> trees; diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 292846c0ad92..79f2d829c51d 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -438,7 +438,7 @@ TEST(Learner, MultiTarget) { Json model{Object()}; learner->SaveModel(&model); - ASSERT_EQ(get(model["learner"]["learner_model_param"]["num_target"]), + ASSERT_EQ(get(model["learner"]["learner_model_param"]["num_class"]), std::to_string(kTargets)); } { diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 7000240df5ea..e257592dedfb 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -98,7 +98,8 @@ TEST(HistEvaluator, Apply) { auto sampler = std::make_shared(); auto evaluator_ = HistEvaluator{param, dmat->Info(), 4, sampler}; - CPUExpandEntry entry{0, 0, 10.0f}; + CPUExpandEntry entry{0, 0}; + entry.split.loss_chg = 10.0f; entry.split.left_sum = GradStats{0.4, 0.6f}; entry.split.right_sum = GradStats{0.5, 0.5f}; diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index d7d0f12cc12f..1c6a496484a3 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -40,10 +40,10 @@ void TestAddHistRows(bool is_distributed) { tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); - nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3), 0.0f); - nodes_for_explicit_hist_build_.emplace_back(4, tree.GetDepth(4), 0.0f); - 
nodes_for_subtraction_trick_.emplace_back(5, tree.GetDepth(5), 0.0f); - nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f); + nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3)); + nodes_for_explicit_hist_build_.emplace_back(4, tree.GetDepth(4)); + nodes_for_subtraction_trick_.emplace_back(5, tree.GetDepth(5)); + nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6)); HistogramBuilder histogram_builder; histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1, @@ -97,7 +97,7 @@ void TestSyncHist(bool is_distributed) { } // level 0 - nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0), 0.0f); + nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0)); histogram.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, &tree); @@ -107,10 +107,8 @@ void TestSyncHist(bool is_distributed) { nodes_for_subtraction_trick_.clear(); // level 1 - nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), - tree.GetDepth(1), 0.0f); - nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), - tree.GetDepth(2), 0.0f); + nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1)); + nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2)); histogram.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, @@ -122,10 +120,10 @@ void TestSyncHist(bool is_distributed) { nodes_for_explicit_hist_build_.clear(); nodes_for_subtraction_trick_.clear(); // level 2 - nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3), 0.0f); - nodes_for_subtraction_trick_.emplace_back(4, tree.GetDepth(4), 0.0f); - nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5), 0.0f); - nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f); + nodes_for_explicit_hist_build_.emplace_back(3, tree.GetDepth(3)); + nodes_for_subtraction_trick_.emplace_back(4, tree.GetDepth(4)); + nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5)); + nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6)); histogram.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, @@ -251,7 +249,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) { std::iota(row_indices.begin(), row_indices.end(), 0); row_set_collection.Init(); - CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); + CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)}; std::vector nodes_for_explicit_hist_build; nodes_for_explicit_hist_build.push_back(node); for (auto const &gidx : p_fmat->GetBatches({kMaxBins, 0.5})) { @@ -320,7 +318,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) { BatchParam batch_param{0, static_cast(kBins)}; RegTree tree; - CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); + CPUExpandEntry node{RegTree::kRoot, tree.GetDepth(0)}; std::vector nodes_for_explicit_hist_build; nodes_for_explicit_hist_build.push_back(node); @@ -392,7 +390,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo RegTree tree; std::vector nodes; - nodes.emplace_back(0, tree.GetDepth(0), 0.0f); + nodes.emplace_back(0, tree.GetDepth(0)); common::GHistRow multi_page; HistogramBuilder multi_build; diff --git a/tests/cpp/tree/hist/test_sampler.cc b/tests/cpp/tree/hist/test_sampler.cc new file mode 100644 index 000000000000..5d747f04bf48 --- /dev/null +++ b/tests/cpp/tree/hist/test_sampler.cc @@ -0,0 +1,57 @@ +/** + * 
Copyright 2023 by XGBoost Contributors + */ +#include + +#include // std::size_t +#include // std::to_string + +#include "../../../../src/tree/hist/sampler.h" // SampleGradient +#include "../../../../src/tree/param.h" // TrainParam +#include "xgboost/base.h" // GradientPair,bst_target_t +#include "xgboost/context.h" // Context +#include "xgboost/data.h" // MetaInfo +#include "xgboost/linalg.h" // Matrix,Constants + +namespace xgboost { +namespace tree { +TEST(Sampler, Basic) { + std::size_t constexpr kRows = 1024; + double constexpr kSubsample = .2; + TrainParam param; + param.UpdateAllowUnknown(Args{{"subsample", std::to_string(kSubsample)}}); + Context ctx; + + auto run = [&](bst_target_t n_targets) { + auto init = GradientPair{1.0f, 1.0f}; + linalg::Matrix gpair = linalg::Constant(&ctx, init, kRows, n_targets); + auto h_gpair = gpair.HostView(); + SampleGradient(&ctx, param, h_gpair); + std::size_t n_sampled{0}; + for (std::size_t i = 0; i < kRows; ++i) { + bool sampled{false}; + if (h_gpair(i, 0).GetGrad() - .0f != .0f) { + sampled = true; + n_sampled++; + } + for (bst_target_t t = 1; t < n_targets; ++t) { + if (sampled) { + ASSERT_EQ(h_gpair(i, t).GetGrad() - init.GetGrad(), .0f); + ASSERT_EQ(h_gpair(i, t).GetHess() - init.GetHess(), .0f); + + } else { + ASSERT_EQ(h_gpair(i, t).GetGrad() - .0f, .0f); + ASSERT_EQ(h_gpair(i, t).GetHess() - .0f, .0f); + } + } + } + auto ratio = static_cast(n_sampled) / static_cast(kRows); + ASSERT_LT(ratio, kSubsample * 1.5); + ASSERT_GT(ratio, kSubsample * 0.5); + }; + + run(1); + run(3); +} +} // namespace tree +} // namespace xgboost diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index 0b2d95100e21..6d25f04d4198 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -20,7 +20,8 @@ TEST(Approx, Partitioner) { auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); ctx.InitAllowUnknown(Args{}); - std::vector candidates{{0, 0, 0.4}}; + std::vector candidates{{0, 0}}; + candidates.front().split.loss_chg = 0.4; auto grad = GenerateRandomGradients(n_samples); std::vector hess(grad.Size()); @@ -74,7 +75,8 @@ void TestLeafPartition(size_t n_samples) { CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid}; auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); - std::vector candidates{{0, 0, 0.4}}; + std::vector candidates{{0, 0}}; + candidates.front().split.loss_chg = 0.4; RegTree tree; std::vector hess(n_samples, 0); // emulate sampling diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 23cb868ee68f..0a66a521096f 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -29,7 +29,8 @@ TEST(QuantileHist, Partitioner) { ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples); auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); - std::vector candidates{{0, 0, 0.4}}; + std::vector candidates{{0, 0}}; + candidates.front().split.loss_chg = 0.4; auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads()); diff --git a/tests/python/test_model_compatibility.py b/tests/python/test_model_compatibility.py index a46715e42001..2102914fdca0 100644 --- a/tests/python/test_model_compatibility.py +++ b/tests/python/test_model_compatibility.py @@ -28,11 +28,11 @@ def run_booster_check(booster, name): 'objective'] == 'multi:softmax' elif name.find('logitraw') != -1: assert len(booster.get_dump()) == gm.kForests * gm.kRounds - assert 
config['learner']['learner_model_param']['num_class'] == str(0) + assert config['learner']['learner_model_param']['num_class'] == str(1) assert config['learner']['learner_train_param']['objective'] == 'binary:logitraw' elif name.find('logit') != -1: assert len(booster.get_dump()) == gm.kForests * gm.kRounds - assert config['learner']['learner_model_param']['num_class'] == str(0) + assert config['learner']['learner_model_param']['num_class'] == str(1) assert config['learner']['learner_train_param'][ 'objective'] == 'binary:logistic' elif name.find('ltr') != -1: diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 2fa20e284e7c..a9f961a2d75e 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1177,7 +1177,7 @@ def test_multilabel_classification() -> None: clf.fit(X, y) booster = clf.get_booster() learner = json.loads(booster.save_config())["learner"] - assert int(learner["learner_model_param"]["num_target"]) == 5 + assert int(learner["learner_model_param"]["num_class"]) == 5 np.testing.assert_allclose(clf.predict(X), y) predt = (clf.predict_proba(X) > 0.5).astype(np.int64)
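Back in QuantileHistMaker::Update, the gradient copy enumerates a flat index over the logical shape and routes both reads and writes through linalg::UnravelIndex, which keeps the copy correct even though the source view is C-ordered and the destination buffer is F-ordered. A standalone sketch of the unravel step for the two-dimensional case (a hypothetical helper, not the linalg implementation):

#include <array>
#include <cstddef>

// Map a flat index, enumerated in row-major order over an (n_rows x n_cols)
// logical shape, back to (row, col) coordinates. Writing through coordinates
// rather than raw offsets makes the copy independent of either side's layout.
std::array<std::size_t, 2> UnravelIndex2D(std::size_t flat, std::size_t n_cols) {
  return {flat / n_cols, flat % n_cols};
}

The copy exists because sampling mutates the gradient in place: when a layer boosts several trees, or several targets share one gradient matrix, every tree must start from the unsampled values.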
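The sampler test added above pins down the contract that makes uniform subsampling compatible with multi-target trees: since all targets share one tree structure, a row must be kept for every target or zeroed for every target. A minimal sketch of such a row-consistent scheme, using a plain std::bernoulli_distribution coin rather than the actual SampleGradient implementation:

#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

#include "xgboost/base.h"  // GradientPair

// Zero out whole rows of a row-major (n_rows x n_targets) gradient matrix
// with probability (1 - subsample). One coin flip per row keeps the decision
// consistent across targets, mirroring what test_sampler.cc asserts.
void RowConsistentSample(std::vector<xgboost::GradientPair> *gpair, std::size_t n_targets,
                         float subsample, std::uint64_t seed) {
  std::mt19937_64 rng{seed};
  std::bernoulli_distribution coin{subsample};
  std::size_t n_rows = gpair->size() / n_targets;
  for (std::size_t i = 0; i < n_rows; ++i) {
    if (!coin(rng)) {
      for (std::size_t t = 0; t < n_targets; ++t) {
        (*gpair)[i * n_targets + t] = xgboost::GradientPair{0.0f, 0.0f};
      }
    }
  }
}

The test's tolerance, accepting a realized ratio between 0.5x and 1.5x of the requested subsample, reflects that with only 1024 rows the sampled fraction fluctuates around its expectation.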