Skip to content

Commit

Permalink
adding sparse support to TreeSHAP in lightgbm
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Apr 30, 2020
1 parent 2c18a0f commit b6c6af1
Show file tree
Hide file tree
Showing 11 changed files with 700 additions and 31 deletions.
7 changes: 4 additions & 3 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,11 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Feature contributions for the model's prediction of one record
* \param feature_values Feature value on this record
* \param output Prediction result for this record
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
*/
virtual void PredictContrib(const double* features, double* output,
const PredictionEarlyStopInstance* early_stop) const = 0;
virtual void PredictContrib(const double* features, double* output) const = 0;

virtual void PredictContribByMap(const std::unordered_map<int, double>& features,
std::vector<std::unordered_map<int, double>>& output) const = 0;

/*!
* \brief Dump model to json format string
Expand Down
93 changes: 93 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,57 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t* out_len,
double* out_result);

/*!
* \brief Make sparse prediction for a new dataset in CSR format. Currently only used for feature contributions.
* \note
* The outputs are pre-allocated, as they can vary for each invocation, but the shape should be the same:
* - for feature contributions, the shape of sparse matrix will be ``num_class * num_data * (num_feature + 1)``.
* The output indptr_type for the sparse matrix will be the same as the given input indptr_type.
* \param handle Handle of booster
* \param indptr Pointer to row headers
* \param indptr_type Type of ``indptr``, can be ``C_API_DTYPE_INT32`` or ``C_API_DTYPE_INT64``
* \param indices Pointer to column indices
* \param data Pointer to the data space
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param nindptr Number of rows in the matrix + 1
* \param nelem Number of nonzero elements in the matrix
* \param num_col Number of columns
* \param predict_type What should be predicted, only feature contributions supported currently
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output indices and data
* \param[out] out_indptr Pointer to output row headers
* \param[out] out_indices Pointer to sparse indices
* \param[out] out_data Pointer to sparse data space
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictSparseForCSR(BoosterHandle handle,
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
int predict_type,
int num_iteration,
const char* parameter,
int64_t* out_len,
void** out_indptr,
int32_t** out_indices,
void** out_data);

/*!
* \brief Method corresponding to LGBM_BoosterPredictSparseForCSR to free the allocated data.
* \param indptr Pointer to output row headers or col headers to be deallocated
* \param indices Pointer to sparse indices to be deallocated
* \param data Pointer to sparse data space to be deallocated
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterFreePredictSparse(void* indptr, int32_t* indices, void* data);

/*!
* \brief Make prediction for a new dataset in CSR format. This method re-uses the internal predictor structure
* from previous calls and is optimized for single row invocation.
Expand Down Expand Up @@ -812,6 +863,48 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t* out_len,
double* out_result);

/*!
* \brief Make sparse prediction for a new dataset in CSC format. Currently only used for feature contributions.
* \note
* The outputs are pre-allocated, as they can vary for each invocation, but the shape should be the same:
* - for feature contributions, the shape of sparse matrix will be ``num_class * num_data * (num_feature + 1)``.
* The output indptr_type for the sparse matrix will be the same as the given input indptr_type.
* \param handle Handle of booster
* \param col_ptr Pointer to column headers
* \param col_ptr_type Type of ``col_ptr``, can be ``C_API_DTYPE_INT32`` or ``C_API_DTYPE_INT64``
* \param indices Pointer to row indices
* \param data Pointer to the data space
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param ncol_ptr Number of columns in the matrix + 1
* \param nelem Number of nonzero elements in the matrix
* \param num_row Number of rows
* \param predict_type What should be predicted
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iteration for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output indices and data
* \param[out] out_col_ptr Pointer to output column headers
* \param[out] out_indices Pointer to sparse indices
* \param[out] out_data Pointer to sparse data space
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictSparseForCSC(BoosterHandle handle,
const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
int predict_type,
int num_iteration,
const char* parameter,
int64_t* out_len,
void** out_col_ptr,
int32_t** out_indices,
void** out_data);

/*!
* \brief Make prediction for a new dataset.
* \note
Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <memory>
#include <utility>
#include <vector>
#include <unordered_map>

#if (defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_AMD64))) || defined(__INTEL_COMPILER) || MM_PREFETCH
#include <xmmintrin.h>
Expand Down Expand Up @@ -58,6 +59,9 @@ typedef int32_t comm_size_t;
using PredictFunction =
std::function<void(const std::vector<std::pair<int, double>>&, double* output)>;

using PredictSparseFunction =
std::function<void(const std::vector<std::pair<int, double>>&, std::vector<std::unordered_map<int, double>>& output)>;

typedef void(*ReduceFunction)(const char* input, char* output, int type_size, comm_size_t array_size);


Expand Down
20 changes: 20 additions & 0 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class Tree {
inline int PredictLeafIndexByMap(const std::unordered_map<int, double>& feature_values) const;

inline void PredictContrib(const double* feature_values, int num_features, double* output);
inline void PredictContribByMap(const std::unordered_map<int, double>& feature_values,
int num_features, std::unordered_map<int, double>& output);

/*! \brief Get Number of leaves*/
inline int num_leaves() const { return num_leaves_; }
Expand Down Expand Up @@ -382,6 +384,12 @@ class Tree {
PathElement *parent_unique_path, double parent_zero_fraction,
double parent_one_fraction, int parent_feature_index) const;

void TreeSHAPByMap(const std::unordered_map<int, double>& feature_values,
std::unordered_map<int, double>& phi,
int node, int unique_depth,
PathElement *parent_unique_path, double parent_zero_fraction,
double parent_one_fraction, int parent_feature_index) const;

/*! \brief Extend our decision path with a fraction of one and zero extensions for TreeSHAP*/
static void ExtendPath(PathElement *unique_path, int unique_depth,
double zero_fraction, double one_fraction, int feature_index);
Expand Down Expand Up @@ -525,6 +533,18 @@ inline void Tree::PredictContrib(const double* feature_values, int num_features,
}
}

inline void Tree::PredictContribByMap(const std::unordered_map<int, double>& feature_values,
int num_features, std::unordered_map<int, double>& output) {
output[num_features] += ExpectedValue();
// Run the recursion with preallocated space for the unique path data
if (num_leaves_ > 1) {
CHECK_GE(max_depth_, 0);
const int max_path_len = max_depth_ + 1;
std::vector<PathElement> unique_path_data(max_path_len*(max_path_len + 1) / 2);
TreeSHAPByMap(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1);
}
}

inline void Tree::RecomputeLeafDepths(int node, int depth) {
if (node == 0) leaf_depth_.resize(num_leaves());
if (node < 0) {
Expand Down
141 changes: 139 additions & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import ctypes
import os
import sys
import warnings
from tempfile import NamedTemporaryFile
from collections import OrderedDict
Expand Down Expand Up @@ -511,8 +512,8 @@ def predict(self, data, num_iteration=-1,
Returns
-------
result : numpy array
Prediction result.
result : numpy array or scipy.sparse or list[scipy.sparse]
Prediction result, can be sparse for feature contributions (when pred_contrib=True).
"""
if isinstance(data, Dataset):
raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
Expand Down Expand Up @@ -637,6 +638,56 @@ def inner_predict(mat, num_iteration, predict_type, preds=None):
else:
return inner_predict(mat, num_iteration, predict_type)

def __convert_ctypes_to_numpy(self, ctypes_array, ctypes_array_len, array_type):
"""Convert the ctypes array to a numpy array, note memory will still need to be managed."""
array_size = np.dtype(array_type).itemsize * ctypes_array_len
if sys.version_info.major >= 3:
bfm = ctypes.pythonapi.PyMemoryView_FromMemory
bfm.restype = ctypes.py_object
bfm.argtypes = (ctypes.c_void_p, ctypes.c_int, ctypes.c_int)
PyBUF_READ = 0x100
py_buffer = bfm(ctypes_array, array_size, PyBUF_READ)
else:
bfm = ctypes.pythonapi.PyBuffer_FromMemory
bfm.restype = ctypes.py_object
py_buffer = bfm(ctypes_array, array_size)
return np.frombuffer(py_buffer, array_type)

def __create_sparse_native(self, cs, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data,
np_indptr_type, np_data_type, is_csr=True):
# create numpy array from output arrays
data_indices_len = out_shape[0]
indptr_len = out_shape[1]
out_indptr = self.__convert_ctypes_to_numpy(out_ptr_indptr, indptr_len, np_indptr_type)
out_indices = self.__convert_ctypes_to_numpy(out_ptr_indices, data_indices_len, np.int32)
out_data = self.__convert_ctypes_to_numpy(out_ptr_data, data_indices_len, np_data_type)
# break up indptr based on number of rows (note more than one matrix in multiclass case)
per_class_indptr_shape = cs.indptr.shape[0]
# for CSC there is extra column added
if not is_csr:
per_class_indptr_shape += 1
out_indptr_arrays = np.split(out_indptr, out_indptr.shape[0] / per_class_indptr_shape)
# reformat output into a csr or csc matrix or list of csr or csc matrices
cs_output_matrices = []
offset = 0
for cs_indptr in out_indptr_arrays:
matrix_indptr_len = cs_indptr[cs_indptr.shape[0] - 1]
cs_indices = out_indices[offset + cs_indptr[0]:offset + matrix_indptr_len]
cs_data = out_data[offset + cs_indptr[0]:offset + matrix_indptr_len]
offset += matrix_indptr_len
# same shape as input csr or csc matrix except extra column for expected value
cs_shape = [cs.shape[0], cs.shape[1] + 1]
# note: make sure we copy data as it will be deallocated next
if is_csr:
cs_output_matrices.append(scipy.sparse.csr_matrix((cs_data, cs_indices, cs_indptr), cs_shape, copy=True))
else:
cs_output_matrices.append(scipy.sparse.csc_matrix((cs_data, cs_indices, cs_indptr), cs_shape, copy=True))
# free the temporary native indptr, indices, and data
_safe_call(_LIB.LGBM_BoosterFreePredictSparse(out_ptr_indptr, out_ptr_indices, out_ptr_data))
if len(cs_output_matrices) == 1:
return cs_output_matrices[0]
return cs_output_matrices

def __pred_for_csr(self, csr, num_iteration, predict_type):
"""Predict for a CSR data."""
def inner_predict(csr, num_iteration, predict_type, preds=None):
Expand All @@ -652,6 +703,7 @@ def inner_predict(csr, num_iteration, predict_type, preds=None):
ptr_data, type_ptr_data, _ = c_float_array(csr.data)

assert csr.shape[1] <= MAX_INT32
# Note: not sure why this was csr.indices, we shouldn't be modifying the input matrix
csr.indices = csr.indices.astype(np.int32, copy=False)

_safe_call(_LIB.LGBM_BoosterPredictForCSR(
Expand All @@ -673,6 +725,48 @@ def inner_predict(csr, num_iteration, predict_type, preds=None):
raise ValueError("Wrong length for predict results")
return preds, nrow

def inner_predict_sparse(csr, num_iteration, predict_type):
ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr)
ptr_data, type_ptr_data, _ = c_float_array(csr.data)
csr.indices = csr.indices.astype(np.int32, copy=False)
if type_ptr_indptr == C_API_DTYPE_INT32:
np_indptr_type = np.int32
out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
else:
np_indptr_type = np.int64
out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
if type_ptr_data == C_API_DTYPE_FLOAT32:
np_data_type = np.float32
out_ptr_data = ctypes.POINTER(ctypes.c_float)()
else:
np_data_type = np.float64
out_ptr_data = ctypes.POINTER(ctypes.c_double)()
out_shape = np.zeros(2, dtype=np.int64)
_safe_call(_LIB.LGBM_BoosterPredictSparseForCSR(
self.handle,
ptr_indptr,
ctypes.c_int32(type_ptr_indptr),
csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ptr_data,
ctypes.c_int(type_ptr_data),
ctypes.c_int64(len(csr.indptr)),
ctypes.c_int64(len(csr.data)),
ctypes.c_int64(csr.shape[1]),
ctypes.c_int(predict_type),
ctypes.c_int(num_iteration),
c_str(self.pred_parameter),
out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
ctypes.byref(out_ptr_indptr),
ctypes.byref(out_ptr_indices),
ctypes.byref(out_ptr_data)))
matrices = self.__create_sparse_native(csr, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data,
np_indptr_type, np_data_type, is_csr=True)
nrow = len(csr.indptr) - 1
return matrices, nrow

if predict_type == C_API_PREDICT_CONTRIB:
return inner_predict_sparse(csr, num_iteration, predict_type)
nrow = len(csr.indptr) - 1
if nrow > MAX_INT32:
sections = [0] + list(np.arange(start=MAX_INT32, stop=nrow, step=MAX_INT32)) + [nrow]
Expand All @@ -688,8 +782,51 @@ def inner_predict(csr, num_iteration, predict_type, preds=None):
else:
return inner_predict(csr, num_iteration, predict_type)


def __pred_for_csc(self, csc, num_iteration, predict_type):
"""Predict for a CSC data."""
def inner_predict_sparse(csc, num_iteration, predict_type):
ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr)
ptr_data, type_ptr_data, _ = c_float_array(csc.data)
csc.indices = csc.indices.astype(np.int32, copy=False)
if type_ptr_indptr == C_API_DTYPE_INT32:
np_indptr_type = np.int32
out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
else:
np_indptr_type = np.int64
out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
if type_ptr_data == C_API_DTYPE_FLOAT32:
np_data_type = np.float32
out_ptr_data = ctypes.POINTER(ctypes.c_float)()
else:
np_data_type = np.float64
out_ptr_data = ctypes.POINTER(ctypes.c_double)()
out_shape = np.zeros(2, dtype=np.int64)
_safe_call(_LIB.LGBM_BoosterPredictSparseForCSC(
self.handle,
ptr_indptr,
ctypes.c_int32(type_ptr_indptr),
csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ptr_data,
ctypes.c_int(type_ptr_data),
ctypes.c_int64(len(csc.indptr)),
ctypes.c_int64(len(csc.data)),
ctypes.c_int64(csc.shape[0]),
ctypes.c_int(predict_type),
ctypes.c_int(num_iteration),
c_str(self.pred_parameter),
out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
ctypes.byref(out_ptr_indptr),
ctypes.byref(out_ptr_indices),
ctypes.byref(out_ptr_data)))
matrices = self.__create_sparse_native(csc, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data,
np_indptr_type, np_data_type, is_csr=False)
nrow = csc.shape[0]
return matrices, nrow

if predict_type == C_API_PREDICT_CONTRIB:
return inner_predict_sparse(csc, num_iteration, predict_type)
nrow = csc.shape[0]
if nrow > MAX_INT32:
return self.__pred_for_csr(csc.tocsr(), num_iteration, predict_type)
Expand Down
Loading

0 comments on commit b6c6af1

Please sign in to comment.