Skip to content

Commit

Permalink
[R-package] Added tests on use of force_col_wise and force_row_wise i…
Browse files Browse the repository at this point in the history
…n training (#2719)
  • Loading branch information
jameslamb authored Jan 31, 2020
1 parent 382e13e commit dec3d79
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
43 changes: 43 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,46 @@ test_that("lgb.train() throws an informative error if 'valids' contains lgb.Data
)
}, regexp = "each element of valids must have a name")
})

test_that("lgb.train() works with force_col_wise and force_row_wise", {
set.seed(1234L)
nrounds <- 10L
dtrain <- lgb.Dataset(
train$data
, label = train$label
)
params <- list(
objective = "binary"
, metric = "binary_error"
, force_col_wise = TRUE
)
bst_colwise <- lgb.train(
params = params
, data = dtrain
, nrounds = nrounds
)

params <- list(
objective = "binary"
, metric = "binary_error"
, force_row_wise = TRUE
)
bst_row_wise <- lgb.train(
params = params
, data = dtrain
, nrounds = nrounds
)

expected_error <- 0.003070782
expect_equal(bst_colwise$eval_train()[[1L]][["value"]], expected_error)
expect_equal(bst_row_wise$eval_train()[[1L]][["value"]], expected_error)

# check some basic details of the boosters just to be sure force_col_wise
# and force_row_wise are not causing any weird side effects
for (bst in list(bst_row_wise, bst_colwise)){
expect_equal(bst$current_iter(), nrounds)
parsed_model <- jsonlite::fromJSON(bst$dump_model())
expect_equal(parsed_model$objective, "binary sigmoid:1")
expect_false(parsed_model$average_output)
}
})
4 changes: 2 additions & 2 deletions R-package/tests/testthat/test_learning_to_rank.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ test_that("learning-to-rank with lgb.train() works as expected", {
}
expect_identical(sapply(eval_results, function(x) {x$name}), eval_names)
expect_equal(eval_results[[1L]][["value"]], 0.825)
expect_true(abs(eval_results[[2L]][["value"]] - 0.795986) < TOLERANCE)
expect_true(abs(eval_results[[3L]][["value"]] - 0.7734639) < TOLERANCE)
expect_true(abs(eval_results[[2L]][["value"]] - 0.7766434) < TOLERANCE)
expect_true(abs(eval_results[[3L]][["value"]] - 0.7527939) < TOLERANCE)
})

test_that("learning-to-rank with lgb.cv() works as expected", {
Expand Down

0 comments on commit dec3d79

Please sign in to comment.