Skip to content

Commit

Permalink
[R-package] move creation of character vectors in some methods to C++…
Browse files Browse the repository at this point in the history
… side (#4256)

* [R-package] move creation of character vectors in some methods to C++ side

* convert LGBM_BoosterGetEvalNames_R

* convert LGBM_BoosterDumpModel_R and LGBM_BoosterSaveModelToString_R

* remove debugging code

* update docs

* remove comment

* add handling for larger model strings

* handle large strings in feature and eval names

* got long feature names working

* more fixes

* linting

* resize

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* stricter test

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
jameslamb and StrikerRUS authored May 9, 2021
1 parent a421217 commit c1d2dbe
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 194 deletions.
85 changes: 8 additions & 77 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -466,40 +466,14 @@ Booster <- R6::R6Class(
num_iteration <- self$best_iter
}

# Create buffer
buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)

# Call buffer
.Call(
model_str <- .Call(
LGBM_BoosterSaveModelToString_R
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, buf_len
, act_len
, buf
)

# Check for buffer content
if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
.Call(
LGBM_BoosterSaveModelToString_R
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, buf_len
, act_len
, buf
)
}

return(
lgb.encode.char(arr = buf, len = act_len)
)
return(model_str)

},

Expand All @@ -511,36 +485,14 @@ Booster <- R6::R6Class(
num_iteration <- self$best_iter
}

buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)
.Call(
model_str <- .Call(
LGBM_BoosterDumpModel_R
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, buf_len
, act_len
, buf
)

if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
.Call(
LGBM_BoosterDumpModel_R
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, buf_len
, act_len
, buf
)
}

return(
lgb.encode.char(arr = buf, len = act_len)
)
return(model_str)

},

Expand Down Expand Up @@ -666,41 +618,20 @@ Booster <- R6::R6Class(

# Check for evaluation names emptiness
if (is.null(private$eval_names)) {

# Get evaluation names
buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)
.Call(
eval_names <- .Call(
LGBM_BoosterGetEvalNames_R
, private$handle
, buf_len
, act_len
, buf
)
if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
.Call(
LGBM_BoosterGetEvalNames_R
, private$handle
, buf_len
, act_len
, buf
)
}
names <- lgb.encode.char(arr = buf, len = act_len)

# Check names' length
if (nchar(names) > 0L) {
if (length(eval_names) > 0L) {

# Parse and store privately names
names <- strsplit(names, "\t")[[1L]]
private$eval_names <- names
private$eval_names <- eval_names

# some metrics don't map cleanly to metric names, for example "ndcg@1" is just the
# ndcg metric evaluated at the first "query result" in learning-to-rank
metric_names <- gsub("@.*", "", names)
metric_names <- gsub("@.*", "", eval_names)
private$higher_better_inner_eval <- .METRICS_HIGHER_BETTER()[metric_names]

}
Expand Down
23 changes: 1 addition & 22 deletions R-package/R/lgb.Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -369,31 +369,10 @@ Dataset <- R6::R6Class(

# Check for handle
if (!lgb.is.null.handle(x = private$handle)) {

# Get feature names and write them
buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)
.Call(
private$colnames <- .Call(
LGBM_DatasetGetFeatureNames_R
, private$handle
, buf_len
, act_len
, buf
)
if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
.Call(
LGBM_DatasetGetFeatureNames_R
, private$handle
, buf_len
, act_len
, buf
)
}
cnames <- lgb.encode.char(arr = buf, len = act_len)
private$colnames <- as.character(base::strsplit(cnames, "\t")[[1L]])
return(private$colnames)

} else if (is.matrix(private$raw_data) || methods::is(private$raw_data, "dgCMatrix")) {
Expand Down
7 changes: 0 additions & 7 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,6 @@ lgb.is.null.handle <- function(x) {
return(is.null(x) || is.na(x))
}

lgb.encode.char <- function(arr, len) {
if (!is.raw(arr)) {
stop("lgb.encode.char: Can only encode from raw type")
}
return(rawToChar(arr[seq_len(len)]))
}

# [description] Get the most recent error stored on the C++ side and raise it
# as an R error.
lgb.last_error <- function() {
Expand Down
2 changes: 0 additions & 2 deletions R-package/src/R_object_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ typedef union { VECTOR_SER s; double align; } SEXPREC_ALIGN;

#define DATAPTR(x) ((reinterpret_cast<SEXPREC_ALIGN*>(x)) + 1)

#define R_CHAR_PTR(x) (reinterpret_cast<char*>DATAPTR(x))

#define R_IS_NULL(x) ((*reinterpret_cast<LGBM_SE>(x)).sxpinfo.type == 0)

// 64bit pointer
Expand Down
128 changes: 82 additions & 46 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,9 @@
return R_NilValue; \
}

using LightGBM::Common::Join;
using LightGBM::Common::Split;
using LightGBM::Log;

LGBM_SE EncodeChar(LGBM_SE dest, const char* src, SEXP buf_len, SEXP actual_len, size_t str_len) {
if (str_len > INT32_MAX) {
Log::Fatal("Don't support large string in R-package");
}
INTEGER(actual_len)[0] = static_cast<int>(str_len);
if (Rf_asInteger(buf_len) < static_cast<int>(str_len)) {
return dest;
}
auto ptr = R_CHAR_PTR(dest);
std::memcpy(ptr, src, str_len);
return dest;
}

SEXP LGBM_GetLastError_R() {
SEXP out;
out = PROTECT(Rf_allocVector(STRSXP, 1));
Expand Down Expand Up @@ -153,10 +139,8 @@ SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
R_API_END();
}

SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
SEXP buf_len,
SEXP actual_len,
LGBM_SE feature_names) {
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) {
SEXP feature_names;
R_API_BEGIN();
int len = 0;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
Expand All @@ -175,10 +159,29 @@ SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
len, &out_len,
reserved_string_size, &required_string_size,
ptr_names.data()));
// if any feature names were larger than allocated size,
// allow for a larger size and try again
if (required_string_size > reserved_string_size) {
for (int i = 0; i < len; ++i) {
names[i].resize(required_string_size);
ptr_names[i] = names[i].data();
}
CHECK_CALL(
LGBM_DatasetGetFeatureNames(
R_GET_PTR(handle),
len,
&out_len,
required_string_size,
&required_string_size,
ptr_names.data()));
}
CHECK_EQ(len, out_len);
CHECK_GE(reserved_string_size, required_string_size);
auto merge_str = Join<char*>(ptr_names, "\t");
EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
feature_names = PROTECT(Rf_allocVector(STRSXP, len));
for (int i = 0; i < len; ++i) {
SET_STRING_ELT(feature_names, i, Rf_mkChar(ptr_names[i]));
}
UNPROTECT(1);
return feature_names;
R_API_END();
}

Expand Down Expand Up @@ -432,10 +435,8 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
R_API_END();
}

SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
SEXP buf_len,
SEXP actual_len,
LGBM_SE eval_names) {
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) {
SEXP eval_names;
R_API_BEGIN();
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
Expand All @@ -456,10 +457,29 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
len, &out_len,
reserved_string_size, &required_string_size,
ptr_names.data()));
// if any eval names were larger than allocated size,
// allow for a larger size and try again
if (required_string_size > reserved_string_size) {
for (int i = 0; i < len; ++i) {
names[i].resize(required_string_size);
ptr_names[i] = names[i].data();
}
CHECK_CALL(
LGBM_BoosterGetEvalNames(
R_GET_PTR(handle),
len,
&out_len,
required_string_size,
&required_string_size,
ptr_names.data()));
}
CHECK_EQ(out_len, len);
CHECK_GE(reserved_string_size, required_string_size);
auto merge_names = Join<char*>(ptr_names, "\t");
EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1);
eval_names = PROTECT(Rf_allocVector(STRSXP, len));
for (int i = 0; i < len; ++i) {
SET_STRING_ELT(eval_names, i, Rf_mkChar(ptr_names[i]));
}
UNPROTECT(1);
return eval_names;
R_API_END();
}

Expand Down Expand Up @@ -616,31 +636,47 @@ SEXP LGBM_BoosterSaveModel_R(LGBM_SE handle,

SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP buffer_len,
SEXP actual_len,
LGBM_SE out_str) {
SEXP feature_importance_type) {
SEXP model_str;
R_API_BEGIN();
int64_t out_len = 0;
int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
int64_t buf_len = 1024 * 1024;
int64_t num_iter = Rf_asInteger(num_iteration);
int64_t importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
}
model_str = PROTECT(Rf_allocVector(STRSXP, 1));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
UNPROTECT(1);
return model_str;
R_API_END();
}

SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP buffer_len,
SEXP actual_len,
LGBM_SE out_str) {
SEXP feature_importance_type) {
SEXP model_str;
R_API_BEGIN();
int64_t out_len = 0;
int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
int64_t buf_len = 1024 * 1024;
int64_t num_iter = Rf_asInteger(num_iteration);
int64_t importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
}
model_str = PROTECT(Rf_allocVector(STRSXP, 1));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
UNPROTECT(1);
return model_str;
R_API_END();
}

Expand All @@ -652,7 +688,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_DatasetCreateFromMat_R" , (DL_FUNC) &LGBM_DatasetCreateFromMat_R , 6},
{"LGBM_DatasetGetSubset_R" , (DL_FUNC) &LGBM_DatasetGetSubset_R , 5},
{"LGBM_DatasetSetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetSetFeatureNames_R , 2},
{"LGBM_DatasetGetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R , 4},
{"LGBM_DatasetGetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R , 1},
{"LGBM_DatasetSaveBinary_R" , (DL_FUNC) &LGBM_DatasetSaveBinary_R , 2},
{"LGBM_DatasetFree_R" , (DL_FUNC) &LGBM_DatasetFree_R , 1},
{"LGBM_DatasetSetField_R" , (DL_FUNC) &LGBM_DatasetSetField_R , 4},
Expand All @@ -676,7 +712,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterGetCurrentIteration_R", (DL_FUNC) &LGBM_BoosterGetCurrentIteration_R, 2},
{"LGBM_BoosterGetUpperBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetUpperBoundValue_R , 2},
{"LGBM_BoosterGetLowerBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetLowerBoundValue_R , 2},
{"LGBM_BoosterGetEvalNames_R" , (DL_FUNC) &LGBM_BoosterGetEvalNames_R , 4},
{"LGBM_BoosterGetEvalNames_R" , (DL_FUNC) &LGBM_BoosterGetEvalNames_R , 1},
{"LGBM_BoosterGetEval_R" , (DL_FUNC) &LGBM_BoosterGetEval_R , 3},
{"LGBM_BoosterGetNumPredict_R" , (DL_FUNC) &LGBM_BoosterGetNumPredict_R , 3},
{"LGBM_BoosterGetPredict_R" , (DL_FUNC) &LGBM_BoosterGetPredict_R , 3},
Expand All @@ -685,8 +721,8 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 6},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 6},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
{NULL, NULL, 0}
};

Expand Down
Loading

0 comments on commit c1d2dbe

Please sign in to comment.