feat: require unique learner ids in benchmark_grid (#1195)
* feat: require unique learner ids in benchmark_grid()

* ...

* ...

* ...

* ...

* ...
be-marc authored Nov 5, 2024
1 parent 2e5267a commit 0c46b69
Showing 13 changed files with 40 additions and 37 deletions.
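For context, a minimal sketch of the user-facing change (a hypothetical session; the error text comes from the new assertion in R/assertions.R below, and the ids rpart_a/rpart_b are illustrative):

library(mlr3)

# learners sharing an id are now rejected by benchmark_grid():
benchmark_grid(tsk("iris"), list(lrn("classif.rpart"), lrn("classif.rpart")), rsmp("holdout"))
# Error: Learners need to have unique IDs: classif.rpart, classif.rpart

# assigning distinct ids restores the old workflow:
learners = list(lrn("classif.rpart", id = "rpart_a"), lrn("classif.rpart", id = "rpart_b"))
design = benchmark_grid(tsk("iris"), learners, rsmp("holdout"))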
8 changes: 7 additions & 1 deletion R/assertions.R
@@ -107,7 +107,13 @@ test_matching_task_type = function(task_type, object, class) {
#' @export
#' @param learners (list of [Learner]).
#' @rdname mlr_assertions
assert_learners = function(learners, task = NULL, task_type = NULL, properties = character(), .var.name = vname(learners)) {
assert_learners = function(learners, task = NULL, task_type = NULL, properties = character(), unique_ids = FALSE, .var.name = vname(learners)) {
if (unique_ids) {
ids = map_chr(learners, "id")
if (!test_character(ids, unique = TRUE)) {
stopf("Learners need to have unique IDs: %s", str_collapse(ids))
}
}
invisible(lapply(learners, assert_learner, task = task, task_type = NULL, properties = properties, .var.name = .var.name))
}
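For illustration, the new check can also be triggered directly; a minimal sketch, assuming the exported assert_learners() above (both learners keep the default id "classif.debug"):

library(mlr3)

learners = list(lrn("classif.debug"), lrn("classif.debug"))
assert_learners(learners, unique_ids = TRUE)
# Error: Learners need to have unique IDs: classif.debug, classif.debug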

4 changes: 0 additions & 4 deletions R/benchmark.R
@@ -128,10 +128,6 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps
# learner = assert_learner(as_learner(learner, clone = TRUE))
assert_learnable(task, learner)

if (resampling$task_hash != task$hash) {
stopf("Resampling '%s' was not instantiated with task '%s'", resampling$id, task$id)
}

iters = resampling$iters
n_params = max(1L, length(param_values))
# insert constant values
5 changes: 3 additions & 2 deletions R/benchmark_grid.R
@@ -67,7 +67,7 @@
#'
benchmark_grid = function(tasks, learners, resamplings, param_values = NULL, paired = FALSE) {
tasks = assert_tasks(as_tasks(tasks))
learners = assert_learners(as_learners(learners))
learners = assert_learners(as_learners(learners), unique_ids = TRUE)
resamplings = assert_resamplings(as_resamplings(resamplings))
if (!is.null(param_values)) {
assert_param_values(param_values, n_learners = length(learners))
@@ -103,7 +103,8 @@ benchmark_grid = function(tasks, learners, resamplings, param_values = NULL, paired = FALSE) {
if (!identical(task_nrow, unique(map_int(resamplings, "task_nrow")))) {
stop("A Resampling is instantiated for a task with a different number of observations")
}
instances = pmap(grid, function(task, resampling) resamplings[[resampling]]$clone())
# clone resamplings for each task and update task hashes
instances = pmap(grid, function(task, resampling) resampling = resamplings[[resampling]]$clone())
} else {
instances = pmap(grid, function(task, resampling) resamplings[[resampling]]$clone()$instantiate(tasks[[task]]))
}
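As a usage sketch of the paired branch touched by this hunk (assuming, per the paired = TRUE contract, one pre-instantiated resampling per task):

library(mlr3)

tasks = list(tsk("iris"), tsk("sonar"))
resamplings = list(rsmp("cv", folds = 3), rsmp("cv", folds = 3))
resamplings[[1]]$instantiate(tasks[[1]])
resamplings[[2]]$instantiate(tasks[[2]])

# each task is crossed with the learners but paired with its own resampling:
design = benchmark_grid(tasks, lrn("classif.rpart"), resamplings, paired = TRUE)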
12 changes: 9 additions & 3 deletions R/helper_hashes.R
@@ -31,7 +31,7 @@ resampling_task_hashes = function(task, resampling, learner = NULL) {
task_hash = function(task, use_ids, test_ids = NULL, ignore_internal_valid_task = FALSE) {
# order matters: we first check for test_ids and then for the internal_valid_task
internal_valid_task_hash = if (!is.null(test_ids)) {
# this does the same as
# task$internal_valid_task = test_ids
# $internal_valid_task$hash
# but avoids the deep clone
@@ -40,6 +40,12 @@ task_hash = function(task, use_ids, test_ids = NULL, ignore_internal_valid_task = FALSE) {
task$internal_valid_task$hash
}

calculate_hash(class(task), task$id, task$backend$hash, task$col_info, use_ids, task$col_roles,
get_private(task)$.properties, internal_valid_task_hash)
calculate_hash(
class(task),
task$id,
task$backend$hash,
task$col_info,
use_ids,
get_private(task)$.properties,
internal_valid_task_hash)
}
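For illustration, calculate_hash() from mlr3misc digests all of its arguments into one hash, so any added, dropped, or changed component produces a different result (the string inputs here are placeholder values):

library(mlr3misc)

h1 = calculate_hash("TaskClassif", "iris", "some_backend_hash")
h2 = calculate_hash("TaskClassif", "iris", "another_backend_hash")
identical(h1, h2)  # FALSE: every component feeds into the hash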
4 changes: 0 additions & 4 deletions R/resample.R
@@ -71,10 +71,6 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe
resampling = resampling$instantiate(task)
}

if (resampling$task_hash != task$hash) {
stopf("Resampling '%s' was not instantiated with task '%s'", resampling$id, task$id)
}

n = resampling$iters
pb = if (isNamespaceLoaded("progressr")) {
# NB: the progress bar needs to be created in this env
4 changes: 4 additions & 0 deletions man/Resampling.Rd


3 changes: 3 additions & 0 deletions man/Task.Rd


1 change: 1 addition & 0 deletions man/mlr_assertions.Rd


15 changes: 6 additions & 9 deletions tests/testthat/test_benchmark.R
@@ -463,7 +463,7 @@ test_that("param_values in benchmark", {


# benchmark grid with multiple params and multiple learners
design = benchmark_grid(tasks, lrns(c("classif.debug", "classif.debug")), rsmp("holdout"), param_values = list(list(list(x = 1), list(x = 0.5)), list()))
design = benchmark_grid(tasks, lrns(c("classif.debug", "classif.rpart")), rsmp("holdout"), param_values = list(list(list(x = 1), list(x = 0.5)), list()))
bmr = benchmark(design)
expect_benchmark_result(bmr)
expect_equal(bmr$n_resample_results, 3)
@@ -582,14 +582,11 @@ test_that("score works with predictions and empty predictions", {
expect_equal(tab$classif.ce[1], NaN)
})

test_that("resampling was instantiated on the task", {
test_that("benchmark_grid only allows unique learner ids", {
task = tsk("iris")
learner = lrn("classif.rpart")
task = tsk("pima")
resampling = rsmp("cv", folds = 5)
resampling$instantiate(task)
task = tsk("spam")

design = data.table(task = list(task), learner = list(learner), resampling = list(resampling))
resampling = rsmp("holdout")

expect_error(benchmark(design), "not instantiated")
expect_error(benchmark_grid(task, list(learner, learner), resampling), "unique")
})

4 changes: 2 additions & 2 deletions tests/testthat/test_hotstart.R
@@ -154,7 +154,7 @@ test_that("learners are hotstarted when benchmark is called", {
resampling = rsmp("cv", folds = 3)
resampling$instantiate(task)

design = benchmark_grid(task, list(learner_1, learner_2), resampling)
design = data.table(task = list(task), learner = list(learner_1, learner_2), resampling = list(resampling))
bmr = benchmark(design, store_models = TRUE)

learners = unlist(map(seq_len(bmr$n_resample_results), function(i) bmr$resample_result(i)$learners))
@@ -183,7 +183,7 @@ test_that("learners are trained and hotstarted when benchmark is called", {
resampling = rsmp("cv", folds = 3)
resampling$instantiate(task)

design = benchmark_grid(task, list(learner_1, learner_2), resampling)
design = data.table(task = list(task), learner = list(learner_1, learner_2), resampling = list(resampling))
bmr = benchmark(design, store_models = TRUE)

learners = unlist(map(seq_len(bmr$n_resample_results), function(i) bmr$resample_result(i)$learners))
2 changes: 1 addition & 1 deletion tests/testthat/test_parallel.R
@@ -46,7 +46,7 @@ test_that("parallel benchmark", {
expect_equal(bmr$aggregate(conditions = TRUE)$warnings, 0L)
expect_equal(bmr$aggregate(conditions = TRUE)$errors, 0L)

grid = benchmark_grid(list(tsk("wine"), tsk("sonar")), replicate(2, lrn("classif.debug")), rsmp("cv", folds = 2))
grid = benchmark_grid(list(tsk("wine"), tsk("sonar")), list(lrn("classif.debug", id = "learner_1"), lrn("classif.debug", id = "learner_2")), rsmp("cv", folds = 2))
njobs = 3L
bmr = with_future(future::multisession, {
benchmark(grid, store_models = TRUE)
10 changes: 1 addition & 9 deletions tests/testthat/test_resample.R
@@ -376,6 +376,7 @@ test_that("can even use internal_valid predict set on learners that don't suppor
task = tsk("mtcars")
task$internal_valid_task = 1:10
rr = resample(task, lrn("regr.debug", predict_sets = "internal_valid"), rsmp("holdout"))
expect_warning(rr$score(), "only predicted on sets")
})

test_that("callr during prediction triggers marshaling", {
@@ -511,12 +511,3 @@ test_that("predict_time is 0 if no predict_set is specified", {
expect_true(all(times == 0))
})

test_that("resampling was instantiated on the task", {
learner = lrn("classif.rpart")
task = tsk("pima")
resampling = rsmp("cv", folds = 5)
resampling$instantiate(task)
task = tsk("spam")

expect_error(resample(task, learner, resampling), "not instantiated")
})
5 changes: 3 additions & 2 deletions tests/testthat/test_resultdata.R
@@ -48,13 +48,14 @@ test_that("results are ordered", {

test_that("mlr3tuning use case", {
task = tsk("iris")
learners = lrns(c("classif.rpart", "classif.rpart", "classif.rpart"))
learners = replicate(3, lrn("classif.rpart"), simplify = FALSE)
learners[[1]]$param_set$values = list(xval = 0, cp = 0.1)
learners[[2]]$param_set$values = list(xval = 0, cp = 0.2)
learners[[3]]$param_set$values = list(xval = 0, cp = 0.3)
resampling = rsmp("holdout")
resampling$instantiate(task)

bmr = benchmark(benchmark_grid(task, learners, resampling))
bmr = benchmark(data.table(task = list(task), learner = learners, resampling = list(resampling)))

rdata = get_private(bmr)$.data

