Add support for k-Prototypes clustering to vetiver #163
Thanks so much for looking at this! 🙌 I haven't thought very deeply about unsupervised algorithms in general or clustering algorithms in particular. What are you expecting to be returned when a clustering algorithm is deployed, in production, etc?

library(tidyverse)
library(clustMixType)
kproto_model <- clustMixType::kproto(
  x = iris,
  k = 2,
  verbose = FALSE
)
predict(kproto_model, iris %>% slice_sample(n = 2))
#> $cluster
#> [1] 1 2
#>
#> $dists
#>            [,1]      [,2]
#> [1,]  0.0958099 17.923974
#> [2,] 13.0284046  2.054489

Created on 2022-11-21 with reprex v2.0.2

This doesn't have one thing to be returned for each observation currently (like a class or predicted outcome), or one row of things (like probabilities of a set of outcomes). Are you only wanting the cluster identifications? Overall, this is different enough from the expectations of vetiver that I wonder about housing such deployment strategies for unsupervised ML in a different package.
Thank you for taking the time to look into the PR! Your feedback is certainly useful. Indeed, it is the cluster identifiers (or memberships) I am after. One way to obtain those would be to update handler_predict.kproto() to return

If my proposed change would make extending vetiver to support this type of clustering model more in line with the project's expectations, I would be happy to update the PR.

Disclaimer: I am not the original developer of
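For reference, pulling only the memberships out of a prediction like the one printed earlier in this thread could look roughly like the sketch below. It uses a hand-built list that mimics the `$cluster` / `$dists` shape shown above instead of a fitted model, and `extract_clusters()` and the `.pred` column name are hypothetical choices for illustration, not part of vetiver or clustMixType.

```r
# Mock of the list shape printed above for a kproto prediction
# (values copied from that output; this is not a fitted model).
mock_pred <- list(
  cluster = c(1L, 2L),
  dists = matrix(
    c(0.0958099, 13.0284046, 17.923974, 2.054489),
    nrow = 2
  )
)

# Hypothetical helper: keep one value per observation and drop the distances.
extract_clusters <- function(pred) {
  data.frame(.pred = as.integer(pred$cluster))
}

extract_clusters(mock_pred)
```

Returning a one-column data frame gives one value per observation, which is closer to what a prediction endpoint for a supervised model would return.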
This model can't predict on only one observation; it looks like there is a bug in clustMixType (probably a missing

library(clustMixType)
data(crickets, package = "modeldata")
crickets <- as.data.frame(crickets)
kp <- kproto(
  crickets,
  k = 3,
  lambda = lambdaest(crickets),
  verbose = FALSE
)
#> Numeric variances:
#> temp rate
#> 14.61703 286.07249
#> Average numeric variance: 150.3448
#>
#> Heuristic for categorical variables: (method = 1)
#> species
#> 0.4953174
#> Average categorical variation: 0.4953174
#>
#> Estimated lambda: 303.5322
predict(kp, newdata = crickets[5:6,])
#> $cluster
#> [1] 3 3
#>
#> $dists
#>          [,1]     [,2]     [,3]
#> [1,] 1027.106 334.0745 94.40806
#> [2,] 1081.126 346.0079 76.67472
predict(kp, newdata = crickets[5,])
#> Error in rowSums(d2): 'x' must be an array of at least two dimensions

Created on 2022-11-30 with reprex v2.0.2

Have you run into this before @galen-ft? Are you aware of any workarounds?
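The error message itself is the one base R's `rowSums()` raises when it is handed a plain vector. Here is a minimal sketch of one way that can happen, assuming (this is not confirmed from the clustMixType source) that a one-row matrix gets subset somewhere without `drop = FALSE`:

```r
# A small matrix standing in for per-observation values (2 rows, 3 columns).
d <- matrix(1:6, nrow = 2)

# Subsetting a single row drops the matrix to a plain vector by default...
one_row_dropped <- d[1, ]
is.null(dim(one_row_dropped))

# ...and rowSums() rejects input with fewer than two dimensions.
err <- tryCatch(
  rowSums(one_row_dropped),
  error = function(e) conditionMessage(e)
)
err

# Keeping the dimensions avoids the error.
one_row_kept <- d[1, , drop = FALSE]
rowSums(one_row_kept)
```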
Hi Julia, many thanks for looking into this! Yes, indeed, it is a bug with the clustMixType package which I encountered while working on the current PR. I already reported it here and also wrote an email with a suggested fix to the author of the package.

Even though the best place to fix this would be inside clustMixType::predict.kproto(), another workaround does come to mind until the fix is accepted. That is, handler_predict.kproto() can be updated so that the request handler it returns attaches a dummy row to the bottom of

If that seems all right with you, I'll be happy to make the update. In any case, I'll share more information as soon as I receive a reply from the professor who wrote the package.
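A rough sketch of the dummy-row idea, assuming the padding is applied to the incoming data and the extra prediction is discarded afterwards. `predict_padded()` and `fake_predict()` are hypothetical stand-ins written for illustration; neither is code from vetiver or clustMixType.

```r
# Hypothetical wrapper: `predict_fn` stands in for any predict function that
# fails on one-row input and returns a list with a `cluster` vector.
predict_padded <- function(predict_fn, newdata) {
  if (nrow(newdata) != 1L) {
    return(predict_fn(newdata)$cluster)
  }
  # Duplicate the single row so the model sees two observations...
  padded <- rbind(newdata, newdata)
  # ...then keep only the prediction for the real (first) row.
  predict_fn(padded)$cluster[1]
}

# Stand-in for a predict method with the one-row bug (illustration only).
fake_predict <- function(newdata) {
  if (nrow(newdata) < 2L) {
    stop("'x' must be an array of at least two dimensions")
  }
  list(cluster = ifelse(newdata$temp > 20, 1L, 2L))
}

new_obs <- data.frame(temp = 25, rate = 80)
predict_padded(fake_predict, new_obs)
```

Using a copy of the real row as the dummy keeps column types and factor levels consistent without inventing values, at the cost of computing one extra prediction.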
This is how it works with

library(clustMixType)
library(vetiver)
library(plumber)
data(crickets, package = "modeldata")
crickets <- as.data.frame(crickets)
kp <- kproto(
  crickets,
  k = 3,
  lambda = lambdaest(crickets),
  verbose = FALSE
)
#> Numeric variances:
#> temp rate
#> 14.61703 286.07249
#> Average numeric variance: 150.3448
#>
#> Heuristic for categorical variables: (method = 1)
#> species
#> 0.4953174
#> Average categorical variation: 0.4953174
#>
#> Estimated lambda: 303.5322
v <- vetiver_model(kp, "crickets-cluster")
## for testing API in a new process:
local_clustMixType_session <- function(pr) {
  rs <- callr::r_session$new()
  rs$call(
    function(pr) {
      library(clustMixType)
      plumber::pr_run(pr = pr, port = 8088)
    },
    args = list(pr = pr)
  )
  withr::defer(rs$close())
  rs
}
pr <- pr() %>% vetiver_api(v, debug = TRUE)
rs <- local_clustMixType_session(pr)
endpoint <- vetiver_endpoint("http://127.0.0.1:8088/predict")
predict(endpoint, new_data = crickets[5:6,])
#> # A tibble: 2 × 1
#>   .pred
#>   <int>
#> 1     3
#> 2     3

Created on 2022-12-02 with reprex v2.0.2
Hmmm, I don't think I want to maintain code to account for this kind of problem in another package; let's hope the other maintainer is willing to do a fix.
Thank you for this PR @galen-ft!
For any folks coming by to see this later, as of today, prediction for one observation is not working due to a bug in clustMixType.
will be fixed soon.
Glad to hear that @g-rho!
The PR adds support for k-Prototypes clustering models to vetiver.
While the current PR is related to tidymodels/butcher#235, both PRs can be merged separately.