Support for luz #187

Merged
merged 38 commits on Apr 14, 2023

Commits (38)
ef34f20
First pass at `luz` support.
dfalbel Mar 10, 2023
21d188c
Update test expectations
dfalbel Mar 10, 2023
b8460b9
try with non-rectangular data
dfalbel Mar 10, 2023
0e8b5fa
Add more expectations.
dfalbel Mar 10, 2023
e0ae2c1
Handle skips + torch installation.
dfalbel Mar 10, 2023
217185d
Force torch installation if tests are to be executed.
dfalbel Mar 10, 2023
c6562b6
retry more times, don't wait more than 10s between retries.
dfalbel Mar 10, 2023
57ffbba
Review: Add parameter count information.
dfalbel Mar 16, 2023
8d21a05
use `tensors_to_array` directly.
dfalbel Mar 16, 2023
8679440
Take approach very similar to keras and support only rectangular data…
dfalbel Mar 16, 2023
4c75d50
convert new data to array before passing to luz.
dfalbel Mar 16, 2023
abe11d1
Update tests
juliasilge Mar 18, 2023
9f29d68
Should be `prototype` here
juliasilge Mar 18, 2023
ed2d042
Update documentation for luz
juliasilge Apr 3, 2023
66344c3
Use tibble, for better JSON serialization
juliasilge Apr 3, 2023
d625d37
Add luz example to `inst/`
juliasilge Apr 3, 2023
66ce9ed
Fix variable name in little example
juliasilge Apr 3, 2023
9276fa9
Try using any::ranger after new release
juliasilge Apr 3, 2023
8b84b36
New method to generate `.Renviron` for deployment
juliasilge Apr 5, 2023
6b7c735
Update tests
juliasilge Apr 5, 2023
43960f0
Redocument
juliasilge Apr 5, 2023
ff83bd6
Attach torch at startup
juliasilge Apr 5, 2023
c0a2956
Namespace torch functions
juliasilge Apr 6, 2023
40e637e
Merge branch 'main' into luz-support
juliasilge Apr 6, 2023
63d9ffb
Try without torch here
juliasilge Apr 6, 2023
94dd35a
Handle ranger problems separately
juliasilge Apr 6, 2023
eedce33
Update NEWS
juliasilge Apr 6, 2023
adf8cb0
Merged origin/main into dfalbel-luz-support
juliasilge Apr 6, 2023
31d484e
Add torch back to `required_pkgs`
juliasilge Apr 6, 2023
28206dc
Test what is happening at startup
juliasilge Apr 6, 2023
1f2fd90
Update snapshot
juliasilge Apr 6, 2023
c11949a
Try attachNamespace
juliasilge Apr 6, 2023
6f7b30a
Try using vetiver's `attach_pkgs()`
juliasilge Apr 6, 2023
7f367c9
Try mapping through `required_pkgs`
juliasilge Apr 6, 2023
6799654
Update test
juliasilge Apr 6, 2023
7ef9d5d
Fix bug in namespace handling
juliasilge Apr 6, 2023
abcba74
Back to using `attach_pkgs()`
juliasilge Apr 6, 2023
a35d8ad
Back to using `library(torch)` for example
juliasilge Apr 6, 2023
4 changes: 3 additions & 1 deletion DESCRIPTION
@@ -78,7 +78,9 @@ Suggests:
    vdiffr,
    workflows,
    xgboost,
-    yardstick
+    yardstick,
+    torch,
+    luz
VignetteBuilder:
knitr
Config/Needs/website: tidyverse/tidytemplate
9 changes: 9 additions & 0 deletions NAMESPACE
@@ -16,6 +16,7 @@ S3method(handler_predict,glm)
S3method(handler_predict,keras.engine.training.Model)
S3method(handler_predict,kproto)
S3method(handler_predict,lm)
S3method(handler_predict,luz_module_fitted)
S3method(handler_predict,model_stack)
S3method(handler_predict,ranger)
S3method(handler_predict,recipe)
@@ -26,6 +27,7 @@ S3method(handler_startup,Learner)
S3method(handler_startup,default)
S3method(handler_startup,gam)
S3method(handler_startup,keras.engine.training.Model)
S3method(handler_startup,luz_module_fitted)
S3method(handler_startup,model_stack)
S3method(handler_startup,ranger)
S3method(handler_startup,recipe)
@@ -48,6 +50,7 @@ S3method(vetiver_create_description,glm)
S3method(vetiver_create_description,keras.engine.training.Model)
S3method(vetiver_create_description,kproto)
S3method(vetiver_create_description,lm)
S3method(vetiver_create_description,luz_module_fitted)
S3method(vetiver_create_description,model_stack)
S3method(vetiver_create_description,ranger)
S3method(vetiver_create_description,recipe)
@@ -59,6 +62,7 @@ S3method(vetiver_create_meta,default)
S3method(vetiver_create_meta,gam)
S3method(vetiver_create_meta,keras.engine.training.Model)
S3method(vetiver_create_meta,kproto)
S3method(vetiver_create_meta,luz_module_fitted)
S3method(vetiver_create_meta,model_stack)
S3method(vetiver_create_meta,ranger)
S3method(vetiver_create_meta,recipe)
@@ -72,6 +76,7 @@ S3method(vetiver_prepare_model,glm)
S3method(vetiver_prepare_model,keras.engine.training.Model)
S3method(vetiver_prepare_model,kproto)
S3method(vetiver_prepare_model,lm)
S3method(vetiver_prepare_model,luz_module_fitted)
S3method(vetiver_prepare_model,model_stack)
S3method(vetiver_prepare_model,ranger)
S3method(vetiver_prepare_model,recipe)
@@ -85,6 +90,7 @@ S3method(vetiver_ptype,glm)
S3method(vetiver_ptype,keras.engine.training.Model)
S3method(vetiver_ptype,kproto)
S3method(vetiver_ptype,lm)
S3method(vetiver_ptype,luz_module_fitted)
S3method(vetiver_ptype,model_stack)
S3method(vetiver_ptype,ranger)
S3method(vetiver_ptype,recipe)
@@ -93,6 +99,8 @@ S3method(vetiver_ptype,workflow)
S3method(vetiver_ptype,xgb.Booster)
S3method(vetiver_python_requirements,default)
S3method(vetiver_python_requirements,keras.engine.training.Model)
S3method(vetiver_renviron_requirements,default)
S3method(vetiver_renviron_requirements,luz_module_fitted)
export(api_spec)
export(attach_pkgs)
export(augment)
@@ -129,6 +137,7 @@ export(vetiver_prepare_docker)
export(vetiver_prepare_model)
export(vetiver_ptype)
export(vetiver_python_requirements)
export(vetiver_renviron_requirements)
export(vetiver_sm_build)
export(vetiver_sm_delete)
export(vetiver_sm_endpoint)
2 changes: 1 addition & 1 deletion NEWS.md
@@ -1,6 +1,6 @@
# vetiver (development version)

-* Added support for keras (#164) and recipes (#179).
+* Added support for keras (#164), recipes (#179), and luz (#187, @dfalbel).

* Moved where `required_pkgs` metadata is stored remotely, from the binary blob to plain text YAML (#176).

8 changes: 4 additions & 4 deletions R/attach-pkgs.R
@@ -34,21 +34,21 @@
#' try(attach_pkgs(c("bloopy", "readr")))
#'
attach_pkgs <- function(pkgs) {
    attached <- paste0("package:", pkgs) %in% search()
    pkgs <- pkgs[!attached]
    namespace_handling(pkgs, attachNamespace, "Package(s) could not be attached:")
}

#' @export
#' @rdname attach_pkgs
load_pkgs <- function(pkgs) {
    loaded <- map_lgl(pkgs, isNamespaceLoaded)
    pkgs <- pkgs[!loaded]
    namespace_handling(pkgs, loadNamespace, "Namespace(s) could not be loaded:")
}

namespace_handling <- function(pkgs, func, error_msg) {
    loaded <- map_lgl(pkgs, isNamespaceLoaded)
    pkgs <- pkgs[!loaded]

    safe_load <- safely(withr::with_preserve_seed(func))

    did_load <- map(pkgs, safe_load)
    bad <- compact(map(did_load, "error"))
    bad <- map_chr(bad, "package")
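
The updated attach_pkgs() hinges on the difference between a namespace that is loaded and a package that is attached to the search path. A minimal sketch of that distinction, assuming torch is installed:

loadNamespace("torch")
isNamespaceLoaded("torch")       # TRUE: the namespace is loaded
"package:torch" %in% search()    # FALSE: but the package is not attached
attachNamespace("torch")         # attach_pkgs() only skips packages already on the search path
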
2 changes: 1 addition & 1 deletion R/keras.R
@@ -58,7 +58,7 @@ handler_predict.keras.engine.training.Model <- function(vetiver_model, ...) {

}

-#' @rdname vetiver_write_plumber
+#' @rdname vetiver_python_requirements
#' @export
vetiver_python_requirements.keras.engine.training.Model <- function(model) {
system.file("requirements/keras-requirements.txt", package = "vetiver")
65 changes: 65 additions & 0 deletions R/luz.R
@@ -0,0 +1,65 @@
#' @rdname vetiver_create_description
#' @export
vetiver_create_description.luz_module_fitted <- function(model) {
    n_parameters <- lapply(model$model$parameters, function(x) prod(x$shape))
    n_parameters <- do.call(sum, n_parameters)
    n_parameters <- formatC(n_parameters, big.mark = ",", format = "d")
    glue("A luz module with {n_parameters} parameters")
}

#' @rdname vetiver_create_meta
#' @export
vetiver_create_meta.luz_module_fitted <- function(model, metadata) {
    pkgs <- c("luz", "torch", model$model$required_pkgs)
    vetiver_meta(metadata, required_pkgs = pkgs)
}

#' @rdname vetiver_create_ptype
#' @export
vetiver_ptype.luz_module_fitted <- function(model, ...) {
    rlang::check_dots_used()
    dots <- list(...)
    check_ptype_data(dots)
    ptype <- vctrs::vec_ptype(dots$prototype_data)
    tibble::as_tibble(ptype)
}

#' @rdname vetiver_create_description
#' @export
vetiver_prepare_model.luz_module_fitted <- function(model) {
    bundle::bundle(model)
}

#' @rdname handler_startup
#' @export
handler_startup.luz_module_fitted <- function(vetiver_model) {
    attach_pkgs(vetiver_model$metadata$required_pkgs)
}

#' @rdname handler_startup
#' @export
handler_predict.luz_module_fitted <- function(vetiver_model, ...) {
    force(vetiver_model)

[Inline review comment on `force(vetiver_model)`]

Member: Can you tell me a little more about this?

Member (PR author): Since we return the closure directly, without evaluating vetiver_model, it will not be in the function environment until the first call of the closure, and by that point vetiver_model could potentially have been garbage collected. This might not be the case for vetiver, but it feels like good practice to force before returning the closure.

I'm trying to avoid something like this:

f <- function(a, force) {
    if (force) force(a)
    function() {
        a + 1
    }
}

b <- 1
fun_f <- f(a = b, force = TRUE)
fun_nf <- f(a = b, force = FALSE)
rm(b);gc()
#>           used (Mb) gc trigger (Mb) limit (Mb) max used (Mb)
#> Ncells  674744 36.1    1413787 75.6         NA   710975 38.0
#> Vcells 1201226  9.2    8388608 64.0      32768  1888973 14.5

fun_f()
#> [1] 2
fun_nf()
#> Error in fun_nf(): object 'b' not found

Member: Thanks for the details!

    function(req) {
        new_data <- vetiver_type_convert(req$body, vetiver_model$prototype)
        new_data <- if (is.data.frame(new_data)) as.matrix(new_data) else new_data
        preds <- tensors_to_array(predict(vetiver_model$model, new_data))
        tibble::tibble(preds)
    }
}

tensors_to_array <- function(x) {
    if (is.list(x)) {
        lapply(x, tensors_to_array)
    } else if (inherits(x, "torch_tensor")) {
        as.array(x$cpu())
    } else {
        x
    }
}

#' @rdname vetiver_python_requirements
#' @export
vetiver_renviron_requirements.luz_module_fitted <- function(model) {
system.file("requirements/luz-renviron.txt", package = "vetiver")
}
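
A small usage sketch of the (unexported) tensors_to_array() helper above, assuming torch is installed and the helper is available, for example via vetiver:::tensors_to_array:

library(torch)
x <- list(mat = torch_tensor(matrix(1:4, nrow = 2)), nested = list(torch_tensor(c(1, 2, 3))))
out <- tensors_to_array(x)
# each torch_tensor is moved to the CPU and converted to a plain R array;
# list structure and names are preserved, and non-tensor elements pass through unchanged
str(out)
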
7 changes: 4 additions & 3 deletions R/vetiver-model.R
@@ -35,10 +35,11 @@
#' as your training data (perhaps with [hardhat::scream()]) and/or simulating
#' data to avoid leaking PII via your deployed model.
#'
-#' Some models, like [ranger::ranger()] and [keras](https://tensorflow.rstudio.com/)
-#' models, *require* that you pass in example training data as `prototype_data`
+#' Some models, like [ranger::ranger()], [keras](https://tensorflow.rstudio.com/),
+#' and [luz (torch)](https://torch.mlverse.org/),
+#' *require* that you pass in example training data as `prototype_data`
#' or else explicitly set `save_prototype = FALSE`. For non-rectangular data
-#' input to models, such as image input for a keras model, we currently
+#' input to models, such as image input for a keras or torch model, we currently
#' recommend that you turn off prototype checking via `save_prototype = FALSE`.
#'
#' @return A new `vetiver_model` object.
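
A minimal sketch of the two options this documentation describes, reusing the luz_fit and x_train objects from the inst/mtcars_luz.R example added later in this PR:

# pass one-row example data so vetiver can store an input prototype
v <- vetiver_model(luz_fit, "cars-luz", prototype_data = data.frame(x_train)[1, ])

# or, for non-rectangular input such as images, skip the prototype entirely
v <- vetiver_model(luz_fit, "cars-luz", save_prototype = FALSE)
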
47 changes: 35 additions & 12 deletions R/write-plumber.R
@@ -19,9 +19,6 @@
#' version. You can override this default behavior by choosing a specific
#' `version`.
#'
-#' This function uses `vetiver_python_requirements()` internally to create a
-#' minimal Python `requirements.txt` for models that need it.
-#'
#' @return
#' The content of the `plumber.R` file, invisibly.
#'
@@ -57,7 +54,7 @@ vetiver_write_plumber <- function(board, name, version = NULL,
        v <- vetiver_pin_read(board, name)
    }

-    write_python_requirements(v$model, file)
+    write_extra_requirements(v$model, file)

    load_infra_pkgs <- glue_collapse(glue("library({infra_pkgs})"), sep = "\n")
    load_required_pkgs <- glue_required_pkgs(v$metadata$required_pkgs, rsconnect)
@@ -126,27 +123,53 @@ choose_version <- function(df) {
    version[["version"]]
}

-write_python_requirements <- function(model, file) {
+write_extra_requirements <- function(model, file) {
    model <- bundle::unbundle(model)
-    path_to_requirements <- vetiver_python_requirements(model)
-    if (!is.null(path_to_requirements)) {
+    path_to_py_requirements <- vetiver_python_requirements(model)
+    file_copy_requirements(path_to_py_requirements, file, "requirements.txt")
+    path_to_renviron_requirements <- vetiver_renviron_requirements(model)
+    file_copy_requirements(path_to_renviron_requirements, file, ".Renviron")
+    TRUE
+}
+
+file_copy_requirements <- function(requirements, plumber_file, new_name) {
+    if (!is.null(requirements)) {
        fs::file_copy(
-            path_to_requirements,
-            fs::path(fs::path_dir(file), "requirements.txt"),
+            requirements,
+            fs::path(fs::path_dir(plumber_file), new_name),
            overwrite = TRUE
        )
    }
-    path_to_requirements
+    requirements
}

-#' @rdname vetiver_write_plumber
+#' Use extra files required for deployment
+#'
+#' Create files required for deploying an app generated via
+#' [vetiver_write_plumber()], such as a Python `requirements.txt` or an
+#' `.Renviron`
+#'
+#' @inheritParams vetiver_model
#' @export
+#' @keywords internal
vetiver_python_requirements <- function(model) {
    UseMethod("vetiver_python_requirements")
}

-#' @rdname vetiver_write_plumber
+#' @rdname vetiver_python_requirements
#' @export
vetiver_python_requirements.default <- function(model) {
    NULL
}
+
+#' @rdname vetiver_python_requirements
+#' @export
+vetiver_renviron_requirements <- function(model) {
+    UseMethod("vetiver_renviron_requirements")
+}
+
+#' @rdname vetiver_python_requirements
+#' @export
+vetiver_renviron_requirements.default <- function(model) {
+    NULL
+}
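
A sketch of the effect of write_extra_requirements() for a luz model; the pins board, the pin name "cars-luz", and the temporary directory are assumptions for illustration, not part of this PR:

library(pins)
library(vetiver)

board <- board_temp()                  # any pins board works for this sketch
vetiver_pin_write(board, v)            # `v` is the vetiver_model() for the luz fit
tmp <- fs::dir_create(fs::file_temp())
vetiver_write_plumber(board, "cars-luz", file = fs::path(tmp, "plumber.R"))
fs::dir_ls(tmp)
# for a luz model, vetiver_renviron_requirements() points to the bundled
# luz-renviron.txt, so an .Renviron with TORCH_INSTALL and TORCH_HOME is copied
# next to plumber.R; a requirements.txt only appears for models (like keras)
# that declare Python requirements
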
39 changes: 39 additions & 0 deletions inst/mtcars_luz.R
@@ -0,0 +1,39 @@
library(vetiver)
library(plumber)
library(torch)

scaled_cars <- as.matrix(mtcars) %>% scale()
x_test <- scaled_cars[26:32, 2:ncol(scaled_cars)]
x_train <- scaled_cars[1:25, 2:ncol(scaled_cars)]
y_train <- scaled_cars[1:25, 1, drop = FALSE]

set.seed(1)

luz_module <- torch::nn_module(
    initialize = function(in_features, out_features) {
        self$linear <- nn_linear(in_features, out_features)
    },
    forward = function(x) {
        if (self$training) {
            self$linear(x)
        } else {
            # at predict time this toy module returns a random image-shaped
            # tensor (non-rectangular output)
            torch_randn(dim(x)[1], 3, 64, 64, device = self$linear$weight$device)
        }
    }
)

luz_fit <- luz_module %>%
    luz::setup(loss = torch::nnf_mse_loss, optimizer = torch::optim_sgd) %>%
    luz::set_hparams(in_features = ncol(x_train), out_features = 1) %>%
    luz::set_opt_hparams(lr = 0.01) %>%
    luz::fit(list(x_train, y_train), verbose = FALSE, dataloader_options = list(batch_size = 5))

v <- vetiver_model(luz_fit, "cars-luz", prototype_data = data.frame(x_train)[1,])
pr() %>% vetiver_api(v, debug = TRUE) %>% pr_run(port = 8080)

##### in new session: ##########################################################
# library(vetiver)
# endpoint <- vetiver_endpoint("http://127.0.0.1:8080/predict")
# x_test <- dplyr::slice(data.frame(scale(mtcars)), 26:32)
# predict(endpoint, x_test[1:2,])
2 changes: 2 additions & 0 deletions inst/requirements/luz-renviron.txt
@@ -0,0 +1,2 @@
TORCH_INSTALL=1
TORCH_HOME="libtorch/"
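
As documented for the torch package, TORCH_INSTALL=1 makes torch install its libtorch backend non-interactively when the package is loaded, and TORCH_HOME controls where that backend is installed, here a libtorch/ directory relative to the deployed app.
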
10 changes: 8 additions & 2 deletions man/handler_startup.Rd

Some generated files are not rendered by default.

8 changes: 7 additions & 1 deletion man/vetiver_create_description.Rd

Some generated files are not rendered by default.
