Skip to content

Commit

Permalink
[Python] / [R] add start_iteration to python predict interface (fix #…
Browse files Browse the repository at this point in the history
…3058) (#3272)

* [python] add start_iteration to python predict interface (#3058)

* Apply suggestions from code review

* Update lightgbm_R.h

* Apply suggestions from code review

* Apply suggestions from code review

* fix R interface

* update R documentation

Co-authored-by: Guolin Ke <guolin.ke@outlook.com>
  • Loading branch information
shiyu1994 and guolinke authored Aug 6, 2020
1 parent 083b02a commit 82e2ff7
Show file tree
Hide file tree
Showing 23 changed files with 438 additions and 108 deletions.
18 changes: 16 additions & 2 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ Booster <- R6::R6Class(

# Predict on new data
predict = function(data,
start_iteration = NULL,
num_iteration = NULL,
rawscore = FALSE,
predleaf = FALSE,
Expand All @@ -494,10 +495,14 @@ Booster <- R6::R6Class(
if (is.null(num_iteration)) {
num_iteration <- self$best_iter
}
# Check if start iteration is non existent
if (is.null(start_iteration)) {
start_iteration <- 0L
}

# Predict on new data
predictor <- Predictor$new(private$handle, ...)
predictor$predict(data, num_iteration, rawscore, predleaf, predcontrib, header, reshape)
predictor$predict(data, start_iteration, num_iteration, rawscore, predleaf, predcontrib, header, reshape)

},

Expand Down Expand Up @@ -698,7 +703,14 @@ Booster <- R6::R6Class(
#' @description Predicted values based on class \code{lgb.Booster}
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param start_iteration int or None, optional (default=None)
#' Start index of the iteration to predict.
#' If None or <= 0, starts from the first iteration.
#' @param num_iteration int or None, optional (default=None)
#' Limit number of iterations in the prediction.
#' If None, if the best iteration exists and start_iteration is None or <= 0, the
#' best iteration is used; otherwise, all iterations from start_iteration are used.
#' If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the for of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
#' for logistic regression would result in predictions for log-odds instead of probabilities.
Expand Down Expand Up @@ -740,6 +752,7 @@ Booster <- R6::R6Class(
#' @export
predict.lgb.Booster <- function(object,
data,
start_iteration = NULL,
num_iteration = NULL,
rawscore = FALSE,
predleaf = FALSE,
Expand All @@ -756,6 +769,7 @@ predict.lgb.Booster <- function(object,
# Return booster predictions
object$predict(
data
, start_iteration
, num_iteration
, rawscore
, predleaf
Expand Down
9 changes: 9 additions & 0 deletions R-package/R/lgb.Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Predictor <- R6::R6Class(

# Predict from data
predict = function(data,
start_iteration = NULL,
num_iteration = NULL,
rawscore = FALSE,
predleaf = FALSE,
Expand All @@ -87,6 +88,10 @@ Predictor <- R6::R6Class(
if (is.null(num_iteration)) {
num_iteration <- -1L
}
# Check if start iterations is existing - if not, then set it to 0 (start from the first iteration)
if (is.null(start_iteration)) {
start_iteration <- 0L
}

# Set temporary variable
num_row <- 0L
Expand All @@ -108,6 +113,7 @@ Predictor <- R6::R6Class(
, as.integer(rawscore)
, as.integer(predleaf)
, as.integer(predcontrib)
, as.integer(start_iteration)
, as.integer(num_iteration)
, private$params
, lgb.c_str(tmp_filename)
Expand All @@ -134,6 +140,7 @@ Predictor <- R6::R6Class(
, as.integer(rawscore)
, as.integer(predleaf)
, as.integer(predcontrib)
, as.integer(start_iteration)
, as.integer(num_iteration)
)

Expand All @@ -156,6 +163,7 @@ Predictor <- R6::R6Class(
, as.integer(rawscore)
, as.integer(predleaf)
, as.integer(predcontrib)
, as.integer(start_iteration)
, as.integer(num_iteration)
, private$params
)
Expand All @@ -178,6 +186,7 @@ Predictor <- R6::R6Class(
, as.integer(rawscore)
, as.integer(predleaf)
, as.integer(predcontrib)
, as.integer(start_iteration)
, as.integer(num_iteration)
, private$params
)
Expand Down
11 changes: 10 additions & 1 deletion R-package/man/predict.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 12 additions & 8 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,14 +541,15 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename,
LGBM_SE call_state) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_AS_INT(data_has_header), pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_CHAR_PTR(result_filename)));
R_API_END();
}
Expand All @@ -558,14 +559,15 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
LGBM_SE out_len,
LGBM_SE call_state) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
pred_type, R_AS_INT(num_iteration), &len));
pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), &len));
R_INT_PTR(out_len)[0] = static_cast<int>(len);
R_API_END();
}
Expand All @@ -580,6 +582,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
Expand All @@ -599,7 +602,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
nrow, pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
R_API_END();
}

Expand All @@ -610,6 +613,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
Expand All @@ -625,7 +629,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
pred_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));

R_API_END();
}
Expand Down Expand Up @@ -706,10 +710,10 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterGetEval_R" , (DL_FUNC) &LGBM_BoosterGetEval_R , 4},
{"LGBM_BoosterGetNumPredict_R" , (DL_FUNC) &LGBM_BoosterGetNumPredict_R , 4},
{"LGBM_BoosterGetPredict_R" , (DL_FUNC) &LGBM_BoosterGetPredict_R , 4},
{"LGBM_BoosterPredictForFile_R" , (DL_FUNC) &LGBM_BoosterPredictForFile_R , 10},
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8},
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
{"LGBM_BoosterPredictForFile_R" , (DL_FUNC) &LGBM_BoosterPredictForFile_R , 11},
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 9},
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 15},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 12},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 7},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 7},
Expand Down
4 changes: 4 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename,
Expand All @@ -511,6 +512,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterCalcNumPredict_R(
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
LGBM_SE out_len,
LGBM_SE call_state
Expand Down Expand Up @@ -545,6 +547,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
Expand Down Expand Up @@ -574,6 +577,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE start_iteration,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
Expand Down
60 changes: 60 additions & 0 deletions R-package/tests/testthat/test_Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,63 @@ test_that("predictions do not fail for integer input", {
pred_double <- predict(fit, X_double)
expect_equal(pred_integer, pred_double)
})

test_that("start_iteration works correctly", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test
dtrain <- lgb.Dataset(
agaricus.train$data
, label = agaricus.train$label
)
dtest <- lgb.Dataset.create.valid(
dtrain
, agaricus.test$data
, label = agaricus.test$label
)
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, num_leaves = 4L
, learning_rate = 0.6
, nrounds = 100L
, objective = "binary"
, save_name = tempfile(fileext = ".model")
, valids = list("test" = dtest)
, early_stopping_rounds = 2L
)
expect_true(lgb.is.Booster(bst))
pred1 <- predict(bst, data = test$data, rawscore = TRUE)
pred_contrib1 <- predict(bst, test$data, predcontrib = TRUE)
pred2 <- rep(0.0, length(pred1))
pred_contrib2 <- rep(0.0, length(pred2))
step <- 11L
end_iter <- 99L
if (bst$best_iter != -1L) {
end_iter <- bst$best_iter - 1L
}
start_iters <- seq(0L, end_iter, by = step)
for (start_iter in start_iters) {
n_iter <- min(c(end_iter - start_iter + 1L, step))
inc_pred <- predict(bst, test$data
, start_iteration = start_iter
, num_iteration = n_iter
, rawscore = TRUE
)
inc_pred_contrib <- bst$predict(test$data
, start_iteration = start_iter
, num_iteration = n_iter
, predcontrib = TRUE
)
pred2 <- pred2 + inc_pred
pred_contrib2 <- pred_contrib2 + inc_pred_contrib
}
expect_equal(pred2, pred1)
expect_equal(pred_contrib2, pred_contrib1)

pred_leaf1 <- predict(bst, test$data, predleaf = TRUE)
pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, predleaf = TRUE)
expect_equal(pred_leaf1, pred_leaf2)
})
8 changes: 8 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,14 @@ Dataset Parameters
Predict Parameters
~~~~~~~~~~~~~~~~~~

- ``start_iteration_predict`` :raw-html:`<a id="start_iteration_predict" title="Permalink to this parameter" href="#start_iteration_predict">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int

- used only in ``prediction`` task

- used to specify from which iteration to start the prediction

- ``<= 0`` means from the first iteration

- ``num_iteration_predict`` :raw-html:`<a id="num_iteration_predict" title="Permalink to this parameter" href="#num_iteration_predict">&#x1F517;&#xFE0E;</a>`, default = ``-1``, type = int

- used only in ``prediction`` task
Expand Down
5 changes: 3 additions & 2 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;

virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;
virtual int NumPredictOneRow(int start_iteration, int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;

/*!
* \brief Prediction for one record, not sigmoid transform
Expand Down Expand Up @@ -284,10 +284,11 @@ class LIGHTGBM_EXPORT Boosting {

/*!
* \brief Initial work for the prediction
* \param start_iteration Start index of the iteration to predict
* \param num_iteration number of used iteration
* \param is_pred_contrib
*/
virtual void InitPredict(int num_iteration, bool is_pred_contrib) = 0;
virtual void InitPredict(int start_iteration, int num_iteration, bool is_pred_contrib) = 0;

/*!
* \brief Name of submodel
Expand Down
Loading

0 comments on commit 82e2ff7

Please sign in to comment.