Skip to content

Commit

Permalink
Merge pull request #342 from tlverse/haldensify_update
Browse files Browse the repository at this point in the history
  • Loading branch information
nhejazi authored Apr 11, 2021
2 parents 37a654d + 369fe9e commit 5cddc6c
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 40 deletions.
7 changes: 5 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,17 @@ r_packages:
- covr
- sessioninfo
- data.table
- Rdpack
- delayed
- hal9001
- haldensify
#- hal9001
#- haldensify

r_github_packages:
- r-lib/covr
- r-lib/sessioninfo
- tlverse/origami
- tlverse/hal9001@devel
- nhejazi/haldensify

after_success:
- travis_wait 80 Rscript -e 'covr::codecov()'
Expand Down
13 changes: 8 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ Imports:
methods,
ggplot2,
digest,
Rdpack,
imputeMissings,
dplyr,
caret,
arm
caret
Suggests:
testthat,
rmarkdown,
Expand All @@ -58,6 +58,7 @@ Suggests:
reticulate,
rgl,
rJava,
arm,
bartMachine,
cvAUC,
e1071,
Expand All @@ -67,7 +68,7 @@ Suggests:
glmnet,
grf,
gbm,
hal9001 (>= 0.2.5),
hal9001 (>= 0.3.0),
h2o,
keras,
kerasR,
Expand All @@ -85,10 +86,11 @@ Suggests:
lightgbm,
dbarts,
gam (>= 1.15.0),
haldensify,
haldensify (>= 0.1.0),
mgcv,
hts
License: GPL-3
Language: en-US
URL: https://tlverse.org/sl3
BugReports: https://github.com/tlverse/sl3/issues
Encoding: UTF-8
Expand All @@ -97,5 +99,6 @@ LazyLoad: yes
VignetteBuilder:
knitr,
R.rsp
RoxygenNote: 7.1.1.9001
Roxygen: list(markdown = TRUE, old_usage = TRUE, r6 = FALSE)
RoxygenNote: 7.1.1.9001
RdMacros: Rdpack
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ import(delayed)
import(ggplot2)
importFrom(BBmisc,requirePackages)
importFrom(R6,R6Class)
importFrom(arm,bayesglm)
importFrom(assertthat,assert_that)
importFrom(assertthat,is.count)
importFrom(assertthat,is.flag)
Expand Down
1 change: 0 additions & 1 deletion R/Lrnr_bayesglm.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#' @docType class
#'
#' @importFrom R6 R6Class
#' @importFrom arm bayesglm
#'
#' @export
#'
Expand Down
39 changes: 33 additions & 6 deletions R/Lrnr_haldensify.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#' each bin has the same number of observations, use "equal_mass" (based on
#' \code{\link[ggplot2]{cut_number}}).
#' }
#' \item{\code{n_bins = c(5, 10)}}{Only used if \code{type} is set to
#' \item{\code{n_bins = c(3, 5)}}{Only used if \code{type} is set to
#' \code{"equal_range"} or \code{"equal_mass"}. This \code{numeric} value
#' indicates the number of bins that the support of the outcome variable is
#' to be divided into.
Expand All @@ -32,6 +32,15 @@
#' sequence of values of the regulariztion parameter of the Lasso regression,
#' to be passed to to \code{\link[hal9001]{fit_hal}}.
#' }
#' \item{\code{trim_dens = 1/sqrt(n)}}{A \code{numeric} giving the minimum
#' allowed value of the resultant density predictions. Any predicted
#' density values below this tolerance threshold are set to the indicated
#' minimum. The default is to use the inverse of the square root of the
#' sample size of the prediction set, i.e., 1/sqrt(n); another notable
#' choice is 1/sqrt(n)/log(n). If there are observations in the prediction
#' set with values of \code{new_A} outside of the support of the training
#' set, their predictions are similarly truncated.
#' }
#' \item{\code{...}}{ Other parameters passed directly to
#' \code{\link[haldensify]{haldensify}}. See its documentation for details.
#' }
Expand All @@ -41,9 +50,10 @@ Lrnr_haldensify <- R6Class(
classname = "Lrnr_haldensify", inherit = Lrnr_base,
portable = TRUE, class = TRUE,
public = list(
initialize = function(grid_type = c("equal_range", "equal_mass"),
n_bins = c(5, 10),
initialize = function(grid_type = "equal_range",
n_bins = c(3, 5),
lambda_seq = exp(seq(-1, -13, length = 1000L)),
trim_dens = NULL,
...) {
params <- args_to_list()
super$initialize(params = params, ...)
Expand Down Expand Up @@ -77,25 +87,42 @@ Lrnr_haldensify <- R6Class(
args$family <- outcome_type$glm_family(return_object = TRUE)$family
}

# extract input data
args$W <- as.matrix(task$X)
args$A <- as.numeric(outcome_type$format(task$Y))
args$use_future <- FALSE

# handle weights
if (task$has_node("weights")) {
args$wts <- task$weights
}

# extract offset
if (task$has_node("offset")) {
args$offset <- task$offset
}

fit_object <- call_with_args(haldensify::haldensify, args)
# fit haldensify conditional density estimator
fit_object <- call_with_args(
haldensify::haldensify, args,
other_valid = c("max_degree", "smoothness_orders", "num_knots",
"adaptive_smoothing", "reduce_basis", "use_min"),
ignore = c("cv_select", "weights", "family", "fit_type", "trim_dens")
)
return(fit_object)
},
.predict = function(task = NULL) {
# set density trimming to haldensify::predict default if NULL
if (is.null(self$params[["trim_dens"]])) {
trim_dens <- 1 / sqrt(task$nrow)
} else {
trim_dens <- self$params[["trim_dens"]]
}

# predict density
predictions <- predict(self$fit_object,
new_A = as.numeric(task$Y),
new_W = as.matrix(task$X)
new_W = as.matrix(task$X),
trim_dens = trim_dens
)
return(predictions)
},
Expand Down
4 changes: 3 additions & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ environment:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true

build_script:
- travis-tool.sh install_github r-lib/covr r-lib/sessioninfo
- travis-tool.sh install_github tlverse/origami tlverse/hal9001@devel
- travis-tool.sh install_github nhejazi/haldensify
- travis-tool.sh install_deps
- travis-tool.sh install_github r-lib/covr r-lib/sessioninfo tlverse/origami

test_script:
- travis-tool.sh run_tests
Expand Down
2 changes: 1 addition & 1 deletion man/Lrnr_haldensify.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions tests/testthat/test-cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ context("test-cv.R -- Cross-validation fold handling")

library(origami)
options(java.parameters = "-Xmx2500m")

data(cpp_imputed)
covars <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn")
outcome <- "haz"
Expand Down Expand Up @@ -104,7 +103,8 @@ test_loocv_learner <- function(learner, loocv_task, ...) {

# test learner chaining
chained_task <- fit_obj$chain()
test_that("Chaining returns a task", expect_true(is(chained_task, "sl3_Task")))
test_that("Chaining returns a task", expect_true(is(chained_task,
"sl3_Task")))
test_that("Chaining returns the correct number of rows", expect_equal(
nrow(chained_task$X),
nrow(loocv_task$X)
Expand All @@ -114,7 +114,8 @@ test_loocv_learner <- function(learner, loocv_task, ...) {
preds_full <- fit_obj$predict_fold(loocv_task, "full")
preds_valid <- fit_obj$predict_fold(loocv_task, "validation")
validation_task <- validation(loocv_task, fold = loocv_task$folds[[1]])
validation_preds <- fit_obj$fit_object$fold_fits[[1]]$predict(validation_task)
validation_preds <-
fit_obj$fit_object$fold_fits[[1]]$predict(validation_task)
test_that("Learners do not error under LOOCV", {
expect_false(any(is.na(preds_valid)))
expect_false(any(is.na(preds_fold1)))
Expand All @@ -125,7 +126,9 @@ test_loocv_learner <- function(learner, loocv_task, ...) {

# make task with LOOCV
d <- cpp_imputed[1:50, ]
expect_warning(loocv_folds <- make_folds(n = d, fold_fun = folds_vfold, V = 50))
expect_warning({
loocv_folds <- make_folds(n = d, fold_fun = folds_vfold, V = 50)
})
loocv_task <- sl3_Task$new(d, covars, outcome, folds = loocv_folds)

# get learners
Expand All @@ -138,5 +141,10 @@ wrap <- sl3::sl3_list_learners("wrapper")
h2o <- sl3::sl3_list_learners("h2o")
learners <- cont_learners[-which(cont_learners %in% c(ts, screen, wrap, h2o))]

# remove LightGBM on Windows
if (Sys.info()["sysname"] == "Windows") {
learners <- learners[!(learners == "Lrnr_lightgbm")]
}

# test all relevant learners
lapply(learners, test_loocv_learner, loocv_task)
22 changes: 15 additions & 7 deletions tests/testthat/test-density-pooled_hazards.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,24 @@ data <- data.table(x = x, y = y)
task <- sl3_Task$new(data, covariates = c("x"), outcome = "y")

# instantiate learners
hal <- Lrnr_hal9001$new(lambda_seq = exp(seq(-1, -13, length = 100)))
hal <- Lrnr_hal9001$new(
lambda = exp(seq(-1, -13, length = 100)),
max_degree = 6,
smoothness_orders = 0
)
haldensify <- Lrnr_haldensify$new(
grid_type = "equal_mass",
n_bins = 10,
lambda_seq = exp(seq(-1, -13, length = 100))
grid_type = "equal_range",
n_bins = 5,
lambda_seq = exp(seq(-1, -13, length = 100)),
max_degree = 6,
smoothness_orders = 0,
trim_dens = 0
)
hazard_learner <- Lrnr_pooled_hazards$new(hal)
density_learner <- Lrnr_density_discretize$new(hazard_learner,
type = "equal_mass",
n_bins = 10
density_learner <- Lrnr_density_discretize$new(
hazard_learner,
type = "equal_range",
n_bins = 5
)

# fit discrete density model to pooled hazards data
Expand Down
18 changes: 11 additions & 7 deletions tests/testthat/test-haldensify.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ if (FALSE) {
getwd()
library("devtools")
document()
load_all("./") # load all R files in /R and datasets in /data. Ignores NAMESPACE:
load_all("./") # load all R files in /R and datasets in /data.
# Ignores NAMESPACE:
devtools::check() # runs full check
setwd("..")
install("sl3",
Expand Down Expand Up @@ -34,9 +35,10 @@ task <- cpp_imputed %>%
)

hal_dens <- Lrnr_haldensify$new(
grid_type = "equal_mass",
n_bins = 10,
lambda_seq = exp(seq(-1, -13, length = 100))
grid_type = "equal_range",
n_bins = c(3, 5),
lambda_seq = exp(seq(-1, -13, length = 100)),
max_degree = 6, smoothness_orders = 0
)

test_that("Lrnr_haldensify produces predictions identical to haldensify", {
Expand All @@ -48,11 +50,13 @@ test_that("Lrnr_haldensify produces predictions identical to haldensify", {
haldensify_fit <- haldensify::haldensify(
A = as.numeric(task$Y),
W = as.matrix(task$X),
grid_type = "equal_mass",
n_bins = 10,
grid_type = "equal_range",
n_bins = c(3, 5),
lambda_seq = exp(seq(-1, -13,
length = 100
))
)),
max_degree = 6,
smoothness_orders = 0
)
haldensify_preds <- predict(haldensify_fit,
new_A = as.numeric(task$Y),
Expand Down
21 changes: 16 additions & 5 deletions tests/testthat/test-lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,28 @@ test_learner <- function(learner, task, ...) {
print(sprintf("Testing Learner: %s", learner_obj$name))
# test learner training
fit_obj <- learner_obj$train(task)
test_that("Learner can be trained on data", expect_true(fit_obj$is_trained))
test_that("Learner can be trained on data", {
skip_on_os("windows")
expect_true(fit_obj$is_trained)
})

# test learner prediction
train_preds <- fit_obj$predict()
test_that("Learner can generate training set predictions", expect_equal(
sl3:::safe_dim(train_preds)[1],
length(task$Y)
))
test_that("Learner can generate training set predictions", {
skip_on_os("windows")
expect_equal(
sl3:::safe_dim(train_preds)[1], length(task$Y)
)
})

# test learner chaining
chained_task <- fit_obj$chain()
test_that("Chaining returns a task", {
skip_on_os("windows")
expect_true(is(chained_task, "sl3_Task"))
})
test_that("Chaining returns the correct number of rows", {
skip_on_os("windows")
expect_equal(nrow(chained_task$X), nrow(task$X))
})
}
Expand All @@ -39,6 +46,7 @@ options(sl3.verbose = TRUE)
test_learner(Lrnr_lightgbm, task)

test_that("Lrnr_lightgbm predictions match lightgbm's: continuous outcome", {
skip_on_os("windows")
## instantiate Lrnr_lightgbm, train on task, and predict on task
set.seed(73964)
lrnr_lightgbm <- Lrnr_lightgbm$new()
Expand All @@ -63,6 +71,7 @@ test_that("Lrnr_lightgbm predictions match lightgbm's: continuous outcome", {
})

test_that("Lrnr_lightgbm predictions match lightgbm's: binary outcome", {
skip_on_os("windows")
## create task with binary outcome
covars <- c("bmi", "haz", "mage", "sexn")
outcome <- "smoked"
Expand Down Expand Up @@ -93,6 +102,7 @@ test_that("Lrnr_lightgbm predictions match lightgbm's: binary outcome", {
})

test_that("Lrnr_lightgbm predictions match lightgbm's: categorical outcome", {
skip_on_os("windows")
## create task with binary outcome
covars <- c("bmi", "haz", "mage", "sexn")
outcome <- "parity"
Expand Down Expand Up @@ -125,6 +135,7 @@ test_that("Lrnr_lightgbm predictions match lightgbm's: categorical outcome", {
})

test_that("Cursory test of Lrnr_lightgbm with weights", {
skip_on_os("windows")
## create task, continuous outcome with observation-level weights
covars <- c("bmi", "parity", "mage", "sexn")
outcome <- "haz"
Expand Down

0 comments on commit 5cddc6c

Please sign in to comment.