Skip to content

Commit

Permalink
[R-package] remove pre-allocated call_state in C++ calls (#4244)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored May 1, 2021
1 parent b27dcfa commit 26cde5f
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 378 deletions.
50 changes: 0 additions & 50 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ Booster <- R6::R6Class(
if (!lgb.is.null.handle(x = private$handle)) {

# Freeing up handle
call_state <- 0L
.Call(
LGBM_BoosterFree_R
, private$handle
, call_state
)
private$handle <- NULL

Expand Down Expand Up @@ -54,13 +52,11 @@ Booster <- R6::R6Class(
params <- modifyList(params, train_set$get_params())
params_str <- lgb.params2str(params = params)
# Store booster handle
call_state <- 0L
.Call(
LGBM_BoosterCreate_R
, train_set_handle
, params_str
, handle
, call_state
)

# Create private booster information
Expand All @@ -73,12 +69,10 @@ Booster <- R6::R6Class(
if (!is.null(private$init_predictor)) {

# Merge booster
call_state <- 0L
.Call(
LGBM_BoosterMerge_R
, handle
, private$init_predictor$.__enclos_env__$private$handle
, call_state
)

}
Expand All @@ -94,12 +88,10 @@ Booster <- R6::R6Class(
}

# Create booster from model
call_state <- 0L
.Call(
LGBM_BoosterCreateFromModelfile_R
, lgb.c_str(x = modelfile)
, handle
, call_state
)

} else if (!is.null(model_str)) {
Expand All @@ -110,12 +102,10 @@ Booster <- R6::R6Class(
}

# Create booster from model
call_state <- 0L
.Call(
LGBM_BoosterLoadModelFromString_R
, lgb.c_str(x = model_str)
, handle
, call_state
)

} else {
Expand All @@ -141,12 +131,10 @@ Booster <- R6::R6Class(
class(handle) <- "lgb.Booster.handle"
private$handle <- handle
private$num_class <- 1L
call_state <- 0L
.Call(
LGBM_BoosterGetNumClasses_R
, private$handle
, private$num_class
, call_state
)

}
Expand Down Expand Up @@ -188,12 +176,10 @@ Booster <- R6::R6Class(
}

# Add validation data to booster
call_state <- 0L
.Call(
LGBM_BoosterAddValidData_R
, private$handle
, data$.__enclos_env__$private$get_handle()
, call_state
)

# Store private information
Expand All @@ -216,12 +202,10 @@ Booster <- R6::R6Class(
params <- modifyList(params, list(...))
params_str <- lgb.params2str(params = params)

call_state <- 0L
.Call(
LGBM_BoosterResetParameter_R
, private$handle
, params_str
, call_state
)
self$params <- params

Expand Down Expand Up @@ -252,12 +236,10 @@ Booster <- R6::R6Class(
}

# Reset training data on booster
call_state <- 0L
.Call(
LGBM_BoosterResetTrainingData_R
, private$handle
, train_set$.__enclos_env__$private$get_handle()
, call_state
)

# Store private train set
Expand All @@ -272,11 +254,9 @@ Booster <- R6::R6Class(
stop("lgb.Booster.update: cannot update due to null objective function")
}
# Boost iteration from known objective
call_state <- 0L
.Call(
LGBM_BoosterUpdateOneIter_R
, private$handle
, call_state
)

} else {
Expand All @@ -299,14 +279,12 @@ Booster <- R6::R6Class(
}

# Return custom boosting gradient/hessian
call_state <- 0L
.Call(
LGBM_BoosterUpdateOneIterCustom_R
, private$handle
, gpair$grad
, gpair$hess
, length(gpair$grad)
, call_state
)

}
Expand All @@ -324,11 +302,9 @@ Booster <- R6::R6Class(
rollback_one_iter = function() {

# Return one iteration behind
call_state <- 0L
.Call(
LGBM_BoosterRollbackOneIter_R
, private$handle
, call_state
)

# Loop through each iteration
Expand All @@ -344,12 +320,10 @@ Booster <- R6::R6Class(
current_iter = function() {

cur_iter <- 0L
call_state <- 0L
.Call(
LGBM_BoosterGetCurrentIteration_R
, private$handle
, cur_iter
, call_state
)
return(cur_iter)

Expand All @@ -359,12 +333,10 @@ Booster <- R6::R6Class(
upper_bound = function() {

upper_bound <- 0.0
call_state <- 0L
.Call(
LGBM_BoosterGetUpperBoundValue_R
, private$handle
, upper_bound
, call_state
)
return(upper_bound)

Expand All @@ -374,12 +346,10 @@ Booster <- R6::R6Class(
lower_bound = function() {

lower_bound <- 0.0
call_state <- 0L
.Call(
LGBM_BoosterGetLowerBoundValue_R
, private$handle
, lower_bound
, call_state
)
return(lower_bound)

Expand Down Expand Up @@ -477,14 +447,12 @@ Booster <- R6::R6Class(
}

# Save booster model
call_state <- 0L
.Call(
LGBM_BoosterSaveModel_R
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, lgb.c_str(x = filename)
, call_state
)

return(invisible(self))
Expand All @@ -504,7 +472,6 @@ Booster <- R6::R6Class(
buf <- raw(buf_len)

# Call buffer
call_state <- 0L
.Call(
LGBM_BoosterSaveModelToString_R
, private$handle
Expand All @@ -513,14 +480,12 @@ Booster <- R6::R6Class(
, buf_len
, act_len
, buf
, call_state
)

# Check for buffer content
if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
call_state <- 0L
.Call(
LGBM_BoosterSaveModelToString_R
, private$handle
Expand All @@ -529,7 +494,6 @@ Booster <- R6::R6Class(
, buf_len
, act_len
, buf
, call_state
)
}

Expand All @@ -550,7 +514,6 @@ Booster <- R6::R6Class(
buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)
call_state <- 0L
.Call(
LGBM_BoosterDumpModel_R
, private$handle
Expand All @@ -559,13 +522,11 @@ Booster <- R6::R6Class(
, buf_len
, act_len
, buf
, call_state
)

if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
call_state <- 0L
.Call(
LGBM_BoosterDumpModel_R
, private$handle
Expand All @@ -574,7 +535,6 @@ Booster <- R6::R6Class(
, buf_len
, act_len
, buf
, call_state
)
}

Expand Down Expand Up @@ -674,14 +634,12 @@ Booster <- R6::R6Class(
if (is.null(private$predict_buffer[[data_name]])) {

# Store predictions
call_state <- 0L
npred <- 0L
.Call(
LGBM_BoosterGetNumPredict_R
, private$handle
, as.integer(idx - 1L)
, npred
, call_state
)
private$predict_buffer[[data_name]] <- numeric(npred)

Expand All @@ -691,13 +649,11 @@ Booster <- R6::R6Class(
if (!private$is_predicted_cur_iter[[idx]]) {

# Use buffer
call_state <- 0L
.Call(
LGBM_BoosterGetPredict_R
, private$handle
, as.integer(idx - 1L)
, private$predict_buffer[[data_name]]
, call_state
)
private$is_predicted_cur_iter[[idx]] <- TRUE
}
Expand All @@ -715,26 +671,22 @@ Booster <- R6::R6Class(
buf_len <- as.integer(1024L * 1024L)
act_len <- 0L
buf <- raw(buf_len)
call_state <- 0L
.Call(
LGBM_BoosterGetEvalNames_R
, private$handle
, buf_len
, act_len
, buf
, call_state
)
if (act_len > buf_len) {
buf_len <- act_len
buf <- raw(buf_len)
call_state <- 0L
.Call(
LGBM_BoosterGetEvalNames_R
, private$handle
, buf_len
, act_len
, buf
, call_state
)
}
names <- lgb.encode.char(arr = buf, len = act_len)
Expand Down Expand Up @@ -778,13 +730,11 @@ Booster <- R6::R6Class(

# Create evaluation values
tmp_vals <- numeric(length(private$eval_names))
call_state <- 0L
.Call(
LGBM_BoosterGetEval_R
, private$handle
, as.integer(data_idx - 1L)
, tmp_vals
, call_state
)

# Loop through all evaluation names
Expand Down
Loading

0 comments on commit 26cde5f

Please sign in to comment.