diff --git a/NEWS.md b/NEWS.md index 2cc4948b..e5df0850 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,10 @@ # tune (development version) +* Iterative searches now cache preprocessor fits that would be duplicated across + iterations, greatly reducing the time to tune computationally intensive + preprocessing steps in `tune_bayes()` (and `tune_sim_anneal()`) in + finetune (#955). + * The package will now warn when parallel processing has been enabled with foreach but not with future. See [`?parallelism`](https://tune.tidymodels.org/dev/reference/parallelism.html) to learn more about transitioning your code to future (#878, #866). * The package will now log a backtrace for errors and warnings that occur during tuning. When a tuning process encounters issues, see the new `trace` column in the `collect_notes(.Last.tune.result)` output to find precisely where the error occurred (#873). diff --git a/R/cache.R b/R/cache.R new file mode 100644 index 00000000..37f223a6 --- /dev/null +++ b/R/cache.R @@ -0,0 +1,32 @@ +# For iterative searches, the same preprocessor fits are recomputed +# for every iteration. Hook into the `tune_env`, an environment that +# defines a call to a tuning function, to cache repeated fits while tuning. +has_cached_result <- function(split_id, param_desc) { + cache <- cached_results() + + if (!split_id %in% names(cache) || !param_desc %in% names(cache[[split_id]])) { + return(FALSE) + } + + TRUE +} + +get_cached_result <- function(split_id, param_desc) { + cache <- cached_results() + cache[[split_id]][[param_desc]] +} + +set_cached_result <- function(split_id, param_desc, workflow) { + cache <- cached_results() + cache[[split_id]][[param_desc]] <- workflow + workflow +} + +cached_results <- function() { + env <- tune_env$progress_env + if (!"cache" %in% names(env)) { + rlang::env_bind(env, cache = rlang::new_environment()) + } + + env$cache +} diff --git a/R/grid_code_paths.R b/R/grid_code_paths.R index 505d6ccc..7082fe45 100644 --- a/R/grid_code_paths.R +++ b/R/grid_code_paths.R @@ -436,13 +436,19 @@ tune_grid_loop_iter <- function(split, grid_preprocessor = iter_grid_preprocessor ) - workflow <- .catch_and_log( - .expr = .fit_pre(workflow, analysis), - control, - split_labels, - iter_msg_preprocessor, - notes = out_notes - ) + if (!has_cached_result(split_labels$id, iter_msg_preprocessor)) { + workflow <- .catch_and_log( + .expr = .fit_pre(workflow, analysis), + control, + split_labels, + iter_msg_preprocessor, + notes = out_notes + ) + + set_cached_result(split_labels$id, iter_msg_preprocessor, workflow) + } else { + workflow <- get_cached_result(split_labels$id, iter_msg_preprocessor) + } if (is_failure(workflow)) { next