initial support for fit size reduction
jeremyrcoyle committed Oct 1, 2021
1 parent 20834ae commit 0de5af7
Showing 10 changed files with 314 additions and 2 deletions.
51 changes: 51 additions & 0 deletions R/Lrnr_base.R
@@ -151,6 +151,9 @@ Lrnr_base <- R6Class(
}
new_object <- self$clone() # copy parameters, and whatever else
new_object$set_train(fit_object, task)
if (getOption("sl3.reduce_fit")) {
  new_object$reduce_fit(check_preds = FALSE)
}
return(new_object)
},
set_train = function(fit_object, training_task) {
@@ -335,6 +338,53 @@
} else {
return(task)
}
},
reduce_fit = function(fit_object = NULL, check_preds = TRUE, set_train = TRUE) {
  if (is.null(fit_object)) {
    fit_object <- self$fit_object
  }
  if (check_preds) {
    preds_full <- self$predict(self$training_task)
  }

  # try reducing the size
  size_full <- true_obj_size(fit_object)

  # to see what's taking up the space, use check_fit_sizes() in utils.R

  # drop the components named in private$.fit_can_remove ("call" by default)
  keep <- setdiff(names(fit_object), private$.fit_can_remove)

  # subsetting a list with `[` drops attributes such as the class, so save
  # them and restore them with names updated to the kept components
  attrs <- attributes(fit_object)
  attrs$names <- keep
  reduced <- fit_object[keep]
  attributes(reduced) <- attrs
  fit_object <- reduced
  size_reduced <- true_obj_size(fit_object)
  reduction_percent <- 1 - size_reduced / size_full

  if (getOption("sl3.verbose")) {
    message(sprintf("Fit object size reduced %0.0f%%", 100 * reduction_percent))
  }

  if (set_train) {
    self$set_train(fit_object, self$training_task)
  }

  # verify prediction still works on the reduced fit (only meaningful when
  # set_train = TRUE, since predict() uses the stored fit object)
  if (check_preds) {
    preds_reduced <- self$predict(self$training_task)
    assert_that(isTRUE(all.equal(preds_full, preds_reduced)))
  }

  return(fit_object)
}
),
active = list(
Expand Down Expand Up @@ -399,6 +449,7 @@ Lrnr_base <- R6Class(
.required_packages = NULL,
.properties = list(),
.custom_chain = NULL,
.fit_can_remove = c("call"),
.train_sublearners = function(task) {
# train sublearners here
return(NULL)
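A minimal usage sketch of the new behavior (the data, task, and learner below are illustrative, not part of the commit):

library(sl3)

# opt in globally: every subsequent train() call reduces its fit
options(sl3.reduce_fit = TRUE, sl3.verbose = TRUE)

task <- make_sl3_Task(mtcars, covariates = c("cyl", "disp", "hp"), outcome = "mpg")
fit <- make_learner(Lrnr_xgboost)$train(task)
# prints "Fit object size reduced NN%" because sl3.verbose is TRUE
"raw" %in% names(fit$fit_object) # FALSE: dropped per .fit_can_remove

# or leave the option off and reduce an existing fit manually, verifying
# that training-task predictions are unchanged
options(sl3.reduce_fit = FALSE)
fit2 <- make_learner(Lrnr_xgboost)$train(task)
fit2$reduce_fit(check_preds = TRUE)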
1 change: 1 addition & 0 deletions R/Lrnr_glm_fast.R
@@ -157,6 +157,7 @@ Lrnr_glm_fast <- R6Class(
}
return(predictions)
},
.fit_can_remove = c("XTX"),
.required_packages = c("speedglm")
)
)
2 changes: 1 addition & 1 deletion R/Lrnr_hal9001.R
@@ -111,7 +111,7 @@ Lrnr_hal9001 <- R6Class(
}
return(predictions)
},

.fit_can_remove = c("lasso_fit","x_basis"),
.required_packages = c("hal9001", "glmnet")
)
)
1 change: 1 addition & 0 deletions R/Lrnr_xgboost.R
@@ -194,6 +194,7 @@ Lrnr_xgboost <- R6Class(

return(predictions)
},
.fit_can_remove = c("raw", "call"),
.required_packages = c("xgboost")
)
)
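The learner-specific overrides above all follow the same pattern: list the fit_object components that .predict() never touches. A hedged sketch of a custom learner opting in (Lrnr_demo is hypothetical; lm fits carry "call" and, by default, the full model frame "model", neither of which predict.lm() uses):

library(R6)
library(sl3)

Lrnr_demo <- R6Class("Lrnr_demo",
  inherit = Lrnr_base,
  public = list(
    initialize = function(...) {
      super$initialize(params = list(...))
    }
  ),
  private = list(
    .properties = c("continuous"),
    .train = function(task) {
      # the returned lm fit keeps "call" and "model", which are
      # not needed at predict time
      lm(task$Y ~ ., data = as.data.frame(task$X))
    },
    .predict = function(task) {
      predict(private$.fit_object, newdata = as.data.frame(task$X))
    },
    # components reduce_fit() may drop for this learner
    .fit_can_remove = c("call", "model")
  )
)

# fit <- Lrnr_demo$new()$train(task); fit$reduce_fit()
# leaves names(fit$fit_object) without "call" or "model"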
10 changes: 10 additions & 0 deletions R/utils.R
@@ -155,6 +155,16 @@ true_obj_size <- function(obj) {
length(serialize(obj, NULL))
}

#' @keywords internal
check_fit_sizes <- function(fit) {
  fo <- fit$fit_object
  # rank each component by its share of the total serialized size
  element_sizes <- sapply(fo, true_obj_size)
  ranked <- sort(element_sizes / sum(element_sizes), decreasing = TRUE)
  return(ranked)
}
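This helper is presumably how the .fit_can_remove lists above were chosen: it shows which components dominate a fit's serialized size. A sketch, reusing the task from the earlier sketch:

options(sl3.reduce_fit = FALSE) # inspect an unreduced fit
fit_raw <- make_learner(Lrnr_xgboost)$train(task)
# named fractions summing to 1, largest first; dominant components
# (e.g., xgboost's serialized "raw" model) are .fit_can_remove candidates
sl3:::check_fit_sizes(fit_raw)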

################################################################################

#' Drop components from learner fits
3 changes: 2 additions & 1 deletion R/zzz.R
@@ -62,7 +62,8 @@ sl3Options <- function(o, value) {
"sl3.pcontinuous" = 0.05,
"sl3.max_p_missing" = 0.5,
"sl3.transform.offset" = TRUE,
"sl3.enable.future" = TRUE
"sl3.enable.future" = TRUE,
"sl3.reduce_fit" = FALSE
)
# for (i in setdiff(names(opts),names(options()))) {
# browser()
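The new default can also be toggled through the package's accessor; a sketch assuming sl3Options(o) reads and sl3Options(o, value) sets, consistent with the signature shown in the hunk above:

sl3Options("sl3.reduce_fit")       # FALSE by default
sl3Options("sl3.reduce_fit", TRUE) # equivalent to options(sl3.reduce_fit = TRUE)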
113 changes: 113 additions & 0 deletions tests/testthat/test-nested-sl.R
@@ -0,0 +1,113 @@
library(testthat)
context("test_sl.R -- Basic Lrnr_sl functionality")

options(sl3.verbose = TRUE)
library(sl3)
library(origami)
library(SuperLearner)

data(cpp_imputed)
covars <- c(
"apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs",
"sexn"
)
outcome <- "haz"
task <- sl3_Task$new(data.table::copy(cpp_imputed),
covariates = covars, outcome = outcome
)
task2 <- sl3_Task$new(data.table::copy(cpp_imputed),
covariates = covars, outcome = outcome
)

glm_learner <- Lrnr_glm$new()
glmnet_learner <- Lrnr_pkg_SuperLearner$new("SL.glmnet")
subset_apgar <- Lrnr_subset_covariates$new(covariates = c("apgar1", "apgar5"))
learners <- list(glm_learner, glmnet_learner, subset_apgar)
sl1 <- make_learner(Lrnr_sl, learners, glm_learner)

sl1_fit <- sl1$train(task)
test_that("Coefficients can extracted from sl fits", {
expect_true(!is.null(coef(sl1_fit)))
})
glm_fit <- sl1_fit$learner_fits$Lrnr_glm_TRUE
test_that("Library fits can extracted from sl fits", {
expect_true(inherits(glm_fit, "Lrnr_glm"))
})


sl1_risk <- sl1_fit$cv_risk(loss_squared_error)

expected_learners <- c(
"Lrnr_glm_TRUE", "Lrnr_pkg_SuperLearner_SL.glmnet",
"Lrnr_subset_covariates_c(\"apgar1\", \"apgar5\")_apgar1",
"Lrnr_subset_covariates_c(\"apgar1\", \"apgar5\")_apgar5"
)
test_that("sl1_fit is based on the right learners", {
expect_equal(
sl1_fit$fit_object$cv_meta_task$nodes$covariates,
expected_learners
)
})

stack <- make_learner(Stack, learners)
sl2 <- make_learner(Lrnr_sl, stack, glm_learner)

sl2_fit <- sl2$train(task)
sl2_risk <- sl2_fit$cv_risk(loss_squared_error)

test_that("Lrnr_sl can accept a pre-made stack", {
expect_equal(sl1_risk$mean_risk, sl2_risk$mean_risk, tolerance = 1e-2)
})

sl_nnls <- Lrnr_sl$new(
learners = list(glm_learner, glmnet_learner),
metalearner = sl3::Lrnr_nnls$new()
)
sl_nnls_fit <- sl_nnls$train(task)

sl1_small <- Lrnr_sl$new(
learners = list(
glm_learner, glmnet_learner,
subset_apgar
),
metalearner = glm_learner, keep_extra = FALSE
)
sl1_small_fit <- sl1_small$train(task)
expect_lt(length(sl1_small_fit$fit_object), length(sl1_fit$fit_object))
preds <- sl1_small_fit$predict(task)
preds_fold <- sl1_small_fit$predict_fold(task, "full")
test_that("predict_fold(task,'full') works if keep_extra=FALSE", {
expect_equal(preds, preds_fold)
})

# sl of a pipeline from https://github.com/tlverse/sl3/issues/81
data(cpp)
cpp <- cpp[!is.na(cpp[, "haz"]), ]
covars <- c(
"apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs",
"sexn"
)
cpp[is.na(cpp)] <- 0
outcome <- "haz"
task <- sl3_Task$new(cpp, covariates = covars, outcome = outcome)
make_inter <- Lrnr_define_interactions$new(
interactions = list(c("apgar1", "parity"), c("apgar5", "parity"))
)

glm_learner <- Lrnr_glm$new()
glmnet_learner <- Lrnr_glmnet$new(nlambda = 5)
learners <- Stack$new(glm_learner, glmnet_learner)
pipe <- Pipeline$new(make_inter, learners)
sl1 <- make_learner(Lrnr_sl, pipe, metalearner = Lrnr_solnp$new())
fit <- sl1$train(task)
print(fit)

# Metalearner does not return coefficients.
glm_learner <- Lrnr_glm$new()
glmnet_learner <- Lrnr_glmnet$new(nlambda = 5)
learners <- Stack$new(glm_learner, glmnet_learner)
# Selecting a metalearner that does not provide coefficients.
ranger_learner <- Lrnr_ranger$new(num.trees = 5L)
sl1 <- make_learner(Lrnr_sl, learners, ranger_learner)
sl1_fit <- sl1$train(task)
print(sl1_fit)
42 changes: 42 additions & 0 deletions tests/testthat/test-reduce_fit.R
@@ -0,0 +1,42 @@

library(testthat)
library(data.table)
library(sl3)

set.seed(1234)

# TODO: maybe check storage at different n to get rate
n <- 1e3
p <- 100
# these two define the DGP (randomly)
p_X <- runif(p, 0.2, 0.8)
beta <- rnorm(p)

# simulate from the DGP
X <- sapply(p_X, function(p_Xi) rbinom(n, 1, p_Xi))
p_Yx <- plogis(X %*% beta)
Y <- rbinom(n, 1, p_Yx)
data <- data.table(X, Y)

# generate the sl3 task and learner
outcome <- "Y"
covariates <- setdiff(names(data), outcome)
task <- make_sl3_Task(data, covariates, outcome)

options(sl3.verbose = TRUE)
options(sl3.reduce_fit = TRUE)
test_reduce_fit <- function(learner) {
  fit <- learner$train(task)
  print(sl3:::check_fit_sizes(fit))
  if (!getOption("sl3.reduce_fit")) {
    # if we aren't automatically reducing, do it manually
    fit_object <- fit$reduce_fit()
  }

  still_present <- intersect(
    names(fit$fit_object),
    fit$.__enclos_env__$private$.fit_can_remove
  )

  expect_equal(length(still_present), 0)
}

test_reduce_fit(make_learner(Lrnr_glmnet))
test_reduce_fit(make_learner(Lrnr_ranger))
test_reduce_fit(make_learner(Lrnr_glm_fast))
test_reduce_fit(make_learner(Lrnr_xgboost))
test_reduce_fit(make_learner(Lrnr_hal9001))
28 changes: 28 additions & 0 deletions tests/testthat/test-xgboost-weights.R
@@ -0,0 +1,28 @@
context("test-xgboost.R -- General testing for Xgboost")

library(sl3)
library(xgboost)
library(data.table)
library(microbenchmark)

# define test dataset
n <- 1e4
p <- 10
W <- matrix(runif(n * p), nrow = n)
y <- sin(W[, 1]) + W[, 2]^2 + W[, 1] * W[, 2] + rnorm(n)
dt <- data.table(W, y)
task <- make_sl3_Task(dt, covariates = paste0("V", 1:p), outcome = "y")
weights <- rbinom(n, 1, 0.1)
new_columns <- task$add_columns(data.table(weights = weights))
task_weights <- task$next_in_chain(column_names = new_columns, weights = "weights")
task_subset <- task[which(weights == 1)]
xgb <- make_learner(Lrnr_xgboost)
glm_fast <- make_learner(Lrnr_glm_fast)

microbenchmark(
  xgb_full <- xgb$train(task),
  xgb_weights <- xgb$train(task_weights),
  xgb_subset <- xgb$train(task_subset),
  times = 100
)

cor(cbind(
  xgb_full$predict(task_subset),
  xgb_weights$predict(task_subset),
  xgb_subset$predict(task_subset)
))
65 changes: 65 additions & 0 deletions tests/testthat/test_joint_density.R
@@ -0,0 +1,65 @@
gen_data <- function(n = 1e3) {
  x1 <- rnorm(n)
  x2 <- 3 * x1 + rnorm(n)
  x3 <- rbinom(n, 1, plogis(x1 + 2 * x2))
  x4 <- runif(n)
  x5 <- rbinom(n, 1, plogis(x1 + 3 * x3))
  x <- cbind(x1, x2, x3, x4, x5)

  return(x)
}

ll_p0 <- function(x) {
  f1 <- dnorm(x[, 1])
  f2 <- dnorm(x[, 2] - 3 * x[, 1])
  f3 <- x[, 3] * plogis(x[, 1] + 2 * x[, 2]) + (1 - x[, 3]) * (1 - plogis(x[, 1] + 2 * x[, 2]))
  f4 <- dunif(x[, 4])
  f5 <- x[, 5] * plogis(x[, 1] + 3 * x[, 3]) + (1 - x[, 5]) * (1 - plogis(x[, 1] + 3 * x[, 3]))
  f <- cbind(f1, f2, f3, f4, f5)
  rowSums(log(f))
}

library(hal9001)


# TODO: need an object to represent a mapping from a raw variable to a set
# of basis functions
discretize <- function(x) {
  discretize_col <- function(col) {
    col_dat <- x[, col]
    unique_vals <- unique(col_dat)
    # TODO: handle disparate variable types better
    if (length(unique_vals) < 5) {
      binned <- col_dat
    } else {
      # TODO: think about how much we want to "expand" each density
      breaks <- seq(min(col_dat) - 1e-5, max(col_dat) + 1e-5, length.out = 5)
      binned <- cut(col_dat, breaks, labels = FALSE)
    }
    return(binned)
  }

  discretized <- sapply(1:ncol(x), discretize_col)

  # use histogram density estimators
  # still pretty sure this reduces to the usual hal basis matrix type
  # outcome is pnz for that basis
  return(discretized)
}

make_grid <- function(x) {
  make_grid_col <- function(col) {
    col_dat <- x[, col]
    # grid points for each column are its observed (discretized) values
    grid <- sort(unique(col_dat))
    return(grid)
  }
  grids <- lapply(1:ncol(x), make_grid_col)
  expanded <- do.call(expand.grid, grids)
  return(expanded)
}

x <- gen_data()
discretized <- discretize(x)
grid <- make_grid(discretized)
cat_vec <- grid[10, ]
apply(grid, 1, function(cat_vec) {
  discretized == cat_vec
})
