Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Update remaining internal function calls to use keyword arguments #3617

Merged
merged 7 commits into from
Dec 1, 2020
2 changes: 1 addition & 1 deletion R-package/R/lgb.convert_with_rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ lgb.convert_with_rules <- function(data, rules = NULL) {
}
if (is_data_table) {
data.table::set(
data
x = data
, j = col_name
, value = unname(rules[[col_name]][data[[col_name]]])
)
Expand Down
56 changes: 28 additions & 28 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ lgb.cv <- function(params = list()
if (is.null(label)) {
stop("'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'")
}
data <- lgb.Dataset(data, label = label)
data <- lgb.Dataset(data = data, label = label)
}

# Setup temporary variables
params <- append(params, list(...))
params$verbose <- verbose
params <- lgb.check.obj(params, obj)
params <- lgb.check.eval(params, eval)
params <- lgb.check.obj(params = params, obj = obj)
params <- lgb.check.eval(params = params, eval = eval)
fobj <- NULL
eval_functions <- list(NULL)

Expand Down Expand Up @@ -175,7 +175,7 @@ lgb.cv <- function(params = list()

# Check for weights
if (!is.null(weight)) {
data$setinfo("weight", weight)
data$setinfo(name = "weight", info = weight)
}

# Update parameters with parsed parameters
Expand Down Expand Up @@ -220,21 +220,21 @@ lgb.cv <- function(params = list()
nfold = nfold
, nrows = nrow(data)
, stratified = stratified
, label = getinfo(data, "label")
, group = getinfo(data, "group")
, label = getinfo(dataset = data, name = "label")
, group = getinfo(dataset = data, name = "group")
, params = params
)

}

# Add printing log callback
if (verbose > 0L && eval_freq > 0L) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
callbacks <- add.cb(cb_list = callbacks, cb = cb.print.evaluation(eval_freq))
}

# Add evaluation log callback
if (record) {
callbacks <- add.cb(callbacks, cb.record.evaluation())
callbacks <- add.cb(cb_list = callbacks, cb = cb.record.evaluation())
}

# Did user pass parameters that indicate they want to use early stopping?
Expand Down Expand Up @@ -267,8 +267,8 @@ lgb.cv <- function(params = list()
# If user supplied early_stopping_rounds, add the early stopping callback
if (using_early_stopping) {
callbacks <- add.cb(
callbacks
, cb.early.stop(
cb_list = callbacks
, cb = cb.early.stop(
stopping_rounds = early_stopping_rounds
, first_metric_only = isTRUE(params[["first_metric_only"]])
, verbose = verbose
Expand All @@ -295,8 +295,8 @@ lgb.cv <- function(params = list()
if (folds_have_group) {
test_indices <- folds[[k]]$fold
test_group_indices <- folds[[k]]$group
test_groups <- getinfo(data, "group")[test_group_indices]
train_groups <- getinfo(data, "group")[-test_group_indices]
test_groups <- getinfo(dataset = data, name = "group")[test_group_indices]
train_groups <- getinfo(dataset = data, name = "group")[-test_group_indices]
} else {
test_indices <- folds[[k]]
}
Expand All @@ -305,32 +305,32 @@ lgb.cv <- function(params = list()
# set up test set
indexDT <- data.table::data.table(
indices = test_indices
, weight = getinfo(data, "weight")[test_indices]
, init_score = getinfo(data, "init_score")[test_indices]
, weight = getinfo(dataset = data, name = "weight")[test_indices]
, init_score = getinfo(dataset = data, name = "init_score")[test_indices]
)
data.table::setorderv(indexDT, "indices", order = 1L)
data.table::setorderv(x = indexDT, cols = "indices", order = 1L)
dtest <- slice(data, indexDT$indices)
setinfo(dtest, "weight", indexDT$weight)
setinfo(dtest, "init_score", indexDT$init_score)
setinfo(dataset = dtest, name = "weight", info = indexDT$weight)
setinfo(dataset = dtest, name = "init_score", info = indexDT$init_score)

# set up training set
indexDT <- data.table::data.table(
indices = train_indices
, weight = getinfo(data, "weight")[train_indices]
, init_score = getinfo(data, "init_score")[train_indices]
, weight = getinfo(data = data, name = "weight")[train_indices]
, init_score = getinfo(data = data, name = "init_score")[train_indices]
)
data.table::setorderv(indexDT, "indices", order = 1L)
data.table::setorderv(x = indexDT, cols = "indices", order = 1L)
dtrain <- slice(data, indexDT$indices)
setinfo(dtrain, "weight", indexDT$weight)
setinfo(dtrain, "init_score", indexDT$init_score)
setinfo(dataset = dtrain, name = "weight", info = indexDT$weight)
setinfo(dataset = dtrain, name = "init_score", info = indexDT$init_score)

if (folds_have_group) {
setinfo(dtest, "group", test_groups)
setinfo(dtrain, "group", train_groups)
setinfo(dataset = dtest, name = "group", info = test_groups)
setinfo(dataset = dtrain, name = "group", info = train_groups)
}

booster <- Booster$new(params, dtrain)
booster$add_valid(dtest, "valid")
booster <- Booster$new(params = params, train_set = dtrain)
booster$add_valid(data = dtest, name = "valid")
return(
list(booster = booster)
)
Expand Down Expand Up @@ -368,7 +368,7 @@ lgb.cv <- function(params = list()
})

# Prepare collection of evaluation results
merged_msg <- lgb.merge.cv.result(msg)
merged_msg <- lgb.merge.cv.result(msg = msg)

# Write evaluation result in environment
env$eval_list <- merged_msg$eval_list
Expand Down Expand Up @@ -446,7 +446,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {

y <- label[rnd_idx]
y <- as.factor(y)
folds <- lgb.stratified.folds(y, nfold)
folds <- lgb.stratified.folds(y = y, k = nfold)

} else {

Expand Down
4 changes: 2 additions & 2 deletions R-package/R/lgb.importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ lgb.importance <- function(model, percentage = TRUE) {
}

# Setup importance
tree_dt <- lgb.model.dt.tree(model)
tree_dt <- lgb.model.dt.tree(model = model)

# Extract elements
tree_imp_dt <- tree_dt[
Expand All @@ -54,7 +54,7 @@ lgb.importance <- function(model, percentage = TRUE) {
]

data.table::setnames(
tree_imp_dt
x = tree_imp_dt
, old = "split_feature"
, new = "Feature"
)
Expand Down
26 changes: 16 additions & 10 deletions R-package/R/lgb.interprete.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ lgb.interprete <- function(model,
num_iteration = NULL) {

# Get tree model
tree_dt <- lgb.model.dt.tree(model, num_iteration)
tree_dt <- lgb.model.dt.tree(model = model, num_iteration = num_iteration)

# Check number of classes
num_class <- model$.__enclos_env__$private$num_class
Expand All @@ -59,7 +59,7 @@ lgb.interprete <- function(model,
# Get parsed predictions of data
pred_mat <- t(
model$predict(
data[idxset, , drop = FALSE]
data = data[idxset, , drop = FALSE]
, num_iteration = num_iteration
, predleaf = TRUE
)
Expand All @@ -81,10 +81,10 @@ lgb.interprete <- function(model,
# Sequence over idxset
for (i in seq_along(idxset)) {
tree_interpretation_dt_list[[i]] <- single.row.interprete(
tree_dt
, num_class
, tree_index_mat_list[[i]]
, leaf_index_mat_list[[i]]
tree_dt = tree_dt
, num_class = num_class
, tree_index_mat = tree_index_mat_list[[i]]
, leaf_index_mat = leaf_index_mat_list[[i]]
)
}

Expand Down Expand Up @@ -122,14 +122,20 @@ single.tree.interprete <- function(tree_dt,
# Not null means existing node
this_node <- node_dt[split_index == parent_id, ]
feature_seq <<- c(this_node[["split_feature"]], feature_seq)
leaf_to_root(this_node[["node_parent"]], this_node[["internal_value"]])
leaf_to_root(
parent_id = this_node[["node_parent"]]
, current_value = this_node[["internal_value"]]
)

}

}

# Perform leaf to root conversion
leaf_to_root(leaf_dt[["leaf_parent"]], leaf_dt[["leaf_value"]])
leaf_to_root(
parent_id = leaf_dt[["leaf_parent"]]
, current_value = leaf_dt[["leaf_value"]]
)

data.table::data.table(
Feature = feature_seq
Expand Down Expand Up @@ -191,7 +197,7 @@ single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index

if (num_class > 1L) {
data.table::setnames(
next_interp_dt
x = next_interp_dt
, old = "Contribution"
, new = paste("Class", i - 1L)
)
Expand Down Expand Up @@ -221,7 +227,7 @@ single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index
for (j in 2L:ncol(tree_interpretation_dt)) {

data.table::set(
tree_interpretation_dt
x = tree_interpretation_dt
, i = which(is.na(tree_interpretation_dt[[j]]))
, j = j
, value = 0.0
Expand Down
6 changes: 3 additions & 3 deletions R-package/R/lgb.plot.interpretation.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,

# Only one class, plot straight away
multiple.tree.plot.interpretation(
tree_interpretation_dt
tree_interpretation = tree_interpretation_dt
, top_n = top_n
, title = NULL
, cex = cex
Expand All @@ -111,12 +111,12 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
# Prepare interpretation, perform T, get the names, and plot straight away
plot_dt <- tree_interpretation_dt[, c(1L, i + 1L), with = FALSE]
data.table::setnames(
plot_dt
x = plot_dt
, old = names(plot_dt)
, new = c("Feature", "Contribution")
)
multiple.tree.plot.interpretation(
plot_dt
tree_interpretation = plot_dt
, top_n = top_n
, title = paste("Class", i - 1L)
, cex = cex
Expand Down