Skip to content

Commit

Permalink
mlr3 upkeep
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jul 18, 2024
1 parent 78e3976 commit 7fa16f2
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 22 deletions.
8 changes: 1 addition & 7 deletions R/ResamplingNestedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,14 @@ ResamplingNestedCV = R6::R6Class("ResamplingNestedCV",
assert_ro_binding(rhs)
pv = self$param_set$get_values()
pv$repeats * pv$folds^2
},
#' @field primary_iters (`integer()`)\cr
#' The primary iterations to be used for point estimation.
primary_iters = function(rhs) {
assert_ro_binding(rhs)
pvs = self$param_set$get_values()
as.vector(outer(seq_len(pvs$folds), pvs$folds^2 * seq(0, pvs$repeats - 1), `+`))
}
),
private = list(
.sample = function(ids, ...) {
pv = self$param_set$get_values()
folds = pv$folds
repeats = pv$repeats
self$primary_iters = as.vector(outer(seq_len(pv$folds), pv$folds^2 * seq(0, pv$repeats - 1), `+`))
map_dtr(seq(repeats), function(r) {
data.table(
row_id = ids,
Expand Down
12 changes: 3 additions & 9 deletions R/ResamplingPairedSubsampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,13 @@ ResamplingPairedSubsampling = R6Class("ResamplingPairedSubsampling",
}

pvs = self$param_set$get_values()

repeats_in = pvs$repeats_in
repeats_out = pvs$repeats_out
ratio = pvs$ratio

self$primary_iters = repeats_in

n = length(ids)
n1 = round(n * ratio)
n2 = n - n1
Expand All @@ -117,7 +120,6 @@ ResamplingPairedSubsampling = R6Class("ResamplingPairedSubsampling",
stopf("Not enough observations in the task")
}
}


instance = vector("list", length = 1 + repeats_out * 2)
instance[[1]] = private$.sample_once(ids, repeats_in, ratio)
Expand All @@ -131,7 +133,6 @@ ResamplingPairedSubsampling = R6Class("ResamplingPairedSubsampling",

new_ratio = (n_sub - n2) / n_sub


for (i in seq_len(repeats_out)) {
sub_ids = sample(ids, n_sub * 2)
instance[[i * 2]] = private$.sample_once(sub_ids[seq(1, n_sub)], repeats_in, new_ratio)
Expand Down Expand Up @@ -163,13 +164,6 @@ ResamplingPairedSubsampling = R6Class("ResamplingPairedSubsampling",
iters = function(rhs) {
pvs = self$param_set$get_values()
(pvs$repeats_out * 2 + 1) * pvs$repeats_in
},
#' @field primary_iters (`integer()`)\cr
#' The primary iterations to be used for point estimation.
primary_iters = function(rhs) {
assert_ro_binding(rhs)
pvs = self$param_set$get_values()
pvs$repeats_in
}
)
)
Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ register_mlr3 = function(...) {
# static checker
future::plan
withr::with_seed
mlr3measures::se

if (Sys.getenv("IN_PKGDOWN") == "true") {
lg$set_threshold("warn")
Expand Down
3 changes: 0 additions & 3 deletions man/mlr_resamplings_ncv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions man/mlr_resamplings_paired_subsampling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions tests/testthat/test_ResamplingNestedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,11 @@ test_that("primary iters", {
r$param_set$set_values(
folds = 4L, repeats = 1
)
r$instantiate(task)
expect_equal(r$primary_iters, 1:4)
r$param_set$set_values(
folds = 4L, repeats = 2
)
r$instantiate(task)
expect_equal(r$primary_iters, c(1:4, 17:20))
})
2 changes: 2 additions & 0 deletions tests/testthat/test_ResamplingPairedSubsampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ test_that("primary_iters", {
r$instantiate(task)
expect_equal(r$primary_iters, 1L)
r$param_set$values$repeats_in = 2
r$instantiate(task)
expect_equal(r$primary_iters, 2L)
r$instantiate(task)
r$param_set$values$repeats_out = 2L
expect_equal(r$primary_iters, 2L)
})

0 comments on commit 7fa16f2

Please sign in to comment.