Skip to content

Commit

Permalink
[R-package] avoid misleading warnings when using interaction constrai…
Browse files Browse the repository at this point in the history
…nts (fixes #4108) (#4232)
  • Loading branch information
jameslamb authored Apr 28, 2021
1 parent 086f078 commit fa6d356
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
10 changes: 9 additions & 1 deletion R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ lgb.cv <- function(params = list()
}
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L

# pop interaction_constraints off of params. It needs some preprocessing on the
# R side before being passed into the Dataset object
interaction_constraints <- params[["interaction_constraints"]]
params["interaction_constraints"] <- NULL

# Construct datasets, if needed
data$update_params(params = params)
data$construct()
Expand All @@ -177,7 +182,10 @@ lgb.cv <- function(params = list()
} else if (!is.null(data$get_colnames())) {
cnames <- data$get_colnames()
}
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params = params, column_names = cnames)
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(
interaction_constraints = interaction_constraints
, column_names = cnames
)

# Check for weights
if (!is.null(weight)) {
Expand Down
7 changes: 6 additions & 1 deletion R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ lgb.train <- function(params = list(),
}
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L

# pop interaction_constraints off of params. It needs some preprocessing on the
# R side before being passed into the Dataset object
interaction_constraints <- params[["interaction_constraints"]]
params["interaction_constraints"] <- NULL

# Construct datasets, if needed
data$update_params(params = params)
data$construct()
Expand All @@ -156,7 +161,7 @@ lgb.train <- function(params = list(),
cnames <- data$get_colnames()
}
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(
params = params
interaction_constraints = interaction_constraints
, column_names = cnames
)

Expand Down
10 changes: 5 additions & 5 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,21 +168,21 @@ lgb.params2str <- function(params, ...) {

}

lgb.check_interaction_constraints <- function(params, column_names) {
lgb.check_interaction_constraints <- function(interaction_constraints, column_names) {

# Convert interaction constraints to feature numbers
string_constraints <- list()

if (!is.null(params[["interaction_constraints"]])) {
if (!is.null(interaction_constraints)) {

if (!methods::is(params[["interaction_constraints"]], "list")) {
if (!methods::is(interaction_constraints, "list")) {
stop("interaction_constraints must be a list")
}
if (!all(sapply(params[["interaction_constraints"]], function(x) {is.character(x) || is.numeric(x)}))) {
if (!all(sapply(interaction_constraints, function(x) {is.character(x) || is.numeric(x)}))) {
stop("every element in interaction_constraints must be a character vector or numeric vector")
}

for (constraint in params[["interaction_constraints"]]) {
for (constraint in interaction_constraints) {

# Check for character name
if (is.character(constraint)) {
Expand Down

0 comments on commit fa6d356

Please sign in to comment.