Skip to content

Commit

Permalink
Merge pull request #1392 from tidymodels/sparse-dummy
Browse files Browse the repository at this point in the history
add `sparse` argument to `step_dummy()`
  • Loading branch information
EmilHvitfeldt authored Nov 14, 2024
2 parents ee360b0 + 6031f7f commit 59345e1
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 34 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

* All steps and checks now require arguments `trained`, `skip`, `role`, and `id` at all times.

* `step_dummy()` gained `sparse` argument. When set to `TRUE`, `step_dummy()` will produce sparse vectors. (#1392)

# recipes 1.1.0

## Improvements
Expand Down
92 changes: 59 additions & 33 deletions R/dummy.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#' @param levels A list that contains the information needed to create dummy
#' variables for each variable contained in `terms`. This is `NULL` until the
#' step is trained by [prep()].
#' @param sparse A logical. Should the columns produced be sparse vectors.
#' Sparsity is only supported for `"contr.treatment"` contrasts. Defaults to
#' `FALSE`.
#' @template step-return
#' @family dummy variable and encoding steps
#' @seealso [dummy_names()]
Expand Down Expand Up @@ -60,7 +63,8 @@
#' this step.
#'
#' Also, there are a number of contrast methods that return fractional values.
#' The columns returned by this step are doubles (not integers).
#' The columns returned by this step are doubles (not integers) when
#' `sparse = FALSE`. The columns returned when `sparse = TRUE` are integers.
#'
#' The [package vignette for dummy variables](https://recipes.tidymodels.org/articles/Dummies.html)
#' and interactions has more information.
Expand Down Expand Up @@ -121,6 +125,7 @@ step_dummy <-
preserve = deprecated(),
naming = dummy_names,
levels = NULL,
sparse = FALSE,
keep_original_cols = FALSE,
skip = FALSE,
id = rand_id("dummy")) {
Expand All @@ -143,6 +148,7 @@ step_dummy <-
preserve = keep_original_cols,
naming = naming,
levels = levels,
sparse = sparse,
keep_original_cols = keep_original_cols,
skip = skip,
id = id
Expand All @@ -151,7 +157,7 @@ step_dummy <-
}

step_dummy_new <-
function(terms, role, trained, one_hot, preserve, naming, levels,
function(terms, role, trained, one_hot, preserve, naming, levels, sparse,
keep_original_cols, skip, id) {
step(
subclass = "dummy",
Expand All @@ -162,6 +168,7 @@ step_dummy_new <-
preserve = preserve,
naming = naming,
levels = levels,
sparse = sparse,
keep_original_cols = keep_original_cols,
skip = skip,
id = id
Expand All @@ -174,6 +181,7 @@ prep.step_dummy <- function(x, training, info = NULL, ...) {
check_type(training[, col_names], types = c("factor", "ordered"))
check_bool(x$one_hot, arg = "one_hot")
check_function(x$naming, arg = "naming", allow_empty = FALSE)
check_bool(x$sparse, arg = "sparse")

if (length(col_names) > 0) {
## I hate doing this but currently we are going to have
Expand Down Expand Up @@ -218,6 +226,7 @@ prep.step_dummy <- function(x, training, info = NULL, ...) {
preserve = x$preserve,
naming = x$naming,
levels = levels,
sparse = x$sparse,
keep_original_cols = get_keep_original_cols(x),
skip = x$skip,
id = x$id
Expand Down Expand Up @@ -285,43 +294,60 @@ bake.step_dummy <- function(object, new_data, ...) {
col_name,
step = "step_dummy"
)

new_data[, col_name] <- factor(
new_data[[col_name]],
levels = levels_values,
ordered = is_ordered
)

new_data[, col_name] <-
factor(
new_data[[col_name]],
levels = levels_values,
ordered = is_ordered
)
if (object$sparse) {
current_contrast <- getOption("contrasts")[is_ordered + 1]
if (current_contrast != "contr.treatment") {
cli::cli_abort(
"When {.code sparse = TRUE}, only {.val contr.treatment} contrasts are
supported, not {.val {current_contrast}}."
)
}

indicators <-
model.frame(
rlang::new_formula(lhs = NULL, rhs = rlang::sym(col_name)),
data = new_data[, col_name],
xlev = levels_values,
na.action = na.pass
indicators <- sparsevctrs::sparse_dummy(
x = new_data[[col_name]],
one_hot = object$one_hot
)

indicators <- tryCatch(
model.matrix(object = levels, data = indicators),
error = function(cnd) {
if (grepl("(vector memory|cannot allocate)", cnd$message)) {
n_levels <- length(attr(levels, "values"))
cli::cli_abort(
"{.var {col_name}} contains too many levels ({n_levels}), \\
which would result in a data.frame too large to fit in memory.",
call = NULL
)
indicators <- tibble::new_tibble(indicators)
used_lvl <- colnames(indicators)
} else {
indicators <-
model.frame(
rlang::new_formula(lhs = NULL, rhs = rlang::sym(col_name)),
data = new_data[, col_name],
xlev = levels_values,
na.action = na.pass
)

indicators <- tryCatch(
model.matrix(object = levels, data = indicators),
error = function(cnd) {
if (grepl("(vector memory|cannot allocate)", cnd$message)) {
n_levels <- length(attr(levels, "values"))
cli::cli_abort(
"{.var {col_name}} contains too many levels ({n_levels}), \\
which would result in a data.frame too large to fit in memory.",
call = NULL
)
}
stop(cnd)
}
stop(cnd)
)

if (!object$one_hot) {
indicators <- indicators[, colnames(indicators) != "(Intercept)", drop = FALSE]
}
)

if (!object$one_hot) {
indicators <- indicators[, colnames(indicators) != "(Intercept)", drop = FALSE]

## use backticks for nonstandard factor levels here
used_lvl <- gsub(paste0("^\\`?", col_name, "\\`?"), "", colnames(indicators))
}

## use backticks for nonstandard factor levels here
used_lvl <- gsub(paste0("^\\`?", col_name, "\\`?"), "", colnames(indicators))

new_names <- object$naming(col_name, used_lvl, is_ordered)
colnames(indicators) <- new_names
indicators <- check_name(indicators, new_data, object, new_names)
Expand Down
8 changes: 7 additions & 1 deletion man/step_dummy.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions tests/testthat/_snaps/dummy.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@
Caused by error in `bake()`:
! Only one factor level in `x`: "only-level".

# sparse = TRUE errors on unsupported contrasts

Code
recipe(~., data = tibble(x = letters)) %>% step_dummy(x, sparse = TRUE) %>%
prep()
Condition
Error in `step_dummy()`:
Caused by error in `bake()`:
! When `sparse = TRUE`, only "contr.treatment" contrasts are supported, not "contr.helmert".

# bake method errors when needed non-standard role columns are missing

Code
Expand Down
28 changes: 28 additions & 0 deletions tests/testthat/test-dummy.R
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,34 @@ test_that("throws an informative error for single level", {
)
})

test_that("sparse = TRUE works", {
rec <- recipe(~ ., data = tibble(x = c(NA, letters)))

suppressWarnings({
dense <- rec %>% step_dummy(x, sparse = FALSE) %>% prep() %>% bake(NULL)
dense <- purrr::map(dense, as.integer) %>% tibble::new_tibble()
sparse <- rec %>% step_dummy(x, sparse = TRUE) %>% prep() %>% bake(NULL)
})

expect_identical(dense, sparse)

expect_false(any(vapply(dense, sparsevctrs::is_sparse_vector, logical(1))))
expect_true(all(vapply(sparse, sparsevctrs::is_sparse_vector, logical(1))))
})

test_that("sparse = TRUE errors on unsupported contrasts", {
go_helmert <- getOption("contrasts")
go_helmert["unordered"] <- "contr.helmert"
withr::local_options(contrasts = go_helmert)

expect_snapshot(
error = TRUE,
recipe(~ ., data = tibble(x = letters)) %>%
step_dummy(x, sparse = TRUE) %>%
prep()
)
})

# Infrastructure ---------------------------------------------------------------

test_that("bake method errors when needed non-standard role columns are missing", {
Expand Down

0 comments on commit 59345e1

Please sign in to comment.