Skip to content

Commit

Permalink
Fix issue #154. (#265)
Browse files Browse the repository at this point in the history
  • Loading branch information
fweber144 authored Jan 14, 2022
1 parent 231b1ce commit e37cf76
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 66 deletions.
28 changes: 14 additions & 14 deletions R/cv_varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ cv_varsel.refmodel <- function(
nclusters = NULL,
ndraws_pred = 400,
nclusters_pred = NULL,
cv_search = !inherits(object, "datafit"),
refit_prj = !inherits(object, "datafit"),
nterms_max = NULL,
penalty = NULL,
verbose = TRUE,
Expand All @@ -126,13 +126,13 @@ cv_varsel.refmodel <- function(
refmodel <- object
## resolve the arguments similar to varsel
args <- parse_args_varsel(
refmodel = refmodel, method = method, cv_search = cv_search,
refmodel = refmodel, method = method, refit_prj = refit_prj,
nterms_max = nterms_max, nclusters = nclusters,
ndraws = ndraws, nclusters_pred = nclusters_pred, ndraws_pred = ndraws_pred,
search_terms = search_terms
)
method <- args$method
cv_search <- args$cv_search
refit_prj <- args$refit_prj
nterms_max <- args$nterms_max
nclusters <- args$nclusters
ndraws <- args$ndraws
Expand All @@ -154,7 +154,7 @@ cv_varsel.refmodel <- function(
sel_cv <- loo_varsel(
refmodel = refmodel, method = method, nterms_max = nterms_max,
ndraws = ndraws, nclusters = nclusters, ndraws_pred = ndraws_pred,
nclusters_pred = nclusters_pred, cv_search = cv_search, penalty = penalty,
nclusters_pred = nclusters_pred, refit_prj = refit_prj, penalty = penalty,
verbose = verbose, opt = opt, nloo = nloo,
validate_search = validate_search, seed = seed,
search_terms = search_terms
Expand All @@ -163,7 +163,7 @@ cv_varsel.refmodel <- function(
sel_cv <- kfold_varsel(
refmodel = refmodel, method = method, nterms_max = nterms_max,
ndraws = ndraws, nclusters = nclusters, ndraws_pred = ndraws_pred,
nclusters_pred = nclusters_pred, cv_search = cv_search, penalty = penalty,
nclusters_pred = nclusters_pred, refit_prj = refit_prj, penalty = penalty,
verbose = verbose, opt = opt, K = K, seed = seed,
search_terms = search_terms
)
Expand All @@ -179,7 +179,7 @@ cv_varsel.refmodel <- function(
sel <- varsel(refmodel,
method = method, ndraws = ndraws, nclusters = nclusters,
ndraws_pred = ndraws_pred, nclusters_pred = nclusters_pred,
cv_search = cv_search, nterms_max = nterms_max - 1,
refit_prj = refit_prj, nterms_max = nterms_max - 1,
penalty = penalty, verbose = verbose,
lambda_min_ratio = lambda_min_ratio, nlambda = nlambda,
regul = regul, search_terms = search_terms, seed = seed)
Expand Down Expand Up @@ -287,7 +287,7 @@ parse_args_cv_varsel <- function(refmodel, cv_method, K) {
}

loo_varsel <- function(refmodel, method, nterms_max, ndraws,
nclusters, ndraws_pred, nclusters_pred, cv_search,
nclusters, ndraws_pred, nclusters_pred, refit_prj,
penalty, verbose, opt, nloo = NULL,
validate_search = TRUE, seed = NULL,
search_terms = NULL) {
Expand Down Expand Up @@ -379,7 +379,7 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
search_path = search_path,
nterms = c(0, seq_along(search_path$solution_terms)),
p_ref = p_pred, refmodel = refmodel, regul = opt$regul,
cv_search = cv_search
refit_prj = refit_prj
)

if (verbose) {
Expand Down Expand Up @@ -459,7 +459,7 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
search_path = search_path,
nterms = c(0, seq_along(search_path$solution_terms)),
p_ref = p_pred, refmodel = refmodel, regul = opt$regul,
cv_search = cv_search
refit_prj = refit_prj
)
summaries_sub <- .get_sub_summaries(
submodels = submodels, test_points = c(i), refmodel = refmodel
Expand Down Expand Up @@ -512,7 +512,7 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,

kfold_varsel <- function(refmodel, method, nterms_max, ndraws,
nclusters, ndraws_pred, nclusters_pred,
cv_search, penalty, verbose, opt, K, seed = NULL,
refit_prj, penalty, verbose, opt, K, seed = NULL,
search_terms = NULL) {
# Fetch the K reference model fits (or fit them now if not already done) and
# create objects of class `refmodel` from them (and also store the `omitted`
Expand Down Expand Up @@ -571,7 +571,7 @@ kfold_varsel <- function(refmodel, method, nterms_max, ndraws,

# Re-project along the solution path (or fetch the projections from the search
# results) for each fold:
if (verbose && cv_search) {
if (verbose && refit_prj) {
print("Computing projections..")
pb <- utils::txtProgressBar(min = 0, max = K, style = 3, initial = 0)
}
Expand All @@ -581,16 +581,16 @@ kfold_varsel <- function(refmodel, method, nterms_max, ndraws,
search_path = search_path,
nterms = c(0, seq_along(search_path$solution_terms)),
p_ref = fold$p_pred, refmodel = fold$refmodel, regul = opt$regul,
cv_search = cv_search
refit_prj = refit_prj
)
if (verbose && cv_search) {
if (verbose && refit_prj) {
utils::setTxtProgressBar(pb, fold_index)
}
return(submodels)
}
submodels_cv <- mapply(get_submodels_cv, search_path_cv, seq_along(list_cv),
SIMPLIFY = FALSE)
if (verbose && cv_search) {
if (verbose && refit_prj) {
close(pb)
}

Expand Down
32 changes: 16 additions & 16 deletions R/project.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
#' of predictor terms for the submodel onto which the projection will be
#' performed. Argument `nterms` is ignored in that case. For an `object` which
#' is not of class `vsel`, `solution_terms` must not be `NULL`.
#' @param cv_search A single logical value indicating whether to fit the
#' @param refit_prj A single logical value indicating whether to fit the
#' submodels (again) (`TRUE`) or to retrieve the fitted submodels from
#' `object` (`FALSE`). For an `object` which is not of class `vsel`,
#' `cv_search` must be `TRUE`.
#' @param ndraws Only relevant if `cv_search` is `TRUE`. Number of posterior
#' `refit_prj` must be `TRUE`.
#' @param ndraws Only relevant if `refit_prj` is `TRUE`. Number of posterior
#' draws to be projected. **Caution:** For `ndraws <= 20`, the value of
#' `ndraws` is passed to `nclusters` (so that clustering is used). Ignored if
#' `nclusters` is not `NULL` or if the reference model is of class `datafit`
#' (in which case one cluster is used). See also section "Details" below.
#' @param nclusters Only relevant if `cv_search` is `TRUE`. Number of clusters
#' @param nclusters Only relevant if `refit_prj` is `TRUE`. Number of clusters
#' of posterior draws to be projected. Ignored if the reference model is of
#' class `datafit` (in which case one cluster is used). For the meaning of
#' `NULL`, see argument `ndraws`. See also section "Details" below.
Expand Down Expand Up @@ -103,7 +103,7 @@
#'
#' @export
project <- function(object, nterms = NULL, solution_terms = NULL,
cv_search = TRUE, ndraws = 400, nclusters = NULL,
refit_prj = TRUE, ndraws = 400, nclusters = NULL,
seed = NULL, regul = 1e-4, ...) {
if (inherits(object, "datafit")) {
stop("project() does not support an `object` of class \"datafit\".")
Expand All @@ -112,32 +112,32 @@ project <- function(object, nterms = NULL, solution_terms = NULL,
stop("Please provide an `object` of class \"vsel\" or use argument ",
"`solution_terms`.")
}
if (!inherits(object, "vsel") && !cv_search) {
if (!inherits(object, "vsel") && !refit_prj) {
stop("Please provide an `object` of class \"vsel\" or use ",
"`cv_search = TRUE`.")
"`refit_prj = TRUE`.")
}

refmodel <- get_refmodel(object, ...)

if (cv_search && inherits(refmodel, "datafit")) {
warning("Automatically setting `cv_search` to `FALSE` since the reference ",
if (refit_prj && inherits(refmodel, "datafit")) {
warning("Automatically setting `refit_prj` to `FALSE` since the reference ",
"model is of class \"datafit\".")
cv_search <- FALSE
refit_prj <- FALSE
}

if (!cv_search &&
if (!refit_prj &&
!is.null(solution_terms) &&
any(
solution_terms(object)[seq_along(solution_terms)] != solution_terms
)) {
warning("The given `solution_terms` are not part of the solution path ",
"(from `solution_terms(object)`), so `cv_search` is automatically ",
"(from `solution_terms(object)`), so `refit_prj` is automatically ",
"set to `TRUE`.")
cv_search <- TRUE
refit_prj <- TRUE
}

if (!cv_search) {
warning("Currently, `cv_search = FALSE` requires some caution, see GitHub ",
if (!refit_prj) {
warning("Currently, `refit_prj = FALSE` requires some caution, see GitHub ",
"issues #168 and #211.")
}

Expand Down Expand Up @@ -224,7 +224,7 @@ project <- function(object, nterms = NULL, solution_terms = NULL,
submodls = object$search_path$submodls
),
nterms = nterms, p_ref = p_ref, refmodel = refmodel, regul = regul,
cv_search = cv_search
refit_prj = refit_prj
)

# Output:
Expand Down
4 changes: 2 additions & 2 deletions R/projfun.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ project_submodel <- function(solution_terms, p_ref, refmodel, regul = 1e-4) {
# sizes `nterms`. Returns a list of submodels (each processed by
# .init_submodel()).
.get_submodels <- function(search_path, nterms, p_ref, refmodel, regul,
cv_search = FALSE) {
if (!cv_search) {
refit_prj = FALSE) {
if (!refit_prj) {
# In this case, simply fetch the already computed projections, so don't
# project again.
fetch_submodel <- function(nterms) {
Expand Down
2 changes: 1 addition & 1 deletion R/refmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
#' returned by [project()] in its output element `submodl` (which in turn is
#' the same as the return value of `div_minimizer`, except if [project()]
#' was used with an `object` of class `vsel` based on an L1 search as well
#' as with `cv_search = FALSE`).
#' as with `refit_prj = FALSE`).
#' + `newdata` accepts data for new observations (at least in the form of a
#' `data.frame`).
#' * `div_minimizer` does not need to have a specific prototype, but it needs to
Expand Down
26 changes: 13 additions & 13 deletions R/varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#' L1 search and `"forward"` for forward search. If `NULL`, then `"forward"`
#' is used if the reference model has multilevel or additive terms and `"L1"`
#' otherwise. See also section "Details" below.
#' @param cv_search A single logical value indicating whether to fit the
#' @param refit_prj A single logical value indicating whether to fit the
#' submodels along the solution path again (`TRUE`) or to retrieve their fits
#' from the search part (`FALSE`) before using those (re-)fits in the
#' evaluation part.
Expand All @@ -30,12 +30,12 @@
#' part. Ignored in case of L1 search (because L1 search always uses a single
#' cluster). For the meaning of `NULL`, see argument `ndraws`. See also
#' section "Details" below.
#' @param ndraws_pred Only relevant if `cv_search` is `TRUE`. Number of
#' @param ndraws_pred Only relevant if `refit_prj` is `TRUE`. Number of
#' posterior draws used in the evaluation part. **Caution:** For `ndraws_pred
#' <= 20`, the value of `ndraws_pred` is passed to `nclusters_pred` (so that
#' clustering is used). Ignored if `nclusters_pred` is not `NULL`. See also
#' section "Details" below.
#' @param nclusters_pred Only relevant if `cv_search` is `TRUE`. Number of
#' @param nclusters_pred Only relevant if `refit_prj` is `TRUE`. Number of
#' clusters of posterior draws used in the evaluation part. For the meaning of
#' `NULL`, see argument `ndraws_pred`. See also section "Details" below.
#' @param nterms_max Maximum number of predictor terms until which the search is
Expand Down Expand Up @@ -140,7 +140,7 @@ varsel.default <- function(object, ...) {
varsel.refmodel <- function(object, d_test = NULL, method = NULL,
ndraws = 20, nclusters = NULL, ndraws_pred = 400,
nclusters_pred = NULL,
cv_search = !inherits(object, "datafit"),
refit_prj = !inherits(object, "datafit"),
nterms_max = NULL, verbose = TRUE,
lambda_min_ratio = 1e-5, nlambda = 150,
thresh = 1e-6, regul = 1e-4, penalty = NULL,
Expand All @@ -149,13 +149,13 @@ varsel.refmodel <- function(object, d_test = NULL, method = NULL,

## fetch the default arguments or replace them by the user defined values
args <- parse_args_varsel(
refmodel = refmodel, method = method, cv_search = cv_search,
refmodel = refmodel, method = method, refit_prj = refit_prj,
nterms_max = nterms_max, nclusters = nclusters, ndraws = ndraws,
nclusters_pred = nclusters_pred, ndraws_pred = ndraws_pred,
search_terms = search_terms
)
method <- args$method
cv_search <- args$cv_search
refit_prj <- args$refit_prj
nterms_max <- args$nterms_max
nclusters <- args$nclusters
ndraws <- args$ndraws
Expand Down Expand Up @@ -191,7 +191,7 @@ varsel.refmodel <- function(object, d_test = NULL, method = NULL,
submodels <- .get_submodels(search_path = search_path,
nterms = c(0, seq_along(solution_terms)),
p_ref = p_pred, refmodel = refmodel,
regul = regul, cv_search = cv_search)
regul = regul, refit_prj = refit_prj)
sub <- .get_sub_summaries(
submodels = submodels, test_points = seq_along(refmodel$y),
refmodel = refmodel
Expand Down Expand Up @@ -276,7 +276,7 @@ select <- function(method, p_sel, refmodel, nterms_max, penalty, verbose, opt,
}
}

parse_args_varsel <- function(refmodel, method, cv_search, nterms_max,
parse_args_varsel <- function(refmodel, method, refit_prj, nterms_max,
nclusters, ndraws, nclusters_pred, ndraws_pred,
search_terms) {
##
Expand Down Expand Up @@ -311,11 +311,11 @@ parse_args_varsel <- function(refmodel, method, cv_search, nterms_max,
stop("Unknown search method")
}

stopifnot(!is.null(cv_search))
if (cv_search && inherits(refmodel, "datafit")) {
warning("For an `object` of class \"datafit\", `cv_search` is ",
stopifnot(!is.null(refit_prj))
if (refit_prj && inherits(refmodel, "datafit")) {
warning("For an `object` of class \"datafit\", `refit_prj` is ",
"automatically set to `FALSE`.")
cv_search <- FALSE
refit_prj <- FALSE
}

stopifnot(!is.null(ndraws))
Expand Down Expand Up @@ -353,7 +353,7 @@ parse_args_varsel <- function(refmodel, method, cv_search, nterms_max,
nterms_max <- min(max_nv_possible, nterms_max + 1)

return(nlist(
method, cv_search, nterms_max, nclusters, ndraws, nclusters_pred,
method, refit_prj, nterms_max, nclusters, ndraws, nclusters_pred,
ndraws_pred, search_terms
))
}
8 changes: 4 additions & 4 deletions man/cv_varsel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/project.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/refmodel-init-get.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e37cf76

Please sign in to comment.