Skip to content

Commit

Permalink
feat: add function to extract information from a model
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 20, 2024
1 parent 4b1f072 commit 44b4496
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 7 deletions.
16 changes: 14 additions & 2 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#' @template param_allow_hotstart
#' @template param_clone
#' @template param_unmarshal
#' @template param_extractor
#' @return [ResampleResult].
#'
#' @template section_predict_sets
Expand Down Expand Up @@ -55,7 +56,18 @@
#' bmr1 = as_benchmark_result(rr)
#' bmr2 = as_benchmark_result(rr_featureless)
#' print(bmr1$combine(bmr2))
resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) {
resample = function(
task,
learner,
resampling,
store_models = FALSE,
store_backends = TRUE,
encapsulate = NA_character_,
allow_hotstart = FALSE,
clone = c("task", "learner", "resampling"),
unmarshal = TRUE,
extractor = NULL
) {
assert_subset(clone, c("task", "learner", "resampling"))
task = assert_task(as_task(task, clone = "task" %in% clone))
learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE))
Expand Down Expand Up @@ -115,7 +127,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe
}

res = future_map(n, workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode,
MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal)
MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, extractor = extractor)
)

data = data.table(
Expand Down
19 changes: 18 additions & 1 deletion R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,20 @@ learner_predict = function(learner, task, row_ids = NULL) {
}


workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE, unmarshal = TRUE) {
workhorse = function(
iteration,
task,
learner,
resampling,
param_values = NULL,
lgr_threshold,
store_models = FALSE,
pb = NULL,
mode = "train",
is_sequential = TRUE,
unmarshal = TRUE,
extractor = NULL
) {
if (!is.null(pb)) {
pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration))
}
Expand Down Expand Up @@ -332,6 +345,10 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL,
}
pdatas = discard(pdatas, is.null)

if (!is.null(extractor)) {
learner$state = insert_named(learner$state, extractor(learner$model))
}

# set the model slot after prediction so it can be sent back to the main process
process_model_after_predict(
learner = learner, store_models = store_models, is_sequential = is_sequential, model_copy = model_copy_or_null,
Expand Down
3 changes: 3 additions & 0 deletions man-roxygen/param_extractor.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#' @param extractor (`function()`)\cr
#' Function to extract information from the learner model on the worker.
#' The function takes `model` as input and must return a named list.
6 changes: 3 additions & 3 deletions man/Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion man/resample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 44b4496

Please sign in to comment.