Skip to content

Commit

Permalink
Fix issue #215 (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
fweber144 authored Jan 14, 2022
1 parent e37cf76 commit d7c64e4
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 95 deletions.
75 changes: 37 additions & 38 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,6 @@ proj_predict_aux <- function(proj, mu, weights, ...) {
#'
#' @inheritParams summary.vsel
#' @param x An object of class `vsel` (returned by [varsel()] or [cv_varsel()]).
#' @param baseline Either `"ref"` or `"best"` indicating whether the baseline is
#' the reference model or the best submodel (in terms of `stats[1]`),
#' respectively.
#'
#' @examples
#' if (requireNamespace("rstanarm", quietly = TRUE)) {
Expand Down Expand Up @@ -329,10 +326,9 @@ plot.vsel <- function(
## compute all the statistics and fetch only those that were asked
nfeat_baseline <- .get_nfeat_baseline(object, baseline, stats[1])
tab <- rbind(
.tabulate_stats(object, stats,
alpha = alpha,
nfeat_baseline = nfeat_baseline),
.tabulate_stats(object, stats, alpha = alpha)
.tabulate_stats(object, stats, alpha = alpha,
nfeat_baseline = nfeat_baseline, ...),
.tabulate_stats(object, stats, alpha = alpha, ...)
)
stats_table <- subset(tab, tab$delta == deltas)
stats_ref <- subset(stats_table, stats_table$size == Inf)
Expand Down Expand Up @@ -430,10 +426,12 @@ plot.vsel <- function(
#' * `"mlpd"`: mean log predictive density, that is, `"elpd"` divided by the
#' number of observations.
#' * `"mse"`: mean squared error.
#' * `"rmse"`: root mean squared error.
#' * `"rmse"`: root mean squared error. For the corresponding standard error,
#' bootstrapping is used.
#' * `"acc"` (or its alias, `"pctcorr"`): classification accuracy
#' ([binomial()] family only).
#' * `"auc"`: area under the ROC curve ([binomial()] family only).
#' * `"auc"`: area under the ROC curve ([binomial()] family only). For the
#' corresponding standard error, bootstrapping is used.
#' @param type One or more items from `"mean"`, `"se"`, `"lower"`, `"upper"`,
#' `"diff"`, and `"diff.se"` indicating which of these to compute for each
#' item from `stats` (mean, standard error, lower and upper confidence
Expand All @@ -449,10 +447,14 @@ plot.vsel <- function(
#' normal-approximation confidence intervals. For example, `alpha = 0.32`
#' corresponds to a coverage of 68%, i.e., one-standard-error intervals
#' (because of the normal approximation).
#' @param baseline Only relevant if `deltas` is `TRUE`. Either `"ref"` or
#' `"best"` indicating whether the baseline is the reference model or the best
#' submodel (in terms of `stats[1]`), respectively.
#' @param ... Currently ignored.
#' @param baseline For [summary.vsel()]: Only relevant if `deltas` is `TRUE`.
#' For [plot.vsel()]: Always relevant. Either `"ref"` or `"best"`, indicating
#' whether the baseline is the reference model or the best submodel found (in
#' terms of `stats[1]`), respectively.
#' @param ... Arguments passed to the internal function which is used for
#' bootstrapping (if applicable; see argument `stats`). Currently, relevant
#' arguments are `b` (the number of bootstrap samples, defaulting to `2000`)
#' and `seed` (see [set.seed()], defaulting to `NULL`).
#'
#' @examples
#' if (requireNamespace("rstanarm", quietly = TRUE)) {
Expand Down Expand Up @@ -513,9 +515,9 @@ summary.vsel <- function(
if (deltas) {
nfeat_baseline <- .get_nfeat_baseline(object, baseline, stats[1])
tab <- .tabulate_stats(object, stats, alpha = alpha,
nfeat_baseline = nfeat_baseline)
nfeat_baseline = nfeat_baseline, ...)
} else {
tab <- .tabulate_stats(object, stats, alpha = alpha)
tab <- .tabulate_stats(object, stats, alpha = alpha, ...)
}
stats_table <- subset(tab, tab$size != Inf) %>%
dplyr::group_by(.data$statistic) %>%
Expand Down Expand Up @@ -642,12 +644,15 @@ print.vselsummary <- function(x, digits = 1, ...) {
#' @param ... Further arguments passed to [summary.vsel()] (apart from
#' argument `digits` which is passed to [print.vselsummary()]).
#'
#' @return The `data.frame` returned by [summary.vsel()] (invisible).
#' @return The output of [summary.vsel()] (invisible).
#'
#' @export
print.vsel <- function(x, ...) {
stats <- summary.vsel(x, ...)
print(stats, ...)
dot_args <- list(...)
stats <- do.call(summary.vsel, c(list(object = x),
dot_args[names(dot_args) != "digits"]))
do.call(print, c(list(x = stats),
dot_args[names(dot_args) == "digits"]))
return(invisible(stats))
}

Expand All @@ -663,35 +668,30 @@ print.vsel <- function(x, ...) {
#' [cv_varsel()]).
#' @param stat Statistic used for the decision. See [summary.vsel()] for
#' possible choices.
#' @param alpha A number determining the (nominal) coverage `1 - alpha` of the
#' normal-approximation confidence intervals based on which the decision is
#' made. For example, `alpha = 0.32` corresponds to a coverage of 68%, i.e.,
#' one-standard-error intervals (because of the normal approximation). See
#' section "Details" below for more information.
#' @param pct A number giving the relative proportion (*not* percents) between
#' baseline model and null model utilities one is willing to sacrifice. See
#' section "Details" below for more information.
#' @param type Either `"upper"` or `"lower"` determining whether the decision is
#' based on the upper or lower confidence interval bound, respectively. See
#' section "Details" below for more information.
#' @param baseline Either `"ref"` or `"best"` indicating whether the baseline is
#' the reference model or the best submodel (in terms of `stat[1]`),
#' respectively.
#' @param warnings Mainly for internal use. A single logical value indicating
#' whether to throw warnings if automatic suggestion fails. Usually there is
#' no reason to set this to `FALSE`.
#' @param ... Currently ignored.
#' @param ... Arguments passed to [summary.vsel()], except for `object`, `stats`
#' (which is set to `stat`), `type`, and `deltas` (which is set to `TRUE`).
#' See section "Details" below for some important arguments which may be
#' passed here.
#'
#' @details The suggested model size is the smallest model size for which either
#' the lower or upper bound (depending on argument `type`) of the
#' normal-approximation confidence interval (with nominal coverage `1 -
#' alpha`) for \eqn{u_k - u_{\mbox{base}}}{u_k - u_base} (with \eqn{u_k}
#' denoting the \eqn{k}-th submodel's utility and
#' \eqn{u_{\mbox{base}}}{u_base} denoting the baseline model's utility) falls
#' above (or is equal to) \deqn{\mbox{pct} * (u_0 - u_{\mbox{base}})}{pct *
#' (u_0 - u_base)} where \eqn{u_0} denotes the null model utility. The
#' baseline is either the reference model or the best submodel found (see
#' argument `baseline`).
#' alpha`, see argument `alpha` of [summary.vsel()]) for \eqn{u_k -
#' u_{\mbox{base}}}{u_k - u_base} (with \eqn{u_k} denoting the \eqn{k}-th
#' submodel's utility and \eqn{u_{\mbox{base}}}{u_base} denoting the baseline
#' model's utility) falls above (or is equal to) \deqn{\mbox{pct} * (u_0 -
#' u_{\mbox{base}})}{pct * (u_0 - u_base)} where \eqn{u_0} denotes the null
#' model utility. The baseline is either the reference model or the best
#' submodel found (see argument `baseline` of [summary.vsel()]).
#'
#' For example, `alpha = 0.32`, `pct = 0`, and `type = "upper"` means that we
#' select the smallest model size for which the upper bound of the confidence
Expand Down Expand Up @@ -742,10 +742,8 @@ suggest_size <- function(object, ...) {
suggest_size.vsel <- function(
object,
stat = "elpd",
alpha = 0.32,
pct = 0,
type = "upper",
baseline = if (!inherits(object$refmodel, "datafit")) "ref" else "best",
warnings = TRUE,
...
) {
Expand All @@ -771,9 +769,10 @@ suggest_size.vsel <- function(
}
bound <- type
stats <- summary.vsel(object,
stats = stat, alpha = alpha,
stats = stat,
type = c("mean", "upper", "lower"),
baseline = baseline, deltas = TRUE)$selection
deltas = TRUE,
...)$selection
util_null <- sgn * unlist(unname(subset(
stats, stats$size == 0,
paste0(stat, suffix)
Expand Down
2 changes: 1 addition & 1 deletion R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ auc <- function(x) {
# Bootstrap an arbitrary quantity `fun` that takes the sample `x` as the first
# input. Other arguments of `fun` can be passed by `...`. Example:
# `bootstrap(x, mean)`.
bootstrap <- function(x, fun = mean, b = 1000, seed = NULL, ...) {
bootstrap <- function(x, fun = mean, b = 2000, seed = NULL, ...) {
# set random seed but ensure the old RNG state is restored on exit
if (exists(".Random.seed")) {
rng_state_old <- .Random.seed
Expand Down
29 changes: 13 additions & 16 deletions R/summary_funs.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
# statistics relative to the baseline model of that size (`nfeat_baseline = Inf`
# means that the baseline model is the reference model).
.tabulate_stats <- function(varsel, stats, alpha = 0.05,
nfeat_baseline = NULL) {
nfeat_baseline = NULL, ...) {
stat_tab <- data.frame()
summ_ref <- varsel$summaries$ref
summ_sub <- varsel$summaries$sub
Expand Down Expand Up @@ -75,7 +75,7 @@
## reference model statistics
summ <- summ_ref
res <- get_stat(summ$mu, summ$lppd, varsel$d_test, stat, mu.bs = mu.bs,
lppd.bs = lppd.bs, weights = summ$w, alpha = alpha)
lppd.bs = lppd.bs, weights = summ$w, alpha = alpha, ...)
row <- data.frame(
data = varsel$d_test$type, size = Inf, delta = delta, statistic = stat,
value = res$value, lq = res$lq, uq = res$uq, se = res$se, diff = NA,
Expand All @@ -93,10 +93,10 @@
## scale
res_ref <- get_stat(summ_ref$mu, summ_ref$lppd, varsel$d_test,
stat, mu.bs = NULL, lppd.bs = NULL,
weights = summ_ref$w, alpha = alpha)
weights = summ_ref$w, alpha = alpha, ...)
res_diff <- get_stat(summ$mu, summ$lppd, varsel$d_test, stat,
mu.bs = summ_ref$mu, lppd.bs = summ_ref$lppd,
weights = summ$w, alpha = alpha)
weights = summ$w, alpha = alpha, ...)
val <- res_ref$value + res_diff$value
val.se <- sqrt(res_ref$se^2 + res_diff$se^2)
lq <- qnorm(alpha / 2, mean = val, sd = val.se)
Expand All @@ -109,10 +109,10 @@
} else {
## normal case
res <- get_stat(summ$mu, summ$lppd, varsel$d_test, stat, mu.bs = mu.bs,
lppd.bs = lppd.bs, weights = summ$w, alpha = alpha)
lppd.bs = lppd.bs, weights = summ$w, alpha = alpha, ...)
diff <- get_stat(summ$mu, summ$lppd, varsel$d_test, stat,
mu.bs = summ_ref$mu, lppd.bs = summ_ref$lppd,
weights = summ$w, alpha = alpha)
weights = summ$w, alpha = alpha, ...)
row <- data.frame(
data = varsel$d_test$type, size = k - 1, delta = delta,
statistic = stat, value = res$value, lq = res$lq, uq = res$uq,
Expand All @@ -127,7 +127,7 @@
}

get_stat <- function(mu, lppd, d_test, stat, mu.bs = NULL, lppd.bs = NULL,
weights = NULL, alpha = 0.1, seed = 1208499, B = 2000) {
weights = NULL, alpha = 0.1, ...) {
##
## Calculates given statistic stat with standard error and confidence bounds.
## mu.bs and lppd.bs are the pointwise mu and lppd for another model that is
Expand Down Expand Up @@ -195,16 +195,14 @@ get_stat <- function(mu, lppd, d_test, stat, mu.bs = NULL, lppd.bs = NULL,
function(resid2) {
sqrt(mean(weights * resid2, na.rm = TRUE))
},
b = B,
seed = seed
...
)
value.bootstrap2 <- bootstrap(
(mu.bs - y)^2,
function(resid2) {
sqrt(mean(weights * resid2, na.rm = TRUE))
},
b = B,
seed = seed
...
)
value.se <- sd(value.bootstrap1 - value.bootstrap2)
} else {
Expand All @@ -214,8 +212,7 @@ get_stat <- function(mu, lppd, d_test, stat, mu.bs = NULL, lppd.bs = NULL,
function(resid2) {
sqrt(mean(weights * resid2, na.rm = TRUE))
},
b = B,
seed = seed
...
)
value.se <- sd(value.bootstrap)
}
Expand All @@ -240,12 +237,12 @@ get_stat <- function(mu, lppd, d_test, stat, mu.bs = NULL, lppd.bs = NULL,
mu[is.na(mu.bs)] <- NA # for which both mu and mu.bs are non-NA
auc.data.bs <- cbind(y, mu.bs, weights)
value <- auc(auc.data) - auc(auc.data.bs)
value.bootstrap1 <- bootstrap(auc.data, auc, b = B, seed = seed)
value.bootstrap2 <- bootstrap(auc.data.bs, auc, b = B, seed = seed)
value.bootstrap1 <- bootstrap(auc.data, auc, ...)
value.bootstrap2 <- bootstrap(auc.data.bs, auc, ...)
value.se <- sd(value.bootstrap1 - value.bootstrap2, na.rm = TRUE)
} else {
value <- auc(auc.data)
value.bootstrap <- bootstrap(auc.data, auc, b = B, seed = seed)
value.bootstrap <- bootstrap(auc.data, auc, ...)
value.se <- sd(value.bootstrap, na.rm = TRUE)
}
}
Expand Down
18 changes: 12 additions & 6 deletions man/plot.vsel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/print.vsel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 11 additions & 20 deletions man/suggest_size.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit d7c64e4

Please sign in to comment.