Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Mar 25, 2020
1 parent 0800f67 commit 68e7840
Showing 1 changed file with 31 additions and 55 deletions.
86 changes: 31 additions & 55 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ train <- agaricus.train
test <- agaricus.test

TOLERANCE <- 1e-6
set.seed(708L)

# [description] Every time this function is called, it adds 0.1
# to an accumulator then returns the current value.
Expand Down Expand Up @@ -788,7 +789,6 @@ test_that("If first_metric_only is TRUE, lgb.train() decides to stop early based
test_that("lgb.train() works when a mixture of functions and strings are passed to eval", {
set.seed(708L)
nrounds <- 10L
early_stopping_rounds <- 3L
increasing_metric_starting_value <- get(ACCUMULATOR_NAME, envir = .GlobalEnv)
bst <- lgb.train(
params = list(
Expand Down Expand Up @@ -818,51 +818,8 @@ test_that("lgb.train() works when a mixture of functions and strings are passed

# the difference metrics shouldn't have been mixed up with each other
results <- bst$record_evals[["valid1"]]
expect_true(abs(results[["rmse"]][["eval"]][[1L]] - 0.9278173) < TOLERANCE)
expect_true(abs(results[["l2"]][["eval"]][[1L]] - 0.8608449) < TOLERANCE)
expected_increasing_metric <- increasing_metric_starting_value + 0.1
expect_true(
abs(
results[["increasing_metric"]][["eval"]][[1L]] - expected_increasing_metric
) < TOLERANCE
)
expect_true(abs(results[["constant_metric"]][["eval"]][[1L]] - CONSTANT_METRIC_VALUE) < TOLERANCE)

})

test_that("lgb.train() works when a character vector is passed to eval", {
set.seed(708L)
nrounds <- 10L
early_stopping_rounds <- 3L
increasing_metric_starting_value <- get(ACCUMULATOR_NAME, envir = .GlobalEnv)
bst <- lgb.train(
params = list(
objective = "binary"
, metric = "None"
)
, data = DTRAIN_RANDOM_CLASSIFICATION
, nrounds = nrounds
, valids = list(
"valid1" = DVALID_RANDOM_CLASSIFICATION
)
, eval = c(
"binary_error"
, "binary_logloss"
)
)

# all 4 metrics should have been used
expect_named(
bst$record_evals[["valid1"]]
, expected = c("rmse", "l2", "increasing_metric", "constant_metric")
, ignore.order = TRUE
, ignore.case = FALSE
)

# the difference metrics shouldn't have been mixed up with each other
results <- bst$record_evals[["valid1"]]
expect_true(abs(results[["rmse"]][["eval"]][[1L]] - 0.9278173) < TOLERANCE)
expect_true(abs(results[["l2"]][["eval"]][[1L]] - 0.8608449) < TOLERANCE)
expect_true(abs(results[["rmse"]][["eval"]][[1L]] - 1.105012) < TOLERANCE)
expect_true(abs(results[["l2"]][["eval"]][[1L]] - 1.221051) < TOLERANCE)
expected_increasing_metric <- increasing_metric_starting_value + 0.1
expect_true(
abs(
Expand All @@ -887,7 +844,6 @@ test_that("lgb.train() works when a list of strings or a character vector is pas

set.seed(708L)
nrounds <- 10L
early_stopping_rounds <- 3L
increasing_metric_starting_value <- get(ACCUMULATOR_NAME, envir = .GlobalEnv)
bst <- lgb.train(
params = list(
Expand All @@ -913,18 +869,17 @@ test_that("lgb.train() works when a list of strings or a character vector is pas
# the difference metrics shouldn't have been mixed up with each other
results <- bst$record_evals[["valid1"]]
if ("binary_error" %in% unlist(eval_variation)) {
expect_true(abs(results[["binary_error"]][["eval"]][[1L]] - 0.5135135) < TOLERANCE)
expect_true(abs(results[["binary_error"]][["eval"]][[1L]] - 0.4864865) < TOLERANCE)
}
if ("binary_logloss" %in% unlist(eval_variation)) {
expect_true(abs(results[["binary_logloss"]][["eval"]][[1L]] - 0.6992222) < TOLERANCE)
expect_true(abs(results[["binary_logloss"]][["eval"]][[1L]] - 0.6932548) < TOLERANCE)
}
}
})

test_that("lgb.train() works when you specify both 'metric' and 'eval' with strings", {
set.seed(708L)
nrounds <- 10L
early_stopping_rounds <- 3L
increasing_metric_starting_value <- get(ACCUMULATOR_NAME, envir = .GlobalEnv)
bst <- lgb.train(
params = list(
Expand All @@ -949,14 +904,13 @@ test_that("lgb.train() works when you specify both 'metric' and 'eval' with stri

# the difference metrics shouldn't have been mixed up with each other
results <- bst$record_evals[["valid1"]]
expect_true(abs(results[["binary_error"]][["eval"]][[1L]] - 0.5135135) < TOLERANCE)
expect_true(abs(results[["binary_logloss"]][["eval"]][[1L]] - 0.6992222) < TOLERANCE)
expect_true(abs(results[["binary_error"]][["eval"]][[1L]] - 0.4864865) < TOLERANCE)
expect_true(abs(results[["binary_logloss"]][["eval"]][[1L]] - 0.6932548) < TOLERANCE)
})

test_that("lgb.train() works when you specify both 'metric' and 'eval' with strings", {
set.seed(708L)
nrounds <- 10L
early_stopping_rounds <- 3L
increasing_metric_starting_value <- get(ACCUMULATOR_NAME, envir = .GlobalEnv)
bst <- lgb.train(
params = list(
Expand All @@ -981,6 +935,28 @@ test_that("lgb.train() works when you specify both 'metric' and 'eval' with stri

# the difference metrics shouldn't have been mixed up with each other
results <- bst$record_evals[["valid1"]]
expect_true(abs(results[["binary_error"]][["eval"]][[1L]] - 0.5135135) < TOLERANCE)
expect_true(abs(results[["binary_logloss"]][["eval"]][[1L]] - 0.6992222) < TOLERANCE)
expect_true(abs(results[["binary_error"]][["eval"]][[1L]] - 0.4864865) < TOLERANCE)
expect_true(abs(results[["binary_logloss"]][["eval"]][[1L]] - 0.6932548) < TOLERANCE)
})

test_that("lgb.train() works when you give a function for eval", {
set.seed(708L)
nrounds <- 10L
increasing_metric_starting_value <- get(ACCUMULATOR_NAME, envir = .GlobalEnv)
bst <- lgb.train(
params = list(
objective = "binary"
, metric = "None"
)
, data = DTRAIN_RANDOM_CLASSIFICATION
, nrounds = nrounds
, valids = list(
"valid1" = DVALID_RANDOM_CLASSIFICATION
)
, eval = .constant_metric
)

# the difference metrics shouldn't have been mixed up with each other
results <- bst$record_evals[["valid1"]]
expect_true(abs(results[["constant_metric"]][["eval"]][[1L]] - CONSTANT_METRIC_VALUE) < TOLERANCE)
})

0 comments on commit 68e7840

Please sign in to comment.