Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Value Change Callback #237

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,9 @@ ParamSet = R6Class("ParamSet",
return(private$.values)
}

xs = Reduce(function(val, fun) fun(val), self$callbacks, xs)
self$assert(xs)

if (length(xs) == 0L) xs = named_list()
private$.values = xs
},
Expand All @@ -390,6 +392,14 @@ ParamSet = R6Class("ParamSet",

extra_values = function() {
private$.values[names(private$.values) %nin% names(private$.params)]
},

callbacks = function(val) {
if (!missing(val)) {
assert_list(val, types = "function", any.missing = FALSE)
private$.callbacks = val
}
private$.callbacks
}
),

Expand All @@ -399,6 +409,7 @@ ParamSet = R6Class("ParamSet",
.params = NULL,
.values = named_list(),
.deps = data.table(id = character(0L), on = character(0L), cond = list()),
.callbacks = list(),
# return a slot / AB, as a named vec, named with id (and can enforce a certain vec-type)
get_member_with_idnames = function(member, astype) set_names(astype(map(self$params, member)), names(self$params)),

Expand Down
8 changes: 4 additions & 4 deletions R/ParamSetCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,11 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
},

values = function(xs) {
sets = private$.sets
names(sets) = map_chr(sets, "set_id")
if (!missing(xs)) {
assert_list(xs)
xs = Reduce(function(val, fun) fun(val), self$callbacks, xs)
self$assert(xs) # make sure everything is valid and feasible

for (s in sets) {
for (s in private$.sets) {
# retrieve sublist for each set, then assign it in set (after removing prefix)
psids = names(s$params)
if (s$set_id != "") {
Expand All @@ -134,6 +132,8 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
s$values = pv
}
}
sets = private$.sets
names(sets) = map_chr(sets, "set_id")
vals = map(sets, "values")
vals = unlist(vals, recursive = FALSE)
if (!length(vals)) vals = named_list()
Expand Down
127 changes: 127 additions & 0 deletions tests/testthat/test_ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,130 @@ test_that("required tag, empty param set (#219)", {
ps$ids()
expect_identical(ps$ids(tags = "required"), character(0))
})

test_that("callbacks", {
ps = ParamSet$new(list(
ParamDbl$new(id = "x", lower = 1, tags = c("t1")),
ParamInt$new(id = "y", lower = 1, upper = 2),
ParamFct$new(id = "z", levels = letters[1:3], tags = c("t1"))
))
ps$callbacks[[1]] = function(x) {
x$x = 2
x
}
expect_equal(ps$values, named_list())
ps$values$y = 1
expect_equal(ps$values, list(y = 1, x = 2))
ps$values$x = 1
expect_equal(ps$values, list(y = 1, x = 2))
ps$callbacks[[2]] = function(x) {
x$x = 1
x
}
ps$values$y = 1
expect_equal(ps$values, list(y = 1, x = 1))
ps$callbacks[[2]] = function(x) {
x$x = 0
x
}
expect_error({ps$values = list(y = 1)}, "is not >= 1")

ps$callbacks[[2]] = function(x) {
x$x = 1
x
}
ps$callbacks[[1]] = function(x) {
x$x = 0
x
}
ps$values = list(y = 2)
expect_equal(ps$values, list(y = 2, x = 1))
ps$callbacks[[1]] = NULL
ps$values = list(y = 1, x = 2)
expect_equal(ps$values, list(y = 1, x = 1))
})

test_that("callbacks on ParamSetCollection", {

psetset = function() {
ps = ParamSet$new(list(ParamUty$new("paramset", custom_check = function(x) check_class(x, "ParamSet", null.ok = TRUE))))
psc = ParamSetCollection$new(list(ps))

psc$callbacks[[1]] = function(x) {
prevset = psc$values$paramset
newset = x$paramset
if (!identical(x$paramset, prevset)) {
psc$params$paramset$assert(newset)
if (!is.null(newset)) {
xcpy = x
xcpy$paramset = NULL
newset$assert(xcpy)
} else {
ParamSet$new()$assert(x)
}
psc$remove_sets("")
psc$add(ps)
if (!is.null(newset)) {
psc$add(newset)
}
}
x
}
psc
}

ps = psetset()

ps1 = ParamSet$new(list(
ParamDbl$new(id = "x", lower = 1, tags = c("t1")),
ParamInt$new(id = "y", lower = 1, upper = 2),
ParamFct$new(id = "z", levels = letters[1:3], tags = c("t1"))
))

ps2 = ParamDbl$new("a")$rep(3)

expect_equal(names(ps$params), "paramset")

ps$values$paramset = ps1

expect_equal(names(ps$params), c("paramset", "x", "y", "z"))

ps$values$x = 1
expect_equal(ps$values, list(paramset = ps1, x = 1))

ps$values = list(paramset = ps2, a_rep_1 = 0)

expect_equal(ps$values, list(paramset = ps2, a_rep_1 = 0))

# The problem here is that there is an ambiguity. suppose
# > psB$values = list(x = 2)
# > ps$values = list(paramset = psA, x = 1)
# Now the command
# (A) > ps$values = list(paramset = psB, x = 1)
# and the command
# (B) > ps$values$paramset = psB
# are functionally the same, but in case (B) we wished we could
# keep the parameter values of psB. However, because (A) is done
# by things like tuning, it takes precedent and must work as
# expected. Therefore the following throws an error.
expect_error({ps$values$paramset = ps1}, "a_rep_1.* not available")

expect_equal(ps$values, list(paramset = ps2, a_rep_1 = 0))

ps$values = c(list(paramset = ps1), ps1$values)

expect_equal(ps$values, list(paramset = ps1, x = 1))

expect_error({ps$values = list(x = 2)}, "Parameter 'x' not available")

expect_equal(ps$values, list(paramset = ps1, x = 1))

expect_error({ps$values$paramset = NULL}, "Parameter 'x' not available")

expect_equal(ps$values, list(paramset = ps1, x = 1))

ps$values = list()

expect_equal(ps$values, named_list())

})