From 425d8d8275ba65386c5b7d063b824cb75d21f43f Mon Sep 17 00:00:00 2001 From: jiamingy Date: Sat, 19 Aug 2023 17:46:59 +0800 Subject: [PATCH 1/4] Use matrix for gradient. - Use the `linalg::Matrix` for storing gradients. - New API for the custom objective. - Custom objective for multi-class/multi-target is now required to return the correct shape. - Custom objective for Python can accept arrays with any strides. (row-major, column-major) --- R-package/R/utils.R | 7 +- R-package/src/init.c | 4 +- R-package/src/xgboost_R.cc | 43 ++--- R-package/src/xgboost_R.h | 3 +- demo/guide-python/custom_softmax.py | 4 +- demo/guide-python/multioutput_regression.py | 6 +- include/xgboost/base.h | 4 +- include/xgboost/c_api.h | 42 ++-- include/xgboost/gbm.h | 33 ++-- include/xgboost/learner.h | 21 +- include/xgboost/linalg.h | 48 ++++- include/xgboost/linear_updater.h | 5 +- include/xgboost/objective.h | 31 ++- include/xgboost/tree_updater.h | 2 +- .../java/ml/dmlc/xgboost4j/java/Booster.java | 26 ++- .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 2 +- .../ml/dmlc/xgboost4j/scala/Booster.scala | 22 ++- .../xgboost4j/src/native/xgboost4j.cpp | 47 +++-- jvm-packages/xgboost4j/src/native/xgboost4j.h | 18 +- plugin/example/custom_obj.cc | 41 ++-- python-package/xgboost/core.py | 61 ++++-- python-package/xgboost/testing/__init__.py | 34 +++- src/c_api/c_api.cc | 89 ++++++--- src/c_api/c_api.cu | 28 ++- src/c_api/c_api_error.h | 8 +- src/c_api/c_api_utils.h | 51 ++++- src/data/array_interface.h | 4 +- src/data/data.cc | 2 +- src/gbm/gblinear.cc | 10 +- src/gbm/gbtree.cc | 50 ++--- src/gbm/gbtree.cu | 30 +-- src/gbm/gbtree.h | 6 +- src/learner.cc | 19 +- src/linear/updater_coordinate.cc | 15 +- src/linear/updater_gpu_coordinate.cu | 17 +- src/linear/updater_shotgun.cc | 25 ++- src/objective/aft_obj.cu | 19 +- src/objective/hinge.cu | 9 +- src/objective/init_estimation.cc | 2 +- src/objective/lambdarank_obj.cc | 58 +++--- src/objective/lambdarank_obj.cu | 31 +-- src/objective/lambdarank_obj.cuh | 6 +- src/objective/lambdarank_obj.h | 6 +- src/objective/multiclass_obj.cu | 17 +- src/objective/quantile_obj.cu | 30 +-- src/objective/regression_obj.cu | 180 ++++++++++-------- src/tree/fit_stump.cc | 7 +- src/tree/fit_stump.h | 2 +- src/tree/updater_approx.cc | 7 +- src/tree/updater_colmaker.cc | 5 +- src/tree/updater_gpu_hist.cu | 12 +- src/tree/updater_prune.cc | 2 +- src/tree/updater_quantile_hist.cc | 5 +- src/tree/updater_refresh.cc | 9 +- src/tree/updater_sync.cc | 2 +- tests/cpp/c_api/test_c_api.cc | 4 +- tests/cpp/gbm/test_gbtree.cc | 7 +- tests/cpp/helpers.cc | 21 +- tests/cpp/helpers.h | 32 ++-- tests/cpp/linear/test_linear.cc | 8 +- tests/cpp/linear/test_linear.cu | 7 +- tests/cpp/objective/test_aft_obj.cc | 19 +- tests/cpp/objective/test_lambdarank_obj.cc | 32 ++-- tests/cpp/objective/test_lambdarank_obj.cu | 29 +-- tests/cpp/objective/test_regression_obj.cc | 22 +-- tests/cpp/predictor/test_cpu_predictor.cc | 7 +- tests/cpp/test_multi_target.cc | 36 +++- tests/cpp/tree/test_fit_stump.cc | 18 +- tests/cpp/tree/test_gpu_hist.cu | 18 +- tests/cpp/tree/test_histmaker.cc | 26 ++- tests/cpp/tree/test_prediction_cache.cc | 2 +- tests/cpp/tree/test_prune.cc | 12 +- tests/cpp/tree/test_quantile_hist.cc | 18 +- tests/cpp/tree/test_refresh.cc | 8 +- tests/cpp/tree/test_tree_stat.cc | 17 +- tests/python-gpu/test_gpu_with_sklearn.py | 91 +++++++++ 76 files changed, 1048 insertions(+), 653 deletions(-) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 458b119f6b4c..b6b14c06f540 100644 --- a/R-package/R/utils.R 
+++ b/R-package/R/utils.R @@ -154,7 +154,12 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) { pred <- predict(booster_handle, dtrain, outputmargin = TRUE, training = TRUE, ntreelimit = 0) gpair <- obj(pred, dtrain) - .Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, gpair$grad, gpair$hess) + n_samples <- dim(dtrain)[1] + gpair$grad <- matrix(gpair$grad, nrow = n_samples, byrow = TRUE) + gpair$hess <- matrix(gpair$hess, nrow = n_samples, byrow = TRUE) + .Call( + XGBoosterBoostOneIter_R, booster_handle, dtrain, iter, gpair$grad, gpair$hess + ) } return(TRUE) } diff --git a/R-package/src/init.c b/R-package/src/init.c index 583dc7e32613..09174222e4d2 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -16,7 +16,7 @@ Check these declarations against the C/Fortran source code. */ /* .Call calls */ -extern SEXP XGBoosterBoostOneIter_R(SEXP, SEXP, SEXP, SEXP); +extern SEXP XGBoosterTrainOneIter_R(SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterCreate_R(SEXP); extern SEXP XGBoosterCreateInEmptyObj_R(SEXP, SEXP); extern SEXP XGBoosterDumpModel_R(SEXP, SEXP, SEXP, SEXP); @@ -53,7 +53,7 @@ extern SEXP XGBGetGlobalConfig_R(void); extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP); static const R_CallMethodDef CallEntries[] = { - {"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterBoostOneIter_R, 4}, + {"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5}, {"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1}, {"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2}, {"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 805e63a32e87..b975ab8ba76a 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -48,13 +48,6 @@ using dmlc::BeginPtr; -xgboost::Context const *BoosterCtx(BoosterHandle handle) { - CHECK_HANDLE(); - auto *learner = static_cast(handle); - CHECK(learner); - return learner->Ctx(); -} - xgboost::Context const *DMatrixCtx(DMatrixHandle handle) { CHECK_HANDLE(); auto p_m = static_cast *>(handle); @@ -394,21 +387,25 @@ XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) { return R_NilValue; } -XGB_DLL SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) { +XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP grad, SEXP hess) { R_API_BEGIN(); - CHECK_EQ(length(grad), length(hess)) - << "gradient and hess must have same length"; - int len = length(grad); - std::vector tgrad(len), thess(len); - auto ctx = BoosterCtx(R_ExternalPtrAddr(handle)); - xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong j) { - tgrad[j] = REAL(grad)[j]; - thess[j] = REAL(hess)[j]; - }); - CHECK_CALL(XGBoosterBoostOneIter(R_ExternalPtrAddr(handle), - R_ExternalPtrAddr(dtrain), - BeginPtr(tgrad), BeginPtr(thess), - len)); + CHECK_EQ(length(grad), length(hess)) << "gradient and hess must have same length"; + SEXP gdim = getAttrib(grad, R_DimSymbol); + auto n_samples = static_cast(INTEGER(gdim)[0]); + auto n_targets = static_cast(INTEGER(gdim)[1]); + + SEXP hdim = getAttrib(hess, R_DimSymbol); + CHECK_EQ(INTEGER(hdim)[0], n_samples) << "mismatched size between gradient and hessian"; + CHECK_EQ(INTEGER(hdim)[1], n_targets) << "mismatched size between gradient and hessian"; + double const *d_grad = REAL(grad); + double const *d_hess = REAL(hess); + + auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle)); + auto [s_grad, s_hess] = + 
xgboost::detail::MakeGradientInterface(ctx, d_grad, d_hess, n_samples, n_targets); + CHECK_CALL(XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain), + asInteger(iter), s_grad.c_str(), s_hess.c_str())); + R_API_END(); return R_NilValue; } @@ -460,7 +457,7 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con len *= out_shape[i]; } r_out_result = PROTECT(allocVector(REALSXP, len)); - auto ctx = BoosterCtx(R_ExternalPtrAddr(handle)); + auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle)); xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) { REAL(r_out_result)[i] = out_result[i]; }); @@ -669,7 +666,7 @@ XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config) { } out_scores_sexp = PROTECT(allocVector(REALSXP, len)); - auto ctx = BoosterCtx(R_ExternalPtrAddr(handle)); + auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle)); xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) { REAL(out_scores_sexp)[i] = out_scores[i]; }); diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 45a43a5bda2f..7f0833b157fe 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -161,12 +161,13 @@ XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain); * \brief update the model, by directly specify gradient and second order gradient, * this can be used to replace UpdateOneIter, to support customized loss function * \param handle handle + * \param iter The current training iteration. * \param dtrain training data * \param grad gradient statistics * \param hess second order gradient statistics * \return R_NilValue */ -XGB_DLL SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess); +XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP grad, SEXP hess); /*! * \brief get evaluation statistics for xgboost diff --git a/demo/guide-python/custom_softmax.py b/demo/guide-python/custom_softmax.py index 153c5d43b713..36265cf4d6c5 100644 --- a/demo/guide-python/custom_softmax.py +++ b/demo/guide-python/custom_softmax.py @@ -76,9 +76,7 @@ def softprob_obj(predt: np.ndarray, data: xgb.DMatrix): grad[r, c] = g hess[r, c] = h - # Right now (XGBoost 1.0.0), reshaping is necessary - grad = grad.reshape((kRows * kClasses, 1)) - hess = hess.reshape((kRows * kClasses, 1)) + # After 2.1.0, pass the gradient as it is. return grad, hess diff --git a/demo/guide-python/multioutput_regression.py b/demo/guide-python/multioutput_regression.py index 7450fd30aa57..a8a546b0c932 100644 --- a/demo/guide-python/multioutput_regression.py +++ b/demo/guide-python/multioutput_regression.py @@ -68,16 +68,14 @@ def rmse_model(plot_result: bool, strategy: str) -> None: def custom_rmse_model(plot_result: bool, strategy: str) -> None: """Train using Python implementation of Squared Error.""" - # As the experimental support status, custom objective doesn't support matrix as - # gradient and hessian, which will be changed in future release. 
def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray: """Compute the gradient squared error.""" y = dtrain.get_label().reshape(predt.shape) - return (predt - y).reshape(y.size) + return predt - y def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray: """Compute the hessian for squared error.""" - return np.ones(predt.shape).reshape(predt.size) + return np.ones(predt.shape) def squared_log( predt: np.ndarray, dtrain: xgb.DMatrix diff --git a/include/xgboost/base.h b/include/xgboost/base.h index f02d75cdc3a2..dec306f0cbb7 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -274,8 +274,8 @@ class GradientPairInt64 { GradientPairInt64(GradientPairInt64 const &g) = default; GradientPairInt64 &operator=(GradientPairInt64 const &g) = default; - XGBOOST_DEVICE [[nodiscard]] T GetQuantisedGrad() const { return grad_; } - XGBOOST_DEVICE [[nodiscard]] T GetQuantisedHess() const { return hess_; } + [[nodiscard]] XGBOOST_DEVICE T GetQuantisedGrad() const { return grad_; } + [[nodiscard]] XGBOOST_DEVICE T GetQuantisedHess() const { return hess_; } XGBOOST_DEVICE GradientPairInt64 &operator+=(const GradientPairInt64 &rhs) { grad_ += rhs.grad_; diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index fc60d2e77c8d..d2e14d752e0b 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -789,16 +789,14 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle, * \param out The address to hold number of rows. * \return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle, - bst_ulong *out); +XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle, bst_ulong *out); /*! * \brief get number of columns * \param handle the handle to the DMatrix * \param out The output of number of columns * \return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle, - bst_ulong *out); +XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle, bst_ulong *out); /*! * \brief Get number of valid values from DMatrix. @@ -945,21 +943,29 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle, int iter, DMatrixHandle * @example c-api-demo.c */ -/*! - * \brief update the model, by directly specify gradient and second order gradient, - * this can be used to replace UpdateOneIter, to support customized loss function - * \param handle handle - * \param dtrain training data - * \param grad gradient statistics - * \param hess second order gradient statistics - * \param len length of grad/hess array - * \return 0 when success, -1 when failure happens +/** + * @deprecated since 2.1.0 */ -XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, - DMatrixHandle dtrain, - float *grad, - float *hess, - bst_ulong len); +XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, float *grad, + float *hess, bst_ulong len); + +/** + * @brief Update a multi-target model with gradient and Hessian. This is used for training + * with a custom objective function. + * + * @since 2.0.0 + * + * @param handle handle + * @param dtrain training data + * @param iter The current iteration number. + * @param grad Json encoded __(cuda)_array_interface__ for gradient. + * @param hess Json encoded __(cuda)_array_interface__ for Hessian. + * + * @return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterTrainOneIter(BoosterHandle handle, DMatrixHandle dtrain, int iter, + char const *grad, char const *hess); + /*! 
* \brief get evaluation statistics for xgboost * \param handle handle diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 6d383209368d..3667421b094c 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -70,22 +70,24 @@ class GradientBooster : public Model, public Configurable { GradientBooster* /*out*/, bool* /*out_of_bound*/) const { LOG(FATAL) << "Slice is not supported by the current booster."; } - /*! \brief Return number of boosted rounds. + /** + * @brief Return number of boosted rounds. */ - virtual int32_t BoostedRounds() const = 0; + [[nodiscard]] virtual int32_t BoostedRounds() const = 0; /** * \brief Whether the model has already been trained. When tree booster is chosen, then * returns true when there are existing trees. */ - virtual bool ModelFitted() const = 0; - /*! - * \brief perform update to the model(boosting) - * \param p_fmat feature matrix that provide access to features - * \param in_gpair address of the gradient pair statistics of the data - * \param prediction The output prediction cache entry that needs to be updated. - * the booster may change content of gpair + [[nodiscard]] virtual bool ModelFitted() const = 0; + /** + * @brief perform update to the model(boosting) + * + * @param p_fmat feature matrix that provide access to features + * @param in_gpair address of the gradient pair statistics of the data + * @param prediction The output prediction cache entry that needs to be updated. + * the booster may change content of gpair */ - virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, + virtual void DoBoost(DMatrix* p_fmat, linalg::Matrix* in_gpair, PredictionCacheEntry*, ObjFunction const* obj) = 0; /** @@ -165,18 +167,17 @@ class GradientBooster : public Model, public Configurable { * \param format the format to dump the model in * \return a vector of dump for boosters. */ - virtual std::vector DumpModel(const FeatureMap& fmap, - bool with_stats, - std::string format) const = 0; + [[nodiscard]] virtual std::vector DumpModel(const FeatureMap& fmap, bool with_stats, + std::string format) const = 0; virtual void FeatureScore(std::string const& importance_type, common::Span trees, std::vector* features, std::vector* scores) const = 0; - /*! - * \brief Whether the current booster uses GPU. + /** + * @brief Whether the current booster uses GPU. */ - virtual bool UseGPU() const = 0; + [[nodiscard]] virtual bool UseGPU() const = 0; /*! * \brief create a gradient booster from given name * \param name name of gradient booster diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 8adb3cb27e41..cd081a2e8227 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -76,17 +76,18 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \param iter current iteration number * \param train reference to the data matrix. */ - virtual void UpdateOneIter(int iter, std::shared_ptr train) = 0; - /*! - * \brief Do customized gradient boosting with in_gpair. - * in_gair can be mutated after this call. - * \param iter current iteration number - * \param train reference to the data matrix. - * \param in_gpair The input gradient statistics. + virtual void UpdateOneIter(std::int32_t iter, std::shared_ptr train) = 0; + /** + * @brief Do customized gradient boosting with in_gpair. + * + * @note in_gpair can be mutated after this call. + * + * @param iter current iteration number + * @param train reference to the data matrix. + * @param in_gpair The input gradient statistics. 
*/ - virtual void BoostOneIter(int iter, - std::shared_ptr train, - HostDeviceVector* in_gpair) = 0; + virtual void BoostOneIter(std::int32_t iter, std::shared_ptr train, + linalg::Matrix* in_gpair) = 0; /*! * \brief evaluate the model for specific iteration using the configured metrics. * \param iter iteration number diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 6d2b54f84c17..ae3489e3b217 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -292,7 +292,7 @@ enum Order : std::uint8_t { template class TensorView { public: - using ShapeT = size_t[kDim]; + using ShapeT = std::size_t[kDim]; using StrideT = ShapeT; private: @@ -400,10 +400,14 @@ class TensorView { * \param shape shape of the tensor * \param device Device ordinal */ - template + template LINALG_HD TensorView(common::Span data, I const (&shape)[D], std::int32_t device) : TensorView{data, shape, device, Order::kC} {} + template + LINALG_HD TensorView(common::Span data, I const (&shape)[D], DeviceOrd device) + : TensorView{data, shape, device.ordinal, Order::kC} {} + template LINALG_HD TensorView(common::Span data, I const (&shape)[D], std::int32_t device, Order order) : data_{data}, ptr_{data_.data()}, device_{device} { @@ -446,6 +450,10 @@ class TensorView { }); this->CalcSize(); } + template + LINALG_HD TensorView(common::Span data, I const (&shape)[D], I const (&stride)[D], + DeviceOrd device) + : TensorView{data, shape, stride, device.ordinal} {} template < typename U, @@ -741,7 +749,7 @@ auto ArrayInterfaceStr(TensorView const &t) { template class Tensor { public: - using ShapeT = size_t[kDim]; + using ShapeT = std::size_t[kDim]; using StrideT = ShapeT; private: @@ -775,6 +783,9 @@ class Tensor { template explicit Tensor(I const (&shape)[D], std::int32_t device, Order order = kC) : Tensor{common::Span{shape}, device, order} {} + template + explicit Tensor(I const (&shape)[D], DeviceOrd device, Order order = kC) + : Tensor{common::Span{shape}, device.ordinal, order} {} template explicit Tensor(common::Span shape, std::int32_t device, Order order = kC) @@ -814,6 +825,10 @@ class Tensor { // shape this->Initialize(shape, device); } + template + explicit Tensor(std::initializer_list data, I const (&shape)[D], DeviceOrd device, + Order order = kC) + : Tensor{data, shape, device.ordinal, order} {} /** * \brief Index operator. Not thread safe, should not be used in performance critical * region. For more efficient indexing, consider getting a view first. @@ -832,9 +847,9 @@ class Tensor { } /** - * \brief Get a \ref TensorView for this tensor. + * @brief Get a @ref TensorView for this tensor. 
*/ - TensorView View(int32_t device) { + TensorView View(std::int32_t device) { if (device >= 0) { data_.SetDevice(device); auto span = data_.DeviceSpan(); @@ -844,7 +859,7 @@ class Tensor { return {span, shape_, device, order_}; } } - TensorView View(int32_t device) const { + TensorView View(std::int32_t device) const { if (device >= 0) { data_.SetDevice(device); auto span = data_.ConstDeviceSpan(); @@ -854,6 +869,26 @@ class Tensor { return {span, shape_, device, order_}; } } + auto View(DeviceOrd device) { + if (device.IsCUDA()) { + data_.SetDevice(device); + auto span = data_.DeviceSpan(); + return TensorView{span, shape_, device.ordinal, order_}; + } else { + auto span = data_.HostSpan(); + return TensorView{span, shape_, device.ordinal, order_}; + } + } + auto View(DeviceOrd device) const { + if (device.IsCUDA()) { + data_.SetDevice(device); + auto span = data_.ConstDeviceSpan(); + return TensorView{span, shape_, device.ordinal, order_}; + } else { + auto span = data_.ConstHostSpan(); + return TensorView{span, shape_, device.ordinal, order_}; + } + } auto HostView() const { return this->View(-1); } auto HostView() { return this->View(-1); } @@ -931,6 +966,7 @@ class Tensor { * \brief Set device ordinal for this tensor. */ void SetDevice(int32_t device) const { data_.SetDevice(device); } + void SetDevice(DeviceOrd device) const { data_.SetDevice(device); } [[nodiscard]] int32_t DeviceIdx() const { return data_.DeviceIdx(); } }; diff --git a/include/xgboost/linear_updater.h b/include/xgboost/linear_updater.h index 6faf11230e1f..bcc8dd890dcd 100644 --- a/include/xgboost/linear_updater.h +++ b/include/xgboost/linear_updater.h @@ -49,9 +49,8 @@ class LinearUpdater : public Configurable { * \param model Model to be updated. * \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty. */ - virtual void Update(HostDeviceVector* in_gpair, DMatrix* data, - gbm::GBLinearModel* model, - double sum_instance_weight) = 0; + virtual void Update(linalg::Matrix* in_gpair, DMatrix* data, + gbm::GBLinearModel* model, double sum_instance_weight) = 0; /*! * \brief Create a linear updater given name diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index a04d2e453df7..d2623ee01df8 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -41,17 +41,16 @@ class ObjFunction : public Configurable { * \param args arguments to the objective function. */ virtual void Configure(const std::vector >& args) = 0; - /*! - * \brief Get gradient over each of predictions, given existing information. - * \param preds prediction of current round - * \param info information about labels, weights, groups in rank - * \param iteration current iteration number. - * \param out_gpair output of get gradient, saves gradient and second order gradient in + /** + * @brief Get gradient over each of predictions, given existing information. + * + * @param preds prediction of current round + * @param info information about labels, weights, groups in rank + * @param iteration current iteration number. + * @param out_gpair output of get gradient, saves gradient and second order gradient in */ - virtual void GetGradient(const HostDeviceVector& preds, - const MetaInfo& info, - int iteration, - HostDeviceVector* out_gpair) = 0; + virtual void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, + std::int32_t iter, linalg::Matrix* out_gpair) = 0; /*! 
\return the default evaluation metric for the objective */ virtual const char* DefaultEvalMetric() const = 0; @@ -81,9 +80,7 @@ class ObjFunction : public Configurable { * used by gradient boosting * \return transformed value */ - virtual bst_float ProbToMargin(bst_float base_score) const { - return base_score; - } + [[nodiscard]] virtual bst_float ProbToMargin(bst_float base_score) const { return base_score; } /** * \brief Make initialize estimation of prediction. * @@ -94,14 +91,14 @@ class ObjFunction : public Configurable { /*! * \brief Return task of this objective. */ - virtual struct ObjInfo Task() const = 0; + [[nodiscard]] virtual struct ObjInfo Task() const = 0; /** - * \brief Return number of targets for input matrix. Right now XGBoost supports only + * @brief Return number of targets for input matrix. Right now XGBoost supports only * multi-target regression. */ - virtual bst_target_t Targets(MetaInfo const& info) const { + [[nodiscard]] virtual bst_target_t Targets(MetaInfo const& info) const { if (info.labels.Shape(1) > 1) { - LOG(FATAL) << "multioutput is not supported by current objective function"; + LOG(FATAL) << "multioutput is not supported by the current objective function"; } return 1; } diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 79b80319f6da..477c8e4a1785 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -71,7 +71,7 @@ class TreeUpdater : public Configurable { * but maybe different random seeds, usually one tree is passed in at a time, * there can be multiple trees when we train random forest style model */ - virtual void Update(tree::TrainParam const* param, HostDeviceVector* gpair, + virtual void Update(tree::TrainParam const* param, linalg::Matrix* gpair, DMatrix* data, common::Span> out_position, const std::vector& out_trees) = 0; diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 23b8b1a80d4c..11f5299c0b67 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -218,34 +218,48 @@ public void update(DMatrix dtrain, int iter) throws XGBoostError { XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); } + @Deprecated + public void update(DMatrix dtrain, IObjective obj) throws XGBoostError { + float[][] predicts = this.predict(dtrain, true, 0, false, false); + List gradients = obj.getGradient(predicts, dtrain); + this.boost(dtrain, gradients.get(0), gradients.get(1)); + } + /** * Update with customize obj func * * @param dtrain training data + * @param iter The current training iteration. 
* @param obj customized objective class * @throws XGBoostError native error */ - public void update(DMatrix dtrain, IObjective obj) throws XGBoostError { + public void update(DMatrix dtrain, int iter, IObjective obj) throws XGBoostError { float[][] predicts = this.predict(dtrain, true, 0, false, false); List gradients = obj.getGradient(predicts, dtrain); - boost(dtrain, gradients.get(0), gradients.get(1)); + this.boost(dtrain, iter, gradients.get(0), gradients.get(1)); + } + + @Deprecated + public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError { + this.boost(dtrain, 0, grad, hess); } /** - * update with give grad and hess + * Update with given grad and hess * * @param dtrain training data + * @param iter The current training iteration. * @param grad first order of gradient * @param hess seconde order of gradient * @throws XGBoostError native error */ - public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError { + public void boost(DMatrix dtrain, int iter, float[] grad, float[] hess) throws XGBoostError { if (grad.length != hess.length) { throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length)); } - XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, - dtrain.getHandle(), grad, hess)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterTrainOneIter(handle, + dtrain.getHandle(), iter, grad, hess)); } /** diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index abe584f05fe4..d71d0a4f5c81 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -110,7 +110,7 @@ public final static native int XGDMatrixGetStrFeatureInfo(long handle, String fi public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain); - public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, + public final static native int XGBoosterTrainOneIter(long handle, long dtrain, int iter, float[] grad, float[] hess); public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index a1d122679966..31be86898e5a 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -106,27 +106,41 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) booster.update(dtrain.jDMatrix, iter) } + @throws(classOf[XGBoostError]) + @deprecated + def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = { + booster.update(dtrain.jDMatrix, obj) + } + /** * update with customize obj func * * @param dtrain training data + * @param iter The current training iteration * @param obj customized objective class */ @throws(classOf[XGBoostError]) - def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = { - booster.update(dtrain.jDMatrix, obj) + def update(dtrain: DMatrix, iter: Int, obj: ObjectiveTrait): Unit = { + booster.update(dtrain.jDMatrix, iter, obj) + } + + @throws(classOf[XGBoostError]) + @deprecated + def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = { + booster.boost(dtrain.jDMatrix, grad, hess) } /** * update 
with give grad and hess * * @param dtrain training data + * @param iter The current training iteration * @param grad first order of gradient * @param hess seconde order of gradient */ @throws(classOf[XGBoostError]) - def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = { - booster.boost(dtrain.jDMatrix, grad, hess) + def boost(dtrain: DMatrix, iter: Int, grad: Array[Float], hess: Array[Float]): Unit = { + booster.boost(dtrain.jDMatrix, iter, grad, hess) } /** diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index a61a68dbcb88..60c2f126c61f 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -28,6 +28,7 @@ #include #include +#include "../../../src/c_api/c_api_error.h" #include "../../../src/c_api/c_api_utils.h" #define JVM_CHECK_CALL(__expr) \ @@ -579,22 +580,44 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterBoostOneIter - * Signature: (JJ[F[F)V + * Method: XGBoosterTrainOneIter + * Signature: (JJI[F[F)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter - (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) { - BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle dtrain = (DMatrixHandle) jdtrain; - jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0); - jfloat* hess = jenv->GetFloatArrayElements(jhess, 0); - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad); - int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); - JVM_CHECK_CALL(ret); - //release +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter( + JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jint jiter, jfloatArray jgrad, + jfloatArray jhess) { + API_BEGIN(); + BoosterHandle handle = reinterpret_cast(jhandle); + DMatrixHandle dtrain = reinterpret_cast(jdtrain); + CHECK(handle); + CHECK(dtrain); + bst_ulong n_samples{0}; + JVM_CHECK_CALL(XGDMatrixNumRow(dtrain, &n_samples)); + + bst_ulong len = static_cast(jenv->GetArrayLength(jgrad)); + jfloat *grad = jenv->GetFloatArrayElements(jgrad, nullptr); + jfloat *hess = jenv->GetFloatArrayElements(jhess, nullptr); + CHECK(grad); + CHECK(hess); + + xgboost::bst_target_t n_targets{1}; + if (len != n_samples && n_samples != 0) { + CHECK_EQ(len % n_samples, 0) << "Invalid size of gradient."; + n_targets = len / n_samples; + } + + auto ctx = xgboost::detail::BoosterCtx(handle); + auto [s_grad, s_hess] = + xgboost::detail::MakeGradientInterface(ctx, grad, hess, n_samples, n_targets); + int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast(jiter), s_grad.c_str(), + s_hess.c_str()); + + // release jenv->ReleaseFloatArrayElements(jgrad, grad, 0); jenv->ReleaseFloatArrayElements(jhess, hess, 0); + return ret; + API_END(); } /* diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 11a2f86ffb82..b221c6a57da7 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -185,11 +185,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterBoostOneIter - * Signature: (JJ[F[F)I + * Method: XGBoosterTrainOneIter + * Signature: (JJI[F[F)I */ -JNIEXPORT jint JNICALL 
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter - (JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter + (JNIEnv *, jclass, jlong, jlong, jint, jfloatArray, jfloatArray); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI @@ -386,19 +386,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterSetStrFeatureInfo - * Signature: (JLjava/lang/String;[Ljava/lang/String;])I + * Signature: (JLjava/lang/String;[Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL -Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo (JNIEnv *, jclass, jlong, jstring, jobjectArray); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterGetStrFeatureInfo - * Signature: (JLjava/lang/String;[Ljava/lang/String;])I + * Signature: (JLjava/lang/String;[Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL -Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo (JNIEnv *, jclass, jlong, jstring, jobjectArray); #ifdef __cplusplus diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index 3f18330ce6a2..b996447a3cd6 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2015-2022 by Contributors +/** + * Copyright 2015-2023, XGBoost Contributors * \file custom_metric.cc * \brief This is an example to define plugin of xgboost. * This plugin defines the additional metric function. @@ -9,9 +9,7 @@ #include #include -namespace xgboost { -namespace obj { - +namespace xgboost::obj { // This is a helpful data structure to define parameters // You do not have to use it. // see http://dmlc-core.readthedocs.org/en/latest/parameter.html @@ -33,38 +31,38 @@ class MyLogistic : public ObjFunction { public: void Configure(const Args& args) override { param_.UpdateAllowUnknown(args); } - ObjInfo Task() const override { return ObjInfo::kRegression; } + [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } - void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int32_t /*iter*/, - HostDeviceVector* out_gpair) override { - out_gpair->Resize(preds.Size()); - const std::vector& preds_h = preds.HostVector(); - std::vector& out_gpair_h = out_gpair->HostVector(); + void GetGradient(const HostDeviceVector& preds, MetaInfo const& info, + std::int32_t /*iter*/, linalg::Matrix* out_gpair) override { + out_gpair->Reshape(info.num_row_, 1); + const std::vector& preds_h = preds.HostVector(); + auto out_gpair_h = out_gpair->HostView(); auto const labels_h = info.labels.HostView(); for (size_t i = 0; i < preds_h.size(); ++i) { - bst_float w = info.GetWeight(i); + float w = info.GetWeight(i); // scale the negative examples! 
if (labels_h(i) == 0.0f) w *= param_.scale_neg_weight; // logistic transformation - bst_float p = 1.0f / (1.0f + std::exp(-preds_h[i])); + float p = 1.0f / (1.0f + std::exp(-preds_h[i])); // this is the gradient - bst_float grad = (p - labels_h(i)) * w; + float grad = (p - labels_h(i)) * w; // this is the second order gradient - bst_float hess = p * (1.0f - p) * w; - out_gpair_h.at(i) = GradientPair(grad, hess); + float hess = p * (1.0f - p) * w; + out_gpair_h(i) = GradientPair(grad, hess); } } - const char* DefaultEvalMetric() const override { + [[nodiscard]] const char* DefaultEvalMetric() const override { return "logloss"; } - void PredTransform(HostDeviceVector *io_preds) const override { + void PredTransform(HostDeviceVector *io_preds) const override { // transform margin value to probability. - std::vector &preds = io_preds->HostVector(); + std::vector &preds = io_preds->HostVector(); for (auto& pred : preds) { pred = 1.0f / (1.0f + std::exp(-pred)); } } - bst_float ProbToMargin(bst_float base_score) const override { + [[nodiscard]] float ProbToMargin(float base_score) const override { // transform probability to margin value return -std::log(1.0f / base_score - 1.0f); } @@ -89,5 +87,4 @@ XGBOOST_REGISTER_OBJECTIVE(MyLogistic, "mylogistic") .describe("User defined logistic regression plugin") .set_body([]() { return new MyLogistic(); }); -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index d59d2f1d17f1..a9600c8fd331 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2053,12 +2053,14 @@ def update( else: pred = self.predict(dtrain, output_margin=True, training=True) grad, hess = fobj(pred, dtrain) - self.boost(dtrain, grad, hess) + self.boost(dtrain, iteration=iteration, grad=grad, hess=hess) - def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None: - """Boost the booster for one iteration, with customized gradient - statistics. Like :py:func:`xgboost.Booster.update`, this - function should not be called directly by users. + def boost( + self, dtrain: DMatrix, iteration: int, grad: NumpyOrCupy, hess: NumpyOrCupy + ) -> None: + """Boost the booster for one iteration with customized gradient statistics. + Like :py:func:`xgboost.Booster.update`, this function should not be called + directly by users. Parameters ---------- @@ -2070,19 +2072,52 @@ def boost(self, dtrain: DMatrix, grad: np.ndarray, hess: np.ndarray) -> None: The second order of gradient. """ - if len(grad) != len(hess): - raise ValueError(f"grad / hess length mismatch: {len(grad)} / {len(hess)}") - if not isinstance(dtrain, DMatrix): - raise TypeError(f"invalid training matrix: {type(dtrain).__name__}") + from .data import ( + _array_interface, + _cuda_array_interface, + _ensure_np_dtype, + _is_cupy_array, + ) + self._assign_dmatrix_features(dtrain) + def is_flatten(array: NumpyOrCupy) -> bool: + return len(array.shape) == 1 or array.shape[1] == 1 + + def array_interface(array: NumpyOrCupy) -> bytes: + msg = ( + "Expecting `np.ndarray` or `cupy.ndarray` for gradient and hessian." 
+ f" Got: {type(array)}" + ) + if not isinstance(array, np.ndarray) and not _is_cupy_array(array): + raise TypeError(msg) + + n_samples = dtrain.num_row() + if array.shape[0] != n_samples and is_flatten(array): + warnings.warn( + "Since 2.1.0, the shape of the gradient and hessian is required to" + " be (n_samples, n_targets) or (n_samples, n_targets).", + FutureWarning, + ) + array = array.reshape(n_samples, array.size // n_samples) + + if isinstance(array, np.ndarray): + array, _ = _ensure_np_dtype(array, array.dtype) + interface = _array_interface(array) + elif _is_cupy_array(array): + interface = _cuda_array_interface(array) + else: + raise TypeError(msg) + + return interface + _check_call( - _LIB.XGBoosterBoostOneIter( + _LIB.XGBoosterTrainOneIter( self.handle, dtrain.handle, - c_array(ctypes.c_float, grad), - c_array(ctypes.c_float, hess), - c_bst_ulong(len(grad)), + iteration, + array_interface(grad), + array_interface(hess), ) ) diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 8a21b60854a9..41fd6405ab24 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -763,13 +763,31 @@ def softmax(x: np.ndarray) -> np.ndarray: return e / np.sum(e) -def softprob_obj(classes: int) -> SklObjective: +def softprob_obj( + classes: int, use_cupy: bool = False, order: str = "C", gdtype: str = "float32" +) -> SklObjective: + """Custom softprob objective for testing. + + Parameters + ---------- + use_cupy : + Whether the objective should return cupy arrays. + order : + The order of gradient matrices. "C" or "F". + gdtype : + DType for gradient. Hessian is not set. This is for testing asymmetric types. + """ + if use_cupy: + import cupy as backend + else: + backend = np + def objective( - labels: np.ndarray, predt: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray]: + labels: backend.ndarray, predt: backend.ndarray + ) -> Tuple[backend.ndarray, backend.ndarray]: rows = labels.shape[0] - grad = np.zeros((rows, classes), dtype=float) - hess = np.zeros((rows, classes), dtype=float) + grad = backend.zeros((rows, classes), dtype=np.float32) + hess = backend.zeros((rows, classes), dtype=np.float32) eps = 1e-6 for r in range(predt.shape[0]): target = labels[r] @@ -781,8 +799,10 @@ def objective( grad[r, c] = g hess[r, c] = h - grad = grad.reshape((rows * classes, 1)) - hess = hess.reshape((rows * classes, 1)) + grad = grad.reshape((rows, classes)) + hess = hess.reshape((rows, classes)) + grad = backend.require(grad, requirements=order, dtype=gdtype) + hess = backend.require(hess, requirements=order) return grad, hess return objective diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 5b49d136f06f..2b0862d4945a 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -22,6 +22,7 @@ #include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch... #include "../common/hist_util.h" // for HistogramCuts #include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf... +#include "../common/linalg_op.h" // for ElementWiseTransformHost #include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor #include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte... 
#include "../data/ellpack_page.h" // for EllpackPage @@ -68,6 +69,7 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) { } } +static_assert(DMLC_CXX11_THREAD_LOCAL, "XGBoost depends on thread-local storage."); using GlobalConfigAPIThreadLocalStore = dmlc::ThreadLocalStore; #if !defined(XGBOOST_USE_CUDA) @@ -717,8 +719,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle, API_END(); } -XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle, - xgboost::bst_ulong *out) { +XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle, xgboost::bst_ulong *out) { API_BEGIN(); CHECK_HANDLE(); auto p_m = CastDMatrixHandle(handle); @@ -727,8 +728,7 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle, API_END(); } -XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle, - xgboost::bst_ulong *out) { +XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle, xgboost::bst_ulong *out) { API_BEGIN(); CHECK_HANDLE(); auto p_m = CastDMatrixHandle(handle); @@ -970,28 +970,71 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle, API_END(); } -XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, - DMatrixHandle dtrain, - bst_float *grad, - bst_float *hess, - xgboost::bst_ulong len) { +XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, bst_float *grad, + bst_float *hess, xgboost::bst_ulong len) { API_BEGIN(); CHECK_HANDLE(); - HostDeviceVector tmp_gpair; - auto* bst = static_cast(handle); - auto* dtr = - static_cast*>(dtrain); - tmp_gpair.Resize(len); - std::vector& tmp_gpair_h = tmp_gpair.HostVector(); - if (len > 0) { - xgboost_CHECK_C_ARG_PTR(grad); - xgboost_CHECK_C_ARG_PTR(hess); - } - for (xgboost::bst_ulong i = 0; i < len; ++i) { - tmp_gpair_h[i] = GradientPair(grad[i], hess[i]); - } + error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter"); + auto *learner = static_cast(handle); + auto ctx = learner->Ctx()->MakeCPU(); + + auto t_grad = linalg::MakeTensorView(&ctx, common::Span{grad, len}, len); + auto t_hess = linalg::MakeTensorView(&ctx, common::Span{hess, len}, len); - bst->BoostOneIter(0, *dtr, &tmp_gpair); + auto s_grad = linalg::ArrayInterfaceStr(t_grad); + auto s_hess = linalg::ArrayInterfaceStr(t_hess); + + return XGBoosterTrainOneIter(handle, dtrain, 0, s_grad.c_str(), s_hess.c_str()); + API_END(); +} + +namespace xgboost { +// copy user-supplied CUDA gradient arrays +void CopyGradientFromCUDAArrays(Context const *, ArrayInterface<2, false> const &, + ArrayInterface<2, false> const &, linalg::Matrix *) +#if !defined(XGBOOST_USE_CUDA) +{ + common::AssertGPUSupport(); +} +#else +; // NOLINT +#endif +} // namespace xgboost + +XGB_DLL int XGBoosterTrainOneIter(BoosterHandle handle, DMatrixHandle dtrain, int iter, + char const *grad, char const *hess) { + API_BEGIN(); + CHECK_HANDLE(); + xgboost_CHECK_C_ARG_PTR(grad); + xgboost_CHECK_C_ARG_PTR(hess); + auto p_fmat = CastDMatrixHandle(dtrain); + ArrayInterface<2, false> i_grad{StringView{grad}}; + ArrayInterface<2, false> i_hess{StringView{hess}}; + StringView msg{"Mismatched shape between the gradient and hessian."}; + CHECK_EQ(i_grad.Shape(0), i_hess.Shape(0)) << msg; + CHECK_EQ(i_grad.Shape(1), i_hess.Shape(1)) << msg; + linalg::Matrix gpair; + auto grad_is_cuda = ArrayInterfaceHandler::IsCudaPtr(i_grad.data); + auto hess_is_cuda = ArrayInterfaceHandler::IsCudaPtr(i_hess.data); + CHECK_EQ(i_grad.Shape(0), p_fmat->Info().num_row_) + << "Mismatched size between the gradient and training data."; + CHECK_EQ(grad_is_cuda, hess_is_cuda) << "gradient and hessian should be on the same device."; + 
auto *learner = static_cast(handle); + auto ctx = learner->Ctx(); + if (!grad_is_cuda) { + gpair.Reshape(i_grad.Shape(0), i_grad.Shape(1)); + auto const shape = gpair.Shape(); + auto h_gpair = gpair.HostView(); + DispatchDType(i_grad, DeviceOrd::CPU(), [&](auto &&t_grad) { + DispatchDType(i_hess, DeviceOrd::CPU(), [&](auto &&t_hess) { + common::ParallelFor(h_gpair.Size(), ctx->Threads(), + detail::CustomGradHessOp{t_grad, t_hess, h_gpair}); + }); + }); + } else { + CopyGradientFromCUDAArrays(ctx, i_grad, i_hess, &gpair); + } + learner->BoostOneIter(iter, p_fmat, &gpair); API_END(); } diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index 964ab0c3fed4..21674f7857e5 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -1,8 +1,12 @@ /** * Copyright 2019-2023 by XGBoost Contributors */ -#include "../common/api_entry.h" // XGBAPIThreadLocalEntry +#include // for transform + +#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry +#include "../common/cuda_context.cuh" // for CUDAContext #include "../common/threading_utils.h" +#include "../data/array_interface.h" // for DispatchDType, ArrayInterface #include "../data/device_adapter.cuh" #include "../data/proxy_dmatrix.h" #include "c_api_error.h" @@ -13,7 +17,6 @@ #include "xgboost/learner.h" namespace xgboost { - void XGBBuildInfoDevice(Json *p_info) { auto &info = *p_info; @@ -55,6 +58,27 @@ void XGBoostAPIGuard::RestoreGPUAttribute() { // If errors, do nothing, assuming running on CPU only machine. cudaSetDevice(device_id_); } + +void CopyGradientFromCUDAArrays(Context const *ctx, ArrayInterface<2, false> const &grad, + ArrayInterface<2, false> const &hess, + linalg::Matrix *out_gpair) { + auto grad_dev = dh::CudaGetPointerDevice(grad.data); + auto hess_dev = dh::CudaGetPointerDevice(hess.data); + CHECK_EQ(grad_dev, hess_dev) << "gradient and hessian should be on the same device."; + auto &gpair = *out_gpair; + gpair.SetDevice(grad_dev); + gpair.Reshape(grad.Shape(0), grad.Shape(1)); + auto d_gpair = gpair.View(grad_dev); + auto cuctx = ctx->CUDACtx(); + + DispatchDType(grad, DeviceOrd::CUDA(grad_dev), [&](auto &&t_grad) { + DispatchDType(hess, DeviceOrd::CUDA(hess_dev), [&](auto &&t_hess) { + CHECK_EQ(t_grad.Size(), t_hess.Size()); + thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), t_grad.Size(), + detail::CustomGradHessOp{t_grad, t_hess, d_gpair}); + }); + }); +} } // namespace xgboost using namespace xgboost; // NOLINT diff --git a/src/c_api/c_api_error.h b/src/c_api/c_api_error.h index 019bc1cf0710..11c4403847e0 100644 --- a/src/c_api/c_api_error.h +++ b/src/c_api/c_api_error.h @@ -1,5 +1,5 @@ -/*! - * Copyright (c) 2015-2022 by Contributors +/** + * Copyright 2015-2023, XGBoost Contributors * \file c_api_error.h * \brief Error handling for C API. */ @@ -35,8 +35,8 @@ } \ return 0; // NOLINT(*) -#define CHECK_HANDLE() if (handle == nullptr) \ - LOG(FATAL) << "DMatrix/Booster has not been initialized or has already been disposed."; +#define CHECK_HANDLE() \ + if (handle == nullptr) ::xgboost::detail::EmptyHandle(); /*! 
* \brief Set the last error message needed by C API diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index 1af0206be080..af43951c0a27 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -7,8 +7,10 @@ #include #include #include -#include // std::shared_ptr -#include +#include // for shared_ptr +#include // for string +#include // for make_tuple +#include // for move #include #include "xgboost/c_api.h" @@ -16,7 +18,7 @@ #include "xgboost/feature_map.h" // for FeatureMap #include "xgboost/json.h" #include "xgboost/learner.h" -#include "xgboost/linalg.h" // ArrayInterfaceHandler +#include "xgboost/linalg.h" // ArrayInterfaceHandler, MakeTensorView, ArrayInterfaceStr #include "xgboost/logging.h" #include "xgboost/string_view.h" // StringView @@ -287,6 +289,19 @@ inline std::shared_ptr CastDMatrixHandle(DMatrixHandle const handle) { } namespace detail { +inline void EmptyHandle() { + LOG(FATAL) << "DMatrix/Booster has not been initialized or has already been disposed."; +} + +inline xgboost::Context const *BoosterCtx(BoosterHandle handle) { + if (handle == nullptr) { + EmptyHandle(); + } + auto *learner = static_cast(handle); + CHECK(learner); + return learner->Ctx(); +} + template void MakeSparseFromPtr(PtrT const *p_indptr, I const *p_indices, T const *p_data, std::size_t nindptr, std::string *indptr_str, std::string *indices_str, @@ -334,6 +349,36 @@ void MakeSparseFromPtr(PtrT const *p_indptr, I const *p_indices, T const *p_data Json::Dump(jindices, indices_str); Json::Dump(jdata, data_str); } + +template +auto MakeGradientInterface(Context const *ctx, G const *grad, H const *hess, std::size_t n_samples, + std::size_t n_targets) { + auto t_grad = + linalg::MakeTensorView(ctx, common::Span{grad, n_samples * n_targets}, n_samples, n_targets); + auto t_hess = + linalg::MakeTensorView(ctx, common::Span{hess, n_samples * n_targets}, n_samples, n_targets); + auto s_grad = linalg::ArrayInterfaceStr(t_grad); + auto s_hess = linalg::ArrayInterfaceStr(t_hess); + return std::make_tuple(s_grad, s_hess); +} + +template +struct CustomGradHessOp { + linalg::MatrixView t_grad; + linalg::MatrixView t_hess; + linalg::MatrixView d_gpair; + + CustomGradHessOp(linalg::MatrixView t_grad, linalg::MatrixView t_hess, + linalg::MatrixView d_gpair) + : t_grad{std::move(t_grad)}, t_hess{std::move(t_hess)}, d_gpair{std::move(d_gpair)} {} + + XGBOOST_DEVICE void operator()(std::size_t i) { + auto [m, n] = linalg::UnravelIndex(i, t_grad.Shape(0), t_grad.Shape(1)); + auto g = t_grad(m, n); + auto h = t_hess(m, n); + d_gpair(m, n) = GradientPair{static_cast(g), static_cast(h)}; + } +}; } // namespace detail } // namespace xgboost #endif // XGBOOST_C_API_C_API_UTILS_H_ diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 99effffef942..c62a5cef2e11 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -384,7 +384,7 @@ inline bool ArrayInterfaceHandler::IsCudaPtr(void const *) { return false; } * numpy has the proper support even though it's in the __cuda_array_interface__ * protocol defined by numba. */ -template +template class ArrayInterface { static_assert(D > 0, "Invalid dimension for array interface."); @@ -588,7 +588,7 @@ class ArrayInterface { }; template -void DispatchDType(ArrayInterface const array, std::int32_t device, Fn fn) { +void DispatchDType(ArrayInterface const array, DeviceOrd device, Fn fn) { // Only used for cuDF at the moment. 
CHECK_EQ(array.valid.Capacity(), 0); auto dispatch = [&](auto t) { diff --git a/src/data/data.cc b/src/data/data.cc index e8ecccb81d13..467770715e45 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -448,7 +448,7 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::TensorView(Context::kCpuId); CHECK(t_out.CContiguous()); auto const shape = t_out.Shape(); - DispatchDType(array, Context::kCpuId, [&](auto&& in) { + DispatchDType(array, DeviceOrd::CPU(), [&](auto&& in) { linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) { return std::apply(in, linalg::UnravelIndex(i, shape)); }); diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 520f76581b53..bf4f6b92f05a 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -29,7 +29,6 @@ #include "../common/error_msg.h" namespace xgboost::gbm { - DMLC_REGISTRY_FILE_TAG(gblinear); // training parameters @@ -142,7 +141,7 @@ class GBLinear : public GradientBooster { this->updater_->SaveConfig(&j_updater); } - void DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, PredictionCacheEntry*, + void DoBoost(DMatrix* p_fmat, linalg::Matrix* in_gpair, PredictionCacheEntry*, ObjFunction const*) override { monitor_.Start("DoBoost"); @@ -232,9 +231,8 @@ class GBLinear : public GradientBooster { std::fill(contribs.begin(), contribs.end(), 0); } - std::vector DumpModel(const FeatureMap& fmap, - bool with_stats, - std::string format) const override { + [[nodiscard]] std::vector DumpModel(const FeatureMap& fmap, bool with_stats, + std::string format) const override { return model_.DumpModel(fmap, with_stats, format); } @@ -263,7 +261,7 @@ class GBLinear : public GradientBooster { } } - bool UseGPU() const override { + [[nodiscard]] bool UseGPU() const override { if (param_.updater == "gpu_coord_descent") { return true; } else { diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index e3df3862915c..e9c5be003c54 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -146,14 +146,6 @@ void GBTree::Configure(Args const& cfg) { if (specified_updater_) { error::WarnManualUpdater(); } - - if (model_.learner_model_param->IsVectorLeaf()) { - CHECK(tparam_.tree_method == TreeMethod::kHist || tparam_.tree_method == TreeMethod::kAuto) - << "Only the hist tree method is supported for building multi-target trees with vector " - "leaf."; - CHECK(ctx_->IsCPU()) << "GPU is not yet supported for vector leaf."; - } - LOG(DEBUG) << "Using tree method: " << static_cast(tparam_.tree_method); if (!specified_updater_) { @@ -175,8 +167,8 @@ void GBTree::Configure(Args const& cfg) { } } -void GPUCopyGradient(HostDeviceVector const*, bst_group_t, bst_group_t, - HostDeviceVector*) +void GPUCopyGradient(Context const*, linalg::Matrix const*, bst_group_t, + linalg::Matrix*) #if defined(XGBOOST_USE_CUDA) ; // NOLINT #else @@ -185,16 +177,19 @@ void GPUCopyGradient(HostDeviceVector const*, bst_group_t, bst_gro } #endif -void CopyGradient(HostDeviceVector const* in_gpair, int32_t n_threads, - bst_group_t n_groups, bst_group_t group_id, - HostDeviceVector* out_gpair) { - if (in_gpair->DeviceIdx() != Context::kCpuId) { - GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair); +void CopyGradient(Context const* ctx, linalg::Matrix const* in_gpair, + bst_group_t group_id, linalg::Matrix* out_gpair) { + out_gpair->SetDevice(ctx->Device()); + out_gpair->Reshape(in_gpair->Shape(0), 1); + if (ctx->IsCUDA()) { + GPUCopyGradient(ctx, in_gpair, group_id, out_gpair); } else { - std::vector &tmp_h = out_gpair->HostVector(); - const auto& gpair_h 
= in_gpair->ConstHostVector(); - common::ParallelFor(out_gpair->Size(), n_threads, - [&](auto i) { tmp_h[i] = gpair_h[i * n_groups + group_id]; }); + auto const& in = *in_gpair; + auto target_gpair = in.Slice(linalg::All(), group_id); + auto h_tmp = out_gpair->HostView(); + auto h_in = in.HostView().Slice(linalg::All(), group_id); + CHECK_EQ(h_tmp.Size(), h_in.Size()); + common::ParallelFor(h_in.Size(), ctx->Threads(), [&](auto i) { h_tmp(i) = h_in(i); }); } } @@ -223,8 +218,15 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector const } } -void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, +void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix* in_gpair, PredictionCacheEntry* predt, ObjFunction const* obj) { + if (model_.learner_model_param->IsVectorLeaf()) { + CHECK(tparam_.tree_method == TreeMethod::kHist || tparam_.tree_method == TreeMethod::kAuto) + << "Only the hist tree method is supported for building multi-target trees with vector " + "leaf."; + CHECK(ctx_->IsCPU()) << "GPU is not yet supported for vector leaf."; + } + TreesOneIter new_trees; bst_target_t const n_groups = model_.learner_model_param->OutputLength(); monitor_.Start("BoostNewTrees"); @@ -264,12 +266,12 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, } } else { CHECK_EQ(in_gpair->Size() % n_groups, 0U) << "must have exactly ngroup * nrow gpairs"; - HostDeviceVector tmp(in_gpair->Size() / n_groups, GradientPair(), - in_gpair->DeviceIdx()); + linalg::Matrix tmp{{in_gpair->Shape(0), static_cast(1ul)}, + ctx_->Ordinal()}; bool update_predict = true; for (bst_target_t gid = 0; gid < n_groups; ++gid) { node_position.clear(); - CopyGradient(in_gpair, ctx_->Threads(), n_groups, gid, &tmp); + CopyGradient(ctx_, in_gpair, gid, &tmp); TreesOneGroup ret; BoostNewTrees(&tmp, p_fmat, gid, &node_position, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, node_position, &ret); @@ -290,7 +292,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, this->CommitModel(std::move(new_trees)); } -void GBTree::BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fmat, int bst_group, +void GBTree::BoostNewTrees(linalg::Matrix* gpair, DMatrix* p_fmat, int bst_group, std::vector>* out_position, TreesOneGroup* ret) { std::vector new_trees; diff --git a/src/gbm/gbtree.cu b/src/gbm/gbtree.cu index c1972b2fcd75..8c4a960904f4 100644 --- a/src/gbm/gbtree.cu +++ b/src/gbm/gbtree.cu @@ -1,22 +1,24 @@ /** * Copyright 2021-2023, XGBoost Contributors */ -#include "../common/device_helpers.cuh" -#include "xgboost/linalg.h" -#include "xgboost/span.h" +#include // for make_counting_iterator + +#include "../common/cuda_context.cuh" +#include "../common/device_helpers.cuh" // for MakeTransformIterator +#include "xgboost/base.h" // for GradientPair +#include "xgboost/linalg.h" // for Matrix namespace xgboost::gbm { -void GPUCopyGradient(HostDeviceVector const *in_gpair, - bst_group_t n_groups, bst_group_t group_id, - HostDeviceVector *out_gpair) { - auto mat = linalg::TensorView( - in_gpair->ConstDeviceSpan(), - {in_gpair->Size() / n_groups, static_cast(n_groups)}, - in_gpair->DeviceIdx()); - auto v_in = mat.Slice(linalg::All(), group_id); - out_gpair->Resize(v_in.Size()); - auto d_out = out_gpair->DeviceSpan(); - dh::LaunchN(v_in.Size(), [=] __device__(size_t i) { d_out[i] = v_in(i); }); +void GPUCopyGradient(Context const *ctx, linalg::Matrix const *in_gpair, + bst_group_t group_id, linalg::Matrix *out_gpair) { + auto v_in = in_gpair->View(ctx->Device()).Slice(linalg::All(), group_id); 
+ out_gpair->SetDevice(ctx->Device()); + out_gpair->Reshape(v_in.Size(), 1); + auto d_out = out_gpair->View(ctx->Device()); + auto cuctx = ctx->CUDACtx(); + auto it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return v_in(i); }); + thrust::copy(cuctx->CTP(), it, it + v_in.Size(), d_out.Values().data()); } void GPUDartPredictInc(common::Span out_predts, diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 81e568368024..827d85217465 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -183,8 +183,8 @@ class GBTree : public GradientBooster { /** * @brief Carry out one iteration of boosting. */ - void DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, - PredictionCacheEntry* predt, ObjFunction const* obj) override; + void DoBoost(DMatrix* p_fmat, linalg::Matrix* in_gpair, PredictionCacheEntry* predt, + ObjFunction const* obj) override; [[nodiscard]] bool UseGPU() const override { return tparam_.tree_method == TreeMethod::kGPUHist; } @@ -326,7 +326,7 @@ class GBTree : public GradientBooster { } protected: - void BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fmat, int bst_group, + void BoostNewTrees(linalg::Matrix* gpair, DMatrix* p_fmat, int bst_group, std::vector>* out_position, std::vector>* ret); diff --git a/src/learner.cc b/src/learner.cc index 81d1b795b0bc..be562f972af5 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1282,14 +1282,14 @@ class LearnerImpl : public LearnerIO { monitor_.Start("GetGradient"); GetGradient(predt.predictions, train->Info(), iter, &gpair_); monitor_.Stop("GetGradient"); - TrainingObserver::Instance().Observe(gpair_, "Gradients"); + TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients"); gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get()); monitor_.Stop("UpdateOneIter"); } void BoostOneIter(int iter, std::shared_ptr train, - HostDeviceVector* in_gpair) override { + linalg::Matrix* in_gpair) override { monitor_.Start("BoostOneIter"); this->Configure(); @@ -1299,6 +1299,9 @@ class LearnerImpl : public LearnerIO { this->ValidateDMatrix(train.get(), true); + CHECK_EQ(this->learner_model_param_.OutputLength(), in_gpair->Shape(1)) + << "The number of columns in gradient should be equal to the number of targets/classes in " + "the model."; auto& predt = prediction_container_.Cache(train, ctx_.gpu_id); gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get()); monitor_.Stop("BoostOneIter"); @@ -1461,18 +1464,18 @@ class LearnerImpl : public LearnerIO { } private: - void GetGradient(HostDeviceVector const& preds, MetaInfo const& info, int iteration, - HostDeviceVector* out_gpair) { - out_gpair->Resize(preds.Size()); - collective::ApplyWithLabels(info, out_gpair->HostPointer(), + void GetGradient(HostDeviceVector const& preds, MetaInfo const& info, + std::int32_t iter, linalg::Matrix* out_gpair) { + out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength()); + collective::ApplyWithLabels(info, out_gpair->Data()->HostPointer(), out_gpair->Size() * sizeof(GradientPair), - [&] { obj_->GetGradient(preds, info, iteration, out_gpair); }); + [&] { obj_->GetGradient(preds, info, iter, out_gpair); }); } /*! \brief random number transformation seed. */ static int32_t constexpr kRandSeedMagic = 127; // gradient pairs - HostDeviceVector gpair_; + linalg::Matrix gpair_; /*! \brief Temporary storage to prediction. 
Useful for storing data transformed by * objective function */ PredictionContainer output_predictions_; diff --git a/src/linear/updater_coordinate.cc b/src/linear/updater_coordinate.cc index f660a1be8504..0d61d7c7cb00 100644 --- a/src/linear/updater_coordinate.cc +++ b/src/linear/updater_coordinate.cc @@ -45,30 +45,31 @@ class CoordinateUpdater : public LinearUpdater { out["coordinate_param"] = ToJson(cparam_); } - void Update(HostDeviceVector *in_gpair, DMatrix *p_fmat, - gbm::GBLinearModel *model, double sum_instance_weight) override { + void Update(linalg::Matrix *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model, + double sum_instance_weight) override { + auto gpair = in_gpair->Data(); tparam_.DenormalizePenalties(sum_instance_weight); const int ngroup = model->learner_model_param->num_output_group; // update bias for (int group_idx = 0; group_idx < ngroup; ++group_idx) { - auto grad = GetBiasGradientParallel(group_idx, ngroup, in_gpair->ConstHostVector(), p_fmat, + auto grad = GetBiasGradientParallel(group_idx, ngroup, gpair->ConstHostVector(), p_fmat, ctx_->Threads()); auto dbias = static_cast(tparam_.learning_rate * CoordinateDeltaBias(grad.first, grad.second)); model->Bias()[group_idx] += dbias; - UpdateBiasResidualParallel(ctx_, group_idx, ngroup, dbias, &in_gpair->HostVector(), p_fmat); + UpdateBiasResidualParallel(ctx_, group_idx, ngroup, dbias, &gpair->HostVector(), p_fmat); } // prepare for updating the weights - selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm, + selector_->Setup(ctx_, *model, gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm, cparam_.top_k); // update weights for (int group_idx = 0; group_idx < ngroup; ++group_idx) { for (unsigned i = 0U; i < model->learner_model_param->num_feature; i++) { int fidx = - selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat, + selector_->NextFeature(ctx_, i, *model, group_idx, gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm); if (fidx < 0) break; - this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), p_fmat, model); + this->UpdateFeature(fidx, group_idx, &gpair->HostVector(), p_fmat, model); } } monitor_.Stop("UpdateFeature"); diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu index b6c817696d84..659b45135cb5 100644 --- a/src/linear/updater_gpu_coordinate.cu +++ b/src/linear/updater_gpu_coordinate.cu @@ -93,17 +93,18 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT } } - void Update(HostDeviceVector *in_gpair, DMatrix *p_fmat, - gbm::GBLinearModel *model, double sum_instance_weight) override { + void Update(linalg::Matrix *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model, + double sum_instance_weight) override { tparam_.DenormalizePenalties(sum_instance_weight); monitor_.Start("LazyInitDevice"); this->LazyInitDevice(p_fmat, *(model->learner_model_param)); monitor_.Stop("LazyInitDevice"); monitor_.Start("UpdateGpair"); + // Update gpair - if (ctx_->gpu_id >= 0) { - this->UpdateGpair(in_gpair->ConstHostVector()); + if (ctx_->IsCUDA()) { + this->UpdateGpair(in_gpair->Data()->ConstHostVector()); } monitor_.Stop("UpdateGpair"); @@ -111,15 +112,15 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT this->UpdateBias(model); monitor_.Stop("UpdateBias"); // prepare for updating the weights - selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm, - 
tparam_.reg_lambda_denorm, coord_param_.top_k); + selector_->Setup(ctx_, *model, in_gpair->Data()->ConstHostVector(), p_fmat, + tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm, coord_param_.top_k); monitor_.Start("UpdateFeature"); for (uint32_t group_idx = 0; group_idx < model->learner_model_param->num_output_group; ++group_idx) { for (auto i = 0U; i < model->learner_model_param->num_feature; i++) { auto fidx = - selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat, - tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm); + selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->Data()->ConstHostVector(), + p_fmat, tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm); if (fidx < 0) break; this->UpdateFeature(fidx, group_idx, model); } diff --git a/src/linear/updater_shotgun.cc b/src/linear/updater_shotgun.cc index 18b747f64985..78fb1fe1b690 100644 --- a/src/linear/updater_shotgun.cc +++ b/src/linear/updater_shotgun.cc @@ -6,8 +6,7 @@ #include #include "coordinate_common.h" -namespace xgboost { -namespace linear { +namespace xgboost::linear { DMLC_REGISTRY_FILE_TAG(updater_shotgun); @@ -32,30 +31,31 @@ class ShotgunUpdater : public LinearUpdater { out["linear_train_param"] = ToJson(param_); } - void Update(HostDeviceVector *in_gpair, DMatrix *p_fmat, - gbm::GBLinearModel *model, double sum_instance_weight) override { - auto &gpair = in_gpair->HostVector(); + void Update(linalg::Matrix *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model, + double sum_instance_weight) override { + auto gpair = in_gpair->Data(); param_.DenormalizePenalties(sum_instance_weight); const int ngroup = model->learner_model_param->num_output_group; // update bias for (int gid = 0; gid < ngroup; ++gid) { - auto grad = GetBiasGradientParallel(gid, ngroup, in_gpair->ConstHostVector(), p_fmat, + auto grad = GetBiasGradientParallel(gid, ngroup, gpair->ConstHostVector(), p_fmat, ctx_->Threads()); auto dbias = static_cast(param_.learning_rate * CoordinateDeltaBias(grad.first, grad.second)); model->Bias()[gid] += dbias; - UpdateBiasResidualParallel(ctx_, gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat); + UpdateBiasResidualParallel(ctx_, gid, ngroup, dbias, &gpair->HostVector(), p_fmat); } // lock-free parallel updates of weights - selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm, + selector_->Setup(ctx_, *model, gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm, param_.reg_lambda_denorm, 0); + auto &h_gpair = gpair->HostVector(); for (const auto &batch : p_fmat->GetBatches(ctx_)) { auto page = batch.GetView(); const auto nfeat = static_cast(batch.Size()); common::ParallelFor(nfeat, ctx_->Threads(), [&](auto i) { - int ii = selector_->NextFeature(ctx_, i, *model, 0, in_gpair->ConstHostVector(), p_fmat, + int ii = selector_->NextFeature(ctx_, i, *model, 0, gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm, param_.reg_lambda_denorm); if (ii < 0) return; const bst_uint fid = ii; @@ -63,7 +63,7 @@ class ShotgunUpdater : public LinearUpdater { for (int gid = 0; gid < ngroup; ++gid) { double sum_grad = 0.0, sum_hess = 0.0; for (auto &c : col) { - const GradientPair &p = gpair[c.index * ngroup + gid]; + const GradientPair &p = h_gpair[c.index * ngroup + gid]; if (p.GetHess() < 0.0f) continue; const bst_float v = c.fvalue; sum_grad += p.GetGrad() * v; @@ -77,7 +77,7 @@ class ShotgunUpdater : public LinearUpdater { w += dw; // update grad values for (auto &c : col) { - GradientPair &p = gpair[c.index * ngroup + gid]; + 
GradientPair &p = h_gpair[c.index * ngroup + gid]; if (p.GetHess() < 0.0f) continue; p += GradientPair(p.GetHess() * c.fvalue * dw, 0); } @@ -98,5 +98,4 @@ XGBOOST_REGISTER_LINEAR_UPDATER(ShotgunUpdater, "shotgun") "Update linear model according to shotgun coordinate descent " "algorithm.") .set_body([]() { return new ShotgunUpdater(); }); -} // namespace linear -} // namespace xgboost +} // namespace xgboost::linear diff --git a/src/objective/aft_obj.cu b/src/objective/aft_obj.cu index 52a58a7f4b0f..522866a4254c 100644 --- a/src/objective/aft_obj.cu +++ b/src/objective/aft_obj.cu @@ -1,5 +1,5 @@ -/*! - * Copyright 2019-2022 by Contributors +/** + * Copyright 2019-2023, XGBoost Contributors * \file aft_obj.cu * \brief Definition of AFT loss for survival analysis. * \author Avinash Barnwal, Hyunsu Cho and Toby Hocking @@ -41,11 +41,9 @@ class AFTObj : public ObjFunction { ObjInfo Task() const override { return ObjInfo::kSurvival; } template - void GetGradientImpl(const HostDeviceVector &preds, - const MetaInfo &info, - HostDeviceVector *out_gpair, - size_t ndata, int device, bool is_null_weight, - float aft_loss_distribution_scale) { + void GetGradientImpl(const HostDeviceVector& preds, const MetaInfo& info, + linalg::Matrix* out_gpair, size_t ndata, int device, + bool is_null_weight, float aft_loss_distribution_scale) { common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, common::Span _out_gpair, @@ -66,16 +64,17 @@ class AFTObj : public ObjFunction { _out_gpair[_idx] = GradientPair(grad * w, hess * w); }, common::Range{0, static_cast(ndata)}, this->ctx_->Threads(), device).Eval( - out_gpair, &preds, &info.labels_lower_bound_, &info.labels_upper_bound_, + out_gpair->Data(), &preds, &info.labels_lower_bound_, &info.labels_upper_bound_, &info.weights_); } void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int /*iter*/, - HostDeviceVector* out_gpair) override { + linalg::Matrix* out_gpair) override { const size_t ndata = preds.Size(); CHECK_EQ(info.labels_lower_bound_.Size(), ndata); CHECK_EQ(info.labels_upper_bound_.Size(), ndata); - out_gpair->Resize(ndata); + out_gpair->SetDevice(ctx_->Device()); + out_gpair->Reshape(ndata, 1); const int device = ctx_->gpu_id; const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale; const bool is_null_weight = info.weights_.Size() == 0; diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index bff3bc593a8d..0d3ed6ca4c04 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -27,8 +27,8 @@ class HingeObj : public ObjFunction { void Configure(Args const&) override {} ObjInfo Task() const override { return ObjInfo::kRegression; } - void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int /*iter*/, - HostDeviceVector *out_gpair) override { + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, + std::int32_t /*iter*/, linalg::Matrix *out_gpair) override { CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided" @@ -41,7 +41,8 @@ class HingeObj : public ObjFunction { CHECK_EQ(info.weights_.Size(), ndata) << "Number of weights should be equal to number of data points."; } - out_gpair->Resize(ndata); + CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target for `binary:hinge` is not yet supported."; + out_gpair->Reshape(ndata, 1); common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, common::Span _out_gpair, @@ -63,7 +64,7 @@ class HingeObj : public ObjFunction { }, 
common::Range{0, static_cast(ndata)}, this->ctx_->Threads(), ctx_->gpu_id).Eval( - out_gpair, &preds, info.labels.Data(), &info.weights_); + out_gpair->Data(), &preds, info.labels.Data(), &info.weights_); } void PredTransform(HostDeviceVector *io_preds) const override { diff --git a/src/objective/init_estimation.cc b/src/objective/init_estimation.cc index 834c052f5609..47e0364fe1e5 100644 --- a/src/objective/init_estimation.cc +++ b/src/objective/init_estimation.cc @@ -21,7 +21,7 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector* b } // Avoid altering any state in child objective. HostDeviceVector dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id); - HostDeviceVector gpair(info.labels.Size(), GradientPair{}, this->ctx_->gpu_id); + linalg::Matrix gpair(info.labels.Shape(), this->ctx_->gpu_id); Json config{Object{}}; this->SaveConfig(&config); diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index d0ff5bda5bde..46fd777056e9 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -165,9 +165,8 @@ class LambdaRankObj : public FitIntercept { void CalcLambdaForGroup(std::int32_t iter, common::Span g_predt, linalg::VectorView g_label, float w, common::Span g_rank, bst_group_t g, Delta delta, - common::Span g_gpair) { - std::fill_n(g_gpair.data(), g_gpair.size(), GradientPair{}); - auto p_gpair = g_gpair.data(); + linalg::VectorView g_gpair) { + std::fill_n(g_gpair.Values().data(), g_gpair.Size(), GradientPair{}); auto ti_plus = ti_plus_.HostView(); auto tj_minus = tj_minus_.HostView(); @@ -198,8 +197,8 @@ class LambdaRankObj : public FitIntercept { std::size_t idx_high = g_rank[rank_high]; std::size_t idx_low = g_rank[rank_low]; - p_gpair[idx_high] += pg; - p_gpair[idx_low] += ng; + g_gpair(idx_high) += pg; + g_gpair(idx_low) += ng; if (unbiased) { auto k = ti_plus.Size(); @@ -225,12 +224,13 @@ class LambdaRankObj : public FitIntercept { MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop); if (sum_lambda > 0.0) { double norm = std::log2(1.0 + sum_lambda) / sum_lambda; - std::transform(g_gpair.data(), g_gpair.data() + g_gpair.size(), g_gpair.data(), - [norm](GradientPair const& g) { return g * norm; }); + std::transform(g_gpair.Values().data(), g_gpair.Values().data() + g_gpair.Size(), + g_gpair.Values().data(), [norm](GradientPair const& g) { return g * norm; }); } auto w_norm = p_cache_->WeightNorm(); - std::transform(g_gpair.begin(), g_gpair.end(), g_gpair.begin(), + std::transform(g_gpair.Values().data(), g_gpair.Values().data() + g_gpair.Size(), + g_gpair.Values().data(), [&](GradientPair const& gpair) { return gpair * w * w_norm; }); } @@ -301,7 +301,7 @@ class LambdaRankObj : public FitIntercept { } void GetGradient(HostDeviceVector const& predt, MetaInfo const& info, std::int32_t iter, - HostDeviceVector* out_gpair) override { + linalg::Matrix* out_gpair) override { CHECK_EQ(info.labels.Size(), predt.Size()) << error::LabelScoreSize(); // init/renew cache @@ -339,7 +339,7 @@ class LambdaRankNDCG : public LambdaRankObj { void CalcLambdaForGroupNDCG(std::int32_t iter, common::Span g_predt, linalg::VectorView g_label, float w, common::Span g_rank, - common::Span g_gpair, + linalg::VectorView g_gpair, linalg::VectorView inv_IDCG, common::Span discount, bst_group_t g) { auto delta = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low, @@ -351,7 +351,7 @@ class LambdaRankNDCG : public LambdaRankObj { } void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, - const 
MetaInfo& info, HostDeviceVector* out_gpair) { + const MetaInfo& info, linalg::Matrix* out_gpair) { if (ctx_->IsCUDA()) { cuda_impl::LambdaRankGetGradientNDCG( ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), @@ -363,8 +363,10 @@ class LambdaRankNDCG : public LambdaRankObj { bst_group_t n_groups = p_cache_->Groups(); auto gptr = p_cache_->DataGroupPtr(ctx_); - out_gpair->Resize(info.num_row_); - auto h_gpair = out_gpair->HostSpan(); + out_gpair->SetDevice(ctx_->Device()); + out_gpair->Reshape(info.num_row_, 1); + + auto h_gpair = out_gpair->HostView(); auto h_predt = predt.ConstHostSpan(); auto h_label = info.labels.HostView(); auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); @@ -378,7 +380,8 @@ class LambdaRankNDCG : public LambdaRankObj { std::size_t cnt = gptr[g + 1] - gptr[g]; auto w = h_weight[g]; auto g_predt = h_predt.subspan(gptr[g], cnt); - auto g_gpair = h_gpair.subspan(gptr[g], cnt); + auto g_gpair = + h_gpair.Slice(linalg::Range(static_cast(gptr[g]), gptr[g] + cnt), 0); auto g_label = h_label.Slice(make_range(g), 0); auto g_rank = rank_idx.subspan(gptr[g], cnt); @@ -420,7 +423,7 @@ void LambdaRankGetGradientNDCG(Context const*, std::int32_t, HostDeviceVector, // input bias ratio linalg::VectorView, // input bias ratio linalg::VectorView, linalg::VectorView, - HostDeviceVector*) { + linalg::Matrix*) { common::AssertGPUSupport(); } @@ -470,7 +473,7 @@ void MAPStat(Context const* ctx, linalg::VectorView label, class LambdaRankMAP : public LambdaRankObj { public: void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, - const MetaInfo& info, HostDeviceVector* out_gpair) { + const MetaInfo& info, linalg::Matrix* out_gpair) { CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective."; if (ctx_->IsCUDA()) { return cuda_impl::LambdaRankGetGradientMAP( @@ -482,8 +485,11 @@ class LambdaRankMAP : public LambdaRankObj { auto gptr = p_cache_->DataGroupPtr(ctx_).data(); bst_group_t n_groups = p_cache_->Groups(); - out_gpair->Resize(info.num_row_); - auto h_gpair = out_gpair->HostSpan(); + CHECK_EQ(info.labels.Shape(1), 1) << "multi-target for learning to rank is not yet supported."; + out_gpair->SetDevice(ctx_->Device()); + out_gpair->Reshape(info.num_row_, this->Targets(info)); + + auto h_gpair = out_gpair->HostView(); auto h_label = info.labels.HostView().Slice(linalg::All(), 0); auto h_predt = predt.ConstHostSpan(); auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt); @@ -514,7 +520,7 @@ class LambdaRankMAP : public LambdaRankObj { auto cnt = gptr[g + 1] - gptr[g]; auto w = h_weight[g]; auto g_predt = h_predt.subspan(gptr[g], cnt); - auto g_gpair = h_gpair.subspan(gptr[g], cnt); + auto g_gpair = h_gpair.Slice(linalg::Range(gptr[g], gptr[g] + cnt), 0); auto g_label = h_label.Slice(make_range(g)); auto g_rank = rank_idx.subspan(gptr[g], cnt); @@ -545,7 +551,7 @@ void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector, // input bias ratio linalg::VectorView, // input bias ratio linalg::VectorView, linalg::VectorView, - HostDeviceVector*) { + linalg::Matrix*) { common::AssertGPUSupport(); } } // namespace cuda_impl @@ -557,7 +563,7 @@ void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector { public: void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, - const MetaInfo& info, HostDeviceVector* out_gpair) { + const MetaInfo& info, linalg::Matrix* out_gpair) { CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective."; if (ctx_->IsCUDA()) { 
return cuda_impl::LambdaRankGetGradientPairwise( @@ -569,8 +575,10 @@ class LambdaRankPairwise : public LambdaRankObjDataGroupPtr(ctx_); bst_group_t n_groups = p_cache_->Groups(); - out_gpair->Resize(info.num_row_); - auto h_gpair = out_gpair->HostSpan(); + out_gpair->SetDevice(ctx_->Device()); + out_gpair->Reshape(info.num_row_, this->Targets(info)); + + auto h_gpair = out_gpair->HostView(); auto h_label = info.labels.HostView().Slice(linalg::All(), 0); auto h_predt = predt.ConstHostSpan(); auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); @@ -585,7 +593,7 @@ class LambdaRankPairwise : public LambdaRankObj, // input bias ratio linalg::VectorView, // input bias ratio linalg::VectorView, linalg::VectorView, - HostDeviceVector*) { + linalg::Matrix*) { common::AssertGPUSupport(); } } // namespace cuda_impl diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index 2a7cac7516f3..0f57fce48ef2 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -93,7 +93,7 @@ struct GetGradOp { // obtain group segment data. auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0); auto g_predt = args.predts.subspan(data_group_begin, n_data); - auto g_gpair = args.gpairs.subspan(data_group_begin, n_data).data(); + auto g_gpair = args.gpairs.Slice(linalg::Range(data_group_begin, data_group_begin + n_data)); auto g_rank = args.d_sorted_idx.subspan(data_group_begin, n_data); auto [i, j] = make_pair(idx, g); @@ -128,8 +128,8 @@ struct GetGradOp { auto ngt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), ng.GetGrad()), common::TruncateWithRounding(gr.GetHess(), ng.GetHess())}; - dh::AtomicAddGpair(g_gpair + idx_high, pgt); - dh::AtomicAddGpair(g_gpair + idx_low, ngt); + dh::AtomicAddGpair(&g_gpair(idx_high), pgt); + dh::AtomicAddGpair(&g_gpair(idx_low), ngt); } if (unbiased && need_update) { @@ -266,16 +266,16 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptrWeightNorm(); - thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.size(), - [=] XGBOOST_DEVICE(std::size_t i) { + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.Size(), + [=] XGBOOST_DEVICE(std::size_t i) mutable { auto g = dh::SegmentId(d_gptr, i); auto sum_lambda = thrust::get<2>(d_max_lambdas[g]); // Normalization if (sum_lambda > 0.0) { double norm = std::log2(1.0 + sum_lambda) / sum_lambda; - d_gpair[i] *= norm; + d_gpair(i, 0) *= norm; } - d_gpair[i] *= (d_weights[g] * w_norm); + d_gpair(i, 0) *= (d_weights[g] * w_norm); }); } @@ -288,7 +288,7 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector const linalg::VectorView ti_plus, // input bias ratio linalg::VectorView tj_minus, // input bias ratio linalg::VectorView li, linalg::VectorView lj, - HostDeviceVector* out_gpair) { + linalg::Matrix* out_gpair) { // boilerplate std::int32_t device_id = ctx->gpu_id; dh::safe_cuda(cudaSetDevice(device_id)); @@ -296,8 +296,8 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector const info.labels.SetDevice(device_id); preds.SetDevice(device_id); - out_gpair->SetDevice(device_id); - out_gpair->Resize(preds.Size()); + out_gpair->SetDevice(ctx->Device()); + out_gpair->Reshape(preds.Size(), 1); CHECK(p_cache); @@ -308,8 +308,9 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector const auto label = info.labels.View(ctx->gpu_id); auto predts = preds.ConstDeviceSpan(); - auto gpairs = 
out_gpair->DeviceSpan(); - thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.data(), gpairs.size(), GradientPair{0.0f, 0.0f}); + auto gpairs = out_gpair->View(ctx->Device()); + thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.Values().data(), gpairs.Size(), + GradientPair{0.0f, 0.0f}); auto const d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr(); auto const d_gptr = p_cache->DataGroupPtr(ctx); @@ -371,7 +372,7 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, linalg::VectorView ti_plus, // input bias ratio linalg::VectorView tj_minus, // input bias ratio linalg::VectorView li, linalg::VectorView lj, - HostDeviceVector* out_gpair) { + linalg::Matrix* out_gpair) { // boilerplate std::int32_t device_id = ctx->gpu_id; dh::safe_cuda(cudaSetDevice(device_id)); @@ -440,7 +441,7 @@ void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, linalg::VectorView ti_plus, // input bias ratio linalg::VectorView tj_minus, // input bias ratio linalg::VectorView li, linalg::VectorView lj, - HostDeviceVector* out_gpair) { + linalg::Matrix* out_gpair) { std::int32_t device_id = ctx->gpu_id; dh::safe_cuda(cudaSetDevice(device_id)); @@ -479,7 +480,7 @@ void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, linalg::VectorView ti_plus, // input bias ratio linalg::VectorView tj_minus, // input bias ratio linalg::VectorView li, linalg::VectorView lj, - HostDeviceVector* out_gpair) { + linalg::Matrix* out_gpair) { std::int32_t device_id = ctx->gpu_id; dh::safe_cuda(cudaSetDevice(device_id)); diff --git a/src/objective/lambdarank_obj.cuh b/src/objective/lambdarank_obj.cuh index be9f479cea3b..2e5724f7f1fd 100644 --- a/src/objective/lambdarank_obj.cuh +++ b/src/objective/lambdarank_obj.cuh @@ -61,7 +61,7 @@ struct KernelInputs { linalg::MatrixView labels; common::Span predts; - common::Span gpairs; + linalg::MatrixView gpairs; linalg::VectorView d_roundings; double const *d_cost_rounding; @@ -79,8 +79,8 @@ struct MakePairsOp { /** * \brief Make pair for the topk pair method. */ - XGBOOST_DEVICE std::tuple WithTruncation(std::size_t idx, - bst_group_t g) const { + [[nodiscard]] XGBOOST_DEVICE std::tuple WithTruncation( + std::size_t idx, bst_group_t g) const { auto thread_group_begin = args.d_threads_group_ptr[g]; auto idx_in_thread_group = idx - thread_group_begin; diff --git a/src/objective/lambdarank_obj.h b/src/objective/lambdarank_obj.h index c2222c028582..f3856e3cecd3 100644 --- a/src/objective/lambdarank_obj.h +++ b/src/objective/lambdarank_obj.h @@ -154,7 +154,7 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, linalg::VectorView t_plus, // input bias ratio linalg::VectorView t_minus, // input bias ratio linalg::VectorView li, linalg::VectorView lj, - HostDeviceVector* out_gpair); + linalg::Matrix* out_gpair); /** * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. 
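In the ranking objectives the per-query gradients are now a slice of the gradient matrix, rows [gptr[g], gptr[g+1]) of its single output column, and once the pairwise lambdas have been accumulated each group is rescaled by log2(1 + sum_lambda) / sum_lambda and by the (normalized) query weight. A small runnable sketch of that final per-group scaling over a flat buffer, with an illustrative ScaleGroup helper standing in for the patch's view-based code:

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for GradientPair.
struct GradPair {
  float grad{0};
  float hess{0};
};

// Scale one query group's gradients in place: divide by the accumulated
// lambda mass (when it is positive) and apply the group weight, mirroring the
// normalization at the end of CalcLambdaForGroup / CalcGrad.
void ScaleGroup(std::vector<GradPair>* gpair, std::vector<std::size_t> const& gptr,
                std::size_t g, double sum_lambda, float group_weight) {
  double norm = sum_lambda > 0.0 ? std::log2(1.0 + sum_lambda) / sum_lambda : 1.0;
  for (std::size_t i = gptr[g]; i < gptr[g + 1]; ++i) {
    auto& p = (*gpair)[i];
    p.grad = static_cast<float>(p.grad * norm) * group_weight;
    p.hess = static_cast<float>(p.hess * norm) * group_weight;
  }
}

int main() {
  // Two query groups: rows [0, 3) and [3, 5).
  std::vector<std::size_t> gptr{0, 3, 5};
  std::vector<GradPair> gpair(5, GradPair{2.0f, 1.0f});
  ScaleGroup(&gpair, gptr, 0, /*sum_lambda=*/3.0, /*group_weight=*/0.5f);
  for (auto const& p : gpair) {
    std::cout << p.grad << " ";  // first group scaled, second untouched
  }
  std::cout << "\n";
}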
@@ -168,7 +168,7 @@ void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, linalg::VectorView t_plus, // input bias ratio linalg::VectorView t_minus, // input bias ratio linalg::VectorView li, linalg::VectorView lj, - HostDeviceVector* out_gpair); + linalg::Matrix* out_gpair); void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, HostDeviceVector const& predt, const MetaInfo& info, @@ -176,7 +176,7 @@ void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, linalg::VectorView ti_plus, // input bias ratio linalg::VectorView tj_minus, // input bias ratio linalg::VectorView li, linalg::VectorView lj, - HostDeviceVector* out_gpair); + linalg::Matrix* out_gpair); void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView li_full, linalg::VectorView lj_full, diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 312992ec59f2..7c762ed48ebc 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -1,5 +1,5 @@ -/*! - * Copyright 2015-2022 by XGBoost Contributors +/** + * Copyright 2015-2023, XGBoost Contributors * \file multi_class.cc * \brief Definition of multi-class classification objectives. * \author Tianqi Chen @@ -48,13 +48,8 @@ class SoftmaxMultiClassObj : public ObjFunction { ObjInfo Task() const override { return ObjInfo::kClassification; } - void GetGradient(const HostDeviceVector& preds, - const MetaInfo& info, - int iter, - HostDeviceVector* out_gpair) override { - // Remove unused parameter compiler warning. - (void) iter; - + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, std::int32_t, + linalg::Matrix* out_gpair) override { if (info.labels.Size() == 0) { return; } @@ -77,7 +72,7 @@ class SoftmaxMultiClassObj : public ObjFunction { label_correct_.Resize(1); label_correct_.SetDevice(device); - out_gpair->Resize(preds.Size()); + out_gpair->Reshape(info.num_row_, static_cast(nclass)); label_correct_.Fill(1); const bool is_null_weight = info.weights_.Size() == 0; @@ -115,7 +110,7 @@ class SoftmaxMultiClassObj : public ObjFunction { gpair[idx * nclass + k] = GradientPair(p * wt, h); } }, common::Range{0, ndata}, ctx_->Threads(), device) - .Eval(out_gpair, info.labels.Data(), &preds, &info.weights_, &label_correct_); + .Eval(out_gpair->Data(), info.labels.Data(), &preds, &info.weights_, &label_correct_); std::vector& label_correct_h = label_correct_.HostVector(); for (auto const flag : label_correct_h) { diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index f94b5edf0494..0774223e7ce9 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -27,13 +27,12 @@ #endif // defined(XGBOOST_USE_CUDA) -namespace xgboost { -namespace obj { +namespace xgboost::obj { class QuantileRegression : public ObjFunction { common::QuantileLossParam param_; HostDeviceVector alpha_; - bst_target_t Targets(MetaInfo const& info) const override { + [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { auto const& alpha = param_.quantile_alpha.Get(); CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured."; if (info.ShouldHaveLabels()) { @@ -50,7 +49,7 @@ class QuantileRegression : public ObjFunction { public: void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, std::int32_t iter, - HostDeviceVector* out_gpair) override { + linalg::Matrix* out_gpair) override { if (iter == 0) { CheckInitInputs(info); } @@ -65,10 +64,11 @@ class QuantileRegression : public 
ObjFunction { auto labels = info.labels.View(ctx_->gpu_id); - out_gpair->SetDevice(ctx_->gpu_id); - out_gpair->Resize(n_targets * info.num_row_); - auto gpair = - linalg::MakeTensorView(ctx_, out_gpair, info.num_row_, n_alphas, n_targets / n_alphas); + out_gpair->SetDevice(ctx_->Device()); + CHECK_EQ(info.labels.Shape(1), 1) + << "Multi-target for quantile regression is not yet supported."; + out_gpair->Reshape(info.num_row_, n_targets); + auto gpair = out_gpair->View(ctx_->Device()); info.weights_.SetDevice(ctx_->gpu_id); common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() @@ -85,15 +85,16 @@ class QuantileRegression : public ObjFunction { ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable { auto [sample_id, quantile_id, target_id] = linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size()); + assert(target_id == 0); auto d = predt(i) - labels(sample_id, target_id); auto h = weight[sample_id]; if (d >= 0) { auto g = (1.0f - alpha[quantile_id]) * weight[sample_id]; - gpair(sample_id, quantile_id, target_id) = GradientPair{g, h}; + gpair(sample_id, quantile_id) = GradientPair{g, h}; } else { auto g = (-alpha[quantile_id] * weight[sample_id]); - gpair(sample_id, quantile_id, target_id) = GradientPair{g, h}; + gpair(sample_id, quantile_id) = GradientPair{g, h}; } }); } @@ -192,7 +193,7 @@ class QuantileRegression : public ObjFunction { param_.Validate(); this->alpha_.HostVector() = param_.quantile_alpha.Get(); } - ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } + [[nodiscard]] ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } static char const* Name() { return "reg:quantileerror"; } void SaveConfig(Json* p_out) const override { @@ -206,8 +207,8 @@ class QuantileRegression : public ObjFunction { alpha_.HostVector() = param_.quantile_alpha.Get(); } - const char* DefaultEvalMetric() const override { return "quantile"; } - Json DefaultMetricConfig() const override { + [[nodiscard]] const char* DefaultEvalMetric() const override { return "quantile"; } + [[nodiscard]] Json DefaultMetricConfig() const override { CHECK(param_.GetInitialised()); Json config{Object{}}; config["name"] = String{this->DefaultEvalMetric()}; @@ -223,5 +224,4 @@ XGBOOST_REGISTER_OBJECTIVE(QuantileRegression, QuantileRegression::Name()) #if defined(XGBOOST_USE_CUDA) DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu); #endif // defined(XGBOOST_USE_CUDA) -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 4c5ed9ec81a8..3c431ea682e1 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -36,12 +36,12 @@ #include "xgboost/tree_model.h" // RegTree #if defined(XGBOOST_USE_CUDA) +#include "../common/cuda_context.cuh" // for CUDAContext #include "../common/device_helpers.cuh" #include "../common/linalg_op.cuh" #endif // defined(XGBOOST_USE_CUDA) -namespace xgboost { -namespace obj { +namespace xgboost::obj { namespace { void CheckRegInputs(MetaInfo const& info, HostDeviceVector const& preds) { CheckInitInputs(info); @@ -68,33 +68,60 @@ class RegLossObj : public FitIntercept { HostDeviceVector additional_input_; public: - // 0 - label_correct flag, 1 - scale_pos_weight, 2 - is_null_weight - RegLossObj(): additional_input_(3) {} + void ValidateLabel(MetaInfo const& info) { + auto label = info.labels.View(ctx_->Ordinal()); + auto valid = ctx_->DispatchDevice( + [&] { + return 
std::all_of(linalg::cbegin(label), linalg::cend(label), + [](float y) -> bool { return Loss::CheckLabel(y); }); + }, + [&] { +#if defined(XGBOOST_USE_CUDA) + auto cuctx = ctx_->CUDACtx(); + auto it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> bool { + auto [m, n] = linalg::UnravelIndex(i, label.Shape()); + return Loss::CheckLabel(label(m, n)); + }); + return dh::Reduce(cuctx->CTP(), it, it + label.Size(), true, thrust::logical_and<>{}); +#else + common::AssertGPUSupport(); + return false; +#endif // defined(XGBOOST_USE_CUDA) + }); + if (!valid) { + LOG(FATAL) << Loss::LabelErrorMsg(); + } + } + // 0 - scale_pos_weight, 1 - is_null_weight + RegLossObj(): additional_input_(2) {} void Configure(const std::vector >& args) override { param_.UpdateAllowUnknown(args); } - ObjInfo Task() const override { return Loss::Info(); } + [[nodiscard]] ObjInfo Task() const override { return Loss::Info(); } - bst_target_t Targets(MetaInfo const& info) const override { + [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { // Multi-target regression. - return std::max(static_cast(1), info.labels.Shape(1)); + return std::max(static_cast(1), info.labels.Shape(1)); } - void GetGradient(const HostDeviceVector& preds, - const MetaInfo &info, int, - HostDeviceVector* out_gpair) override { + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, + std::int32_t iter, linalg::Matrix* out_gpair) override { CheckRegInputs(info, preds); + if (iter == 0) { + ValidateLabel(info); + } + size_t const ndata = preds.Size(); - out_gpair->Resize(ndata); + out_gpair->SetDevice(ctx_->Device()); auto device = ctx_->gpu_id; - additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag bool is_null_weight = info.weights_.Size() == 0; auto scale_pos_weight = param_.scale_pos_weight; - additional_input_.HostVector().begin()[1] = scale_pos_weight; - additional_input_.HostVector().begin()[2] = is_null_weight; + additional_input_.HostVector().begin()[0] = scale_pos_weight; + additional_input_.HostVector().begin()[1] = is_null_weight; const size_t nthreads = ctx_->Threads(); bool on_device = device >= 0; @@ -102,7 +129,8 @@ class RegLossObj : public FitIntercept { // for better performance. const size_t n_data_blocks = std::max(static_cast(1), (on_device ? ndata : nthreads)); const size_t block_size = ndata / n_data_blocks + !!(ndata % n_data_blocks); - auto const n_targets = std::max(info.labels.Shape(1), static_cast(1)); + auto const n_targets = this->Targets(info); + out_gpair->Reshape(info.num_row_, n_targets); common::Transform<>::Init( [block_size, ndata, n_targets] XGBOOST_DEVICE( @@ -117,8 +145,8 @@ class RegLossObj : public FitIntercept { GradientPair* out_gpair_ptr = _out_gpair.data(); const size_t begin = data_block_idx*block_size; const size_t end = std::min(ndata, begin + block_size); - const float _scale_pos_weight = _additional_input[1]; - const bool _is_null_weight = _additional_input[2]; + const float _scale_pos_weight = _additional_input[0]; + const bool _is_null_weight = _additional_input[1]; for (size_t idx = begin; idx < end; ++idx) { bst_float p = Loss::PredTransform(preds_ptr[idx]); @@ -127,16 +155,12 @@ class RegLossObj : public FitIntercept { if (label == 1.0f) { w *= _scale_pos_weight; } - if (!Loss::CheckLabel(label)) { - // If there is an incorrect label, the host code will know. 
- _additional_input[0] = 0; - } out_gpair_ptr[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w, Loss::SecondOrderGradient(p, label) * w); } }, common::Range{0, static_cast(n_data_blocks)}, nthreads, device) - .Eval(&additional_input_, out_gpair, &preds, info.labels.Data(), + .Eval(&additional_input_, out_gpair->Data(), &preds, info.labels.Data(), &info.weights_); auto const flag = additional_input_.HostVector().begin()[0]; @@ -146,7 +170,7 @@ class RegLossObj : public FitIntercept { } public: - const char* DefaultEvalMetric() const override { + [[nodiscard]] const char* DefaultEvalMetric() const override { return Loss::DefaultEvalMetric(); } @@ -160,7 +184,7 @@ class RegLossObj : public FitIntercept { .Eval(io_preds); } - float ProbToMargin(float base_score) const override { + [[nodiscard]] float ProbToMargin(float base_score) const override { return Loss::ProbToMargin(base_score); } @@ -215,21 +239,21 @@ class PseudoHuberRegression : public FitIntercept { public: void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } - ObjInfo Task() const override { return ObjInfo::kRegression; } - bst_target_t Targets(MetaInfo const& info) const override { - return std::max(static_cast(1), info.labels.Shape(1)); + [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } + [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { + return std::max(static_cast(1), info.labels.Shape(1)); } void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, int /*iter*/, - HostDeviceVector* out_gpair) override { + linalg::Matrix* out_gpair) override { CheckRegInputs(info, preds); auto slope = param_.huber_slope; CHECK_NE(slope, 0.0) << "slope for pseudo huber cannot be 0."; auto labels = info.labels.View(ctx_->gpu_id); out_gpair->SetDevice(ctx_->gpu_id); - out_gpair->Resize(info.labels.Size()); - auto gpair = linalg::MakeVec(out_gpair); + out_gpair->Reshape(info.num_row_, this->Targets(info)); + auto gpair = out_gpair->View(ctx_->Device()); preds.SetDevice(ctx_->gpu_id); auto predt = linalg::MakeVec(&preds); @@ -252,7 +276,7 @@ class PseudoHuberRegression : public FitIntercept { }); } - const char* DefaultEvalMetric() const override { return "mphe"; } + [[nodiscard]] const char* DefaultEvalMetric() const override { return "mphe"; } void SaveConfig(Json* p_out) const override { auto& out = *p_out; @@ -292,15 +316,15 @@ class PoissonRegression : public FitIntercept { param_.UpdateAllowUnknown(args); } - ObjInfo Task() const override { return ObjInfo::kRegression; } + [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } - void GetGradient(const HostDeviceVector& preds, - const MetaInfo &info, int, - HostDeviceVector *out_gpair) override { + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int, + linalg::Matrix* out_gpair) override { CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; size_t const ndata = preds.Size(); - out_gpair->Resize(ndata); + out_gpair->SetDevice(ctx_->Device()); + out_gpair->Reshape(info.num_row_, this->Targets(info)); auto device = ctx_->gpu_id; label_correct_.Resize(1); label_correct_.Fill(1); @@ -328,7 +352,7 @@ class PoissonRegression : public FitIntercept { expf(p + max_delta_step) * w}; }, common::Range{0, static_cast(ndata)}, this->ctx_->Threads(), device).Eval( - &label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); + &label_correct_, 
out_gpair->Data(), &preds, info.labels.Data(), &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); for (auto const flag : label_correct_h) { @@ -349,10 +373,10 @@ class PoissonRegression : public FitIntercept { void EvalTransform(HostDeviceVector *io_preds) override { PredTransform(io_preds); } - bst_float ProbToMargin(bst_float base_score) const override { + [[nodiscard]] float ProbToMargin(bst_float base_score) const override { return std::log(base_score); } - const char* DefaultEvalMetric() const override { + [[nodiscard]] const char* DefaultEvalMetric() const override { return "poisson-nloglik"; } @@ -383,16 +407,15 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson") class CoxRegression : public FitIntercept { public: void Configure(Args const&) override {} - ObjInfo Task() const override { return ObjInfo::kRegression; } + [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } - void GetGradient(const HostDeviceVector& preds, - const MetaInfo &info, int, - HostDeviceVector *out_gpair) override { + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int, + linalg::Matrix* out_gpair) override { CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; const auto& preds_h = preds.HostVector(); - out_gpair->Resize(preds_h.size()); - auto& gpair = out_gpair->HostVector(); + out_gpair->Reshape(info.num_row_, this->Targets(info)); + auto gpair = out_gpair->HostView(); const std::vector &label_order = info.LabelAbsSort(ctx_); const omp_ulong ndata = static_cast(preds_h.size()); // NOLINT(*) @@ -440,8 +463,8 @@ class CoxRegression : public FitIntercept { } const double grad = exp_p*r_k - static_cast(y > 0); - const double hess = exp_p*r_k - exp_p*exp_p * s_k; - gpair.at(ind) = GradientPair(grad * w, hess * w); + const double hess = exp_p * r_k - exp_p * exp_p * s_k; + gpair(ind) = GradientPair(grad * w, hess * w); last_abs_y = abs_y; last_exp_p = exp_p; @@ -457,10 +480,10 @@ class CoxRegression : public FitIntercept { void EvalTransform(HostDeviceVector *io_preds) override { PredTransform(io_preds); } - bst_float ProbToMargin(bst_float base_score) const override { + [[nodiscard]] float ProbToMargin(bst_float base_score) const override { return std::log(base_score); } - const char* DefaultEvalMetric() const override { + [[nodiscard]] const char* DefaultEvalMetric() const override { return "cox-nloglik"; } @@ -480,16 +503,16 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox") class GammaRegression : public FitIntercept { public: void Configure(Args const&) override {} - ObjInfo Task() const override { return ObjInfo::kRegression; } + [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } - void GetGradient(const HostDeviceVector &preds, - const MetaInfo &info, int, - HostDeviceVector *out_gpair) override { + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, std::int32_t, + linalg::Matrix* out_gpair) override { CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; const size_t ndata = preds.Size(); auto device = ctx_->gpu_id; - out_gpair->Resize(ndata); + out_gpair->SetDevice(ctx_->Device()); + out_gpair->Reshape(info.num_row_, this->Targets(info)); label_correct_.Resize(1); label_correct_.Fill(1); @@ -514,7 +537,7 @@ class GammaRegression : 
public FitIntercept { _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w); }, common::Range{0, static_cast(ndata)}, this->ctx_->Threads(), device).Eval( - &label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); + &label_correct_, out_gpair->Data(), &preds, info.labels.Data(), &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); @@ -536,10 +559,10 @@ class GammaRegression : public FitIntercept { void EvalTransform(HostDeviceVector *io_preds) override { PredTransform(io_preds); } - bst_float ProbToMargin(bst_float base_score) const override { + [[nodiscard]] float ProbToMargin(bst_float base_score) const override { return std::log(base_score); } - const char* DefaultEvalMetric() const override { + [[nodiscard]] const char* DefaultEvalMetric() const override { return "gamma-nloglik"; } void SaveConfig(Json* p_out) const override { @@ -578,15 +601,15 @@ class TweedieRegression : public FitIntercept { metric_ = os.str(); } - ObjInfo Task() const override { return ObjInfo::kRegression; } + [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } - void GetGradient(const HostDeviceVector& preds, - const MetaInfo &info, int, - HostDeviceVector *out_gpair) override { + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, std::int32_t, + linalg::Matrix* out_gpair) override { CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; const size_t ndata = preds.Size(); - out_gpair->Resize(ndata); + out_gpair->SetDevice(ctx_->Device()); + out_gpair->Reshape(info.num_row_, this->Targets(info)); auto device = ctx_->gpu_id; label_correct_.Resize(1); @@ -619,7 +642,7 @@ class TweedieRegression : public FitIntercept { _out_gpair[_idx] = GradientPair(grad * w, hess * w); }, common::Range{0, static_cast(ndata), 1}, this->ctx_->Threads(), device) - .Eval(&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); + .Eval(&label_correct_, out_gpair->Data(), &preds, info.labels.Data(), &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); @@ -639,11 +662,11 @@ class TweedieRegression : public FitIntercept { .Eval(io_preds); } - bst_float ProbToMargin(bst_float base_score) const override { + [[nodiscard]] float ProbToMargin(bst_float base_score) const override { return std::log(base_score); } - const char* DefaultEvalMetric() const override { + [[nodiscard]] const char* DefaultEvalMetric() const override { return metric_.c_str(); } @@ -672,19 +695,19 @@ XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie") class MeanAbsoluteError : public ObjFunction { public: void Configure(Args const&) override {} - ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } - bst_target_t Targets(MetaInfo const& info) const override { - return std::max(static_cast(1), info.labels.Shape(1)); + [[nodiscard]] ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } + [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { + return std::max(static_cast(1), info.labels.Shape(1)); } - void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, int /*iter*/, - HostDeviceVector* out_gpair) override { + void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, + std::int32_t /*iter*/, linalg::Matrix* out_gpair) override { CheckRegInputs(info, preds); 
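A recurring pattern in the regression objectives above: the output gradient is reshaped to (info.num_row_, Targets(info)) up front and then written per (sample, target) entry through a matrix view rather than through a flat index. A standalone sketch of that layout with plain squared error standing in for the templated loss; the FillGradient name and the std::vector storage are illustrative only:

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for GradientPair.
struct GradPair {
  float grad{0};
  float hess{0};
};

// Fill an (n_rows x n_targets) row-major gradient matrix, the shape the
// objectives now produce by reshaping the output to (num_row, n_targets).
// Squared error is the stand-in loss: grad = (pred - label) * w, hess = w.
void FillGradient(std::vector<float> const& predt, std::vector<float> const& label,
                  std::vector<float> const& weight, std::size_t n_rows,
                  std::size_t n_targets, std::vector<GradPair>* out_gpair) {
  out_gpair->resize(n_rows * n_targets);
  for (std::size_t i = 0; i < n_rows * n_targets; ++i) {
    std::size_t sample_id = i / n_targets;  // unravel the flat index, row-major
    std::size_t target_id = i % n_targets;
    float w = weight.empty() ? 1.0f : weight[sample_id];
    float residual = predt[i] - label[sample_id * n_targets + target_id];
    (*out_gpair)[i] = GradPair{residual * w, w};
  }
}

int main() {
  // Two samples, two targets each; no sample weights.
  std::vector<float> predt{0.5f, 2.0f, -1.0f, 0.0f};
  std::vector<float> label{1.0f, 1.0f, -2.0f, 0.0f};
  std::vector<GradPair> gpair;
  FillGradient(predt, label, {}, 2, 2, &gpair);
  for (auto const& g : gpair) {
    std::cout << g.grad << "/" << g.hess << " ";  // -0.5/1 1/1 1/1 0/1
  }
  std::cout << "\n";
}

With an actual matrix view the pair of indices is passed directly, as in gpair(sample_id, target_id), so no manual unravelling is needed.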
auto labels = info.labels.View(ctx_->gpu_id); - out_gpair->SetDevice(ctx_->gpu_id); - out_gpair->Resize(info.labels.Size()); - auto gpair = linalg::MakeVec(out_gpair); + out_gpair->SetDevice(ctx_->Device()); + out_gpair->Reshape(info.num_row_, this->Targets(info)); + auto gpair = out_gpair->View(ctx_->Device()); preds.SetDevice(ctx_->gpu_id); auto predt = linalg::MakeVec(&preds); @@ -692,14 +715,14 @@ class MeanAbsoluteError : public ObjFunction { common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan() : info.weights_.ConstDeviceSpan()}; - linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable { + linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, float y) mutable { auto sign = [](auto x) { return (x > static_cast(0)) - (x < static_cast(0)); }; - auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); + auto [sample_id, target_id] = linalg::UnravelIndex(i, labels.Shape()); auto grad = sign(predt(i) - y) * weight[sample_id]; auto hess = weight[sample_id]; - gpair(i) = GradientPair{grad, hess}; + gpair(sample_id, target_id) = GradientPair{grad, hess}; }); } @@ -748,7 +771,7 @@ class MeanAbsoluteError : public ObjFunction { p_tree); } - const char* DefaultEvalMetric() const override { return "mae"; } + [[nodiscard]] const char* DefaultEvalMetric() const override { return "mae"; } void SaveConfig(Json* p_out) const override { auto& out = *p_out; @@ -763,5 +786,4 @@ class MeanAbsoluteError : public ObjFunction { XGBOOST_REGISTER_OBJECTIVE(MeanAbsoluteError, "reg:absoluteerror") .describe("Mean absoluate error.") .set_body([]() { return new MeanAbsoluteError(); }); -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc index 3533de772f59..ec654a1b2fdc 100644 --- a/src/tree/fit_stump.cc +++ b/src/tree/fit_stump.cc @@ -66,14 +66,13 @@ inline void FitStump(Context const*, linalg::TensorView, #endif // !defined(XGBOOST_USE_CUDA) } // namespace cuda_impl -void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector const& gpair, +void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix const& gpair, bst_target_t n_targets, linalg::Vector* out) { out->SetDevice(ctx->gpu_id); out->Reshape(n_targets); - auto n_samples = gpair.Size() / n_targets; - gpair.SetDevice(ctx->gpu_id); - auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets); + gpair.SetDevice(ctx->Device()); + auto gpair_t = gpair.View(ctx->Device()); ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView()) : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id)); } diff --git a/src/tree/fit_stump.h b/src/tree/fit_stump.h index 4778ecfc5dec..2af779f77c46 100644 --- a/src/tree/fit_stump.h +++ b/src/tree/fit_stump.h @@ -31,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) { /** * @brief Fit a tree stump as an estimation of base_score. 
*/ -void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector const& gpair, +void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix const& gpair, bst_target_t n_targets, linalg::Vector* out); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 2110cd6e6b06..17e020ced5fc 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -269,17 +269,18 @@ class GlobalApproxUpdater : public TreeUpdater { out["hist_train_param"] = ToJson(hist_param_); } - void InitData(TrainParam const ¶m, HostDeviceVector const *gpair, + void InitData(TrainParam const ¶m, linalg::Matrix const *gpair, linalg::Matrix *sampled) { *sampled = linalg::Empty(ctx_, gpair->Size(), 1); - sampled->Data()->Copy(*gpair); + auto in = gpair->HostView().Values(); + std::copy(in.data(), in.data() + in.size(), sampled->HostView().Values().data()); SampleGradient(ctx_, param, sampled->HostView()); } [[nodiscard]] char const *Name() const override { return "grow_histmaker"; } - void Update(TrainParam const *param, HostDeviceVector *gpair, DMatrix *m, + void Update(TrainParam const *param, linalg::Matrix *gpair, DMatrix *m, common::Span> out_position, const std::vector &trees) override { CHECK(hist_param_.GetInitialised()); diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index bda9b4dfa3bc..3afbe3e46bdd 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -91,7 +91,7 @@ class ColMaker: public TreeUpdater { } } - void Update(TrainParam const *param, HostDeviceVector *gpair, DMatrix *dmat, + void Update(TrainParam const *param, linalg::Matrix *gpair, DMatrix *dmat, common::Span> /*out_position*/, const std::vector &trees) override { if (collective::IsDistributed()) { @@ -106,10 +106,11 @@ class ColMaker: public TreeUpdater { // rescale learning rate according to size of trees interaction_constraints_.Configure(*param, dmat->Info().num_row_); // build tree + CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented(); for (auto tree : trees) { CHECK(ctx_); Builder builder(*param, colmaker_param_, interaction_constraints_, ctx_, column_densities_); - builder.Update(gpair->ConstHostVector(), dmat, tree); + builder.Update(gpair->Data()->ConstHostVector(), dmat, tree); } } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 5cce89e2cd57..e0d221362ebb 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -759,16 +759,18 @@ class GPUHistMaker : public TreeUpdater { dh::GlobalMemoryLogger().Log(); } - void Update(TrainParam const* param, HostDeviceVector* gpair, DMatrix* dmat, + void Update(TrainParam const* param, linalg::Matrix* gpair, DMatrix* dmat, common::Span> out_position, const std::vector& trees) override { monitor_.Start("Update"); + CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented(); + auto gpair_hdv = gpair->Data(); // build tree try { std::size_t t_idx{0}; for (xgboost::RegTree* tree : trees) { - this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]); + this->UpdateTree(param, gpair_hdv, dmat, tree, &out_position[t_idx]); this->hist_maker_param_.CheckTreesSynchronized(tree); ++t_idx; } @@ -886,7 +888,7 @@ class GPUGlobalApproxMaker : public TreeUpdater { } ~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); } - void Update(TrainParam const* param, HostDeviceVector* gpair, DMatrix* p_fmat, + void Update(TrainParam const* param, linalg::Matrix* gpair, DMatrix* p_fmat, common::Span> out_position, const 
std::vector& trees) override { monitor_.Start("Update"); @@ -897,7 +899,7 @@ class GPUGlobalApproxMaker : public TreeUpdater { auto hess = dh::ToSpan(hess_); gpair->SetDevice(ctx_->Device()); - auto d_gpair = gpair->ConstDeviceSpan(); + auto d_gpair = gpair->Data()->ConstDeviceSpan(); auto cuctx = ctx_->CUDACtx(); thrust::transform(cuctx->CTP(), dh::tcbegin(d_gpair), dh::tcend(d_gpair), dh::tbegin(hess), [=] XGBOOST_DEVICE(GradientPair const& g) { return g.GetHess(); }); @@ -911,7 +913,7 @@ class GPUGlobalApproxMaker : public TreeUpdater { std::size_t t_idx{0}; for (xgboost::RegTree* tree : trees) { - this->UpdateTree(gpair, p_fmat, tree, &out_position[t_idx]); + this->UpdateTree(gpair->Data(), p_fmat, tree, &out_position[t_idx]); this->hist_maker_param_.CheckTreesSynchronized(tree); ++t_idx; } diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 29f9917ba1e0..2c2d1a2f0d93 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -31,7 +31,7 @@ class TreePruner : public TreeUpdater { [[nodiscard]] bool CanModifyTree() const override { return true; } // update the tree, do pruning - void Update(TrainParam const* param, HostDeviceVector* gpair, DMatrix* p_fmat, + void Update(TrainParam const* param, linalg::Matrix* gpair, DMatrix* p_fmat, common::Span> out_position, const std::vector& trees) override { pruner_monitor_.Start("PrunerUpdate"); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 883c18f361fc..34890c2e5326 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -492,7 +492,7 @@ class QuantileHistMaker : public TreeUpdater { [[nodiscard]] char const *Name() const override { return "grow_quantile_histmaker"; } - void Update(TrainParam const *param, HostDeviceVector *gpair, DMatrix *p_fmat, + void Update(TrainParam const *param, linalg::Matrix *gpair, DMatrix *p_fmat, common::Span> out_position, const std::vector &trees) override { if (trees.front()->IsMultiTarget()) { @@ -511,8 +511,7 @@ class QuantileHistMaker : public TreeUpdater { } bst_target_t n_targets = trees.front()->NumTargets(); - auto h_gpair = - linalg::MakeTensorView(ctx_, gpair->HostSpan(), p_fmat->Info().num_row_, n_targets); + auto h_gpair = gpair->HostView(); linalg::Matrix sample_out; auto h_sample_out = h_gpair; diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 2bfd3c8defb9..941df7aec491 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -31,11 +31,14 @@ class TreeRefresher : public TreeUpdater { [[nodiscard]] char const *Name() const override { return "refresh"; } [[nodiscard]] bool CanModifyTree() const override { return true; } // update the tree, do pruning - void Update(TrainParam const *param, HostDeviceVector *gpair, DMatrix *p_fmat, + void Update(TrainParam const *param, linalg::Matrix *gpair, DMatrix *p_fmat, common::Span> /*out_position*/, const std::vector &trees) override { - if (trees.size() == 0) return; - const std::vector &gpair_h = gpair->ConstHostVector(); + if (trees.size() == 0) { + return; + } + CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented(); + const std::vector &gpair_h = gpair->Data()->ConstHostVector(); // thread temporal space std::vector > stemp; std::vector fvec_temp; diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index 2422807e2f30..f64f354837f6 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -31,7 +31,7 @@ class TreeSyncher : public TreeUpdater { [[nodiscard]] char const* Name() 
const override { return "prune"; } - void Update(TrainParam const*, HostDeviceVector*, DMatrix*, + void Update(TrainParam const*, linalg::Matrix*, DMatrix*, common::Span> /*out_position*/, const std::vector& trees) override { if (collective::GetWorldSize() == 1) return; diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 3bf03c95581a..7fcab199ed08 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -565,7 +565,7 @@ void TestXGDMatrixGetQuantileCut(Context const *ctx) { ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0); ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0); if (ctx->IsCUDA()) { - ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0); + ASSERT_EQ(XGBoosterSetParam(booster, "device", ctx->DeviceName().c_str()), 0); } ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0); ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0); @@ -596,7 +596,7 @@ void TestXGDMatrixGetQuantileCut(Context const *ctx) { ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0); ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0); if (ctx->IsCUDA()) { - ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0); + ASSERT_EQ(XGBoosterSetParam(booster, "device", ctx->DeviceName().c_str()), 0); } ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0); ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0); diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 9e6311701e58..d7b7e588d11c 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -65,7 +65,9 @@ TEST(GBTree, PredictionCache) { gbtree.Configure({{"tree_method", "hist"}}); auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); - auto gpair = GenerateRandomGradients(kRows); + linalg::Matrix gpair({kRows}, ctx.Ordinal()); + gpair.Data()->Copy(GenerateRandomGradients(kRows)); + PredictionCacheEntry out_predictions; gbtree.DoBoost(p_m.get(), &gpair, &out_predictions, nullptr); @@ -213,7 +215,8 @@ TEST(GBTree, ChooseTreeMethod) { } learner->Configure(); for (std::int32_t i = 0; i < 3; ++i) { - HostDeviceVector gpair{GenerateRandomGradients(Xy->Info().num_row_)}; + linalg::Matrix gpair{{Xy->Info().num_row_}, Context::kCpuId}; + gpair.Data()->Copy(GenerateRandomGradients(Xy->Info().num_row_)); learner->BoostOneIter(0, Xy, &gpair); } diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 111c7b30e96a..a9ff347ea5ef 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -96,9 +96,9 @@ void CheckObjFunctionImpl(std::unique_ptr const& obj, std::vector out_grad, std::vector out_hess) { xgboost::HostDeviceVector in_preds(preds); - xgboost::HostDeviceVector out_gpair; - obj->GetGradient(in_preds, info, 1, &out_gpair); - std::vector& gpair = out_gpair.HostVector(); + xgboost::linalg::Matrix out_gpair; + obj->GetGradient(in_preds, info, 0, &out_gpair); + std::vector& gpair = out_gpair.Data()->HostVector(); ASSERT_EQ(gpair.size(), in_preds.Size()); for (int i = 0; i < static_cast(gpair.size()); ++i) { @@ -119,8 +119,8 @@ void CheckObjFunction(std::unique_ptr const& obj, std::vector out_hess) { xgboost::MetaInfo info; info.num_row_ = labels.size(); - info.labels = - xgboost::linalg::Tensor{labels.cbegin(), labels.cend(), {labels.size()}, -1}; + info.labels = xgboost::linalg::Tensor{ + labels.cbegin(), labels.cend(), {labels.size(), static_cast(1)}, -1}; info.weights_.HostVector() = weights; CheckObjFunctionImpl(obj, 
preds, labels, weights, info, out_grad, out_hess); @@ -155,8 +155,8 @@ void CheckRankingObjFunction(std::unique_ptr const& obj, std::vector out_hess) { xgboost::MetaInfo info; info.num_row_ = labels.size(); - info.labels = xgboost::linalg::Tensor{ - labels.cbegin(), labels.cend(), {labels.size(), static_cast(1)}, -1}; + info.labels = xgboost::linalg::Matrix{ + labels.cbegin(), labels.cend(), {labels.size(), static_cast(1)}, -1}; info.weights_.HostVector() = weights; info.group_ptr_ = groups; @@ -645,11 +645,10 @@ std::unique_ptr CreateTrainedGBM(std::string name, Args kwargs, } p_dmat->Info().labels = linalg::Tensor{labels.cbegin(), labels.cend(), {labels.size()}, -1}; - HostDeviceVector gpair; - auto& h_gpair = gpair.HostVector(); - h_gpair.resize(kRows); + linalg::Matrix gpair({kRows}, ctx->Ordinal()); + auto h_gpair = gpair.HostView(); for (size_t i = 0; i < kRows; ++i) { - h_gpair[i] = GradientPair{static_cast(i), 1}; + h_gpair(i) = GradientPair{static_cast(i), 1}; } PredictionCacheEntry predts; diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index e39375dfabc1..bad15c69591f 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -387,23 +387,6 @@ std::unique_ptr CreateTrainedGBM(std::string name, Args kwargs, LearnerModelParam const* learner_model_param, Context const* generic_param); -inline std::unique_ptr> GenerateGradients( - std::size_t rows, bst_target_t n_targets = 1) { - auto p_gradients = std::make_unique>(rows * n_targets); - auto& h_gradients = p_gradients->HostVector(); - - xgboost::SimpleLCG gen; - xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); - - for (std::size_t i = 0; i < rows * n_targets; ++i) { - auto grad = dist(&gen); - auto hess = dist(&gen); - h_gradients[i] = GradientPair{grad, hess}; - } - - return p_gradients; -} - /** * \brief Make a context that uses CUDA if device >= 0. 
*/ @@ -415,11 +398,12 @@ inline Context MakeCUDACtx(std::int32_t device) { } inline HostDeviceVector GenerateRandomGradients(const size_t n_rows, - float lower= 0.0f, float upper = 1.0f) { + float lower = 0.0f, + float upper = 1.0f) { xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(lower, upper); std::vector h_gpair(n_rows); - for (auto &gpair : h_gpair) { + for (auto& gpair : h_gpair) { bst_float grad = dist(&gen); bst_float hess = dist(&gen); gpair = GradientPair(grad, hess); @@ -428,6 +412,16 @@ inline HostDeviceVector GenerateRandomGradients(const size_t n_row return gpair; } +inline linalg::Matrix GenerateRandomGradients(Context const* ctx, bst_row_t n_rows, + bst_target_t n_targets, + float lower = 0.0f, + float upper = 1.0f) { + auto g = GenerateRandomGradients(n_rows * n_targets, lower, upper); + linalg::Matrix gpair({n_rows, static_cast(n_targets)}, ctx->Device()); + gpair.Data()->Copy(g); + return gpair; +} + typedef void *DMatrixHandle; // NOLINT(*); class ArrayIterForTest { diff --git a/tests/cpp/linear/test_linear.cc b/tests/cpp/linear/test_linear.cc index 6b2d17e1028e..8f81428b3b4b 100644 --- a/tests/cpp/linear/test_linear.cc +++ b/tests/cpp/linear/test_linear.cc @@ -24,8 +24,8 @@ TEST(Linear, Shotgun) { auto updater = std::unique_ptr(xgboost::LinearUpdater::Create("shotgun", &ctx)); updater->Configure({{"eta", "1."}}); - xgboost::HostDeviceVector gpair( - p_fmat->Info().num_row_, xgboost::GradientPair(-5, 1.0)); + linalg::Matrix gpair{ + linalg::Constant(&ctx, xgboost::GradientPair(-5, 1.0), p_fmat->Info().num_row_, 1)}; xgboost::gbm::GBLinearModel model{&mparam}; model.LazyInitModel(); updater->Update(&gpair, p_fmat.get(), &model, gpair.Size()); @@ -55,8 +55,8 @@ TEST(Linear, coordinate) { auto updater = std::unique_ptr( xgboost::LinearUpdater::Create("coord_descent", &ctx)); updater->Configure({{"eta", "1."}}); - xgboost::HostDeviceVector gpair( - p_fmat->Info().num_row_, xgboost::GradientPair(-5, 1.0)); + linalg::Matrix gpair{ + linalg::Constant(&ctx, xgboost::GradientPair(-5, 1.0), p_fmat->Info().num_row_, 1)}; xgboost::gbm::GBLinearModel model{&mparam}; model.LazyInitModel(); updater->Update(&gpair, p_fmat.get(), &model, gpair.Size()); diff --git a/tests/cpp/linear/test_linear.cu b/tests/cpp/linear/test_linear.cu index 6a2a6ef8c042..8475116bc954 100644 --- a/tests/cpp/linear/test_linear.cu +++ b/tests/cpp/linear/test_linear.cu @@ -1,4 +1,6 @@ -// Copyright by Contributors +/** + * Copyright 2018-2023, XGBoost Contributors + */ #include #include @@ -19,8 +21,7 @@ TEST(Linear, GPUCoordinate) { auto updater = std::unique_ptr( xgboost::LinearUpdater::Create("gpu_coord_descent", &ctx)); updater->Configure({{"eta", "1."}}); - xgboost::HostDeviceVector gpair( - mat->Info().num_row_, xgboost::GradientPair(-5, 1.0)); + auto gpair = linalg::Constant(&ctx, xgboost::GradientPair(-5, 1.0), mat->Info().num_row_, 1); xgboost::gbm::GBLinearModel model{&mparam}; model.LazyInitModel(); diff --git a/tests/cpp/objective/test_aft_obj.cc b/tests/cpp/objective/test_aft_obj.cc index 74973918c10b..972dfc53f58e 100644 --- a/tests/cpp/objective/test_aft_obj.cc +++ b/tests/cpp/objective/test_aft_obj.cc @@ -1,5 +1,5 @@ -/*! 
- * Copyright (c) by Contributors 2020 +/** + * Copyright 2020-2023, XGBoost Contributors */ #include #include @@ -12,9 +12,7 @@ #include "../helpers.h" #include "../../../src/common/survival_util.h" -namespace xgboost { -namespace common { - +namespace xgboost::common { TEST(Objective, DeclareUnifiedTest(AFTObjConfiguration)) { auto ctx = MakeCUDACtx(GPUIDX); std::unique_ptr objective(ObjFunction::Create("survival:aft", &ctx)); @@ -65,14 +63,14 @@ static inline void CheckGPairOverGridPoints( preds[i] = std::log(std::pow(2.0, i * (log_y_high - log_y_low) / (num_point - 1) + log_y_low)); } - HostDeviceVector out_gpair; + linalg::Matrix out_gpair; obj->GetGradient(HostDeviceVector(preds), info, 1, &out_gpair); - const auto& gpair = out_gpair.HostVector(); + const auto gpair = out_gpair.HostView(); CHECK_EQ(num_point, expected_grad.size()); CHECK_EQ(num_point, expected_hess.size()); for (int i = 0; i < num_point; ++i) { - EXPECT_NEAR(gpair[i].GetGrad(), expected_grad[i], ftol); - EXPECT_NEAR(gpair[i].GetHess(), expected_hess[i], ftol); + EXPECT_NEAR(gpair(i).GetGrad(), expected_grad[i], ftol); + EXPECT_NEAR(gpair(i).GetHess(), expected_hess[i], ftol); } } @@ -169,5 +167,4 @@ TEST(Objective, DeclareUnifiedTest(AFTObjGPairIntervalCensoredLabels)) { 0.2757f, 0.1776f, 0.1110f, 0.0682f, 0.0415f, 0.0251f, 0.0151f, 0.0091f, 0.0055f, 0.0033f }); } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc index c808e97f0c75..963f6963969b 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cc +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -74,35 +74,35 @@ void TestNDCGGPair(Context const* ctx) { info.labels = linalg::Tensor{{0, 1, 0, 1}, {4, 1}, GPUIDX}; info.group_ptr_ = {0, 2, 4}; info.num_row_ = 4; - HostDeviceVector gpairs; + linalg::Matrix gpairs; obj->GetGradient(predts, info, 0, &gpairs); ASSERT_EQ(gpairs.Size(), predts.Size()); { predts = {1, 0, 1, 0}; - HostDeviceVector gpairs; + linalg::Matrix gpairs; obj->GetGradient(predts, info, 0, &gpairs); - for (size_t i = 0; i < gpairs.Size(); ++i) { - ASSERT_GT(gpairs.HostSpan()[i].GetHess(), 0); + for (std::size_t i = 0; i < gpairs.Size(); ++i) { + ASSERT_GT(gpairs.HostView()(i).GetHess(), 0); } - ASSERT_LT(gpairs.HostSpan()[1].GetGrad(), 0); - ASSERT_LT(gpairs.HostSpan()[3].GetGrad(), 0); + ASSERT_LT(gpairs.HostView()(1).GetGrad(), 0); + ASSERT_LT(gpairs.HostView()(3).GetGrad(), 0); - ASSERT_GT(gpairs.HostSpan()[0].GetGrad(), 0); - ASSERT_GT(gpairs.HostSpan()[2].GetGrad(), 0); + ASSERT_GT(gpairs.HostView()(0).GetGrad(), 0); + ASSERT_GT(gpairs.HostView()(2).GetGrad(), 0); info.weights_ = {2, 3}; - HostDeviceVector weighted_gpairs; + linalg::Matrix weighted_gpairs; obj->GetGradient(predts, info, 0, &weighted_gpairs); - auto const& h_gpairs = gpairs.ConstHostSpan(); - auto const& h_weighted_gpairs = weighted_gpairs.ConstHostSpan(); + auto const& h_gpairs = gpairs.HostView(); + auto const& h_weighted_gpairs = weighted_gpairs.HostView(); for (size_t i : {0ul, 1ul}) { - ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 2.0f); - ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 2.0f); + ASSERT_FLOAT_EQ(h_weighted_gpairs(i).GetGrad(), h_gpairs(i).GetGrad() * 2.0f); + ASSERT_FLOAT_EQ(h_weighted_gpairs(i).GetHess(), h_gpairs(i).GetHess() * 2.0f); } for (size_t i : {2ul, 3ul}) { - ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 3.0f); - 
ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 3.0f); + ASSERT_FLOAT_EQ(h_weighted_gpairs(i).GetGrad(), h_gpairs(i).GetGrad() * 3.0f); + ASSERT_FLOAT_EQ(h_weighted_gpairs(i).GetHess(), h_gpairs(i).GetHess() * 3.0f); } } @@ -125,7 +125,7 @@ void TestUnbiasedNDCG(Context const* ctx) { std::sort(h_label.begin(), h_label.end(), std::greater<>{}); HostDeviceVector predt(p_fmat->Info().num_row_, 1.0f); - HostDeviceVector out_gpair; + linalg::Matrix out_gpair; obj->GetGradient(predt, p_fmat->Info(), 0, &out_gpair); Json config{Object{}}; diff --git a/tests/cpp/objective/test_lambdarank_obj.cu b/tests/cpp/objective/test_lambdarank_obj.cu index 16dc453079cc..1c13665fcca2 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cu +++ b/tests/cpp/objective/test_lambdarank_obj.cu @@ -42,20 +42,21 @@ void TestGPUMakePair() { auto d = dummy.View(ctx.gpu_id); linalg::Vector dgpair; auto dg = dgpair.View(ctx.gpu_id); - cuda_impl::KernelInputs args{d, - d, - d, - d, - p_cache->DataGroupPtr(&ctx), - p_cache->CUDAThreadsGroupPtr(), - rank_idx, - info.labels.View(ctx.gpu_id), - predt.ConstDeviceSpan(), - {}, - dg, - nullptr, - y_sorted_idx, - 0}; + cuda_impl::KernelInputs args{ + d, + d, + d, + d, + p_cache->DataGroupPtr(&ctx), + p_cache->CUDAThreadsGroupPtr(), + rank_idx, + info.labels.View(ctx.gpu_id), + predt.ConstDeviceSpan(), + linalg::MatrixView{common::Span{}, {0}, 0}, + dg, + nullptr, + y_sorted_idx, + 0}; return args; }; diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index b8a40603b348..35e8287b645d 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -122,8 +122,8 @@ TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) { EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f); EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f); EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f); - EXPECT_ANY_THROW(obj->ProbToMargin(10)) - << "Expected error when base_score not in range [0,1f] for LogisticRegression"; + EXPECT_ANY_THROW((void)obj->ProbToMargin(10)) + << "Expected error when base_score not in range [0,1f] for LogisticRegression"; // test PredTransform HostDeviceVector io_preds = {0, 0.1f, 0.5f, 0.9f, 1}; @@ -282,9 +282,9 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) { TEST(Objective, CPU_vs_CUDA) { Context ctx = MakeCUDACtx(GPUIDX); - ObjFunction* obj = ObjFunction::Create("reg:squarederror", &ctx); - HostDeviceVector cpu_out_preds; - HostDeviceVector cuda_out_preds; + std::unique_ptr obj{ObjFunction::Create("reg:squarederror", &ctx)}; + linalg::Matrix cpu_out_preds; + linalg::Matrix cuda_out_preds; constexpr size_t kRows = 400; constexpr size_t kCols = 100; @@ -300,7 +300,7 @@ TEST(Objective, CPU_vs_CUDA) { info.labels.Reshape(kRows); auto& h_labels = info.labels.Data()->HostVector(); for (size_t i = 0; i < h_labels.size(); ++i) { - h_labels[i] = 1 / (float)(i+1); + h_labels[i] = 1 / static_cast(i+1); } { @@ -314,19 +314,17 @@ TEST(Objective, CPU_vs_CUDA) { obj->GetGradient(preds, info, 0, &cuda_out_preds); } - auto& h_cpu_out = cpu_out_preds.HostVector(); - auto& h_cuda_out = cuda_out_preds.HostVector(); + auto h_cpu_out = cpu_out_preds.HostView(); + auto h_cuda_out = cuda_out_preds.HostView(); float sgrad = 0; float shess = 0; for (size_t i = 0; i < kRows; ++i) { - sgrad += std::pow(h_cpu_out[i].GetGrad() - h_cuda_out[i].GetGrad(), 2); - shess += std::pow(h_cpu_out[i].GetHess() - h_cuda_out[i].GetHess(), 2); + sgrad += std::pow(h_cpu_out(i).GetGrad() 
- h_cuda_out(i).GetGrad(), 2); + shess += std::pow(h_cpu_out(i).GetHess() - h_cuda_out(i).GetHess(), 2); } ASSERT_NEAR(sgrad, 0.0f, kRtEps); ASSERT_NEAR(shess, 0.0f, kRtEps); - - delete obj; } #endif diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index a54c42a98ecb..5ff0fdeecc2c 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -189,11 +189,10 @@ void TestUpdatePredictionCache(bool use_subsampling) { auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses); - HostDeviceVector gpair; - auto& h_gpair = gpair.HostVector(); - h_gpair.resize(kRows * kClasses); + linalg::Matrix gpair({kRows, kClasses}, ctx.Device()); + auto h_gpair = gpair.HostView(); for (size_t i = 0; i < kRows * kClasses; ++i) { - h_gpair[i] = {static_cast(i), 1}; + std::apply(h_gpair, linalg::UnravelIndex(i, kRows, kClasses)) = {static_cast(i), 1}; } PredictionCacheEntry predtion_cache; diff --git a/tests/cpp/test_multi_target.cc b/tests/cpp/test_multi_target.cc index c8d371941255..cc81a4ba2ccc 100644 --- a/tests/cpp/test_multi_target.cc +++ b/tests/cpp/test_multi_target.cc @@ -68,10 +68,12 @@ class TestL1MultiTarget : public ::testing::Test { } } - void RunTest(std::string const& tree_method, bool weight) { + void RunTest(Context const* ctx, std::string const& tree_method, bool weight) { auto p_fmat = weight ? Xyw_ : Xy_; std::unique_ptr learner{Learner::Create({p_fmat})}; - learner->SetParams(Args{{"tree_method", tree_method}, {"objective", "reg:absoluteerror"}}); + learner->SetParams(Args{{"tree_method", tree_method}, + {"objective", "reg:absoluteerror"}, + {"device", ctx->DeviceName()}}); learner->Configure(); for (auto i = 0; i < 4; ++i) { learner->UpdateOneIter(i, p_fmat); @@ -87,7 +89,9 @@ class TestL1MultiTarget : public ::testing::Test { for (bst_target_t t{0}; t < p_fmat->Info().labels.Shape(1); ++t) { auto t_Xy = weight ? 
single_w_[t] : single_[t]; std::unique_ptr sl{Learner::Create({t_Xy})}; - sl->SetParams(Args{{"tree_method", tree_method}, {"objective", "reg:absoluteerror"}}); + sl->SetParams(Args{{"tree_method", tree_method}, + {"objective", "reg:absoluteerror"}, + {"device", ctx->DeviceName()}}); sl->Configure(); sl->UpdateOneIter(0, t_Xy); Json s_config{Object{}}; @@ -104,20 +108,32 @@ class TestL1MultiTarget : public ::testing::Test { ASSERT_FLOAT_EQ(mean, base_score); } - void RunTest(std::string const& tree_method) { - this->RunTest(tree_method, false); - this->RunTest(tree_method, true); + void RunTest(Context const* ctx, std::string const& tree_method) { + this->RunTest(ctx, tree_method, false); + this->RunTest(ctx, tree_method, true); } }; -TEST_F(TestL1MultiTarget, Hist) { this->RunTest("hist"); } +TEST_F(TestL1MultiTarget, Hist) { + Context ctx; + this->RunTest(&ctx, "hist"); +} -TEST_F(TestL1MultiTarget, Exact) { this->RunTest("exact"); } +TEST_F(TestL1MultiTarget, Exact) { + Context ctx; + this->RunTest(&ctx, "exact"); +} -TEST_F(TestL1MultiTarget, Approx) { this->RunTest("approx"); } +TEST_F(TestL1MultiTarget, Approx) { + Context ctx; + this->RunTest(&ctx, "approx"); +} #if defined(XGBOOST_USE_CUDA) -TEST_F(TestL1MultiTarget, GpuHist) { this->RunTest("gpu_hist"); } +TEST_F(TestL1MultiTarget, GpuHist) { + auto ctx = MakeCUDACtx(0); + this->RunTest(&ctx, "hist"); +} #endif // defined(XGBOOST_USE_CUDA) TEST(MultiStrategy, Configure) { diff --git a/tests/cpp/tree/test_fit_stump.cc b/tests/cpp/tree/test_fit_stump.cc index 18511c3a0aa1..d9441fd6f3fd 100644 --- a/tests/cpp/tree/test_fit_stump.cc +++ b/tests/cpp/tree/test_fit_stump.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022 by XGBoost Contributors + * Copyright 2022-2023, XGBoost Contributors */ #include #include @@ -8,17 +8,17 @@ #include "../../src/tree/fit_stump.h" #include "../helpers.h" -namespace xgboost { -namespace tree { +namespace xgboost::tree { namespace { void TestFitStump(Context const *ctx, DataSplitMode split = DataSplitMode::kRow) { std::size_t constexpr kRows = 16, kTargets = 2; - HostDeviceVector gpair; - auto &h_gpair = gpair.HostVector(); - h_gpair.resize(kRows * kTargets); + linalg::Matrix gpair; + gpair.SetDevice(ctx->Device()); + gpair.Reshape(kRows, kTargets); + auto h_gpair = gpair.HostView(); for (std::size_t i = 0; i < kRows; ++i) { for (std::size_t t = 0; t < kTargets; ++t) { - h_gpair.at(i * kTargets + t) = GradientPair{static_cast(i), 1}; + h_gpair(i, t) = GradientPair{static_cast(i), 1}; } } linalg::Vector out; @@ -53,6 +53,4 @@ TEST(InitEstimation, FitStumpColumnSplit) { auto constexpr kWorldSize{3}; RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol); } - -} // namespace tree -} // namespace xgboost +} // namespace xgboost::tree diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index dd2d802cac9f..50cdae7413bd 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -214,7 +214,7 @@ TEST(GpuHist, TestHistogramIndex) { TestHistogramIndexImpl(); } -void UpdateTree(Context const* ctx, HostDeviceVector* gpair, DMatrix* dmat, +void UpdateTree(Context const* ctx, linalg::Matrix* gpair, DMatrix* dmat, size_t gpu_page_size, RegTree* tree, HostDeviceVector* preds, float subsample = 1.0f, const std::string& sampling_method = "uniform", int max_bin = 2) { @@ -264,7 +264,8 @@ TEST(GpuHist, UniformSampling) { // Create an in-memory DMatrix. 
std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); - auto gpair = GenerateRandomGradients(kRows); + linalg::Matrix gpair({kRows}, Context{}.MakeCUDA().Ordinal()); + gpair.Data()->Copy(GenerateRandomGradients(kRows)); // Build a tree using the in-memory DMatrix. RegTree tree; @@ -294,7 +295,8 @@ TEST(GpuHist, GradientBasedSampling) { // Create an in-memory DMatrix. std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); - auto gpair = GenerateRandomGradients(kRows); + linalg::Matrix gpair({kRows}, MakeCUDACtx(0).Ordinal()); + gpair.Data()->Copy(GenerateRandomGradients(kRows)); // Build a tree using the in-memory DMatrix. RegTree tree; @@ -330,11 +332,12 @@ TEST(GpuHist, ExternalMemory) { // Create a single batch DMatrix. std::unique_ptr dmat(CreateSparsePageDMatrix(kRows, kCols, 1, tmpdir.path + "/cache")); - auto gpair = GenerateRandomGradients(kRows); + Context ctx(MakeCUDACtx(0)); + linalg::Matrix gpair({kRows}, ctx.Ordinal()); + gpair.Data()->Copy(GenerateRandomGradients(kRows)); // Build a tree using the in-memory DMatrix. RegTree tree; - Context ctx(MakeCUDACtx(0)); HostDeviceVector preds(kRows, 0.0, 0); UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows); // Build another tree using multiple ELLPACK pages. @@ -367,12 +370,13 @@ TEST(GpuHist, ExternalMemoryWithSampling) { std::unique_ptr dmat_ext( CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache")); - auto gpair = GenerateRandomGradients(kRows); + Context ctx(MakeCUDACtx(0)); + linalg::Matrix gpair({kRows}, ctx.Ordinal()); + gpair.Data()->Copy(GenerateRandomGradients(kRows)); // Build a tree using the in-memory DMatrix. auto rng = common::GlobalRandom(); - Context ctx(MakeCUDACtx(0)); RegTree tree; HostDeviceVector preds(kRows, 0.0, 0); UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod, kRows); diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index d034403392f7..e90120231835 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -26,9 +26,11 @@ TEST(GrowHistMaker, InteractionConstraint) { auto constexpr kRows = 32; auto constexpr kCols = 16; auto p_dmat = GenerateDMatrix(kRows, kCols); - auto p_gradients = GenerateGradients(kRows); - Context ctx; + + linalg::Matrix gpair({kRows}, ctx.Ordinal()); + gpair.Data()->Copy(GenerateRandomGradients(kRows)); + ObjInfo task{ObjInfo::kRegression}; { // With constraints RegTree tree; Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); std::vector> position(1); updater->Configure(Args{}); - updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&tree}); + updater->Update(&param, &gpair, p_dmat.get(), position, {&tree}); ASSERT_EQ(tree.NumExtraNodes(), 4); ASSERT_EQ(tree[0].SplitIndex(), 1); @@ -57,7 +59,7 @@ TrainParam param; param.Init(Args{}); updater->Configure(Args{}); - updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&tree}); + updater->Update(&param, &gpair, p_dmat.get(), position, {&tree}); ASSERT_EQ(tree.NumExtraNodes(), 10); ASSERT_EQ(tree[0].SplitIndex(), 1); @@ -70,9 +72,12 @@ namespace { void VerifyColumnSplit(int32_t rows, bst_feature_t cols, bool categorical, RegTree const& expected_tree) { - auto p_dmat = GenerateDMatrix(rows, cols, categorical); - auto p_gradients = GenerateGradients(rows); Context
ctx; + auto p_dmat = GenerateDMatrix(rows, cols, categorical); + linalg::Matrix gpair({rows}, ctx.Ordinal()); + gpair.Data()->Copy(GenerateRandomGradients(rows)); + + ObjInfo task{ObjInfo::kRegression}; std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; std::vector> position(1); @@ -84,7 +89,7 @@ void VerifyColumnSplit(int32_t rows, bst_feature_t cols, bool categorical, TrainParam param; param.Init(Args{}); updater->Configure(Args{}); - updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree}); + updater->Update(&param, &gpair, sliced.get(), position, {&tree}); Json json{Object{}}; tree.SaveModel(&json); @@ -100,15 +105,16 @@ void TestColumnSplit(bool categorical) { RegTree expected_tree{1u, kCols}; ObjInfo task{ObjInfo::kRegression}; { - auto p_dmat = GenerateDMatrix(kRows, kCols, categorical); - auto p_gradients = GenerateGradients(kRows); Context ctx; + auto p_dmat = GenerateDMatrix(kRows, kCols, categorical); + linalg::Matrix gpair({kRows}, ctx.Ordinal()); + gpair.Data()->Copy(GenerateRandomGradients(kRows)); std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; std::vector> position(1); TrainParam param; param.Init(Args{}); updater->Configure(Args{}); - updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&expected_tree}); + updater->Update(&param, &gpair, p_dmat.get(), position, {&expected_tree}); } auto constexpr kWorldSize = 2; diff --git a/tests/cpp/tree/test_prediction_cache.cc b/tests/cpp/tree/test_prediction_cache.cc index 0aafb0a4f62c..fc1d0508797c 100644 --- a/tests/cpp/tree/test_prediction_cache.cc +++ b/tests/cpp/tree/test_prediction_cache.cc @@ -69,7 +69,7 @@ class TestPredictionCache : public ::testing::Test { std::unique_ptr updater{TreeUpdater::Create(updater_name, ctx, &task)}; RegTree tree; std::vector trees{&tree}; - auto gpair = GenerateRandomGradients(n_samples_); + auto gpair = GenerateRandomGradients(ctx, n_samples_, 1); tree::TrainParam param; param.UpdateAllowUnknown(Args{{"max_bin", "64"}}); diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index 843e2b2ee7b4..1a3ec532e18b 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -21,15 +21,13 @@ TEST(Updater, Prune) { std::vector> cfg; cfg.emplace_back("num_feature", std::to_string(kCols)); cfg.emplace_back("min_split_loss", "10"); + Context ctx; // These data are just place holders.
- HostDeviceVector gpair = - { {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, - {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} }; - std::shared_ptr p_dmat { - RandomDataGenerator{32, 10, 0}.GenerateDMatrix() }; - - Context ctx; + linalg::Matrix gpair + {{ {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, + {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} }, {8, 1}, ctx.Device()}; + std::shared_ptr p_dmat{RandomDataGenerator{32, 10, 0}.GenerateDMatrix()}; // prepare tree RegTree tree = RegTree{1u, kCols}; diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 4afea74cee9c..6327703edbcb 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -202,13 +202,13 @@ TEST(QuantileHist, PartitionerColSplit) { TestColumnSplitPartitioner(3); } namespace { -void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_targets, +void VerifyColumnSplit(Context const* ctx, bst_row_t rows, bst_feature_t cols, bst_target_t n_targets, RegTree const& expected_tree) { auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true); - auto p_gradients = GenerateGradients(rows, n_targets); - Context ctx; + linalg::Matrix gpair = GenerateRandomGradients(ctx, rows, n_targets); + ObjInfo task{ObjInfo::kRegression}; - std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)}; + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker", ctx, &task)}; std::vector> position(1); std::unique_ptr sliced{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())}; @@ -217,7 +217,7 @@ void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_target TrainParam param; param.Init(Args{}); updater->Configure(Args{}); - updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree}); + updater->Update(&param, &gpair, sliced.get(), position, {&tree}); Json json{Object{}}; tree.SaveModel(&json); @@ -232,21 +232,21 @@ void TestColumnSplit(bst_target_t n_targets) { RegTree expected_tree{n_targets, kCols}; ObjInfo task{ObjInfo::kRegression}; + Context ctx; { auto Xy = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); - auto p_gradients = GenerateGradients(kRows, n_targets); - Context ctx; + auto gpair = GenerateRandomGradients(&ctx, kRows, n_targets); std::unique_ptr updater{ TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)}; std::vector> position(1); TrainParam param; param.Init(Args{}); updater->Configure(Args{}); - updater->Update(&param, p_gradients.get(), Xy.get(), position, {&expected_tree}); + updater->Update(&param, &gpair, Xy.get(), position, {&expected_tree}); } auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, kRows, kCols, n_targets, + RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, &ctx, kRows, kCols, n_targets, std::cref(expected_tree)); } } // anonymous namespace diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index 11ce94f5942c..c8859c898519 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -17,10 +17,11 @@ namespace xgboost::tree { TEST(Updater, Refresh) { bst_row_t constexpr kRows = 8; bst_feature_t constexpr kCols = 16; + Context ctx; - HostDeviceVector gpair = - { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, - {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} }; + linalg::Matrix gpair + {{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.23f, 0.24f}, + {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} }, {8, 1}, ctx.Device()}; std::shared_ptr p_dmat{ RandomDataGenerator{kRows, kCols, 0.4f}.Seed(3).GenerateDMatrix()}; std::vector> cfg{ @@ -29,7 +30,6 @@ TEST(Updater, Refresh) { {"reg_lambda", "1"}}; RegTree tree = RegTree{1u, kCols}; - Context ctx; std::vector trees{&tree}; ObjInfo task{ObjInfo::kRegression}; diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index d125c84d55b0..dc9a9c2096c4 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -16,7 +16,7 @@ namespace xgboost { class UpdaterTreeStatTest : public ::testing::Test { protected: std::shared_ptr p_dmat_; - HostDeviceVector gpairs_; + linalg::Matrix gpairs_; size_t constexpr static kRows = 10; size_t constexpr static kCols = 10; @@ -24,8 +24,8 @@ class UpdaterTreeStatTest : public ::testing::Test { void SetUp() override { p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix(true); auto g = GenerateRandomGradients(kRows); - gpairs_.Resize(kRows); - gpairs_.Copy(g); + gpairs_.Reshape(kRows, 1); + gpairs_.Data()->Copy(g); } void RunTest(std::string updater) { @@ -63,7 +63,7 @@ TEST_F(UpdaterTreeStatTest, Approx) { this->RunTest("grow_histmaker"); } class UpdaterEtaTest : public ::testing::Test { protected: std::shared_ptr p_dmat_; - HostDeviceVector gpairs_; + linalg::Matrix gpairs_; size_t constexpr static kRows = 10; size_t constexpr static kCols = 10; size_t constexpr static kClasses = 10; @@ -71,8 +71,8 @@ class UpdaterEtaTest : public ::testing::Test { void SetUp() override { p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix(true, false, kClasses); auto g = GenerateRandomGradients(kRows); - gpairs_.Resize(kRows); - gpairs_.Copy(g); + gpairs_.Reshape(kRows, 1); + gpairs_.Data()->Copy(g); } void RunTest(std::string updater) { @@ -125,14 +125,15 @@ TEST_F(UpdaterEtaTest, GpuHist) { this->RunTest("grow_gpu_hist"); } class TestMinSplitLoss : public ::testing::Test { std::shared_ptr dmat_; - HostDeviceVector gpair_; + linalg::Matrix gpair_; void SetUp() override { constexpr size_t kRows = 32; constexpr size_t kCols = 16; constexpr float kSparsity = 0.6; dmat_ = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix(); - gpair_ = GenerateRandomGradients(kRows); + gpair_.Reshape(kRows, 1); + gpair_.Data()->Copy(GenerateRandomGradients(kRows)); } std::int32_t Update(Context const* ctx, std::string updater, float gamma) { diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py index 530d3e9dfbab..9f902ce32a33 100644 --- a/tests/python-gpu/test_gpu_with_sklearn.py +++ b/tests/python-gpu/test_gpu_with_sklearn.py @@ -1,3 +1,4 @@ +import itertools import json import os import sys @@ -158,6 +159,96 @@ def test_classififer(): clf.fit(X, y) +@pytest.mark.parametrize( + "use_cupy,tree_method,device,order,gdtype,strategy", + [ + c + for c in itertools.product( + (True, False), + ("hist", "approx"), + ("cpu", "cuda"), + ("C", "F"), + ("float64", "float32"), + ("one_output_per_tree", "multi_output_tree"), + ) + ], +) +def test_custom_objective( + use_cupy: bool, + tree_method: str, + device: str, + order: str, + gdtype: str, + strategy: str, +) -> None: + from sklearn.datasets import load_iris + + X, y = load_iris(return_X_y=True) + + params = { + "tree_method": tree_method, + "device": device, + "n_estimators": 8, + "multi_strategy": strategy, + } + + obj = tm.softprob_obj(y.max() + 1, use_cupy=use_cupy, order=order, 
gdtype=gdtype) + + clf = xgb.XGBClassifier(objective=obj, **params) + + if strategy == "multi_output_tree" and tree_method == "approx": + with pytest.raises(ValueError, match=r"Only the hist"): + clf.fit(X, y) + return + if strategy == "multi_output_tree" and device == "cuda": + with pytest.raises(ValueError, match=r"GPU is not yet"): + clf.fit(X, y) + return + + clf.fit(X, y) + + clf_1 = xgb.XGBClassifier(**params) + clf_1.fit(X, y) + + np.testing.assert_allclose(clf.predict_proba(X), clf_1.predict_proba(X), rtol=1e-4) + + params["n_estimators"] = 2 + + def wrong_shape(labels, predt): + grad, hess = obj(labels, predt) + return grad[:, :-1], hess[:, :-1] + + with pytest.raises(ValueError, match="should be equal to the number of"): + clf = xgb.XGBClassifier(objective=wrong_shape, **params) + clf.fit(X, y) + + def wrong_shape_1(labels, predt): + grad, hess = obj(labels, predt) + return grad[:-1, :], hess[:-1, :] + + with pytest.raises(ValueError, match="Mismatched size between the gradient"): + clf = xgb.XGBClassifier(objective=wrong_shape_1, **params) + clf.fit(X, y) + + def wrong_shape_2(labels, predt): + grad, hess = obj(labels, predt) + return grad[:, :], hess[:-1, :] + + with pytest.raises(ValueError, match="Mismatched shape between the gradient"): + clf = xgb.XGBClassifier(objective=wrong_shape_2, **params) + clf.fit(X, y) + + def wrong_shape_3(labels, predt): + grad, hess = obj(labels, predt) + grad = grad.reshape(grad.size) + hess = hess.reshape(hess.size) + return grad, hess + + with pytest.warns(FutureWarning, match="required to be"): + clf = xgb.XGBClassifier(objective=wrong_shape_3, **params) + clf.fit(X, y) + + @pytest.mark.skipif(**tm.no_pandas()) def test_ranking_qid_df(): import cudf From 83e62d5ce7d3f4664b591e40d52a3c27a1837e50 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 23 Aug 2023 01:41:50 +0800 Subject: [PATCH 2/4] fix. --- python-package/xgboost/training.py | 2 +- src/objective/regression_obj.cu | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index aa3c18a01e8b..e74a56904f07 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -178,7 +178,7 @@ def train( for i in range(start_iteration, num_boost_round): if cb_container.before_iteration(bst, i, dtrain, evals): break - bst.update(dtrain, i, obj) + bst.update(dtrain, iteration=i, fobj=obj) if cb_container.after_iteration(bst, i, dtrain, evals): break diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 3c431ea682e1..5751d6102633 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -162,11 +162,6 @@ class RegLossObj : public FitIntercept { common::Range{0, static_cast(n_data_blocks)}, nthreads, device) .Eval(&additional_input_, out_gpair->Data(), &preds, info.labels.Data(), &info.weights_); - - auto const flag = additional_input_.HostVector().begin()[0]; - if (flag == 0) { - LOG(FATAL) << Loss::LabelErrorMsg(); - } } public: From 3f8acdd25851073174db4ea9bf08688b6ae13bb2 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 23 Aug 2023 01:45:08 +0800 Subject: [PATCH 3/4] typo. 
--- python-package/xgboost/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index a9600c8fd331..dc76e7487e38 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2096,7 +2096,7 @@ def array_interface(array: NumpyOrCupy) -> bytes: if array.shape[0] != n_samples and is_flatten(array): warnings.warn( "Since 2.1.0, the shape of the gradient and hessian is required to" - " be (n_samples, n_targets) or (n_samples, n_targets).", + " be (n_samples, n_targets) or (n_samples, n_classes).", FutureWarning, ) array = array.reshape(n_samples, array.size // n_samples) From 63dafaedaac9735d58ffc86ebbefc29007a02feb Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 24 Aug 2023 01:49:56 +0800 Subject: [PATCH 4/4] comments. --- R-package/R/utils.R | 2 ++ demo/guide-python/multioutput_regression.py | 1 + include/xgboost/c_api.h | 13 +++++++------ include/xgboost/gbm.h | 3 ++- python-package/xgboost/core.py | 1 + src/c_api/c_api_utils.h | 12 ++++++++---- 6 files changed, 21 insertions(+), 11 deletions(-) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index b6b14c06f540..5faca2ef492b 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -155,6 +155,8 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) { ntreelimit = 0) gpair <- obj(pred, dtrain) n_samples <- dim(dtrain)[1] + # We still require row-major in R as I'm not quite sure how to get the stride of + # the matrix in C. gpair$grad <- matrix(gpair$grad, nrow = n_samples, byrow = TRUE) gpair$hess <- matrix(gpair$hess, nrow = n_samples, byrow = TRUE) .Call( diff --git a/demo/guide-python/multioutput_regression.py b/demo/guide-python/multioutput_regression.py index a8a546b0c932..cc64e4e09680 100644 --- a/demo/guide-python/multioutput_regression.py +++ b/demo/guide-python/multioutput_regression.py @@ -82,6 +82,7 @@ def squared_log( ) -> Tuple[np.ndarray, np.ndarray]: grad = gradient(predt, dtrain) hess = hessian(predt, dtrain) + # both numpy.ndarray and cupy.ndarray work. return grad, hess def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]: diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index d2e14d752e0b..afc1f47fd334 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -950,16 +950,17 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, fl float *hess, bst_ulong len); /** - * @brief Update a multi-target model with gradient and Hessian. This is used for training - * with a custom objective function. + * @brief Update a model with gradient and Hessian. This is used for training with a + * custom objective function. * * @since 2.0.0 * * @param handle handle - * @param dtrain training data - * @param iter The current iteration number. - * @param grad Json encoded __(cuda)_array_interface__ for gradient. - * @param hess Json encoded __(cuda)_array_interface__ for Hessian. + * @param dtrain The training data. + * @param iter The current iteration round. When training continuation is used, the count + * should restart. + * @param grad Json encoded __(cuda)_array_interface__ for gradient. + * @param hess Json encoded __(cuda)_array_interface__ for Hessian.
* * @return 0 when success, -1 when failure happens */ diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 3667421b094c..ae8652eee66d 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -73,7 +73,7 @@ class GradientBooster : public Model, public Configurable { /** * @brief Return number of boosted rounds. */ - [[nodiscard]] virtual int32_t BoostedRounds() const = 0; + [[nodiscard]] virtual std::int32_t BoostedRounds() const = 0; /** * \brief Whether the model has already been trained. When tree booster is chosen, then * returns true when there are existing trees. @@ -86,6 +86,7 @@ class GradientBooster : public Model, public Configurable { * @param in_gpair address of the gradient pair statistics of the data * @param prediction The output prediction cache entry that needs to be updated. * the booster may change content of gpair + * @param obj The objective function used for boosting. */ virtual void DoBoost(DMatrix* p_fmat, linalg::Matrix* in_gpair, PredictionCacheEntry*, ObjFunction const* obj) = 0; diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index dc76e7487e38..486cee514c67 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2085,6 +2085,7 @@ def is_flatten(array: NumpyOrCupy) -> bool: return len(array.shape) == 1 or array.shape[1] == 1 def array_interface(array: NumpyOrCupy) -> bytes: + # Can we check for __array_interface__ instead of a specific type? msg = ( "Expecting `np.ndarray` or `cupy.ndarray` for gradient and hessian." f" Got: {type(array)}" ) diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index af43951c0a27..e42eed633adb 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -350,6 +350,9 @@ void MakeSparseFromPtr(PtrT const *p_indptr, I const *p_indices, T const *p_data Json::Dump(jdata, data_str); } +/** + * @brief Make array interface for other language bindings. + */ template auto MakeGradientInterface(Context const *ctx, G const *grad, H const *hess, std::size_t n_samples, std::size_t n_targets) { @@ -362,13 +365,13 @@ auto MakeGradientInterface(Context const *ctx, G const *grad, H const *hess, std return std::make_tuple(s_grad, s_hess); } -template +template struct CustomGradHessOp { - linalg::MatrixView t_grad; - linalg::MatrixView t_hess; + linalg::MatrixView t_grad; + linalg::MatrixView t_hess; linalg::MatrixView d_gpair; - CustomGradHessOp(linalg::MatrixView t_grad, linalg::MatrixView t_hess, + CustomGradHessOp(linalg::MatrixView t_grad, linalg::MatrixView t_hess, linalg::MatrixView d_gpair) : t_grad{std::move(t_grad)}, t_hess{std::move(t_hess)}, d_gpair{std::move(d_gpair)} {} @@ -376,6 +379,7 @@ struct CustomGradHessOp { auto [m, n] = linalg::UnravelIndex(i, t_grad.Shape(0), t_grad.Shape(1)); auto g = t_grad(m, n); auto h = t_hess(m, n); + // from struct of arrays to array of structs. d_gpair(m, n) = GradientPair{static_cast(g), static_cast(h)}; } };
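As a minimal sketch of the custom-objective contract introduced by this series (the two-target data, the squared_error name, and the label reshape below are illustrative assumptions, not code taken from the patch): the objective returns the gradient and Hessian with the same (n_samples, n_targets) shape as the predictions, as numpy or cupy arrays in either row- or column-major layout; flattened outputs only trigger the FutureWarning added in core.py, and mismatched shapes are rejected.

import numpy as np
import xgboost as xgb

def squared_error(predt: np.ndarray, dtrain: xgb.DMatrix):
    # Keep the matrix shape: grad and hess are (n_samples, n_targets), matching predt.
    y = dtrain.get_label().reshape(predt.shape)
    grad = predt - y
    hess = np.ones_like(predt)
    return grad, hess

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 8))
y = np.stack([X[:, 0] * 2.0, X[:, 1] - X[:, 2]], axis=1)  # two regression targets
dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=4, obj=squared_error)

The same idea applies through the sklearn wrapper, except the callable receives (labels, predt) as in the new test above.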