Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] use R standard routine to access read-only ints passed to C++ #4246

Merged
6 commits merged on May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions R-package/src/R_object_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ typedef union { VECTOR_SER s; double align; } SEXPREC_ALIGN;

#define R_REAL_PTR(x) (reinterpret_cast<double*> DATAPTR(x))

#define R_AS_INT(x) (*(reinterpret_cast<int*> DATAPTR(x)))

#define R_IS_NULL(x) ((*reinterpret_cast<LGBM_SE>(x)).sxpinfo.type == 0)

// 64bit pointer
Expand Down
162 changes: 82 additions & 80 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ using LightGBM::Common::Join;
using LightGBM::Common::Split;
using LightGBM::Log;

LGBM_SE EncodeChar(LGBM_SE dest, const char* src, LGBM_SE buf_len, LGBM_SE actual_len, size_t str_len) {
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, SEXP buf_len, LGBM_SE actual_len, size_t str_len) {
if (str_len > INT32_MAX) {
Log::Fatal("Don't support large string in R-package");
}
R_INT_PTR(actual_len)[0] = static_cast<int>(str_len);
if (R_AS_INT(buf_len) < static_cast<int>(str_len)) {
if (Rf_asInteger(buf_len) < static_cast<int>(str_len)) {
return dest;
}
auto ptr = R_CHAR_PTR(dest);
Expand Down Expand Up @@ -79,9 +79,9 @@ SEXP LGBM_DatasetCreateFromFile_R(LGBM_SE filename,
SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
LGBM_SE indices,
LGBM_SE data,
LGBM_SE num_indptr,
LGBM_SE nelem,
LGBM_SE num_row,
SEXP num_indptr,
SEXP nelem,
SEXP num_row,
LGBM_SE parameters,
LGBM_SE reference,
LGBM_SE out) {
Expand All @@ -90,9 +90,9 @@ SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
const int* p_indices = R_INT_PTR(indices);
const double* p_data = R_REAL_PTR(data);

int64_t nindptr = static_cast<int64_t>(R_AS_INT(num_indptr));
int64_t ndata = static_cast<int64_t>(R_AS_INT(nelem));
int64_t nrow = static_cast<int64_t>(R_AS_INT(num_row));
int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
DatasetHandle handle = nullptr;
CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
Expand All @@ -102,14 +102,14 @@ SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
}

SEXP LGBM_DatasetCreateFromMat_R(LGBM_SE data,
LGBM_SE num_row,
LGBM_SE num_col,
SEXP num_row,
SEXP num_col,
LGBM_SE parameters,
LGBM_SE reference,
LGBM_SE out) {
R_API_BEGIN();
int32_t nrow = static_cast<int32_t>(R_AS_INT(num_row));
int32_t ncol = static_cast<int32_t>(R_AS_INT(num_col));
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
double* p_mat = R_REAL_PTR(data);
DatasetHandle handle = nullptr;
CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
Expand All @@ -120,11 +120,11 @@ SEXP LGBM_DatasetCreateFromMat_R(LGBM_SE data,

SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle,
LGBM_SE used_row_indices,
LGBM_SE len_used_row_indices,
SEXP len_used_row_indices,
LGBM_SE parameters,
LGBM_SE out) {
R_API_BEGIN();
int len = R_AS_INT(len_used_row_indices);
int len = Rf_asInteger(len_used_row_indices);
std::vector<int> idxvec(len);
// convert from one-based to zero-based index
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
Expand Down Expand Up @@ -154,7 +154,7 @@ SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
}

SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
LGBM_SE buf_len,
SEXP buf_len,
LGBM_SE actual_len,
LGBM_SE feature_names) {
R_API_BEGIN();
Expand Down Expand Up @@ -202,9 +202,9 @@ SEXP LGBM_DatasetFree_R(LGBM_SE handle) {
SEXP LGBM_DatasetSetField_R(LGBM_SE handle,
LGBM_SE field_name,
LGBM_SE field_data,
LGBM_SE num_element) {
SEXP num_element) {
R_API_BEGIN();
int len = static_cast<int>(R_AS_INT(num_element));
int len = static_cast<int>(Rf_asInteger(num_element));
const char* name = R_CHAR_PTR(field_name);
if (!strcmp("group", name) || !strcmp("query", name)) {
std::vector<int32_t> vec(len);
Expand Down Expand Up @@ -387,10 +387,10 @@ SEXP LGBM_BoosterUpdateOneIter_R(LGBM_SE handle) {
SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
LGBM_SE grad,
LGBM_SE hess,
LGBM_SE len) {
SEXP len) {
int is_finished = 0;
R_API_BEGIN();
int int_len = R_AS_INT(len);
int int_len = Rf_asInteger(len);
std::vector<float> tgrad(int_len), thess(int_len);
#pragma omp parallel for schedule(static, 512) if (int_len >= 1024)
for (int j = 0; j < int_len; ++j) {
Expand Down Expand Up @@ -433,7 +433,7 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
}

SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
LGBM_SE buf_len,
SEXP buf_len,
LGBM_SE actual_len,
LGBM_SE eval_names) {
R_API_BEGIN();
Expand Down Expand Up @@ -464,83 +464,83 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
}

SEXP LGBM_BoosterGetEval_R(LGBM_SE handle,
LGBM_SE data_idx,
SEXP data_idx,
LGBM_SE out_result) {
R_API_BEGIN();
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
double* ptr_ret = R_REAL_PTR(out_result);
int out_len;
CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), R_AS_INT(data_idx), &out_len, ptr_ret));
CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
CHECK_EQ(out_len, len);
R_API_END();
}

SEXP LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
LGBM_SE data_idx,
SEXP data_idx,
LGBM_SE out) {
R_API_BEGIN();
int64_t len;
CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), R_AS_INT(data_idx), &len));
CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &len));
R_INT_PTR(out)[0] = static_cast<int>(len);
R_API_END();
}

SEXP LGBM_BoosterGetPredict_R(LGBM_SE handle,
LGBM_SE data_idx,
SEXP data_idx,
LGBM_SE out_result) {
R_API_BEGIN();
double* ptr_ret = R_REAL_PTR(out_result);
int64_t out_len;
CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), R_AS_INT(data_idx), &out_len, ptr_ret));
CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
R_API_END();
}

int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx, LGBM_SE is_predcontrib) {
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
int pred_type = C_API_PREDICT_NORMAL;
if (R_AS_INT(is_rawscore)) {
if (Rf_asInteger(is_rawscore)) {
pred_type = C_API_PREDICT_RAW_SCORE;
}
if (R_AS_INT(is_leafidx)) {
if (Rf_asInteger(is_leafidx)) {
pred_type = C_API_PREDICT_LEAF_INDEX;
}
if (R_AS_INT(is_predcontrib)) {
if (Rf_asInteger(is_predcontrib)) {
pred_type = C_API_PREDICT_CONTRIB;
}
return pred_type;
}

SEXP LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE data_filename,
LGBM_SE data_has_header,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
SEXP data_has_header,
SEXP is_rawscore,
SEXP is_leafidx,
SEXP is_predcontrib,
SEXP start_iteration,
SEXP num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
R_AS_INT(data_has_header), pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), R_CHAR_PTR(parameter),
R_CHAR_PTR(result_filename)));
R_API_END();
}

SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
SEXP num_row,
SEXP is_rawscore,
SEXP is_leafidx,
SEXP is_predcontrib,
SEXP start_iteration,
SEXP num_iteration,
LGBM_SE out_len) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), &len));
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), Rf_asInteger(num_row),
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
R_INT_PTR(out_len)[0] = static_cast<int>(len);
R_API_END();
}
Expand All @@ -549,14 +549,14 @@ SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE indptr,
LGBM_SE indices,
LGBM_SE data,
LGBM_SE num_indptr,
LGBM_SE nelem,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
SEXP num_indptr,
SEXP nelem,
SEXP num_row,
SEXP is_rawscore,
SEXP is_leafidx,
SEXP is_predcontrib,
SEXP start_iteration,
SEXP num_iteration,
LGBM_SE parameter,
LGBM_SE out_result) {
R_API_BEGIN();
Expand All @@ -566,78 +566,80 @@ SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
const int* p_indices = R_INT_PTR(indices);
const double* p_data = R_REAL_PTR(data);

int64_t nindptr = R_AS_INT(num_indptr);
int64_t ndata = R_AS_INT(nelem);
int64_t nrow = R_AS_INT(num_row);
int64_t nindptr = Rf_asInteger(num_indptr);
int64_t ndata = Rf_asInteger(nelem);
int64_t nrow = Rf_asInteger(num_row);
double* ptr_ret = R_REAL_PTR(out_result);
int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
R_API_END();
}

SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE data,
LGBM_SE num_row,
LGBM_SE num_col,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
SEXP num_row,
SEXP num_col,
SEXP is_rawscore,
SEXP is_leafidx,
SEXP is_predcontrib,
SEXP start_iteration,
SEXP num_iteration,
LGBM_SE parameter,
LGBM_SE out_result) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

int32_t nrow = R_AS_INT(num_row);
int32_t ncol = R_AS_INT(num_col);
int32_t nrow = Rf_asInteger(num_row);
int32_t ncol = Rf_asInteger(num_col);

const double* p_mat = R_REAL_PTR(data);
double* ptr_ret = R_REAL_PTR(out_result);
int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));

R_API_END();
}

SEXP LGBM_BoosterSaveModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
SEXP num_iteration,
SEXP feature_importance_type,
LGBM_SE filename) {
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_CHAR_PTR(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), R_CHAR_PTR(filename)));
R_API_END();
}

SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str) {
R_API_BEGIN();
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END();
}

SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str) {
R_API_BEGIN();
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END();
}
Expand Down
Loading