diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h
index 31bb430f0aed..b6286c238006 100644
--- a/include/LightGBM/boosting.h
+++ b/include/LightGBM/boosting.h
@@ -166,10 +166,11 @@ class LIGHTGBM_EXPORT Boosting {
    * \brief Feature contributions for the model's prediction of one record
    * \param feature_values Feature value on this record
    * \param output Prediction result for this record
-   * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
    */
-  virtual void PredictContrib(const double* features, double* output,
-                              const PredictionEarlyStopInstance* early_stop) const = 0;
+  virtual void PredictContrib(const double* features, double* output) const = 0;
+
+  virtual void PredictContribByMap(const std::unordered_map<int, double>& features,
+                                   std::vector<std::unordered_map<int, double>>* output) const = 0;

   /*!
    * \brief Dump model to json format string
diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index 6a30fce495c5..55204295d51b 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -727,6 +727,59 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                                                 int64_t* out_len,
                                                 double* out_result);

+/*!
+ * \brief Make sparse prediction for a new dataset in CSR format. Currently only used for feature contributions.
+ * \note
+ * The output arrays are allocated by this call, since the number of nonzero elements can vary per invocation;
+ * free them with ``LGBM_BoosterFreePredictSparse``. The logical shape is fixed:
+ * - for feature contributions, the shape of the sparse matrix will be ``num_class * num_data * (num_feature + 1)``.
+ * The output indptr_type for the sparse matrix will be the same as the given input indptr_type.
+ * \param handle Handle of booster
+ * \param indptr Pointer to row headers
+ * \param indptr_type Type of ``indptr``, can be ``C_API_DTYPE_INT32`` or ``C_API_DTYPE_INT64``
+ * \param indices Pointer to column indices
+ * \param data Pointer to the data space
+ * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
+ * \param nindptr Number of rows in the matrix + 1
+ * \param nelem Number of nonzero elements in the matrix
+ * \param num_col Number of columns
+ * \param predict_type What should be predicted, only feature contributions supported currently
+ *   - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
+ * \param num_iteration Number of iterations for prediction, <= 0 means no limit
+ * \param parameter Other parameters for prediction, e.g. early stopping for prediction
+ * \param[out] out_len Array with two output lengths: the number of nonzero elements and the length of ``out_indptr``
+ * \param[out] out_indptr Pointer to output row headers
+ * \param[out] out_indices Pointer to sparse column indices
+ * \param[out] out_data Pointer to sparse data space
+ * \return 0 when succeed, -1 when failure happens
+ */
+LIGHTGBM_C_EXPORT int LGBM_BoosterPredictSparseForCSR(BoosterHandle handle,
+                                                      const void* indptr,
+                                                      int indptr_type,
+                                                      const int32_t* indices,
+                                                      const void* data,
+                                                      int data_type,
+                                                      int64_t nindptr,
+                                                      int64_t nelem,
+                                                      int64_t num_col,
+                                                      int predict_type,
+                                                      int num_iteration,
+                                                      const char* parameter,
+                                                      int64_t* out_len,
+                                                      void** out_indptr,
+                                                      int32_t** out_indices,
+                                                      void** out_data);
+
+/*!
+ * \brief Method corresponding to ``LGBM_BoosterPredictSparseForCSR`` and ``LGBM_BoosterPredictSparseForCSC`` to free the allocated data.
+ * \param indptr Pointer to output row headers or column headers to be deallocated
+ * \param indices Pointer to sparse indices to be deallocated
+ * \param data Pointer to sparse data space to be deallocated
+ * \param indptr_type Type of ``indptr``, can be ``C_API_DTYPE_INT32`` or ``C_API_DTYPE_INT64``
+ * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
+ * \return 0 when succeed, -1 when failure happens
+ */
+LIGHTGBM_C_EXPORT int LGBM_BoosterFreePredictSparse(void* indptr, int32_t* indices, void* data, int indptr_type, int data_type);
+
 /*!
  * \brief Make prediction for a new dataset in CSR format. This method re-uses the internal predictor structure
  *        from previous calls and is optimized for single row invocation.
@@ -812,6 +865,48 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
                                                 int64_t* out_len,
                                                 double* out_result);

+/*!
+ * \brief Make sparse prediction for a new dataset in CSC format. Currently only used for feature contributions.
+ * \note
+ * The output arrays are allocated by this call, since the number of nonzero elements can vary per invocation;
+ * free them with ``LGBM_BoosterFreePredictSparse``. The logical shape is fixed:
+ * - for feature contributions, the shape of the sparse matrix will be ``num_class * num_data * (num_feature + 1)``.
+ * The output col_ptr_type for the sparse matrix will be the same as the given input col_ptr_type.
+ * \param handle Handle of booster
+ * \param col_ptr Pointer to column headers
+ * \param col_ptr_type Type of ``col_ptr``, can be ``C_API_DTYPE_INT32`` or ``C_API_DTYPE_INT64``
+ * \param indices Pointer to row indices
+ * \param data Pointer to the data space
+ * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
+ * \param ncol_ptr Number of columns in the matrix + 1
+ * \param nelem Number of nonzero elements in the matrix
+ * \param num_row Number of rows
+ * \param predict_type What should be predicted
+ *   - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
+ * \param num_iteration Number of iterations for prediction, <= 0 means no limit
+ * \param parameter Other parameters for prediction, e.g. early stopping for prediction
+ * \param[out] out_len Array with two output lengths: the number of nonzero elements and the length of ``out_col_ptr``
+ * \param[out] out_col_ptr Pointer to output column headers
+ * \param[out] out_indices Pointer to sparse row indices
+ * \param[out] out_data Pointer to sparse data space
+ * \return 0 when succeed, -1 when failure happens
+ */
+LIGHTGBM_C_EXPORT int LGBM_BoosterPredictSparseForCSC(BoosterHandle handle,
+                                                      const void* col_ptr,
+                                                      int col_ptr_type,
+                                                      const int32_t* indices,
+                                                      const void* data,
+                                                      int data_type,
+                                                      int64_t ncol_ptr,
+                                                      int64_t nelem,
+                                                      int64_t num_row,
+                                                      int predict_type,
+                                                      int num_iteration,
+                                                      const char* parameter,
+                                                      int64_t* out_len,
+                                                      void** out_col_ptr,
+                                                      int32_t** out_indices,
+                                                      void** out_data);
+
 /*!
  * \brief Make prediction for a new dataset.
  * \note
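For reference, a rough Python sketch (not part of this diff) of driving the new CSR entry point directly through ctypes, mirroring what python-package/lightgbm/basic.py does further down. Assumptions: `lib` is an already-loaded lib_lightgbm (`ctypes.CDLL`), `booster_handle` is a valid BoosterHandle, the model has a single output matrix (binary or regression), and the constants are copied from basic.py.

import ctypes

import numpy as np
import scipy.sparse

# constants as defined in python-package/lightgbm/basic.py
C_API_DTYPE_FLOAT64 = 1
C_API_DTYPE_INT32 = 2
C_API_PREDICT_CONTRIB = 3


def predict_contrib_csr(lib, booster_handle, csr, num_iteration=-1):
    """Call LGBM_BoosterPredictSparseForCSR on an int32/float64 CSR matrix."""
    indptr = csr.indptr.astype(np.int32, copy=False)
    indices = csr.indices.astype(np.int32, copy=False)
    data = csr.data.astype(np.float64, copy=False)
    out_len = np.zeros(2, dtype=np.int64)
    out_indptr = ctypes.POINTER(ctypes.c_int32)()
    out_indices = ctypes.POINTER(ctypes.c_int32)()
    out_data = ctypes.POINTER(ctypes.c_double)()
    ret = lib.LGBM_BoosterPredictSparseForCSR(
        booster_handle,
        indptr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
        ctypes.c_int(C_API_DTYPE_INT32),
        indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
        data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        ctypes.c_int(C_API_DTYPE_FLOAT64),
        ctypes.c_int64(len(indptr)),
        ctypes.c_int64(len(data)),
        ctypes.c_int64(csr.shape[1]),
        ctypes.c_int(C_API_PREDICT_CONTRIB),
        ctypes.c_int(num_iteration),
        ctypes.c_char_p(b""),
        out_len.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
        ctypes.byref(out_indptr),
        ctypes.byref(out_indices),
        ctypes.byref(out_data))
    assert ret == 0
    nnz, indptr_len = int(out_len[0]), int(out_len[1])
    # copy out of the native buffers before freeing them
    contribs = scipy.sparse.csr_matrix(
        (np.ctypeslib.as_array(out_data, shape=(nnz,)).copy(),
         np.ctypeslib.as_array(out_indices, shape=(nnz,)).copy(),
         np.ctypeslib.as_array(out_indptr, shape=(indptr_len,)).copy()),
        shape=(csr.shape[0], csr.shape[1] + 1))  # extra column for the expected value
    lib.LGBM_BoosterFreePredictSparse(out_indptr, out_indices, out_data,
                                      ctypes.c_int(C_API_DTYPE_INT32),
                                      ctypes.c_int(C_API_DTYPE_FLOAT64))
    return contribs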
diff --git a/include/LightGBM/meta.h b/include/LightGBM/meta.h
index b15b1ba4b378..bcaa214292b8 100644
--- a/include/LightGBM/meta.h
+++ b/include/LightGBM/meta.h
@@ -11,6 +11,7 @@
 #include <limits>
 #include <memory>
 #include <utility>
+#include <unordered_map>

 #if (defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_AMD64))) || defined(__INTEL_COMPILER) || MM_PREFETCH
 #include <xmmintrin.h>
@@ -58,6 +59,9 @@ typedef int32_t comm_size_t;
 using PredictFunction =
 std::function<void(const std::vector<std::pair<int, double>>&, double* output)>;

+using PredictSparseFunction =
+std::function<void(const std::vector<std::pair<int, double>>&, std::vector<std::unordered_map<int, double>>* output)>;
+
 typedef void(*ReduceFunction)(const char* input, char* output, int type_size, comm_size_t array_size);

diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 047215231fc6..0edd1af4b719 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -135,6 +135,8 @@ class Tree {
   inline int PredictLeafIndexByMap(const std::unordered_map<int, double>& feature_values) const;

   inline void PredictContrib(const double* feature_values, int num_features, double* output);
+  inline void PredictContribByMap(const std::unordered_map<int, double>& feature_values,
+                                  int num_features, std::unordered_map<int, double>* output);

   /*! \brief Get Number of leaves*/
   inline int num_leaves() const { return num_leaves_; }
@@ -382,6 +384,12 @@ class Tree {
                 PathElement *parent_unique_path, double parent_zero_fraction,
                 double parent_one_fraction, int parent_feature_index) const;

+  void TreeSHAPByMap(const std::unordered_map<int, double>& feature_values,
+                     std::unordered_map<int, double>* phi,
+                     int node, int unique_depth,
+                     PathElement *parent_unique_path, double parent_zero_fraction,
+                     double parent_one_fraction, int parent_feature_index) const;
+
   /*! \brief Extend our decision path with a fraction of one and zero extensions for TreeSHAP*/
   static void ExtendPath(PathElement *unique_path, int unique_depth,
                          double zero_fraction, double one_fraction, int feature_index);
@@ -525,6 +533,18 @@ inline void Tree::PredictContrib(const double* feature_values, int num_features,
   }
 }

+inline void Tree::PredictContribByMap(const std::unordered_map<int, double>& feature_values,
+                                      int num_features, std::unordered_map<int, double>* output) {
+  (*output)[num_features] += ExpectedValue();
+  // Run the recursion with preallocated space for the unique path data
+  if (num_leaves_ > 1) {
+    CHECK_GE(max_depth_, 0);
+    const int max_path_len = max_depth_ + 1;
+    std::vector<PathElement> unique_path_data(max_path_len*(max_path_len + 1) / 2);
+    TreeSHAPByMap(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1);
+  }
+}
+
 inline void Tree::RecomputeLeafDepths(int node, int depth) {
   if (node == 0) leaf_depth_.resize(num_leaves());
   if (node < 0) {
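The idea behind the ByMap variants above: the dense PredictContrib writes into a preallocated double array of width num_features + 1, while the sparse path accumulates only the touched features into a map per output class, with the extra key num_features holding the expected-value (baseline) term. A loose Python picture of that aggregation (not part of the diff; the `tree_shap` callables are hypothetical stand-ins for the C++ Tree objects):

from collections import defaultdict


def predict_contrib_by_map(trees_per_class, row, num_features):
    """trees_per_class: one list of tree-SHAP functions per class, each mapping a
    {feature_index: value} row to a {feature_index: contribution} dict that also
    carries the baseline under key num_features."""
    output = []
    for trees in trees_per_class:
        phi = defaultdict(float)
        for tree_shap in trees:
            for feature_idx, contrib in tree_shap(row).items():
                phi[feature_idx] += contrib
        output.append(dict(phi))
    return output  # one sparse dict per class, like std::vector<std::unordered_map<int, double>>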
""" if isinstance(data, Dataset): raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead") @@ -637,6 +638,58 @@ def inner_predict(mat, num_iteration, predict_type, preds=None): else: return inner_predict(mat, num_iteration, predict_type) + def __convert_ctypes_to_numpy(self, ctypes_array, ctypes_array_len, array_type): + """Convert the ctypes array to a numpy array, note memory will still need to be managed.""" + array_size = np.dtype(array_type).itemsize * ctypes_array_len + if sys.version_info.major >= 3: + bfm = ctypes.pythonapi.PyMemoryView_FromMemory + bfm.restype = ctypes.py_object + bfm.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_int) + PyBUF_READ = 0x100 + py_buffer = bfm(ctypes_array, array_size, PyBUF_READ) + else: + bfm = ctypes.pythonapi.PyBuffer_FromMemory + bfm.restype = ctypes.py_object + bfm.argtypes = (ctypes.c_void_p, ctypes.c_int) + py_buffer = bfm(ctypes_array, array_size) + return np.frombuffer(py_buffer, array_type) + + def __create_sparse_native(self, cs, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data, + np_indptr_type, np_data_type, indptr_type, data_type, is_csr=True): + # create numpy array from output arrays + data_indices_len = out_shape[0] + indptr_len = out_shape[1] + out_indptr = self.__convert_ctypes_to_numpy(out_ptr_indptr, indptr_len, np_indptr_type) + out_indices = self.__convert_ctypes_to_numpy(out_ptr_indices, data_indices_len, np.int32) + out_data = self.__convert_ctypes_to_numpy(out_ptr_data, data_indices_len, np_data_type) + # break up indptr based on number of rows (note more than one matrix in multiclass case) + per_class_indptr_shape = cs.indptr.shape[0] + # for CSC there is extra column added + if not is_csr: + per_class_indptr_shape += 1 + out_indptr_arrays = np.split(out_indptr, out_indptr.shape[0] / per_class_indptr_shape) + # reformat output into a csr or csc matrix or list of csr or csc matrices + cs_output_matrices = [] + offset = 0 + for cs_indptr in out_indptr_arrays: + matrix_indptr_len = cs_indptr[cs_indptr.shape[0] - 1] + cs_indices = out_indices[offset + cs_indptr[0]:offset + matrix_indptr_len] + cs_data = out_data[offset + cs_indptr[0]:offset + matrix_indptr_len] + offset += matrix_indptr_len + # same shape as input csr or csc matrix except extra column for expected value + cs_shape = [cs.shape[0], cs.shape[1] + 1] + # note: make sure we copy data as it will be deallocated next + if is_csr: + cs_output_matrices.append(scipy.sparse.csr_matrix((cs_data, cs_indices, cs_indptr), cs_shape, copy=True)) + else: + cs_output_matrices.append(scipy.sparse.csc_matrix((cs_data, cs_indices, cs_indptr), cs_shape, copy=True)) + # free the temporary native indptr, indices, and data + _safe_call(_LIB.LGBM_BoosterFreePredictSparse(out_ptr_indptr, out_ptr_indices, out_ptr_data, + ctypes.c_int(indptr_type), ctypes.c_int(data_type))) + if len(cs_output_matrices) == 1: + return cs_output_matrices[0] + return cs_output_matrices + def __pred_for_csr(self, csr, num_iteration, predict_type): """Predict for a CSR data.""" def inner_predict(csr, num_iteration, predict_type, preds=None): @@ -652,6 +705,7 @@ def inner_predict(csr, num_iteration, predict_type, preds=None): ptr_data, type_ptr_data, _ = c_float_array(csr.data) assert csr.shape[1] <= MAX_INT32 + # Note: not sure why this was csr.indices, we shouldn't be modifying the input matrix csr.indices = csr.indices.astype(np.int32, copy=False) _safe_call(_LIB.LGBM_BoosterPredictForCSR( @@ -673,6 +727,49 @@ def inner_predict(csr, num_iteration, predict_type, 
     def __pred_for_csr(self, csr, num_iteration, predict_type):
         """Predict for a CSR data."""
         def inner_predict(csr, num_iteration, predict_type, preds=None):
@@ -652,6 +705,7 @@ def inner_predict(csr, num_iteration, predict_type, preds=None):
             ptr_data, type_ptr_data, _ = c_float_array(csr.data)

             assert csr.shape[1] <= MAX_INT32
+            # Note: this modifies the input matrix in place by casting its indices to int32; ideally the caller's data should not be mutated
             csr.indices = csr.indices.astype(np.int32, copy=False)

             _safe_call(_LIB.LGBM_BoosterPredictForCSR(
@@ -673,6 +727,49 @@ def inner_predict(csr, num_iteration, predict_type, preds=None):
                 raise ValueError("Wrong length for predict results")
             return preds, nrow

+        def inner_predict_sparse(csr, num_iteration, predict_type):
+            ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr)
+            ptr_data, type_ptr_data, _ = c_float_array(csr.data)
+            csr.indices = csr.indices.astype(np.int32, copy=False)
+            if type_ptr_indptr == C_API_DTYPE_INT32:
+                np_indptr_type = np.int32
+                out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
+            else:
+                np_indptr_type = np.int64
+                out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
+            out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
+            if type_ptr_data == C_API_DTYPE_FLOAT32:
+                np_data_type = np.float32
+                out_ptr_data = ctypes.POINTER(ctypes.c_float)()
+            else:
+                np_data_type = np.float64
+                out_ptr_data = ctypes.POINTER(ctypes.c_double)()
+            out_shape = np.zeros(2, dtype=np.int64)
+            _safe_call(_LIB.LGBM_BoosterPredictSparseForCSR(
+                self.handle,
+                ptr_indptr,
+                ctypes.c_int(type_ptr_indptr),
+                csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
+                ptr_data,
+                ctypes.c_int(type_ptr_data),
+                ctypes.c_int64(len(csr.indptr)),
+                ctypes.c_int64(len(csr.data)),
+                ctypes.c_int64(csr.shape[1]),
+                ctypes.c_int(predict_type),
+                ctypes.c_int(num_iteration),
+                c_str(self.pred_parameter),
+                out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
+                ctypes.byref(out_ptr_indptr),
+                ctypes.byref(out_ptr_indices),
+                ctypes.byref(out_ptr_data)))
+            matrices = self.__create_sparse_native(csr, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data,
+                                                   np_indptr_type, np_data_type, type_ptr_indptr,
+                                                   type_ptr_data, is_csr=True)
+            nrow = len(csr.indptr) - 1
+            return matrices, nrow
+
+        if predict_type == C_API_PREDICT_CONTRIB:
+            return inner_predict_sparse(csr, num_iteration, predict_type)
         nrow = len(csr.indptr) - 1
         if nrow > MAX_INT32:
             sections = [0] + list(np.arange(start=MAX_INT32, stop=nrow, step=MAX_INT32)) + [nrow]
@@ -690,6 +787,49 @@ def inner_predict(csr, num_iteration, predict_type, preds=None):

     def __pred_for_csc(self, csc, num_iteration, predict_type):
         """Predict for a CSC data."""
+        def inner_predict_sparse(csc, num_iteration, predict_type):
+            ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr)
+            ptr_data, type_ptr_data, _ = c_float_array(csc.data)
+            csc.indices = csc.indices.astype(np.int32, copy=False)
+            if type_ptr_indptr == C_API_DTYPE_INT32:
+                np_indptr_type = np.int32
+                out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
+            else:
+                np_indptr_type = np.int64
+                out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
+            out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
+            if type_ptr_data == C_API_DTYPE_FLOAT32:
+                np_data_type = np.float32
+                out_ptr_data = ctypes.POINTER(ctypes.c_float)()
+            else:
+                np_data_type = np.float64
+                out_ptr_data = ctypes.POINTER(ctypes.c_double)()
+            out_shape = np.zeros(2, dtype=np.int64)
+            _safe_call(_LIB.LGBM_BoosterPredictSparseForCSC(
+                self.handle,
+                ptr_indptr,
+                ctypes.c_int(type_ptr_indptr),
+                csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
+                ptr_data,
+                ctypes.c_int(type_ptr_data),
+                ctypes.c_int64(len(csc.indptr)),
+                ctypes.c_int64(len(csc.data)),
+                ctypes.c_int64(csc.shape[0]),
+                ctypes.c_int(predict_type),
+                ctypes.c_int(num_iteration),
+                c_str(self.pred_parameter),
+                out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
+                ctypes.byref(out_ptr_indptr),
+                ctypes.byref(out_ptr_indices),
+                ctypes.byref(out_ptr_data)))
+            matrices = self.__create_sparse_native(csc, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data,
+                                                   np_indptr_type, np_data_type, type_ptr_indptr,
+                                                   type_ptr_data, is_csr=False)
+            nrow = csc.shape[0]
+            return matrices, nrow
+
+        if predict_type == C_API_PREDICT_CONTRIB:
+            return inner_predict_sparse(csc, num_iteration, predict_type)
         nrow = csc.shape[0]
         if nrow > MAX_INT32:
             return self.__pred_for_csr(csc.tocsr(), num_iteration, predict_type)
diff --git a/src/application/predictor.hpp b/src/application/predictor.hpp
index 1c56cfa5eb2c..48ef227de2c6 100644
--- a/src/application/predictor.hpp
+++ b/src/application/predictor.hpp
@@ -88,12 +88,18 @@ class Predictor {
                          double* output) {
         int tid = omp_get_thread_num();
         CopyToPredictBuffer(predict_buf_[tid].data(), features);
-        // get result for leaf index
-        boosting_->PredictContrib(predict_buf_[tid].data(), output,
-                                  &early_stop_);
+        // get feature contributions
+        boosting_->PredictContrib(predict_buf_[tid].data(), output);
         ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
       };
+      predict_sparse_fun_ = [=](const std::vector<std::pair<int, double>>& features,
+                                std::vector<std::unordered_map<int, double>>* output) {
+        auto buf = CopyToPredictMap(features);
+        // get sparse feature contributions
+        boosting_->PredictContribByMap(buf, output);
+      };
     } else {
       if (is_raw_score) {
         predict_fun_ = [=](const std::vector<std::pair<int, double>>& features,
@@ -140,6 +146,11 @@ class Predictor {
     return predict_fun_;
   }

+  inline const PredictSparseFunction& GetPredictSparseFunction() const {
+    return predict_sparse_fun_;
+  }
+
   /*!
    * \brief predicting on data, then saving result to disk
    * \param data_filename Filename of data
@@ -275,6 +286,7 @@ class Predictor {
   const Boosting* boosting_;
   /*! \brief function for prediction */
   PredictFunction predict_fun_;
+  PredictSparseFunction predict_sparse_fun_;
   PredictionEarlyStopInstance early_stop_;
   int num_feature_;
   int num_pred_one_row_;
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 6a2e3e27c791..cf4a615b4696 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -571,8 +571,7 @@ const double* GBDT::GetTrainingScore(int64_t* out_len) {
   return train_score_updater_->score();
 }

-void GBDT::PredictContrib(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
-  int early_stop_round_counter = 0;
+void GBDT::PredictContrib(const double* features, double* output) const {
   // set zero
   const int num_features = max_feature_idx_ + 1;
   std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * (num_features + 1));
@@ -581,13 +580,16 @@ void GBDT::PredictContrib(const double* features, double* output, const Predicti
     for (int k = 0; k < num_tree_per_iteration_; ++k) {
       models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k*(num_features + 1));
     }
-    // check early stopping
-    ++early_stop_round_counter;
-    if (early_stop->round_period == early_stop_round_counter) {
-      if (early_stop->callback_function(output, num_tree_per_iteration_)) {
-        return;
-      }
-      early_stop_round_counter = 0;
-    }
+  }
+}
+
+void GBDT::PredictContribByMap(const std::unordered_map<int, double>& features,
+                               std::vector<std::unordered_map<int, double>>* output) const {
+  const int num_features = max_feature_idx_ + 1;
+  for (int i = 0; i < num_iteration_for_pred_; ++i) {
+    // predict all the trees for one iteration
+    for (int k = 0; k < num_tree_per_iteration_; ++k) {
+      models_[i * num_tree_per_iteration_ + k]->PredictContribByMap(features, num_features, &((*output)[k]));
     }
   }
 }
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index 67c30c86be2e..6f14db2af2f3 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -210,18 +210,18 @@ class GBDT : public GBDTBase {
    * \return number of prediction
    */
   inline int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const override {
-    int num_preb_in_one_row = num_class_;
+    int num_pred_in_one_row = num_class_;
     if (is_pred_leaf) {
       int max_iteration = GetCurrentIteration();
       if (num_iteration > 0) {
-        num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
+        num_pred_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
       } else {
-        num_preb_in_one_row *= max_iteration;
+        num_pred_in_one_row *= max_iteration;
       }
     } else if (is_pred_contrib) {
-      num_preb_in_one_row = num_tree_per_iteration_ * (max_feature_idx_ + 2);  // +1 for 0-based indexing, +1 for baseline
+      num_pred_in_one_row = num_tree_per_iteration_ * (max_feature_idx_ + 2);  // +1 for 0-based indexing, +1 for baseline
     }
-    return num_preb_in_one_row;
+    return num_pred_in_one_row;
   }

   void PredictRaw(const double* features, double* output,
@@ -240,8 +240,10 @@ class GBDT : public GBDTBase {
   void PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const override;

-  void PredictContrib(const double* features, double* output,
-                      const PredictionEarlyStopInstance* earlyStop) const override;
+  void PredictContrib(const double* features, double* output) const override;
+
+  void PredictContribByMap(const std::unordered_map<int, double>& features,
+                           std::vector<std::unordered_map<int, double>>* output) const override;

   /*!
    * \brief Dump model to json format string
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 5066947a1482..97e1b8859d50 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -382,16 +382,11 @@ class Booster {
     *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
   }

-  void Predict(int num_iteration, int predict_type, int nrow, int ncol,
-               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
-               const Config& config,
-               double* out_result, int64_t* out_len) {
+  Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) {
     if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
       Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
                  "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
     }
-    std::lock_guard<std::mutex> lock(mutex_);
     bool is_predict_leaf = false;
     bool is_raw_score = false;
     bool predict_contrib = false;
@@ -407,6 +402,22 @@ class Booster {
     Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
                         config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
+    return predictor;
+  }
+
+  void Predict(int num_iteration, int predict_type, int nrow, int ncol,
+               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
+               const Config& config,
+               double* out_result, int64_t* out_len) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
+    bool is_predict_leaf = false;
+    bool predict_contrib = false;
+    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
+      is_predict_leaf = true;
+    } else if (predict_type == C_API_PREDICT_CONTRIB) {
+      predict_contrib = true;
+    }
     int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
     auto pred_fun = predictor.GetPredictFunction();
     OMP_INIT_EX();
@@ -422,6 +433,214 @@ class Booster {
     *out_len = num_pred_in_one_row * nrow;
   }

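A quick worked example (not part of the diff) of the contribution row width computed by NumPredictOneRow above: one slot per feature plus one for the expected value, repeated once per tree-per-iteration (i.e. per class).

# max_feature_idx_ + 2 == num_features + 1 (features are 0-based, +1 for the baseline slot)
num_class, num_features = 3, 20
num_pred_in_one_row = num_class * (num_features + 1)
assert num_pred_in_one_row == 63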
+  void PredictSparse(int num_iteration, int predict_type, int nrow, int ncol,
+                     std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
+                     const Config& config, int* out_elements_size,
+                     std::vector<std::vector<std::unordered_map<int, double>>>* agg_ptr,
+                     int32_t** out_indices, void** out_data, int data_type,
+                     bool* is_data_float32_ptr, int num_matrices) {
+    auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
+    auto pred_sparse_fun = predictor.GetPredictSparseFunction();
+    std::vector<std::vector<std::unordered_map<int, double>>>& agg = *agg_ptr;
+    OMP_INIT_EX();
+    #pragma omp parallel for schedule(static)
+    for (int i = 0; i < nrow; ++i) {
+      OMP_LOOP_EX_BEGIN();
+      auto one_row = get_row_fun(i);
+      agg[i] = std::vector<std::unordered_map<int, double>>(num_matrices);
+      pred_sparse_fun(one_row, &agg[i]);
+      OMP_LOOP_EX_END();
+    }
+    OMP_THROW_EX();
+    // calculate the nonzero data and indices size
+    int elements_size = 0;
+    for (int i = 0; i < static_cast<int>(agg.size()); i++) {
+      const auto& row_vector = agg[i];
+      for (int j = 0; j < static_cast<int>(row_vector.size()); j++) {
+        elements_size += static_cast<int>(row_vector[j].size());
+      }
+    }
+    *out_elements_size = elements_size;
+    *is_data_float32_ptr = false;
+    // allocate data and indices arrays
+    if (data_type == C_API_DTYPE_FLOAT32) {
+      *out_data = new float[elements_size];
+      *is_data_float32_ptr = true;
+    } else if (data_type == C_API_DTYPE_FLOAT64) {
+      *out_data = new double[elements_size];
+    } else {
+      Log::Fatal("Unknown data type in PredictSparse");
+      return;
+    }
+    *out_indices = new int32_t[elements_size];
+  }
+
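A small Python mirror (not part of the diff) of the CSR assembly loop in PredictSparseCSR below: per-row maps are flattened matrix by matrix, and a fresh indptr block starting at 0 is emitted for every output matrix. The C++ iterates unordered_map in arbitrary order; sorting here is only for determinism, since scipy accepts unsorted indices.

def to_csr_arrays(agg, num_matrices):
    """agg: list over rows of lists (one per matrix) of {col: value} dicts."""
    indices, data, indptr = [], [], []
    for m in range(num_matrices):
        indptr.append(0)  # each per-matrix indptr block restarts at 0
        count = 0
        for row_vector in agg:
            for col, val in sorted(row_vector[m].items()):
                indices.append(col)
                data.append(val)
                count += 1
            indptr.append(count)
    return indices, data, indptr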
+  void PredictSparseCSR(int num_iteration, int predict_type, int nrow, int ncol,
+                        std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
+                        const Config& config,
+                        int64_t* out_len, void** out_indptr, int indptr_type,
+                        int32_t** out_indices, void** out_data, int data_type) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    // Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
+    int num_matrices = boosting_->NumModelPerIteration();
+    bool is_indptr_int32 = false;
+    bool is_data_float32 = false;
+    int indptr_size = (nrow + 1) * num_matrices;
+    if (indptr_type == C_API_DTYPE_INT32) {
+      *out_indptr = new int32_t[indptr_size];
+      is_indptr_int32 = true;
+    } else if (indptr_type == C_API_DTYPE_INT64) {
+      *out_indptr = new int64_t[indptr_size];
+    } else {
+      Log::Fatal("Unknown indptr type in PredictSparseCSR");
+      return;
+    }
+    // aggregated per row feature contribution results
+    std::vector<std::vector<std::unordered_map<int, double>>> agg(nrow);
+    int elements_size = 0;
+    PredictSparse(num_iteration, predict_type, nrow, ncol, get_row_fun, config, &elements_size, &agg,
+                  out_indices, out_data, data_type, &is_data_float32, num_matrices);
+    // copy vector results to output for each row
+    // TODO(imatiach-msft): should this be parallelized, or is there no perf benefit?
+    int element_index = 0;
+    int indptr_index = 0;
+    for (int m = 0; m < num_matrices; m++) {
+      int indptr_value = 0;
+      if (is_indptr_int32) {
+        (reinterpret_cast<int32_t*>(*out_indptr))[indptr_index] = indptr_value;
+      } else {
+        (reinterpret_cast<int64_t*>(*out_indptr))[indptr_index] = indptr_value;
+      }
+      indptr_index++;
+      for (int i = 0; i < static_cast<int>(agg.size()); i++) {
+        const auto& row_vector = agg[i];
+        for (auto it = row_vector[m].begin(); it != row_vector[m].end(); ++it) {
+          (*out_indices)[element_index] = it->first;
+          if (is_data_float32) {
+            (reinterpret_cast<float*>(*out_data))[element_index] = static_cast<float>(it->second);
+          } else {
+            (reinterpret_cast<double*>(*out_data))[element_index] = it->second;
+          }
+          element_index++;
+          indptr_value++;
+        }
+        if (is_indptr_int32) {
+          (reinterpret_cast<int32_t*>(*out_indptr))[indptr_index] = indptr_value;
+        } else {
+          (reinterpret_cast<int64_t*>(*out_indptr))[indptr_index] = indptr_value;
+        }
+        indptr_index++;
+      }
+    }
+    out_len[0] = elements_size;
+    out_len[1] = indptr_size;
+  }
+
+  void PredictSparseCSC(int num_iteration, int predict_type, int nrow, int ncol,
+                        std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
+                        const Config& config,
+                        int64_t* out_len, void** out_col_ptr, int col_ptr_type,
+                        int32_t** out_indices, void** out_data, int data_type) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    // Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
+    int num_matrices = boosting_->NumModelPerIteration();
+    auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
+    auto pred_sparse_fun = predictor.GetPredictSparseFunction();
+    bool is_col_ptr_int32 = false;
+    bool is_data_float32 = false;
+    int num_output_cols = ncol + 1;
+    int col_ptr_size = (num_output_cols + 1) * num_matrices;
+    if (col_ptr_type == C_API_DTYPE_INT32) {
+      *out_col_ptr = new int32_t[col_ptr_size];
+      is_col_ptr_int32 = true;
+    } else if (col_ptr_type == C_API_DTYPE_INT64) {
+      *out_col_ptr = new int64_t[col_ptr_size];
+    } else {
+      Log::Fatal("Unknown col_ptr type in PredictSparseCSC");
+      return;
+    }
+    // aggregated per row feature contribution results
+    std::vector<std::vector<std::unordered_map<int, double>>> agg(nrow);
+    int elements_size = 0;
+    PredictSparse(num_iteration, predict_type, nrow, ncol, get_row_fun, config, &elements_size, &agg,
+                  out_indices, out_data, data_type, &is_data_float32, num_matrices);
+    // calculate number of elements per column to construct
+    // the CSC matrix with random access
+    std::vector<std::vector<int>> column_sizes(num_matrices);
+    for (int m = 0; m < num_matrices; m++) {
+      column_sizes[m] = std::vector<int>(num_output_cols, 0);
+      for (int i = 0; i < static_cast<int>(agg.size()); i++) {
+        const auto& row_vector = agg[i];
+        for (auto it = row_vector[m].begin(); it != row_vector[m].end(); ++it) {
+          column_sizes[m][it->first] += 1;
+        }
+      }
+    }
+    // keep track of column counts
+    std::vector<std::vector<int>> column_counts(num_matrices);
+    // keep track of beginning index for each column
+    std::vector<std::vector<int>> column_start_indices(num_matrices);
+    // keep track of beginning index for each matrix
+    std::vector<int> matrix_start_indices(num_matrices, 0);
+    int col_ptr_index = 0;
+    for (int m = 0; m < num_matrices; m++) {
+      int col_ptr_value = 0;
+      column_start_indices[m] = std::vector<int>(num_output_cols, 0);
+      column_counts[m] = std::vector<int>(num_output_cols, 0);
+      if (is_col_ptr_int32) {
+        (reinterpret_cast<int32_t*>(*out_col_ptr))[col_ptr_index] = col_ptr_value;
+      } else {
+        (reinterpret_cast<int64_t*>(*out_col_ptr))[col_ptr_index] = col_ptr_value;
+      }
+      col_ptr_index++;
+      for (int i = 1; i < static_cast<int>(column_sizes[m].size()); i++) {
+        column_start_indices[m][i] = column_sizes[m][i - 1] + column_start_indices[m][i - 1];
+        if (is_col_ptr_int32) {
+          (reinterpret_cast<int32_t*>(*out_col_ptr))[col_ptr_index] = column_start_indices[m][i];
+        } else {
+          (reinterpret_cast<int64_t*>(*out_col_ptr))[col_ptr_index] = column_start_indices[m][i];
+        }
+        col_ptr_index++;
+      }
+      int last_elem_index = static_cast<int>(column_sizes[m].size()) - 1;
+      int last_column_start_index = column_start_indices[m][last_elem_index];
+      int last_column_size = column_sizes[m][last_elem_index];
+      if (is_col_ptr_int32) {
+        (reinterpret_cast<int32_t*>(*out_col_ptr))[col_ptr_index] = last_column_start_index + last_column_size;
+      } else {
+        (reinterpret_cast<int64_t*>(*out_col_ptr))[col_ptr_index] = last_column_start_index + last_column_size;
+      }
+      col_ptr_index++;
+      // the next matrix begins after this matrix's elements
+      if (m + 1 < num_matrices) {
+        matrix_start_indices[m + 1] = matrix_start_indices[m] +
+                                      last_column_start_index +
+                                      last_column_size;
+      }
+    }
+    for (int m = 0; m < num_matrices; m++) {
+      for (int i = 0; i < static_cast<int>(agg.size()); i++) {
+        const auto& row_vector = agg[i];
+        for (auto it = row_vector[m].begin(); it != row_vector[m].end(); ++it) {
+          int col_idx = it->first;
+          int element_index = column_start_indices[m][col_idx] +
+                              matrix_start_indices[m] +
+                              column_counts[m][col_idx];
+          // store the row index
+          (*out_indices)[element_index] = i;
+          // update column count
+          column_counts[m][col_idx]++;
+          if (is_data_float32) {
+            (reinterpret_cast<float*>(*out_data))[element_index] = static_cast<float>(it->second);
+          } else {
+            (reinterpret_cast<double*>(*out_data))[element_index] = it->second;
+          }
+        }
+      }
+    }
+    out_len[0] = elements_size;
+    out_len[1] = col_ptr_size;
+  }
+
   void Predict(int num_iteration, int predict_type, const char* data_filename,
                int data_has_header, const Config& config,
                const char* result_filename) {
@@ -1513,6 +1732,62 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
   API_END();
 }

+int LGBM_BoosterPredictSparseForCSR(BoosterHandle handle,
+                                    const void* indptr,
+                                    int indptr_type,
+                                    const int32_t* indices,
+                                    const void* data,
+                                    int data_type,
+                                    int64_t nindptr,
+                                    int64_t nelem,
+                                    int64_t num_col,
+                                    int predict_type,
+                                    int num_iteration,
+                                    const char* parameter,
+                                    int64_t* out_len,
+                                    void** out_indptr,
+                                    int32_t** out_indices,
+                                    void** out_data) {
+  API_BEGIN();
+  if (num_col <= 0) {
+    Log::Fatal("The number of columns should be greater than zero.");
+  } else if (num_col >= INT32_MAX) {
+    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
+  }
+  auto param = Config::Str2Map(parameter);
+  Config config;
+  config.Set(param);
+  if (config.num_threads > 0) {
+    omp_set_num_threads(config.num_threads);
+  }
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
+  int nrow = static_cast<int>(nindptr - 1);
+  ref_booster->PredictSparseCSR(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
+                                config, out_len, out_indptr, indptr_type, out_indices, out_data, data_type);
+  API_END();
+}
+
+int LGBM_BoosterFreePredictSparse(void* indptr, int32_t* indices, void* data, int indptr_type, int data_type) {
+  API_BEGIN();
+  if (indptr_type == C_API_DTYPE_INT32) {
+    delete[] reinterpret_cast<int32_t*>(indptr);
+  } else if (indptr_type == C_API_DTYPE_INT64) {
+    delete[] reinterpret_cast<int64_t*>(indptr);
+  } else {
+    Log::Fatal("Unknown indptr type in LGBM_BoosterFreePredictSparse");
+  }
+  delete[] indices;
+  if (data_type == C_API_DTYPE_FLOAT32) {
+    delete[] reinterpret_cast<float*>(data);
+  } else if (data_type == C_API_DTYPE_FLOAT64) {
+    delete[] reinterpret_cast<double*>(data);
+  } else {
+    Log::Fatal("Unknown data type in LGBM_BoosterFreePredictSparse");
+  }
+  API_END();
+}
+
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
                                       const void* indptr,
                                       int indptr_type,
@@ -1594,6 +1869,56 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
   API_END();
 }

+int LGBM_BoosterPredictSparseForCSC(BoosterHandle handle,
+                                    const void* col_ptr,
+                                    int col_ptr_type,
+                                    const int32_t* indices,
+                                    const void* data,
+                                    int data_type,
+                                    int64_t ncol_ptr,
+                                    int64_t nelem,
+                                    int64_t num_row,
+                                    int predict_type,
+                                    int num_iteration,
+                                    const char* parameter,
+                                    int64_t* out_len,
+                                    void** out_col_ptr,
+                                    int32_t** out_indices,
+                                    void** out_data) {
+  API_BEGIN();
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  auto param = Config::Str2Map(parameter);
+  Config config;
+  config.Set(param);
+  if (config.num_threads > 0) {
+    omp_set_num_threads(config.num_threads);
+  }
+  int num_threads = OMP_NUM_THREADS();
+  int ncol = static_cast<int>(ncol_ptr - 1);
+  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
+  for (int i = 0; i < num_threads; ++i) {
+    for (int j = 0; j < ncol; ++j) {
+      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
+    }
+  }
+  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
+      [&iterators, ncol](int i) {
+        std::vector<std::pair<int, double>> one_row;
+        one_row.reserve(ncol);
+        const int tid = omp_get_thread_num();
+        for (int j = 0; j < ncol; ++j) {
+          auto val = iterators[tid][j].Get(i);
+          if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
+            one_row.emplace_back(j, val);
+          }
+        }
+        return one_row;
+      };
+  ref_booster->PredictSparseCSC(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
+                                out_len, out_col_ptr, col_ptr_type, out_indices, out_data, data_type);
+  API_END();
+}
+
 int LGBM_BoosterPredictForMat(BoosterHandle handle,
                               const void* data,
                               int data_type,
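A Python mirror (not part of the diff) of the counting pass PredictSparseCSC uses above for a single output matrix: element counts per column become cumulative column start offsets, so each value can then be dropped into its final slot with random access, counting-sort style.

import numpy as np


def to_csc_arrays(agg_one_matrix, num_output_cols):
    """agg_one_matrix: list over rows of {col: value} dicts for one output matrix."""
    # first pass: count elements per column
    column_sizes = np.zeros(num_output_cols, dtype=np.int64)
    for row_vector in agg_one_matrix:
        for col in row_vector:
            column_sizes[col] += 1
    # prefix sums give the column start offsets (the col_ptr array)
    col_ptr = np.zeros(num_output_cols + 1, dtype=np.int64)
    np.cumsum(column_sizes, out=col_ptr[1:])
    indices = np.zeros(int(column_sizes.sum()), dtype=np.int32)
    data = np.zeros(int(column_sizes.sum()), dtype=np.float64)
    # second pass: place each value at its column's next free slot
    column_counts = np.zeros(num_output_cols, dtype=np.int64)
    for row_idx, row_vector in enumerate(agg_one_matrix):
        for col, val in row_vector.items():
            pos = col_ptr[col] + column_counts[col]
            indices[pos] = row_idx  # store the row index
            data[pos] = val
            column_counts[col] += 1
    return col_ptr, indices, data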
diff --git a/src/io/tree.cpp b/src/io/tree.cpp
index be928b7e3124..94969b22f439 100644
--- a/src/io/tree.cpp
+++ b/src/io/tree.cpp
@@ -727,6 +727,56 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
   }
 }

+// recursive sparse computation of SHAP values for a decision tree
+void Tree::TreeSHAPByMap(const std::unordered_map<int, double>& feature_values, std::unordered_map<int, double>* phi,
+                         int node, int unique_depth,
+                         PathElement *parent_unique_path, double parent_zero_fraction,
+                         double parent_one_fraction, int parent_feature_index) const {
+  // extend the unique path
+  PathElement* unique_path = parent_unique_path + unique_depth;
+  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
+  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
+             parent_one_fraction, parent_feature_index);
+
+  // leaf node
+  if (node < 0) {
+    for (int i = 1; i <= unique_depth; ++i) {
+      const double w = UnwoundPathSum(unique_path, unique_depth, i);
+      const PathElement &el = unique_path[i];
+      (*phi)[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
+    }
+
+  // internal node
+  } else {
+    const int hot_index = Decision(feature_values.count(split_feature_[node]) > 0 ? feature_values.at(split_feature_[node]) : 0.0f, node);
+    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
+    const double w = data_count(node);
+    const double hot_zero_fraction = data_count(hot_index) / w;
+    const double cold_zero_fraction = data_count(cold_index) / w;
+    double incoming_zero_fraction = 1;
+    double incoming_one_fraction = 1;
+
+    // see if we have already split on this feature,
+    // if so we undo that split so we can redo it for this node
+    int path_index = 0;
+    for (; path_index <= unique_depth; ++path_index) {
+      if (unique_path[path_index].feature_index == split_feature_[node]) break;
+    }
+    if (path_index != unique_depth + 1) {
+      incoming_zero_fraction = unique_path[path_index].zero_fraction;
+      incoming_one_fraction = unique_path[path_index].one_fraction;
+      UnwindPath(unique_path, unique_depth, path_index);
+      unique_depth -= 1;
+    }
+
+    TreeSHAPByMap(feature_values, phi, hot_index, unique_depth + 1, unique_path,
+                  hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
+
+    TreeSHAPByMap(feature_values, phi, cold_index, unique_depth + 1, unique_path,
+                  cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
+  }
+}
+
 double Tree::ExpectedValue() const {
   if (num_leaves_ == 1) return LeafOutput(0);
   const double total_count = internal_count_[0];
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 3be0568e622a..699426fd4282 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -14,6 +14,7 @@
                               load_iris, load_svmlight_file)
 from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, roc_auc_score
 from sklearn.model_selection import train_test_split, TimeSeriesSplit, GroupKFold
+from sklearn.datasets import make_multilabel_classification

 try:
     import cPickle as pickle
@@ -941,6 +942,40 @@ def test_contribs(self):
         self.assertLess(np.linalg.norm(gbm.predict(X_test, raw_score=True)
                                        - np.sum(gbm.predict(X_test, pred_contrib=True), axis=1)), 1e-4)

+    def test_contribs_sparse(self):
+        n_features = 20
+        n_samples = 100
+        # generate CSR sparse dataset
+        X, y = make_multilabel_classification(n_samples=n_samples,
+                                              sparse=True,
+                                              n_features=n_features,
+                                              n_classes=1,
+                                              n_labels=2)
+        y = y.flatten()
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
+        params = {
+            'objective': 'binary',
+            'metric': 'binary_logloss',
+            'verbose': -1,
+        }
+        lgb_train = lgb.Dataset(X_train, y_train)
+        gbm = lgb.train(params, lgb_train, num_boost_round=20)
+        contribs_csr = gbm.predict(X_test, pred_contrib=True)
+        # convert data to dense and get back the same contribs
+        contribs_dense = gbm.predict(X_test.toarray(), pred_contrib=True)
+        # validate the values are the same
+        np.testing.assert_allclose(contribs_csr.toarray(), contribs_dense)
+        self.assertLess(np.linalg.norm(gbm.predict(X_test, raw_score=True)
+                                       - np.sum(contribs_dense, axis=1)), 1e-4)
+        # validate using CSC matrix
+        X_train_csc = X_train.tocsc()
+        X_test_csc = X_test.tocsc()
+        lgb_train = lgb.Dataset(X_train_csc, y_train)
+        gbm = lgb.train(params, lgb_train, num_boost_round=20)
+        contribs_csc = gbm.predict(X_test_csc, pred_contrib=True)
+        # validate the values are the same
+        np.testing.assert_allclose(contribs_csr.toarray(), contribs_csc.toarray())
+
     def test_sliced_data(self):
         def train_and_get_predictions(features, labels):
             dataset = lgb.Dataset(features, label=labels)
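End-to-end usage sketch (not part of the diff): with these changes, pred_contrib on a scipy sparse input returns a sparse result, and for multiclass objectives a list with one sparse matrix per class. The last column holds the expected value, so each row sums to the raw score, as the tests above also verify.

import numpy as np
import scipy.sparse
import lightgbm as lgb

X = scipy.sparse.random(500, 30, density=0.2, format='csr', random_state=1)
y = np.random.RandomState(1).binomial(1, 0.5, 500)
booster = lgb.train({'objective': 'binary', 'verbose': -1},
                    lgb.Dataset(X, y), num_boost_round=10)
contribs = booster.predict(X, pred_contrib=True)  # scipy.sparse.csr_matrix
raw = booster.predict(X, raw_score=True)
assert np.allclose(np.asarray(contribs.sum(axis=1)).ravel(), raw, atol=1e-4)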