Skip to content

Commit

Permalink
initial support for fit size reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyrcoyle committed Oct 1, 2021
1 parent 20834ae commit 86d610f
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 5 deletions.
51 changes: 51 additions & 0 deletions R/Lrnr_base.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ Lrnr_base <- R6Class(
}
new_object <- self$clone() # copy parameters, and whatever else
new_object$set_train(fit_object, task)
if (getOption("sl3.reduce_fit")) {
new_object$reduce_fit(check_preds = FALSE)
}
return(new_object)
},
set_train = function(fit_object, training_task) {
Expand Down Expand Up @@ -335,6 +338,53 @@ Lrnr_base <- R6Class(
} else {
return(task)
}
},
reduce_fit = function(fit_object = NULL, check_preds = TRUE, set_train = TRUE) {
  # Shrink a fit object by dropping the elements named in
  # private$.fit_can_remove, keeping all of its other attributes.
  #
  # Args:
  #   fit_object: the fit to reduce; defaults to self$fit_object when NULL.
  #   check_preds: if TRUE, predict on self$training_task before and after
  #     the reduction and fail if the two sets of predictions differ.
  #   set_train: if TRUE, store the reduced fit on this learner through
  #     self$set_train().
  #
  # Returns: the (possibly) reduced fit object.
  if (is.null(fit_object)) {
    fit_object <- self$fit_object
  }

  # This method takes no task argument, so predictions are checked on the
  # task the learner was trained on.
  if (check_preds) {
    preds_full <- self$predict(self$training_task)
  }

  # try reducing the size
  size_full <- true_obj_size(fit_object)

  # Select by position rather than by name: an unnamed fit object is then
  # left untouched instead of being emptied, and duplicated names are safe.
  drop_idx <- names(fit_object) %in% private$.fit_can_remove
  if (any(drop_idx)) {
    # `[` keeps only the names of a list, so every other attribute (class
    # included) has to be restored by hand after subsetting.
    attrs <- attributes(fit_object)
    attrs$names <- names(fit_object)[!drop_idx]
    reduced <- fit_object[!drop_idx]
    attributes(reduced) <- attrs
    fit_object <- reduced
  }

  size_reduced <- true_obj_size(fit_object)
  reduction_percent <- 1 - size_reduced / size_full

  # isTRUE() keeps an unset option from raising "argument is of length zero"
  if (isTRUE(getOption("sl3.verbose"))) {
    message(sprintf("Fit object size reduced %0.0f%%", 100 * reduction_percent))
  }

  if (set_train) {
    self$set_train(fit_object, self$training_task)
  }

  # verify prediction still works
  # NOTE(review): with set_train = FALSE the learner still holds the full
  # fit, so this comparison cannot detect a broken reduction.
  if (check_preds) {
    preds_reduced <- self$predict(self$training_task)
    # all.equal() returns a character description (not FALSE) on mismatch
    assert_that(isTRUE(all.equal(preds_full, preds_reduced)))
  }

  return(fit_object)
}
),
active = list(
Expand Down Expand Up @@ -399,6 +449,7 @@ Lrnr_base <- R6Class(
.required_packages = NULL,
.properties = list(),
.custom_chain = NULL,
.fit_can_remove = c("call"),
.train_sublearners = function(task) {
# train sublearners here
return(NULL)
Expand Down
1 change: 1 addition & 0 deletions R/Lrnr_glm_fast.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ Lrnr_glm_fast <- R6Class(
}
return(predictions)
},
.fit_can_remove = c("XTX"),
.required_packages = c("speedglm")
)
)
3 changes: 1 addition & 2 deletions R/Lrnr_hal9001.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ Lrnr_hal9001 <- R6Class(

return(fit_object)
},

.predict = function(task = NULL) {
predictions <- stats::predict(
self$fit_object,
Expand All @@ -111,7 +110,7 @@ Lrnr_hal9001 <- R6Class(
}
return(predictions)
},

.fit_can_remove = c("lasso_fit", "x_basis"),
.required_packages = c("hal9001", "glmnet")
)
)
1 change: 1 addition & 0 deletions R/Lrnr_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ Lrnr_xgboost <- R6Class(

return(predictions)
},
.fit_can_remove = c("raw", "call"),
.required_packages = c("xgboost")
)
)
10 changes: 10 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ true_obj_size <- function(obj) {
length(serialize(obj, NULL))
}

#' Rank the elements of a fit object by serialized size
#'
#' @param fit An object with a \code{fit_object} element whose components
#'  can be iterated over (e.g. a trained learner).
#'
#' @return A named numeric vector giving each component's share of the summed
#'  component sizes, sorted in decreasing order. Empty when \code{fit_object}
#'  has no components.
#'
#' @keywords internal
check_fit_sizes <- function(fit) {
  fo <- fit$fit_object
  # see what's taking up the space
  # vapply() rather than sapply(): an empty fit object then yields numeric(0)
  # instead of list(), which sum() cannot handle.
  element_sizes <- vapply(fo, true_obj_size, numeric(1))
  ranked <- sort(element_sizes / sum(element_sizes), decreasing = TRUE)

  return(ranked)
}

################################################################################

#' Drop components from learner fits
Expand Down
6 changes: 3 additions & 3 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ sl3Options <- function(o, value) {
}
if (is.null(value)) {
res[o] <- list(NULL)
}
else {
} else {
res[[o]] <- value
}
options(res[o])
Expand All @@ -62,7 +61,8 @@ sl3Options <- function(o, value) {
"sl3.pcontinuous" = 0.05,
"sl3.max_p_missing" = 0.5,
"sl3.transform.offset" = TRUE,
"sl3.enable.future" = TRUE
"sl3.enable.future" = TRUE,
"sl3.reduce_fit" = FALSE
)
# for (i in setdiff(names(opts),names(options()))) {
# browser()
Expand Down
29 changes: 29 additions & 0 deletions prof_dt.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
learner,train (S).elapsed,predict (S).elapsed,size (MB),MSE,n,p,p_encoded
lasso,0.211999999999989,0.00900000000001455,0.0478677749633789,0.161776523228791,100,100,100
lasso_fast,0.0660000000002583,0.0100000000002183,0.0478677749633789,0.169428931381144,100,100,100
mean,0.0109999999999673,0.00399999999990541,0.000121116638183594,0.17908,100,100,100
ranger,0.0549999999998363,0.0160000000000764,0.436310768127441,0.165616521662889,100,100,100
ranger_small,0.0359999999996035,0.0100000000002183,0.0905466079711914,0.168278786405556,100,100,100
xgb,0.0370000000002619,0.00700000000006185,0.0262327194213867,0.174146900124846,100,100,100
glm,0.0419999999999163,0.00700000000006185,0.0339899063110352,0.477337871242782,100,100,100
hal_ls2,0.813000000000102,0.0279999999997926,0.760213851928711,0.179214823680041,100,100,100
lasso,2.33800000000019,0.01299999999992,0.0838642120361328,0.0823997719258026,1000,100,100
lasso_fast,1.41199999999981,0.0140000000001237,0.0838642120361328,0.0830177435791345,1000,100,100
mean,0.0209999999997308,0.00600000000031287,0.000121116638183594,0.1780702,1000,100,100
ranger,0.357000000000426,0.0879999999997381,4.20776462554932,0.14595598192,1000,100,100
ranger_small,0.103000000000065,0.0329999999999018,0.850159645080566,0.146991513661111,1000,100,100
xgb,0.141000000000076,0.01299999999992,0.0262327194213867,0.132485156061917,1000,100,100
glm,0.0949999999997999,0.00999999999976353,0.0408601760864258,0.0903512975129307,1000,100,100
hal_ls2,8.8130000000001,0.152000000000044,0.760213851928711,0.0983646681386519,1000,100,100
lasso,15.2350000000001,0.0520000000001346,0.0769224166870117,0.0659705855526775,10000,100,100
lasso_fast,4.76299999999992,0.0549999999998363,0.0769224166870117,0.0659733840218839,10000,100,100
mean,0.0109999999999673,0.0320000000001528,0.000121116638183594,0.17774657,10000,100,100
ranger,6.94100000000026,1.28200000000015,40.0050668716431,0.133503125161,10000,100,100
ranger_small,1.43199999999979,0.297999999999774,8.07391452789307,0.134473546638889,10000,100,100
xgb,0.909999999999854,0.0720000000001164,0.0262327194213867,0.111062539475018,10000,100,100
glm,0.493999999999687,0.0500000000001819,0.109524726867676,0.0660307579952773,10000,100,100
lasso_fast,58.0370000000003,0.372999999999593,0.0737161636352539,0.0654017031817335,100000,100,100
mean,0.0140000000001237,0.231999999999971,0.000121116638183594,0.1777693129,100000,100,100
ranger_small,51.5020000000004,4.70499999999993,77.7932748794556,0.124000491958333,100000,100,100
xgb,8.3149999999996,0.532999999999447,0.0262327194213867,0.108275682560827,100000,100,100
glm,4.62099999999919,0.353000000000066,0.796170234680176,0.0654282103103834,100000,100,100
44 changes: 44 additions & 0 deletions tests/testthat/test-reduce_fit.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

set.seed(1234)

# TODO: maybe check storage at different n to get rate
# sample size and number of binary covariates
n <- 1e3
p <- 100

# random DGP parameters: per-covariate success probabilities and coefficients
p_X <- runif(p, 0.2, 0.8)
beta <- rnorm(p)

# draw the covariates column by column, then the binary outcome
X <- vapply(
  p_X,
  function(p_Xi) {
    rbinom(n, 1, p_Xi)
  },
  integer(n)
)
p_Yx <- plogis(X %*% beta)
Y <- rbinom(n, 1, p_Yx)
data <- data.table(X, Y)

# build the sl3 task: every column except the outcome is a covariate
outcome <- "Y"
covariates <- names(data)[!(names(data) %in% outcome)]
task <- make_sl3_Task(data, covariates, outcome)

# report size reductions and reduce fits automatically
options(
  sl3.verbose = TRUE,
  sl3.reduce_fit = TRUE
)
test_reduce_fit <- function(learner) {
  # Train on the shared task and print each fit element's share of the size.
  trained <- learner$train(task)
  print(sl3:::check_fit_sizes(trained))

  # if we aren't automatically reducing, do it manually
  auto_reduce <- getOption("sl3.reduce_fit")
  if (!auto_reduce) {
    trained$reduce_fit()
  }

  # None of the learner's removable elements may remain in the stored fit.
  removable <- trained$.__enclos_env__$private$.fit_can_remove
  leftover <- intersect(names(trained$fit_object), removable)

  expect_equal(length(leftover), 0)
}

# run the check once per learner type, in the same order as before
invisible(lapply(
  list(Lrnr_glmnet, Lrnr_ranger, Lrnr_glm_fast, Lrnr_xgboost, Lrnr_hal9001),
  function(learner_class) {
    test_reduce_fit(make_learner(learner_class))
  }
))

0 comments on commit 86d610f

Please sign in to comment.