Skip to content

Commit

Permalink
fix same bug in step_bs()
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Jun 7, 2024
1 parent ea44fda commit ced9fbf
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

* `recipe()` will now show better error when columns are misspelled in formula (#1283).

* Fixed bug in `step_ns()` where `knots` field in `options` argument wasn't correctly used. (#1297)
* Fixed bug in `step_ns()` and `step_bs()` where `knots` field in `options` argument wasn't correctly used. (#1297)

* `add_role()` now errors if a column would simultaneously have roles `"outcome"` and `"predictor"`. (#935)

Expand Down
9 changes: 7 additions & 2 deletions R/bs.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#' @details `step_bs` can create new features from a single variable
#' that enable fitting routines to model this variable in a
#' nonlinear manner. The extent of the possible nonlinearity is
#' determined by the `df`, `degree`, or `knot` arguments of
#' determined by the `df`, `degree`, or `knots` arguments of
#' [splines::bs()]. The original variables are removed
#' from the data and new columns are added. The naming convention
#' for the new variables is `varname_bs_1` and so on.
Expand Down Expand Up @@ -121,12 +121,17 @@ bs_statistics <- function(x, args) {
ok <- !is.na(x) & x >= boundary[1L] & x <= boundary[2L]
knots <- unname(quantile(x[ok], seq_len(num_knots) / (num_knots + 1L)))
} else {
knots <- numeric()
if (is.null(args$knots)) {
knots <- numeric()
} else {
knots <- args$knots
}
}

# Only construct the data necessary for splines_predict
out <- matrix(NA, ncol = degree + length(knots) + intercept, nrow = 1L)
class(out) <- c("bs", "basis", "matrix")
attr(out, "degree") <- 3L
attr(out, "knots") <- knots
attr(out, "Boundary.knots") <- boundary
attr(out, "intercept") <- intercept
Expand Down
24 changes: 24 additions & 0 deletions tests/testthat/test-bs.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,30 @@ test_that("correct basis functions", {
expect_equal(hydrogen_bs_te_res, hydrogen_bs_te_exp)
})

test_that("options(knots) works correctly (#1297)", {
exmaple_data <- tibble(x = seq(-2, 2, 0.01))

rec_res <- recipe(~., data = exmaple_data) %>%
step_bs(x, options = list(knots = seq(-1, 1, 0.125),
Boundary.knots = c(-2.5, 2.5))) %>%
prep() %>%
bake(new_data = NULL)

mm_res <- model.matrix(
~ splines::bs(
x,
knots = seq(-1, 1, 0.125),
Boundary.knots = c(-2.5, 2.5)
) - 1,
data = exmaple_data
)

attr(mm_res, "assign") <- NULL
mm_res <- setNames(as_tibble(mm_res), names(rec_res))

expect_identical(rec_res, mm_res)
})

test_that("check_name() is used", {
dat <- mtcars
dat$mpg_bs_1 <- dat$mpg
Expand Down

0 comments on commit ced9fbf

Please sign in to comment.