From 43407feba5145597efdb90c38d11e1886b621477 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 16 Feb 2022 20:19:15 -0600 Subject: [PATCH 1/3] factor out lgb.check.obj() --- R-package/R/lgb.cv.R | 3 +-- R-package/R/lgb.train.R | 3 +-- R-package/tests/testthat/test_basic.R | 18 ++++++++++++++++++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index abe72220c9e4..fdc93c2f2913 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -127,7 +127,7 @@ lgb.cv <- function(params = list() params <- lgb.check.wrapper_param( main_param_name = "objective" , params = params - , alternative_kwarg_value = NULL + , alternative_kwarg_value = obj ) params <- lgb.check.wrapper_param( main_param_name = "early_stopping_round" @@ -137,7 +137,6 @@ lgb.cv <- function(params = list() early_stopping_rounds <- params[["early_stopping_round"]] # extract any function objects passed for objective or metric - params <- lgb.check.obj(params = params, obj = obj) fobj <- NULL if (is.function(params$objective)) { fobj <- params$objective diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index eebf66ba405f..89b12d0bb03d 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -95,7 +95,7 @@ lgb.train <- function(params = list(), params <- lgb.check.wrapper_param( main_param_name = "objective" , params = params - , alternative_kwarg_value = NULL + , alternative_kwarg_value = obj ) params <- lgb.check.wrapper_param( main_param_name = "early_stopping_round" @@ -105,7 +105,6 @@ lgb.train <- function(params = list(), early_stopping_rounds <- params[["early_stopping_round"]] # extract any function objects passed for objective or metric - params <- lgb.check.obj(params = params, obj = obj) fobj <- NULL if (is.function(params$objective)) { fobj <- params$objective diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 673d8f781e6e..c2d32f66fa47 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -635,6 +635,24 @@ test_that("lgb.train() works as expected with multiple eval metrics", { ) }) +test_that("lgb.train() raises an informative error for unrecognized objectives", { + dtrain <- lgb.Dataset( + data = train$data + , label = train$label + ) + expect_error({ + expect_warning({ + bst <- lgb.train( + data = dtrain + , params = list( + objective_type = "not_a_real_objective" + , verbosity = VERBOSITY + ) + ) + }, regexp = "[LightGBM] [Fatal] Unknown objective type name: not_a_real_objective") + }, regexp = "lgb.Booster: cannot create Booster handle") +}) + test_that("lgb.train() respects parameter aliases for objective", { nrounds <- 3L dtrain <- lgb.Dataset( From ca095aa21ea7be096024b3f927a53c6e21f04c15 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 20 Feb 2022 13:25:54 -0600 Subject: [PATCH 2/3] remove lgb.check.obj() --- R-package/R/utils.R | 58 --------------------------------------------- 1 file changed, 58 deletions(-) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 86b3624f482f..c89bfe9fb0b2 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -117,64 +117,6 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na } -lgb.check.obj <- function(params) { - - # List known objectives in a vector - OBJECTIVES <- c( - "regression" - , "regression_l1" - , "regression_l2" - , "mean_squared_error" - , "mse" - , "l2_root" - , "root_mean_squared_error" - , "rmse" - , "mean_absolute_error" - , "mae" - , "quantile" - , "huber" - , "fair" - , "poisson" - , "binary" - , "lambdarank" - , "multiclass" - , "softmax" - , "multiclassova" - , "multiclass_ova" - , "ova" - , "ovr" - , "xentropy" - , "cross_entropy" - , "xentlambda" - , "cross_entropy_lambda" - , "mean_absolute_percentage_error" - , "mape" - , "gamma" - , "tweedie" - , "rank_xendcg" - , "xendcg" - , "xe_ndcg" - , "xe_ndcg_mart" - , "xendcg_mart" - ) - - if (is.null(params$objective)) { - stop("lgb.check.obj: objective should be a character or a function") - } - - if (is.character(params$objective)) { - - if (!(params$objective %in% OBJECTIVES)) { - - stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")") - - } - - } - - return(params) - -} # [description] # Take any character values from eval and store them in params$metric. From 7171599cbec9c0b577798df0418e1c94192934b7 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 20 Feb 2022 13:32:33 -0600 Subject: [PATCH 3/3] add test on lgb.cv() --- R-package/tests/testthat/test_basic.R | 32 +++++++++++++++++++-------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index c27cbae06af3..ab5accab6144 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -520,6 +520,22 @@ test_that("lgb.cv() respects showsd argument", { expect_identical(evals_no_showsd[["eval_err"]], list()) }) +test_that("lgb.cv() raises an informative error for unrecognized objectives", { + dtrain <- lgb.Dataset( + data = train$data + , label = train$label + ) + expect_error({ + bst <- lgb.cv( + data = dtrain + , params = list( + objective_type = "not_a_real_objective" + , verbosity = VERBOSITY + ) + ) + }, regexp = "Unknown objective type name: not_a_real_objective") +}) + test_that("lgb.cv() respects parameter aliases for objective", { nrounds <- 3L nfold <- 4L @@ -669,16 +685,14 @@ test_that("lgb.train() raises an informative error for unrecognized objectives", , label = train$label ) expect_error({ - expect_warning({ - bst <- lgb.train( - data = dtrain - , params = list( - objective_type = "not_a_real_objective" - , verbosity = VERBOSITY - ) + bst <- lgb.train( + data = dtrain + , params = list( + objective_type = "not_a_real_objective" + , verbosity = VERBOSITY ) - }, regexp = "[LightGBM] [Fatal] Unknown objective type name: not_a_real_objective") - }, regexp = "lgb.Booster: cannot create Booster handle") + ) + }, regexp = "Unknown objective type name: not_a_real_objective") }) test_that("lgb.train() respects parameter aliases for objective", {