Skip to content

Commit

Permalink
Fixes #78
Browse files Browse the repository at this point in the history
  • Loading branch information
spsanderson committed Jan 11, 2023
1 parent 1c71a2b commit ef5a1e0
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 0 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export(internal_make_fitted_wflw)
export(internal_make_spec_tbl)
export(internal_make_wflw)
export(internal_make_wflw_predictions)
export(internal_set_args_to_tune)
export(load_deps)
export(make_classification_base_tbl)
export(make_regression_base_tbl)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ None
1. Fix #73 - Add function `make_regression_base_tbl()`
2. Fix #74 - Add function `make_classification_base_tbl()`
3. Fix #77 - Add function `internal_make_spec_tbl()`
4. Fix #78 - Add function `internal_set_args_to_tune()`

## Minor Fixes and Improvements
1. Fix #72 - Update `fast_classification_parsnip_spec_tbl()` and
Expand Down
136 changes: 136 additions & 0 deletions R/internals-set-tune-modspec-args.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#' Internals Make a Tunable Model Specification
#'
#' @family Internals
#'
#' @author Steven P. Sanderson II, MPH
#'
#' @description Make a tuned model specification object.
#'
#' @details This will take a model specification that is created from a function
#' like [tidyAML::fast_regression_parsnip_spec_tbl()] and update the __model_spec__
#' `args` to `tune::tune()`. This is done dynamically, meaning you do not need
#' to know the names of the parameters inside of the model specification.
#'
#' @param .model_tbl The model table that is generated from a function like
#' `fast_regression_parsnip_spec_tbl()`, must have a class of "tidyaml_mod_spec_tbl".
#'
#' @examples
#' library(dplyr)
#'
#' mod_tbl <- fast_regression_parsnip_spec_tbl()
#' mod_tbl$model_spec[[1]]
#'
#' updated_tbl <- mod_tbl %>%
#' mutate(model_spec = internal_set_args_to_tune(mod_tbl))
#' updated_tbl$model_spec[[1]]
#'
#' @return
#' A list object of workflows.
#'
#' @name internal_set_args_to_tune
NULL

#' @export
#' @rdname internal_set_args_to_tune

internal_set_args_to_tune <- function(.model_tbl){

# Tidyeval
model_tbl <- .model_tbl

# Checks ----
if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
rlang::abort(
message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl",
use_cli_format = TRUE
)
}

model_tbl_with_params <- mod_tbl %>%
dplyr::mutate(
model_params = purrr::pmap(
dplyr::cur_data(),
~ list(formalArgs(..4))
)
)

models_list_new <- model_tbl_with_params %>%
dplyr::group_split(.model_id)

tuned_params_list <- models_list_new %>%
purrr::imap(
.f = function(obj, id){

# Pull the model params
mod_params <- obj %>% dplyr::pull(6) %>% purrr::pluck(1) # change to pull(6)
mod_params_list <- unlist(mod_params) %>% as.list()
#param_names <- unlist(mod_params)
names(mod_params_list) <- unlist(mod_params)

# Set mode and engine
p_mode <- obj %>% dplyr::pull(3) %>% purrr::pluck(1)
p_engine <- obj %>% dplyr::pull(2) %>% purrr::pluck(1)
me_list <- list(
mode = paste0("mode = ", p_mode),
engine = paste0("engine = ", p_engine)
)

# Get all other params
me_vec <- c("mode","engine")
pv <- unlist(mod_params)
params_to_modify <- pv[!pv %in% me_vec] %>% as.list()
names(params_to_modify) <- unlist(params_to_modify)

# Set each item equal to .x = tune::tune()
tuned_params_list <- purrr::map(
params_to_modify,
~ paste0("tune::tune()")
)

# use modifyList()
res <- utils::modifyList(mod_params_list, tuned_params_list)
res <- utils::modifyList(res, me_list)

# Return
return(res)

}
)

models_with_params_list <- purrr::map2(
.x = tuned_params_list,
.y = models_list_new,
~ {.y$model_params <- list(.x[.y$model_params[[1]][[1]]]);.y}
)

new_mod_obj <- models_with_params_list %>%
purrr::imap(
.f = function(obj, id){

# Get Model Specification
mod_spec <- obj %>%
dplyr::pull(5) %>%
purrr::pluck(1)

# Get the tuned params
new_mod_args <- obj %>%
dplyr::pull(6) %>%
purrr::pluck(1)

# Drop the ones we don't need to set
new_mod_args <- new_mod_args %>%
unlist() %>%
subset(!names(.) %in% c('mode','engine')) %>%
as.list()

# Set the new model arguments
mod_spec$args <- new_mod_args

# Return the newly modified model specification
return(mod_spec)
}
)

return(new_mod_obj)

}
1 change: 1 addition & 0 deletions man/internal_make_fitted_wflw.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/internal_make_spec_tbl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/internal_make_wflw.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/internal_make_wflw_predictions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions man/internal_set_args_to_tune.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/make_classification_base_tbl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/make_regression_base_tbl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ef5a1e0

Please sign in to comment.