Support for luz
#187
Conversation
Thank you so much for this contribution @dfalbel!
Thanks again for contributing this! Let me know if you have any questions on this feedback, and I would very much welcome your input and ideas on #55.
R/luz.R
Outdated
#' @rdname vetiver_create_meta
#' @export
vetiver_create_meta.luz_module_fitted <- function(model, metadata) {
  pkgs <- c("luz", model$model$required_pkgs)
Can you tell me more about how this works in luz? I have trained some of the example models and they do require me to have, for example, torch and/or torchvision loaded, but then they are not stored in this slot. Instead, I see:
model$model$required_pkgs
#> NULL
In luz we don't try to be smart about capturing used packages, but users can optionally set this field in the `nn_module` so it's available as metadata. E.g., one can do:
module <- torch::nn_module(
  initialize = function(in_features, out_features) {
    self$linear <- torch::nn_linear(in_features, out_features)
  },
  forward = function(x) {
    self$linear(x)
  },
  # extra fields like this one are carried along as metadata on the module
  required_pkgs = c("torch", "torchvision")
)
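As an illustration (an assumption based on the snippet above, not output from this thread), the field should then be readable on an instantiated module, which is what `model$model$required_pkgs` relies on for the fitted luz object:

m <- module(in_features = 10, out_features = 10)
m$required_pkgs
# expected: c("torch", "torchvision")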
We could try to traverse the `forward` expression and find function calls that come from other packages, but I feel this can still have many edge cases and is kind of error prone.
Would they always need torch? Should we include that there? I think this is about what needs to be installed and attached for predictions to work. Getting the right packages installed for the deployment is a big part of what vetiver aims to do.
torch doesn't necessarily need to be attached, but it definitely needs to be installed; it should already be installed since it's a hard dependency of luz. In the above example, torch wouldn't need to be attached for predictions to work.
#' @rdname handler_startup
#' @export
handler_predict.luz_module_fitted <- function(vetiver_model, ...) {
  force(vetiver_model)
Can you tell me a little more about this?
Since we return the closure directly, without evaluating `vetiver_model`, it will not be in the function env until the first closure call - but at that point `vetiver_model` could potentially have been garbage collected. This might not be the case for vetiver, though; it just feels like good practice to `force()` before returning the closure.
I'm trying to avoid something like this:
f <- function(a, force) {
if (force) force(a)
function() {
a + 1
}
}
b <- 1
fun_f <- f(a = b, force = TRUE)
fun_nf <- f(a = b, force = FALSE)
rm(b);gc()
#> used (Mb) gc trigger (Mb) limit (Mb) max used (Mb)
#> Ncells 674744 36.1 1413787 75.6 NA 710975 38.0
#> Vcells 1201226 9.2 8388608 64.0 32768 1888973 14.5
fun_f()
#> [1] 2
fun_nf()
#> Error in fun_nf(): object 'b' not found
Thanks for the details!
tests/testthat/test-luz.R
Outdated
expect_error(predict(v, as.array(torch::torch_randn(10, 2))), regexp = "dim error")
})

test_that("can call endpoints", {
You can check out the approach for testing a local plumber session that I have already set up in the package, and use that instead of this. Unfortunately it's not practical to set up APIs for all the model types that run in CI (it just takes too long for the API to come up on some architectures), so a test like this will need to skip on CI (as well as, of course, on CRAN).
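For illustration, a minimal sketch of such a guard using testthat's standard skip helpers (a sketch under assumptions, not the package's existing test code):

test_that("can call endpoints", {
  # skip in environments where spinning up a local API is impractical
  skip_on_cran()
  skip_on_ci()
  # ... start the local plumber session and hit the /predict endpoint here ...
})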
I removed that test for now. I wasn't able to use `local_plumber_session` because an unbundled model is passed to the `callr` session, and that breaks the luz model. We could perhaps 'unbundle' on the first call to the API instead? But that should probably be part of another PR.
@juliasilge Thank you very much for the review and suggestions. I simplified the PR to make luz support very similar to keras. I'll think about multi-output support and post on #55.
I made a little example for returning higher dimensional tensors:

library(vetiver)
endpoint <- vetiver_endpoint("http://127.0.0.1:8080/predict")

scaled_cars <- scale(as.matrix(mtcars))
x_test <- scaled_cars[26:32, 2:ncol(scaled_cars)]

predict(endpoint, data.frame(x_test[1:2,]))
#> # A tibble: 2 × 1
#>   preds
#>   <list>
#> 1 <dbl [3 × 64 × 64]>
#> 2 <dbl [3 × 64 × 64]>

Created on 2023-04-03 with reprex v2.0.2

@dfalbel would you mind taking a look at this again and seeing if you have any feedback (other than, of course, how to extend the prototype checking to non-rectangular data, which we can handle separately)?
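For reference, a rough sketch (illustrative only, not the PR's actual handler code) of how a batch of higher dimensional predictions could be packed into a list column like the one in the reprex above:

# fake "predictions": a batch of 2 tensors, each 3 x 64 x 64
preds <- torch::torch_randn(2, 3, 64, 64)
arr <- as.array(preds)

# one list-column entry per observation, preserving each array's dimensions
out <- tibble::tibble(
  preds = lapply(seq_len(dim(arr)[1]), function(i) arr[i, , , ])
)
# each element of out$preds is a 3 x 64 x 64 array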
Ah, I went to deploy one of these models on Connect and realized that we haven't set up the torch installation for the API. 🙈 What do you think is the best way to go about this @dfalbel? The way we handle installing keras is via a … What would be a good way to handle this for torch? What do you all do for installing torch on Connect typically?
In theory, just setting the env var …
Does that mean it will install torch every time the content starts, i.e. the API starts up? That's not ideal. How do you all typically install torch into content when you are deploying on Connect? Do you have an example I can look at? We would want the install to happen when the content deploys, not each time it starts up.
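For reference, a minimal sketch (an assumption about one possible approach, not a recommendation from this thread) of a lazy install at API startup using `torch::torch_is_installed()` and `torch::install_torch()`; whether this should happen at deploy time rather than start time is exactly the open question above:

# in the API's startup code: install the torch backend only if it is missing,
# so subsequent starts skip the download
if (!torch::torch_is_installed()) {
  torch::install_torch()
}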
Hi 👋!

This is a first pass at supporting luz in vetiver. There are a few things where I'd like to ask about the best way to proceed:

1. Calling `predict` on the result of `vetiver_model` doesn't yield the same structure as calling predict on an endpoint containing a luz model, which can be confusing. I wonder if we should enforce that somehow; in this case, I think it would be nice if the same validations that happen in `handler_predict.luz_module_fitted` also happened for `predict.vetiver_model`.

2. Outputs of luz models can be arrays with arbitrary dimensions, while vetiver enforces a data frame output. To handle this, we are returning a data frame with an `array` column, which helps preserve the output dimensions. However, the JSON serializer and de-serializer somehow break the array column. I wonder if there's a way to safely override the de-serializer for those models so the original structure is preserved.