Add support for k-Prototypes clustering to vetiver #163

Merged 10 commits on Dec 2, 2022

galen-ft
Contributor

The PR adds support for k-Prototypes clustering models to vetiver.

While the current PR is related to tidymodels/butcher#235, both PRs can be merged separately.

@juliasilge
Member

Thanks so much for looking at this! 🙌 I haven't thought very deeply about unsupervised algorithms in general or clustering algorithms in particular.

What do you expect to be returned when a clustering algorithm is deployed in production?

library(tidyverse)
library(clustMixType)
kproto_model <- clustMixType::kproto(
    x = iris,
    k = 2,
    verbose = FALSE
)

predict(kproto_model, iris %>% slice_sample(n = 2))
#> $cluster
#> [1] 1 2
#> 
#> $dists
#>            [,1]      [,2]
#> [1,]  0.0958099 17.923974
#> [2,] 13.0284046  2.054489

Created on 2022-11-21 with reprex v2.0.2

Currently this doesn't return one thing per observation (like a class or predicted outcome), or one row of things per observation (like probabilities for a set of outcomes). Do you only want the cluster identifications?

Overall, this is different enough from the expectations of vetiver that I wonder about housing such deployment strategies for unsupervised ML in a different package.

@galen-ft
Contributor Author

Thank you for taking the time to look into the PR! Your feedback is certainly useful.

Indeed, it is the cluster identifiers (or memberships) I am after. One way to obtain those would be to update handler_predict.kproto() to return list(.pred = ret[["cluster"]]) instead of list(.pred = ret). This would give us a single vector with one prediction for each input observation.
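A minimal sketch of that change, assuming predict() on a kproto model returns a list with cluster and dists elements as in the reprex above. The helper name extract_cluster_pred() is hypothetical and is not part of vetiver, and ret is a hand-built stand-in rather than real model output.

```r
# Hypothetical helper (not vetiver source): keep only the cluster
# memberships so there is exactly one value per input observation.
extract_cluster_pred <- function(ret) {
    stopifnot(is.list(ret), "cluster" %in% names(ret))
    list(.pred = ret[["cluster"]])
}

# Stand-in for the list returned by predict() on a kproto model,
# shaped like the reprex output earlier in the thread.
ret <- list(
    cluster = c(1L, 2L),
    dists = matrix(
        c(0.0958099, 13.0284046, 17.923974, 2.054489),
        nrow = 2
    )
)

out <- extract_cluster_pred(ret)
str(out)
```

Dropping dists means the response has a single .pred column, which matches the one-value-per-observation shape that other model types return.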

If my proposed change would bring support for this type of clustering model more in line with the project's expectations, I would be happy to update the PR.

Disclaimer: I am not the original developer of clustMixType and have no control over its source code. I would simply like to use it within a company project and it would be great if we could do it through vetiver.

@juliasilge
Member

This model can't predict on only one observation; it looks like there is a bug in clustMixType (probably a missing drop = FALSE somewhere).

library(clustMixType)

data(crickets, package = "modeldata")
crickets <- as.data.frame(crickets)

kp <- kproto(
    crickets,
    k = 3,
    lambda = lambdaest(crickets),
    verbose = FALSE
)
#> Numeric variances:
#>      temp      rate 
#>  14.61703 286.07249 
#> Average numeric variance: 150.3448 
#> 
#> Heuristic for categorical variables: (method = 1) 
#>   species 
#> 0.4953174 
#> Average categorical variation: 0.4953174 
#> 
#> Estimated lambda: 303.5322

predict(kp, newdata = crickets[5:6,])
#> $cluster
#> [1] 3 3
#> 
#> $dists
#>          [,1]     [,2]     [,3]
#> [1,] 1027.106 334.0745 94.40806
#> [2,] 1081.126 346.0079 76.67472
predict(kp, newdata = crickets[5,])
#> Error in rowSums(d2): 'x' must be an array of at least two dimensions

Created on 2022-11-30 with reprex v2.0.2
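The suspected cause (a missing drop = FALSE) can be illustrated in base R without clustMixType. This is only a sketch of the general mechanism, assuming the package subsets a distance matrix by row somewhere before calling rowSums(); the matrix d below is a stand-in built from the distances printed above, not the package's internal object.

```r
# Stand-in distance matrix: 2 observations x 3 clusters.
d <- matrix(
    c(1027.106, 1081.126, 334.0745, 346.0079, 94.40806, 76.67472),
    nrow = 2
)

# Selecting a single row without drop = FALSE collapses the matrix
# to a plain numeric vector, which has no dim attribute.
one_row_dropped <- d[1, ]
one_row_kept <- d[1, , drop = FALSE]

is.null(dim(one_row_dropped))
dim(one_row_kept)

# rowSums() needs at least two dimensions, so the dropped version fails.
res <- tryCatch(
    rowSums(one_row_dropped),
    error = function(e) conditionMessage(e)
)
res

# The 1 x 3 matrix works and yields one row sum.
rowSums(one_row_kept)
```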

Have you run into this before @galen-ft? Are you aware of any workarounds?

@galen-ft
Contributor Author

galen-ft commented Dec 2, 2022

Hi Julia, many thanks for looking into this!

Yes, indeed, it is a bug with the clustMixType package which I encountered while working on the current PR.

I already reported it here and also wrote an email with a suggested fix to the author of the package.

Even though the best place to fix this would be inside clustMixType::predict.kproto(), another workaround does come to mind until the fix is accepted. That is, handler_predict.kproto() can be updated so that the request handler it returns attaches a dummy row to the bottom of newdata whenever it has a single observation. Then the function can simply return all but the last (dummy) prediction.

If that seems all right with you, I'll be happy to make the update.

In any case, I'll share more information as soon as I receive a reply from the professor who wrote the package.
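The dummy-row idea could look roughly like the sketch below. Everything here is an assumption for illustration: predict_with_dummy_row() and fragile_predict() are hypothetical names, fragile_predict() is a stand-in that only mimics the single-row failure, and none of this is the actual handler_predict.kproto() code.

```r
# Hypothetical wrapper: pad single-row input with a duplicate of itself,
# predict on both rows, then drop the prediction for the dummy row.
predict_with_dummy_row <- function(predict_fn, newdata) {
    if (nrow(newdata) != 1L) {
        return(predict_fn(newdata))
    }
    # Duplicating the real row keeps column types and factor levels intact.
    padded <- rbind(newdata, newdata)
    preds <- predict_fn(padded)
    preds[-length(preds)]
}

# Stand-in for a predict function that fails on one observation.
fragile_predict <- function(newdata) {
    if (nrow(newdata) < 2L) {
        stop("'x' must be an array of at least two dimensions")
    }
    as.integer(newdata$temp > 80) + 1L
}

new_obs <- data.frame(temp = 88.6, rate = 20)
single <- predict_with_dummy_row(fragile_predict, new_obs)

two_obs <- data.frame(temp = c(88.6, 71.6), rate = c(20, 16))
multi <- predict_with_dummy_row(fragile_predict, two_obs)
```

Using a copy of the real row as the dummy avoids inventing values that might fall outside the training data's factor levels, at the cost of one redundant prediction per single-row request.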

@juliasilge
Member

This is how it works with vetiver_endpoint():

library(clustMixType)
library(vetiver)
library(plumber)

data(crickets, package = "modeldata")
crickets <- as.data.frame(crickets)

kp <- kproto(
    crickets,
    k = 3,
    lambda = lambdaest(crickets),
    verbose = FALSE
)
#> Numeric variances:
#>      temp      rate 
#>  14.61703 286.07249 
#> Average numeric variance: 150.3448 
#> 
#> Heuristic for categorical variables: (method = 1) 
#>   species 
#> 0.4953174 
#> Average categorical variation: 0.4953174 
#> 
#> Estimated lambda: 303.5322
v <- vetiver_model(kp, "crickets-cluster")

## for testing API in a new process:
local_clustMixType_session <- function(pr) {
    rs <- callr::r_session$new()
    rs$call(
        function(pr) {
            library(clustMixType)
            plumber::pr_run(pr = pr, port = 8088)
        },
        args = list(pr = pr)
    )
    withr::defer(rs$close())
    rs
}

pr <- pr() %>% vetiver_api(v, debug = TRUE)
rs <- local_clustMixType_session(pr) 

endpoint <- vetiver_endpoint("http://127.0.0.1:8088/predict")
predict(endpoint, new_data = crickets[5:6,])
#> # A tibble: 2 × 1
#>   .pred
#>   <int>
#> 1     3
#> 2     3

Created on 2022-12-02 with reprex v2.0.2

@juliasilge
Member

> Even though the best place to fix this would be inside clustMixType::predict.kproto(), another workaround does come to mind until the fix is accepted. That is, handler_predict.kproto() can be updated so that the request handler it returns attaches a dummy row to the bottom of newdata whenever it has a single observation. Then the function can simply return all but the last (dummy) prediction.
>
> If that seems all right with you, I'll be happy to make the update.

Hmmm, I don't think I want to maintain code to account for this kind of problem in another package; let's hope the other maintainer is willing to do a fix.

Member

@juliasilge juliasilge left a comment

Thank you for this PR @galen-ft!

For any folks coming by to see this later, as of today, prediction for one observation is not working due to a bug in clustMixType.

@juliasilge juliasilge merged commit 8ac5138 into rstudio:main Dec 2, 2022
@g-rho

g-rho commented Dec 3, 2022

will be fixed soon.

@juliasilge
Member

Glad to hear that @g-rho!
