Skip to content

Commit

Permalink
Merge pull request #202 from spsanderson/development
Browse files Browse the repository at this point in the history
Add `.drop_na` to fast regress/class
  • Loading branch information
spsanderson authored Dec 31, 2023
2 parents 113343b + 4d7d741 commit eaa362c
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 11 deletions.
10 changes: 8 additions & 2 deletions R/make-classification-fast.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#' split supported by `rsample`
#' @param .split_args The default is NULL, when NULL then the default parameters
#' of the split type will be executed for the rsample split type.
#' @param .drop_na The default is TRUE, which will drop all NA's from the data.
#'
#' @examples
#' library(recipes)
Expand All @@ -37,7 +38,7 @@
#' fct_tbl <- fast_classification(
#' .data = df,
#' .rec_obj = rec_obj,
#' .parsnip_eng = "glm"
#' .parsnip_eng = c("glm","earth")
#' )
#'
#' fct_tbl
Expand All @@ -54,7 +55,7 @@ NULL

fast_classification <- function(.data, .rec_obj, .parsnip_fns = "all",
.parsnip_eng = "all", .split_type = "initial_split",
.split_args = NULL){
.split_args = NULL, .drop_na = TRUE){

# Tidy Eval ----
call <- list(.parsnip_fns) |>
Expand Down Expand Up @@ -98,13 +99,18 @@ fast_classification <- function(.data, .rec_obj, .parsnip_fns = "all",
pred_wflw = internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
)

if (.drop_na){
mod_pred_tbl <- tidyr::drop_na(mod_pred_tbl, pred_wflw) |>
dplyr::mutate(.model_id = dplyr::row_number())
}

# Return ----
class(mod_tbl) <- c("fst_reg_tbl", class(mod_tbl))
attr(mod_tbl, ".parsnip_engines") <- .parsnip_eng
attr(mod_tbl, ".parsnip_functions") <- .parsnip_fns
attr(mod_tbl, ".split_type") <- .split_type
attr(mod_tbl, ".split_args") <- .split_args
attr(mod_tbl, ".drop_na") <- .drop_na

return(mod_pred_tbl)
}
16 changes: 13 additions & 3 deletions R/make-regression-fast.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@
#' split supported by `rsample`
#' @param .split_args The default is NULL, when NULL then the default parameters
#' of the split type will be executed for the rsample split type.
#' @param .drop_na The default is TRUE, which will drop all NA's from the data.
#'
#' @examples
#' library(recipes, quietly = TRUE)
#'
#' rec_obj <- recipe(mpg ~ ., data = mtcars)
#' frt_tbl <- fast_regression(mtcars, rec_obj, .parsnip_eng = c("lm","glm"),
#' .parsnip_fns = "linear_reg")
#' frt_tbl <- fast_regression(
#' mtcars,
#' rec_obj,
#' .parsnip_eng = c("lm","glm","gee"),
#' .parsnip_fns = "linear_reg"
#' )
#'
#' frt_tbl
#'
Expand All @@ -41,7 +46,7 @@ NULL

fast_regression <- function(.data, .rec_obj, .parsnip_fns = "all",
.parsnip_eng = "all", .split_type = "initial_split",
.split_args = NULL){
.split_args = NULL, .drop_na = TRUE){

# Tidy Eval ----
call <- list(.parsnip_fns) |>
Expand Down Expand Up @@ -85,13 +90,18 @@ fast_regression <- function(.data, .rec_obj, .parsnip_fns = "all",
pred_wflw = internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
)

if (.drop_na){
mod_pred_tbl <- tidyr::drop_na(mod_pred_tbl, pred_wflw) |>
dplyr::mutate(.model_id = dplyr::row_number())
}

# Return ----
class(mod_tbl) <- c("fst_reg_tbl", class(mod_tbl))
attr(mod_tbl, ".parsnip_engines") <- .parsnip_eng
attr(mod_tbl, ".parsnip_functions") <- .parsnip_fns
attr(mod_tbl, ".split_type") <- .split_type
attr(mod_tbl, ".split_args") <- .split_args
attr(mod_tbl, ".drop_na") <- .drop_na

return(mod_pred_tbl)
}
15 changes: 14 additions & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ frt_tbl <- fast_regression(
.parsnip_fns = "linear_reg"
)
frt_tbl$pred_wflw
extract_wf_preds(frt_tbl)
```

_Getting Regression Residuals_

Getting residuals is easy with `{tidyAML}`. Let's take a look.

```{r}
extract_regression_residuals(frt_tbl)
```

You can also pivot them into a long format making plotting easy with `ggplot2`.

```{r}
extract_regression_residuals(frt_tbl, .pivot_long = TRUE)
```
7 changes: 5 additions & 2 deletions man/fast_classification.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 10 additions & 3 deletions man/fast_regression.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit eaa362c

Please sign in to comment.