From 55e75e5a18e21451de1d1515df1c9f6aa3e88d12 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 28 Nov 2023 07:12:00 -0500 Subject: [PATCH 1/6] initial code for single selections --- NAMESPACE | 1 + R/metric-selection.R | 81 +++++++ .../_snaps/eval-time-single-selection.md | 84 +++++++ .../test-eval-time-single-selection.R | 219 ++++++++++++++++++ 4 files changed, 385 insertions(+) create mode 100644 R/metric-selection.R create mode 100644 tests/testthat/_snaps/eval-time-single-selection.md create mode 100644 tests/testthat/test-eval-time-single-selection.R diff --git a/NAMESPACE b/NAMESPACE index 0baf6b108..5d92b47d9 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -223,6 +223,7 @@ export(required_pkgs) export(select_best) export(select_by_one_std_err) export(select_by_pct_loss) +export(select_eval_time) export(show_best) export(show_notes) export(tunable) diff --git a/R/metric-selection.R b/R/metric-selection.R new file mode 100644 index 000000000..bfae5fb9b --- /dev/null +++ b/R/metric-selection.R @@ -0,0 +1,81 @@ +# For iterative search and racing, what metric will be optimized? +first_metric <- function(mtr_set) { + tibble::as_tibble(mtr_set)[1,] +} + +# Did the user pass an improper metric (i.e. want rmse but not computed)? +check_chosen_metric <- function(metric, mtr_set) { + mtr_info <- tibble::as_tibble(mtr_set) + in_set <- any(mtr_info$metric == metric) + if (!in_set) { + cli::cli_abort("metric '{metric}' is not in the metric set.") + } + invisible(TRUE) +} + +# Return the validated evaluation times. May subset to a single value if needed. +#' @export +select_eval_time <- function(mtr_set, eval_time = NULL, single = FALSE) { + + only_stc <- function(x) all(x$class == "static_survival_metric") + only_dyn <- function(x) all(x$class == "dynamic_survival_metric") + only_int <- function(x) all(x$class == "integrated_survival_metric") + has_int <- function(x) any(x$class == "integrated_survival_metric") + has_dyn <- function(x) any(x$class == "dynamic_survival_metric") + + if (single) { + # This means that we will be using the metric results to rank or + # optimize something. We need one eval time when the (single) metric + # is dynamic; null otherwise + mtr_info <- first_metric(mtr_set) + } else { + # In this case, we will use the results for autoplot(), int_pct(), or + # augment(). We need a valid set of evaluation times which could be none + # one, or more than one depending on the metric + mtr_info <- tibble::as_tibble(mtr_set) + } + mtr_first <- first_metric(mtr_set) + + if (!any(grepl("_survival_", mtr_info$class))) { + return(NULL) + } + + # ------------------------------------------------------------------------------ + # check the size of the eval times + + num_times <- length(eval_time) + + if (only_stc(mtr_info) & num_times != 0) { + cli::cli_warn("Evaluation times are only required for dynmanic or integrated metrics.") + eval_time <- NULL + } + + if (only_dyn(mtr_info) & num_times == 0) { + cli::cli_abort("A single evaluation time is required; please choose one.") + } + + # this requires all metrics + if ( has_int( tibble::as_tibble(mtr_set) ) & num_times < 2 ) { + cli::cli_abort("2+ evaluation times are required.") + } + + # checks for cases where only a single eval time should be returned + if (single) { + + if ( only_dyn(mtr_info) & num_times > 1 ) { + eval_time <- eval_time[1] + print_time <- format(eval_time, digits = 3) + cli::cli_warn("{num_times} evaluation times were selected; the first ({print_time}) will be used.") + } + + if (only_int(mtr_info)) { + eval_time <- NULL + } + } else { + # cases where we maybe need evaluation time and return them all + + } + + + eval_time +} diff --git a/tests/testthat/_snaps/eval-time-single-selection.md b/tests/testthat/_snaps/eval-time-single-selection.md new file mode 100644 index 000000000..f14c47926 --- /dev/null +++ b/tests/testthat/_snaps/eval-time-single-selection.md @@ -0,0 +1,84 @@ +# selecting single eval time - pure metric sets + + Evaluation times are only required for dynmanic or integrated metrics. + +--- + + Evaluation times are only required for dynmanic or integrated metrics. + +--- + + Code + select_eval_time(met_dyn, eval_time = NULL, single = TRUE) + Condition + Error in `select_eval_time()`: + ! A single evaluation time is required; please choose one. + +--- + + 2 evaluation times were selected; the first (0.714) will be used. + +--- + + Code + select_eval_time(met_int, eval_time = NULL, single = TRUE) + Condition + Error in `select_eval_time()`: + ! 2+ evaluation times are required. + +--- + + Code + select_eval_time(met_int, eval_time = times_1, single = TRUE) + Condition + Error in `select_eval_time()`: + ! 2+ evaluation times are required. + +# selecting single eval time - mixed metric sets - dynamic first + + Code + select_eval_time(met_mix_dyn, eval_time = NULL, single = TRUE) + Condition + Error in `select_eval_time()`: + ! A single evaluation time is required; please choose one. + +--- + + Code + select_eval_time(met_mix_dyn_all, eval_time = NULL, single = TRUE) + Condition + Error in `select_eval_time()`: + ! A single evaluation time is required; please choose one. + +# selecting single eval time - mixed metric sets - integrated first + + Code + select_eval_time(met_mix_int, eval_time = NULL, single = TRUE) + Condition + Error in `select_eval_time()`: + ! 2+ evaluation times are required. + +--- + + Code + select_eval_time(met_mix_int, eval_time = times_1, single = TRUE) + Condition + Error in `select_eval_time()`: + ! 2+ evaluation times are required. + +--- + + Code + select_eval_time(met_mix_int_all, eval_time = NULL, single = TRUE) + Condition + Error in `select_eval_time()`: + ! 2+ evaluation times are required. + +--- + + Code + select_eval_time(met_mix_int_all, eval_time = times_1, single = TRUE) + Condition + Error in `select_eval_time()`: + ! 2+ evaluation times are required. + diff --git a/tests/testthat/test-eval-time-single-selection.R b/tests/testthat/test-eval-time-single-selection.R new file mode 100644 index 000000000..d1382f0d7 --- /dev/null +++ b/tests/testthat/test-eval-time-single-selection.R @@ -0,0 +1,219 @@ + +test_that("selecting single eval time - non-survival case", { + library(yardstick) + + met_reg <- metric_set(rmse) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # ---------------------------------------------------------------------------- + # eval time is not applicable outside of survival models; return null + + expect_null(select_eval_time(met_reg, eval_time = NULL, single = TRUE)) + expect_null(select_eval_time(met_reg, eval_time = times_1, single = TRUE)) + expect_null(select_eval_time(met_reg, eval_time = times_2, single = TRUE)) + +}) + +test_that("selecting single eval time - pure metric sets", { + library(yardstick) + + met_int <- metric_set(brier_survival_integrated) + met_dyn <- metric_set(brier_survival) + met_stc <- metric_set(concordance_survival) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # ---------------------------------------------------------------------------- + # all static; return NULL and add warning if times are given + + expect_null(select_eval_time(met_stc, eval_time = NULL, single = TRUE)) + + expect_snapshot_warning( + stc_one <- select_eval_time(met_stc, eval_time = times_1, single = TRUE) + ) + expect_null(stc_one) + + expect_snapshot_warning( + stc_multi <- select_eval_time(met_stc, eval_time = times_2, single = TRUE) + ) + expect_null(stc_multi) + + # ---------------------------------------------------------------------------- + # all dynamic; return a single time and warn if there are more or zero + + expect_snapshot( + select_eval_time(met_dyn, eval_time = NULL, single = TRUE), + error = TRUE + ) + expect_equal( + select_eval_time(met_dyn, eval_time = times_1, single = TRUE), + times_1 + ) + expect_snapshot_warning( + dyn_multi <- select_eval_time(met_dyn, eval_time = times_2, single = TRUE) + ) + expect_equal(dyn_multi, times_2[1]) + + # ---------------------------------------------------------------------------- + # all integrated; return NULL and error if there < 2 + + expect_snapshot( + select_eval_time(met_int, eval_time = NULL, single = TRUE), + error = TRUE + ) + expect_snapshot( + select_eval_time(met_int, eval_time = times_1, single = TRUE), + error = TRUE + ) + + expect_silent( + int_1 <- select_eval_time(met_int, eval_time = times_2, single = TRUE) + ) + expect_null(int_1) + + +}) + +test_that("selecting single eval time - mixed metric sets - static first", { + library(yardstick) + + met_mix_stc <- metric_set(concordance_survival, brier_survival) + met_mix_stc_all <- metric_set(concordance_survival, brier_survival, brier_survival_integrated) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # ---------------------------------------------------------------------------- + # static is first but includes dynamic. Should return NULL and add warning + # if times are given + + expect_null( + select_eval_time(met_mix_stc, eval_time = NULL, single = TRUE) + ) + # TODO should not warn + expect_warning( + select_eval_time(met_mix_stc, eval_time = times_1, single = TRUE) + ) + expect_warning( + select_eval_time(met_mix_stc, eval_time = times_2, single = TRUE) + ) + + # ---------------------------------------------------------------------------- + # static is first but includes dynamic and integrated. Should return NULL and add warning + # if times are given + + # TODO errors but should not since first is static; should not warn + + # expect_null( + # select_eval_time(met_mix_stc_all, eval_time = NULL, single = TRUE) + # ) + # expect_warning( + # select_eval_time(met_mix_stc_all, eval_time = times_1, single = TRUE) + # ) + # expect_warning( + # select_eval_time(met_mix_stc_all, eval_time = times_2, single = TRUE) + # ) + + +}) + +test_that("selecting single eval time - mixed metric sets - dynamic first", { + library(yardstick) + + met_mix_dyn <- metric_set(brier_survival, concordance_survival) + met_mix_dyn_all <- + metric_set(brier_survival, + brier_survival_integrated, + concordance_survival) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # ---------------------------------------------------------------------------- + # dynamic is first but includes static Should return NULL and add warning + # if times are given + + expect_snapshot( + select_eval_time(met_mix_dyn, eval_time = NULL, single = TRUE), + error = TRUE + ) + # TODO should not warn + expect_equal( + select_eval_time(met_mix_dyn, eval_time = times_1, single = TRUE), + times_1 + ) + expect_warning( + dyn_multi <- select_eval_time(met_mix_dyn, eval_time = times_2, single = TRUE) + ) + expect_equal(dyn_multi, times_2[1]) + + # ---------------------------------------------------------------------------- + # dynamic is first but includes static and integrated. Should return NULL and add warning + # if times are given + + expect_snapshot( + select_eval_time(met_mix_dyn_all, eval_time = NULL, single = TRUE), + error = TRUE + ) + # TODO errors but should not + # expect_warning( + # select_eval_time(met_mix_dyn_all, eval_time = times_1, single = TRUE) + # ) + expect_warning( + dyn_multi <- select_eval_time(met_mix_dyn_all, eval_time = times_2, single = TRUE) + ) + expect_equal(dyn_multi, times_2[1]) + +}) + + +test_that("selecting single eval time - mixed metric sets - integrated first", { + library(yardstick) + + met_mix_int <- metric_set(brier_survival_integrated, concordance_survival) + met_mix_int_all <- + metric_set(brier_survival_integrated, + brier_survival, + concordance_survival) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # ---------------------------------------------------------------------------- + # integrated is first but includes static. Should return NULL and add error + # if <2 times are given + + expect_snapshot( + select_eval_time(met_mix_int, eval_time = NULL, single = TRUE), + error = TRUE + ) + expect_snapshot( + select_eval_time(met_mix_int, eval_time = times_1, single = TRUE), + error = TRUE + ) + expect_silent( + int_multi <- select_eval_time(met_mix_int, eval_time = times_2, single = TRUE) + ) + expect_null(int_multi) + + # ---------------------------------------------------------------------------- + # integrated is first but includes static and dynamic. Should return NULL and + # add error if <2 times are given + + expect_snapshot( + select_eval_time(met_mix_int_all, eval_time = NULL, single = TRUE), + error = TRUE + ) + expect_snapshot( + select_eval_time(met_mix_int_all, eval_time = times_1, single = TRUE), + error = TRUE + ) + expect_silent( + int_multi <- select_eval_time(met_mix_int_all, eval_time = times_2, single = TRUE) + ) + expect_null(int_multi) + +}) From 9db7ec13fecab77d037c349ed872565fb3542e63 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 28 Nov 2023 07:25:19 -0500 Subject: [PATCH 2/6] added temp notes --- R/select_best.R | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/R/select_best.R b/R/select_best.R index a9771b1e6..85305667d 100644 --- a/R/select_best.R +++ b/R/select_best.R @@ -76,6 +76,8 @@ show_best.default <- function(x, ...) { #' @export #' @rdname show_best show_best.tune_results <- function(x, metric = NULL, n = 5, eval_time = NULL, ...) { + # TODO should return the as_tibble(metric_set) results to get the class etc. + # TODO new function start metric <- choose_metric(metric, x) dots <- rlang::enquos(...) @@ -92,8 +94,12 @@ show_best.tune_results <- function(x, metric = NULL, n = 5, eval_time = NULL, .. metric <- metrics } + # TODO new function stop + # get estimates/summarise summary_res <- summary_res %>% dplyr::filter(.metric == metric) + + # TODO split selecting the req time and seeing if it is in the data summary_res <- choose_eval_time(summary_res, x, eval_time) if (nrow(summary_res) == 0) { @@ -349,7 +355,8 @@ middle_eval_time <- function(x) { eval_time } - +# NOTE this chooses the time and subsets the data; break it up to only select +# time choose_eval_time <- function(x, object, eval_time) { mtrs <- .get_tune_metrics(object) mtrs <- tibble::as_tibble(mtrs) From a364dbe18d6310e878acecd93e85447154cbdcb2 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 28 Nov 2023 16:28:58 -0500 Subject: [PATCH 3/6] refactor for single time point and reset unit tests --- NAMESPACE | 5 +- R/metric-selection.R | 80 ++++----- .../_snaps/eval-time-single-selection.md | 70 ++------ .../test-eval-time-single-selection.R | 156 ++++++++++-------- 4 files changed, 135 insertions(+), 176 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 5d92b47d9..91431de6e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -149,6 +149,8 @@ export(.stash_last_result) export(.use_case_weights_with_yardstick) export(augment) export(autoplot) +export(check_chosen_metric) +export(check_eval_time) export(check_initial) export(check_metrics) export(check_parameters) @@ -186,6 +188,8 @@ export(finalize_model) export(finalize_recipe) export(finalize_workflow) export(finalize_workflow_preprocessor) +export(first_eval_time) +export(first_metric) export(fit_best) export(fit_max_value) export(fit_resamples) @@ -223,7 +227,6 @@ export(required_pkgs) export(select_best) export(select_by_one_std_err) export(select_by_pct_loss) -export(select_eval_time) export(show_best) export(show_notes) export(tunable) diff --git a/R/metric-selection.R b/R/metric-selection.R index bfae5fb9b..207249b29 100644 --- a/R/metric-selection.R +++ b/R/metric-selection.R @@ -1,9 +1,13 @@ # For iterative search and racing, what metric will be optimized? +#' @keywords internal +#' @export first_metric <- function(mtr_set) { tibble::as_tibble(mtr_set)[1,] } # Did the user pass an improper metric (i.e. want rmse but not computed)? +#' @keywords internal +#' @export check_chosen_metric <- function(metric, mtr_set) { mtr_info <- tibble::as_tibble(mtr_set) in_set <- any(mtr_info$metric == metric) @@ -13,69 +17,51 @@ check_chosen_metric <- function(metric, mtr_set) { invisible(TRUE) } -# Return the validated evaluation times. May subset to a single value if needed. +#' @keywords internal #' @export -select_eval_time <- function(mtr_set, eval_time = NULL, single = FALSE) { - - only_stc <- function(x) all(x$class == "static_survival_metric") - only_dyn <- function(x) all(x$class == "dynamic_survival_metric") - only_int <- function(x) all(x$class == "integrated_survival_metric") - has_int <- function(x) any(x$class == "integrated_survival_metric") - has_dyn <- function(x) any(x$class == "dynamic_survival_metric") +first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) { + num_times <- length(eval_time) - if (single) { - # This means that we will be using the metric results to rank or - # optimize something. We need one eval time when the (single) metric - # is dynamic; null otherwise + if (is.null(metric)) { mtr_info <- first_metric(mtr_set) + metric <- mtr_info$metric } else { - # In this case, we will use the results for autoplot(), int_pct(), or - # augment(). We need a valid set of evaluation times which could be none - # one, or more than one depending on the metric mtr_info <- tibble::as_tibble(mtr_set) + mtr_info <- mtr_info[mtr_info$metric == metric,] } - mtr_first <- first_metric(mtr_set) + # Not a survival metric if (!any(grepl("_survival_", mtr_info$class))) { return(NULL) } - # ------------------------------------------------------------------------------ - # check the size of the eval times - - num_times <- length(eval_time) - - if (only_stc(mtr_info) & num_times != 0) { - cli::cli_warn("Evaluation times are only required for dynmanic or integrated metrics.") - eval_time <- NULL - } - - if (only_dyn(mtr_info) & num_times == 0) { - cli::cli_abort("A single evaluation time is required; please choose one.") + # Not a metric that requires an eval_time + no_time_req <- c("static_survival_metric", "integrated_survival_metric") + if (mtr_info$class %in% no_time_req) { + if (num_times > 0) { + cli::cli_warn("Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric.") + } + return(NULL) } - # this requires all metrics - if ( has_int( tibble::as_tibble(mtr_set) ) & num_times < 2 ) { - cli::cli_abort("2+ evaluation times are required.") + # checks for dynamic metrics + if (num_times == 0) { + cli::cli_abort("A single evaluation time is required to use this metric.") + } else if ( num_times > 1 ) { + eval_time <- eval_time[1] + print_time <- format(eval_time, digits = 3) + cli::cli_warn("{num_times} evaluation times were available; the first ({print_time}) will be used.") } - # checks for cases where only a single eval time should be returned - if (single) { - - if ( only_dyn(mtr_info) & num_times > 1 ) { - eval_time <- eval_time[1] - print_time <- format(eval_time, digits = 3) - cli::cli_warn("{num_times} evaluation times were selected; the first ({print_time}) will be used.") - } - - if (only_int(mtr_info)) { - eval_time <- NULL - } - } else { - # cases where we maybe need evaluation time and return them all + eval_time +} +#' @keywords internal +#' @export +check_eval_time <- function(eval_time = NULL, all_times = NULL) { + if (!is.null(eval_time)) { + return(eval_time) } - - eval_time + all_times <- sort(unique(all_times)) } diff --git a/tests/testthat/_snaps/eval-time-single-selection.md b/tests/testthat/_snaps/eval-time-single-selection.md index f14c47926..0a1829035 100644 --- a/tests/testthat/_snaps/eval-time-single-selection.md +++ b/tests/testthat/_snaps/eval-time-single-selection.md @@ -1,84 +1,44 @@ # selecting single eval time - pure metric sets - Evaluation times are only required for dynmanic or integrated metrics. + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. --- - Evaluation times are only required for dynmanic or integrated metrics. + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. --- Code - select_eval_time(met_dyn, eval_time = NULL, single = TRUE) + first_eval_time(met_dyn, eval_time = NULL) Condition - Error in `select_eval_time()`: - ! A single evaluation time is required; please choose one. - ---- - - 2 evaluation times were selected; the first (0.714) will be used. + Error in `first_eval_time()`: + ! A single evaluation time is required to use this metric. --- Code - select_eval_time(met_int, eval_time = NULL, single = TRUE) + first_eval_time(met_dyn, "brier_survival", eval_time = NULL) Condition - Error in `select_eval_time()`: - ! 2+ evaluation times are required. + Error in `first_eval_time()`: + ! A single evaluation time is required to use this metric. --- - Code - select_eval_time(met_int, eval_time = times_1, single = TRUE) - Condition - Error in `select_eval_time()`: - ! 2+ evaluation times are required. + 2 evaluation times were available; the first (0.714) will be used. # selecting single eval time - mixed metric sets - dynamic first Code - select_eval_time(met_mix_dyn, eval_time = NULL, single = TRUE) - Condition - Error in `select_eval_time()`: - ! A single evaluation time is required; please choose one. - ---- - - Code - select_eval_time(met_mix_dyn_all, eval_time = NULL, single = TRUE) - Condition - Error in `select_eval_time()`: - ! A single evaluation time is required; please choose one. - -# selecting single eval time - mixed metric sets - integrated first - - Code - select_eval_time(met_mix_int, eval_time = NULL, single = TRUE) - Condition - Error in `select_eval_time()`: - ! 2+ evaluation times are required. - ---- - - Code - select_eval_time(met_mix_int, eval_time = times_1, single = TRUE) - Condition - Error in `select_eval_time()`: - ! 2+ evaluation times are required. - ---- - - Code - select_eval_time(met_mix_int_all, eval_time = NULL, single = TRUE) + first_eval_time(met_mix_dyn, eval_time = NULL) Condition - Error in `select_eval_time()`: - ! 2+ evaluation times are required. + Error in `first_eval_time()`: + ! A single evaluation time is required to use this metric. --- Code - select_eval_time(met_mix_int_all, eval_time = times_1, single = TRUE) + first_eval_time(met_mix_dyn_all, eval_time = NULL) Condition - Error in `select_eval_time()`: - ! 2+ evaluation times are required. + Error in `first_eval_time()`: + ! A single evaluation time is required to use this metric. diff --git a/tests/testthat/test-eval-time-single-selection.R b/tests/testthat/test-eval-time-single-selection.R index d1382f0d7..eb2b104fc 100644 --- a/tests/testthat/test-eval-time-single-selection.R +++ b/tests/testthat/test-eval-time-single-selection.R @@ -1,3 +1,10 @@ +# "selecting single eval time" means how functions like `show_best()` will pick +# an evaluation time for a dynamic metric when none is given. Previously we +# would find what is in the data and select a time that was close to the median +# time. This was fine but inconsistent with other parts of tidymodels that do +# similar operations. For example, tune_bayes has to have a metric to optimize +# on so it uses the first metric in the metric set and, if needed, the first +# evaluation time given to the function. test_that("selecting single eval time - non-survival case", { library(yardstick) @@ -10,9 +17,9 @@ test_that("selecting single eval time - non-survival case", { # ---------------------------------------------------------------------------- # eval time is not applicable outside of survival models; return null - expect_null(select_eval_time(met_reg, eval_time = NULL, single = TRUE)) - expect_null(select_eval_time(met_reg, eval_time = times_1, single = TRUE)) - expect_null(select_eval_time(met_reg, eval_time = times_2, single = TRUE)) + expect_null(first_eval_time(met_reg, eval_time = NULL)) + expect_null(first_eval_time(met_reg, eval_time = times_1)) + expect_null(first_eval_time(met_reg, eval_time = times_2)) }) @@ -29,51 +36,59 @@ test_that("selecting single eval time - pure metric sets", { # ---------------------------------------------------------------------------- # all static; return NULL and add warning if times are given - expect_null(select_eval_time(met_stc, eval_time = NULL, single = TRUE)) + expect_null(first_eval_time(met_stc, eval_time = NULL)) + expect_null(first_eval_time(met_stc, "concordance_survival", eval_time = NULL)) expect_snapshot_warning( - stc_one <- select_eval_time(met_stc, eval_time = times_1, single = TRUE) + stc_one <- first_eval_time(met_stc, eval_time = times_1) ) expect_null(stc_one) expect_snapshot_warning( - stc_multi <- select_eval_time(met_stc, eval_time = times_2, single = TRUE) + stc_multi <- first_eval_time(met_stc, eval_time = times_2) ) expect_null(stc_multi) # ---------------------------------------------------------------------------- - # all dynamic; return a single time and warn if there are more or zero + # all dynamic; return a single time and warn if there are more and error if + # there are none expect_snapshot( - select_eval_time(met_dyn, eval_time = NULL, single = TRUE), + first_eval_time(met_dyn, eval_time = NULL), error = TRUE ) + expect_snapshot( + first_eval_time(met_dyn, "brier_survival", eval_time = NULL), + error = TRUE + ) + expect_equal( - select_eval_time(met_dyn, eval_time = times_1, single = TRUE), + first_eval_time(met_dyn, eval_time = times_1), times_1 ) + expect_snapshot_warning( - dyn_multi <- select_eval_time(met_dyn, eval_time = times_2, single = TRUE) + dyn_multi <- first_eval_time(met_dyn, eval_time = times_2) ) expect_equal(dyn_multi, times_2[1]) # ---------------------------------------------------------------------------- - # all integrated; return NULL and error if there < 2 + # all integrated; return NULL and warn if there 1+ times - expect_snapshot( - select_eval_time(met_int, eval_time = NULL, single = TRUE), - error = TRUE - ) - expect_snapshot( - select_eval_time(met_int, eval_time = times_1, single = TRUE), - error = TRUE + expect_null(first_eval_time(met_int, eval_time = NULL)) + expect_null( + first_eval_time(met_int, "brier_survival_integrated", eval_time = NULL) ) - expect_silent( - int_1 <- select_eval_time(met_int, eval_time = times_2, single = TRUE) + expect_warning( + int_1 <- first_eval_time(met_int, eval_time = times_1) ) expect_null(int_1) + expect_warning( + int_multi <- first_eval_time(met_int, eval_time = times_2) + ) + expect_null(int_multi) }) @@ -91,33 +106,36 @@ test_that("selecting single eval time - mixed metric sets - static first", { # if times are given expect_null( - select_eval_time(met_mix_stc, eval_time = NULL, single = TRUE) + first_eval_time(met_mix_stc, eval_time = NULL) ) - # TODO should not warn + expect_warning( - select_eval_time(met_mix_stc, eval_time = times_1, single = TRUE) + stc_1 <- first_eval_time(met_mix_stc, eval_time = times_1) ) + expect_null(stc_1) + expect_warning( - select_eval_time(met_mix_stc, eval_time = times_2, single = TRUE) + stc_multi <- first_eval_time(met_mix_stc, eval_time = times_2) ) + expect_null(stc_multi) # ---------------------------------------------------------------------------- - # static is first but includes dynamic and integrated. Should return NULL and add warning - # if times are given + # static is first but includes dynamic and integrated. Should return NULL and + # add warning if times are given - # TODO errors but should not since first is static; should not warn - - # expect_null( - # select_eval_time(met_mix_stc_all, eval_time = NULL, single = TRUE) - # ) - # expect_warning( - # select_eval_time(met_mix_stc_all, eval_time = times_1, single = TRUE) - # ) - # expect_warning( - # select_eval_time(met_mix_stc_all, eval_time = times_2, single = TRUE) - # ) + expect_null( + first_eval_time(met_mix_stc_all, eval_time = NULL) + ) + expect_warning( + stc_1 <- first_eval_time(met_mix_stc_all, eval_time = times_1) + ) + expect_null(stc_1) + expect_warning( + stc_multi <- first_eval_time(met_mix_stc_all, eval_time = times_2) + ) + expect_null(stc_multi) }) test_that("selecting single eval time - mixed metric sets - dynamic first", { @@ -133,37 +151,36 @@ test_that("selecting single eval time - mixed metric sets - dynamic first", { times_2 <- as.numeric(5:4) / 7 # ---------------------------------------------------------------------------- - # dynamic is first but includes static Should return NULL and add warning - # if times are given + # dynamic is first but includes static. Should return single time and add warning + # if 2+ times are given expect_snapshot( - select_eval_time(met_mix_dyn, eval_time = NULL, single = TRUE), + first_eval_time(met_mix_dyn, eval_time = NULL), error = TRUE ) - # TODO should not warn expect_equal( - select_eval_time(met_mix_dyn, eval_time = times_1, single = TRUE), + first_eval_time(met_mix_dyn, eval_time = times_1), times_1 ) expect_warning( - dyn_multi <- select_eval_time(met_mix_dyn, eval_time = times_2, single = TRUE) + dyn_multi <- first_eval_time(met_mix_dyn, eval_time = times_2) ) expect_equal(dyn_multi, times_2[1]) # ---------------------------------------------------------------------------- - # dynamic is first but includes static and integrated. Should return NULL and add warning - # if times are given + # dynamic is first but includes static and integrated. Should return single + # time and add warning if 2+ times are given expect_snapshot( - select_eval_time(met_mix_dyn_all, eval_time = NULL, single = TRUE), + first_eval_time(met_mix_dyn_all, eval_time = NULL), error = TRUE ) - # TODO errors but should not - # expect_warning( - # select_eval_time(met_mix_dyn_all, eval_time = times_1, single = TRUE) - # ) + expect_equal( + first_eval_time(met_mix_dyn_all, eval_time = times_1), + times_1 + ) expect_warning( - dyn_multi <- select_eval_time(met_mix_dyn_all, eval_time = times_2, single = TRUE) + dyn_multi <- first_eval_time(met_mix_dyn_all, eval_time = times_2) ) expect_equal(dyn_multi, times_2[1]) @@ -183,37 +200,30 @@ test_that("selecting single eval time - mixed metric sets - integrated first", { times_2 <- as.numeric(5:4) / 7 # ---------------------------------------------------------------------------- - # integrated is first but includes static. Should return NULL and add error - # if <2 times are given + # integrated is first but includes static. Should return NULL and add warning + # if 1+ times are given - expect_snapshot( - select_eval_time(met_mix_int, eval_time = NULL, single = TRUE), - error = TRUE - ) - expect_snapshot( - select_eval_time(met_mix_int, eval_time = times_1, single = TRUE), - error = TRUE + expect_null(first_eval_time(met_mix_int, eval_time = NULL)) + + expect_warning( + first_eval_time(met_mix_int, eval_time = times_1) ) - expect_silent( - int_multi <- select_eval_time(met_mix_int, eval_time = times_2, single = TRUE) + expect_warning( + int_multi <- first_eval_time(met_mix_int, eval_time = times_2) ) expect_null(int_multi) # ---------------------------------------------------------------------------- # integrated is first but includes static and dynamic. Should return NULL and - # add error if <2 times are given + # add warning if 1+ times are given - expect_snapshot( - select_eval_time(met_mix_int_all, eval_time = NULL, single = TRUE), - error = TRUE - ) - expect_snapshot( - select_eval_time(met_mix_int_all, eval_time = times_1, single = TRUE), - error = TRUE + expect_null(first_eval_time(met_mix_int_all, eval_time = NULL)) + + expect_warning( + first_eval_time(met_mix_int_all, eval_time = times_1) ) - expect_silent( - int_multi <- select_eval_time(met_mix_int_all, eval_time = times_2, single = TRUE) + expect_warning( + int_multi <- first_eval_time(met_mix_int_all, eval_time = times_2) ) expect_null(int_multi) - }) From e3e43b803869af6685e1bbc135b8065e5b6c2381 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 28 Nov 2023 17:06:41 -0500 Subject: [PATCH 4/6] remove a function and add docs --- NAMESPACE | 2 -- R/metric-selection.R | 30 +++++-------------- man/first_metric.Rd | 23 ++++++++++++++ .../test-eval-time-single-selection.R | 13 ++++++++ 4 files changed, 43 insertions(+), 25 deletions(-) create mode 100644 man/first_metric.Rd diff --git a/NAMESPACE b/NAMESPACE index 91431de6e..a8efca1b6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -149,8 +149,6 @@ export(.stash_last_result) export(.use_case_weights_with_yardstick) export(augment) export(autoplot) -export(check_chosen_metric) -export(check_eval_time) export(check_initial) export(check_metrics) export(check_parameters) diff --git a/R/metric-selection.R b/R/metric-selection.R index 207249b29..880bc0b7f 100644 --- a/R/metric-selection.R +++ b/R/metric-selection.R @@ -1,22 +1,16 @@ -# For iterative search and racing, what metric will be optimized? +#' Tools for selecting metrics and evaluation times +#' +#' @param mtr_set A [yardstick::metric_set()]. +#' @param metric A character value for which metric is being used. +#' @param eval_time An optional vector of times to compute dynamic and/or +#' integrated metrics. #' @keywords internal #' @export first_metric <- function(mtr_set) { tibble::as_tibble(mtr_set)[1,] } -# Did the user pass an improper metric (i.e. want rmse but not computed)? -#' @keywords internal -#' @export -check_chosen_metric <- function(metric, mtr_set) { - mtr_info <- tibble::as_tibble(mtr_set) - in_set <- any(mtr_info$metric == metric) - if (!in_set) { - cli::cli_abort("metric '{metric}' is not in the metric set.") - } - invisible(TRUE) -} - +#' @rdname first_metric #' @keywords internal #' @export first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) { @@ -55,13 +49,3 @@ first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) { eval_time } - -#' @keywords internal -#' @export -check_eval_time <- function(eval_time = NULL, all_times = NULL) { - if (!is.null(eval_time)) { - return(eval_time) - } - - all_times <- sort(unique(all_times)) -} diff --git a/man/first_metric.Rd b/man/first_metric.Rd new file mode 100644 index 000000000..169d9b423 --- /dev/null +++ b/man/first_metric.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/metric-selection.R +\name{first_metric} +\alias{first_metric} +\alias{first_eval_time} +\title{Tools for selecting metrics and evaluation times} +\usage{ +first_metric(mtr_set) + +first_eval_time(mtr_set, metric = NULL, eval_time = NULL) +} +\arguments{ +\item{mtr_set}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} + +\item{metric}{A character value for which metric is being used.} + +\item{eval_time}{An optional vector of times to compute dynamic and/or +integrated metrics.} +} +\description{ +Tools for selecting metrics and evaluation times +} +\keyword{internal} diff --git a/tests/testthat/test-eval-time-single-selection.R b/tests/testthat/test-eval-time-single-selection.R index eb2b104fc..6ea4c2138 100644 --- a/tests/testthat/test-eval-time-single-selection.R +++ b/tests/testthat/test-eval-time-single-selection.R @@ -227,3 +227,16 @@ test_that("selecting single eval time - mixed metric sets - integrated first", { ) expect_null(int_multi) }) + + +test_that("selecting the first metric", { + library(yardstick) + + met_1 <- metric_set(rmse) + tbl_1 <- as_tibble(met_1)[1,] + met_2 <- metric_set(rmse, ccc) + tbl_2 <- as_tibble(met_2)[1,] + + expect_equal(first_metric(met_1), tbl_1) + expect_equal(first_metric(met_2), tbl_2) +}) From 743a5400b867cf34906c1eb086ec0bd1b0396fd5 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Sat, 2 Dec 2023 12:49:23 -0500 Subject: [PATCH 5/6] Apply suggestions from code review Co-authored-by: Hannah Frick --- .../test-eval-time-single-selection.R | 60 +++++-------------- 1 file changed, 16 insertions(+), 44 deletions(-) diff --git a/tests/testthat/test-eval-time-single-selection.R b/tests/testthat/test-eval-time-single-selection.R index 6ea4c2138..b241dcdb0 100644 --- a/tests/testthat/test-eval-time-single-selection.R +++ b/tests/testthat/test-eval-time-single-selection.R @@ -1,20 +1,11 @@ -# "selecting single eval time" means how functions like `show_best()` will pick -# an evaluation time for a dynamic metric when none is given. Previously we -# would find what is in the data and select a time that was close to the median -# time. This was fine but inconsistent with other parts of tidymodels that do -# similar operations. For example, tune_bayes has to have a metric to optimize -# on so it uses the first metric in the metric set and, if needed, the first -# evaluation time given to the function. +library(yardstick) test_that("selecting single eval time - non-survival case", { - library(yardstick) - met_reg <- metric_set(rmse) times_1 <- 1 / 3 times_2 <- as.numeric(5:4) / 7 - # ---------------------------------------------------------------------------- # eval time is not applicable outside of survival models; return null expect_null(first_eval_time(met_reg, eval_time = NULL)) @@ -24,8 +15,6 @@ test_that("selecting single eval time - non-survival case", { }) test_that("selecting single eval time - pure metric sets", { - library(yardstick) - met_int <- metric_set(brier_survival_integrated) met_dyn <- metric_set(brier_survival) met_stc <- metric_set(concordance_survival) @@ -33,23 +22,21 @@ test_that("selecting single eval time - pure metric sets", { times_1 <- 1 / 3 times_2 <- as.numeric(5:4) / 7 - # ---------------------------------------------------------------------------- # all static; return NULL and add warning if times are given expect_null(first_eval_time(met_stc, eval_time = NULL)) expect_null(first_eval_time(met_stc, "concordance_survival", eval_time = NULL)) - expect_snapshot_warning( + expect_snapshot( stc_one <- first_eval_time(met_stc, eval_time = times_1) ) expect_null(stc_one) - expect_snapshot_warning( + expect_snapshot( stc_multi <- first_eval_time(met_stc, eval_time = times_2) ) expect_null(stc_multi) - # ---------------------------------------------------------------------------- # all dynamic; return a single time and warn if there are more and error if # there are none @@ -67,12 +54,11 @@ test_that("selecting single eval time - pure metric sets", { times_1 ) - expect_snapshot_warning( + expect_snapshot( dyn_multi <- first_eval_time(met_dyn, eval_time = times_2) ) expect_equal(dyn_multi, times_2[1]) - # ---------------------------------------------------------------------------- # all integrated; return NULL and warn if there 1+ times expect_null(first_eval_time(met_int, eval_time = NULL)) @@ -80,12 +66,12 @@ test_that("selecting single eval time - pure metric sets", { first_eval_time(met_int, "brier_survival_integrated", eval_time = NULL) ) - expect_warning( + expect_snapshot( int_1 <- first_eval_time(met_int, eval_time = times_1) ) expect_null(int_1) - expect_warning( + expect_snapshot( int_multi <- first_eval_time(met_int, eval_time = times_2) ) expect_null(int_multi) @@ -93,15 +79,12 @@ test_that("selecting single eval time - pure metric sets", { }) test_that("selecting single eval time - mixed metric sets - static first", { - library(yardstick) - met_mix_stc <- metric_set(concordance_survival, brier_survival) met_mix_stc_all <- metric_set(concordance_survival, brier_survival, brier_survival_integrated) times_1 <- 1 / 3 times_2 <- as.numeric(5:4) / 7 - # ---------------------------------------------------------------------------- # static is first but includes dynamic. Should return NULL and add warning # if times are given @@ -109,17 +92,16 @@ test_that("selecting single eval time - mixed metric sets - static first", { first_eval_time(met_mix_stc, eval_time = NULL) ) - expect_warning( + expect_snapshot( stc_1 <- first_eval_time(met_mix_stc, eval_time = times_1) ) expect_null(stc_1) - expect_warning( + expect_snapshot( stc_multi <- first_eval_time(met_mix_stc, eval_time = times_2) ) expect_null(stc_multi) - # ---------------------------------------------------------------------------- # static is first but includes dynamic and integrated. Should return NULL and # add warning if times are given @@ -127,20 +109,18 @@ test_that("selecting single eval time - mixed metric sets - static first", { first_eval_time(met_mix_stc_all, eval_time = NULL) ) - expect_warning( + expect_snapshot( stc_1 <- first_eval_time(met_mix_stc_all, eval_time = times_1) ) expect_null(stc_1) - expect_warning( + expect_snapshot( stc_multi <- first_eval_time(met_mix_stc_all, eval_time = times_2) ) expect_null(stc_multi) }) test_that("selecting single eval time - mixed metric sets - dynamic first", { - library(yardstick) - met_mix_dyn <- metric_set(brier_survival, concordance_survival) met_mix_dyn_all <- metric_set(brier_survival, @@ -150,7 +130,6 @@ test_that("selecting single eval time - mixed metric sets - dynamic first", { times_1 <- 1 / 3 times_2 <- as.numeric(5:4) / 7 - # ---------------------------------------------------------------------------- # dynamic is first but includes static. Should return single time and add warning # if 2+ times are given @@ -162,12 +141,11 @@ test_that("selecting single eval time - mixed metric sets - dynamic first", { first_eval_time(met_mix_dyn, eval_time = times_1), times_1 ) - expect_warning( + expect_snapshot( dyn_multi <- first_eval_time(met_mix_dyn, eval_time = times_2) ) expect_equal(dyn_multi, times_2[1]) - # ---------------------------------------------------------------------------- # dynamic is first but includes static and integrated. Should return single # time and add warning if 2+ times are given @@ -179,7 +157,7 @@ test_that("selecting single eval time - mixed metric sets - dynamic first", { first_eval_time(met_mix_dyn_all, eval_time = times_1), times_1 ) - expect_warning( + expect_snapshot( dyn_multi <- first_eval_time(met_mix_dyn_all, eval_time = times_2) ) expect_equal(dyn_multi, times_2[1]) @@ -188,8 +166,6 @@ test_that("selecting single eval time - mixed metric sets - dynamic first", { test_that("selecting single eval time - mixed metric sets - integrated first", { - library(yardstick) - met_mix_int <- metric_set(brier_survival_integrated, concordance_survival) met_mix_int_all <- metric_set(brier_survival_integrated, @@ -199,30 +175,28 @@ test_that("selecting single eval time - mixed metric sets - integrated first", { times_1 <- 1 / 3 times_2 <- as.numeric(5:4) / 7 - # ---------------------------------------------------------------------------- # integrated is first but includes static. Should return NULL and add warning # if 1+ times are given expect_null(first_eval_time(met_mix_int, eval_time = NULL)) - expect_warning( + expect_snapshot( first_eval_time(met_mix_int, eval_time = times_1) ) - expect_warning( + expect_snapshot( int_multi <- first_eval_time(met_mix_int, eval_time = times_2) ) expect_null(int_multi) - # ---------------------------------------------------------------------------- # integrated is first but includes static and dynamic. Should return NULL and # add warning if 1+ times are given expect_null(first_eval_time(met_mix_int_all, eval_time = NULL)) - expect_warning( + expect_snapshot( first_eval_time(met_mix_int_all, eval_time = times_1) ) - expect_warning( + expect_snapshot( int_multi <- first_eval_time(met_mix_int_all, eval_time = times_2) ) expect_null(int_multi) @@ -230,8 +204,6 @@ test_that("selecting single eval time - mixed metric sets - integrated first", { test_that("selecting the first metric", { - library(yardstick) - met_1 <- metric_set(rmse) tbl_1 <- as_tibble(met_1)[1,] met_2 <- metric_set(rmse, ccc) From 90d5d8a12179b85168389e7dea9b3cc78758da6f Mon Sep 17 00:00:00 2001 From: topepo Date: Sat, 2 Dec 2023 12:56:15 -0500 Subject: [PATCH 6/6] update unit test and snapshots --- .../_snaps/eval-time-single-selection.md | 118 +++++++++++++++++- .../test-eval-time-single-selection.R | 4 +- 2 files changed, 117 insertions(+), 5 deletions(-) diff --git a/tests/testthat/_snaps/eval-time-single-selection.md b/tests/testthat/_snaps/eval-time-single-selection.md index 0a1829035..5d122bb3c 100644 --- a/tests/testthat/_snaps/eval-time-single-selection.md +++ b/tests/testthat/_snaps/eval-time-single-selection.md @@ -1,10 +1,18 @@ # selecting single eval time - pure metric sets - Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + Code + stc_one <- first_eval_time(met_stc, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. --- - Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + Code + stc_multi <- first_eval_time(met_stc, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. --- @@ -24,7 +32,59 @@ --- - 2 evaluation times were available; the first (0.714) will be used. + Code + dyn_multi <- first_eval_time(met_dyn, eval_time = times_2) + Condition + Warning: + 2 evaluation times were available; the first (0.714) will be used. + +--- + + Code + int_1 <- first_eval_time(met_int, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + int_multi <- first_eval_time(met_int, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +# selecting single eval time - mixed metric sets - static first + + Code + stc_1 <- first_eval_time(met_mix_stc, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + stc_multi <- first_eval_time(met_mix_stc, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + stc_1 <- first_eval_time(met_mix_stc_all, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + stc_multi <- first_eval_time(met_mix_stc_all, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. # selecting single eval time - mixed metric sets - dynamic first @@ -34,6 +94,14 @@ Error in `first_eval_time()`: ! A single evaluation time is required to use this metric. +--- + + Code + dyn_multi <- first_eval_time(met_mix_dyn, eval_time = times_2) + Condition + Warning: + 2 evaluation times were available; the first (0.714) will be used. + --- Code @@ -42,3 +110,47 @@ Error in `first_eval_time()`: ! A single evaluation time is required to use this metric. +--- + + Code + dyn_multi <- first_eval_time(met_mix_dyn_all, eval_time = times_2) + Condition + Warning: + 2 evaluation times were available; the first (0.714) will be used. + +# selecting single eval time - mixed metric sets - integrated first + + Code + first_eval_time(met_mix_int, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + Output + NULL + +--- + + Code + int_multi <- first_eval_time(met_mix_int, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + first_eval_time(met_mix_int_all, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + Output + NULL + +--- + + Code + int_multi <- first_eval_time(met_mix_int_all, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + diff --git a/tests/testthat/test-eval-time-single-selection.R b/tests/testthat/test-eval-time-single-selection.R index b241dcdb0..c40621b18 100644 --- a/tests/testthat/test-eval-time-single-selection.R +++ b/tests/testthat/test-eval-time-single-selection.R @@ -205,9 +205,9 @@ test_that("selecting single eval time - mixed metric sets - integrated first", { test_that("selecting the first metric", { met_1 <- metric_set(rmse) - tbl_1 <- as_tibble(met_1)[1,] + tbl_1 <- tibble::as_tibble(met_1)[1,] met_2 <- metric_set(rmse, ccc) - tbl_2 <- as_tibble(met_2)[1,] + tbl_2 <- tibble::as_tibble(met_2)[1,] expect_equal(first_metric(met_1), tbl_1) expect_equal(first_metric(met_2), tbl_2)