yardstick integration #120

Closed
topepo opened this issue Apr 16, 2021 · 11 comments

Comments

@topepo
Contributor

topepo commented Apr 16, 2021

If you were good with a yardstick dependency, we could expand the list of metrics that can be used.

@bgreenwell
Member

The more metrics the merrier! Would you be able to put together a PR for this too 😁? If not, I could probably get around to it later this month.

@topepo
Contributor Author

topepo commented Dec 4, 2022

I'll be coming back to this soon (as well as some other PRs for model-based methods).

Off-hand, where do you handle the direction of the metric (i.e. maximize or minimize)?

@bgreenwell
Member

Sounds good @topepo. There's an option called smaller_is_better in the call to vi_permute(); it defaults to NULL, with some logic to handle the case where a user picks one of the built-in metrics.
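To make the role of the direction flag concrete, here is a minimal sketch of how a flag like smaller_is_better could decide the sign of a permutation importance score. The helper importance_score() is hypothetical and written only for illustration; it is not taken from vip's source.

```r
# Illustrative only: how metric direction could set the sign of a
# permutation importance score (hypothetical helper, not vip's internals).
importance_score <- function(baseline, permuted, smaller_is_better) {
  if (smaller_is_better) {
    permuted - baseline  # error-type metric: importance is the increase in error
  } else {
    baseline - permuted  # score-type metric: importance is the drop in score
  }
}

# Accuracy-like metric: permuting a useful feature lowers the score
importance_score(baseline = 0.90, permuted = 0.75, smaller_is_better = FALSE)

# RMSE-like metric: permuting a useful feature raises the error
importance_score(baseline = 1.2, permuted = 1.9, smaller_is_better = TRUE)
```

Under either convention a useful feature ends up with a positive score, which is why the direction has to be known for every metric.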

@topepo
Contributor Author

topepo commented Jan 6, 2023

How much of the under-the-hood stuff would you be open to changing/refactoring?

For example, it would be good to put the directionality bits into the metrics file and so on.

I would also try to consolidate the metric checking into a separate function (or multiple functions for user- and pre-defined metrics).
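One possible shape for that refactoring, as a rough sketch rather than a proposal for the actual code: each pre-defined metric carries its own direction, and checking is split between user-supplied and pre-defined metrics. All names here (builtin_metrics, check_builtin_metric(), check_user_metric(), get_metric()) are hypothetical, and the base R metric functions are stand-ins for whatever implementation the package would really use.

```r
# Hypothetical registry: each built-in metric stores its own direction.
# The base R functions below are stand-ins, not the package's metrics.
builtin_metrics <- list(
  rmse = list(
    fun = function(truth, estimate) sqrt(mean((truth - estimate)^2)),
    smaller_is_better = TRUE
  ),
  accuracy = list(
    fun = function(truth, estimate) mean(truth == estimate),
    smaller_is_better = FALSE
  )
)

# Check a pre-defined metric supplied by name
check_builtin_metric <- function(metric) {
  if (!metric %in% names(builtin_metrics)) {
    stop("Unknown metric: ", metric, call. = FALSE)
  }
  builtin_metrics[[metric]]
}

# Check a user-defined metric supplied as a function
check_user_metric <- function(metric, smaller_is_better) {
  if (!all(c("truth", "estimate") %in% names(formals(metric)))) {
    stop("`metric` must have `truth` and `estimate` arguments.", call. = FALSE)
  }
  if (!is.logical(smaller_is_better) || length(smaller_is_better) != 1L ||
      is.na(smaller_is_better)) {
    stop("`smaller_is_better` must be TRUE or FALSE for a user-defined metric.",
         call. = FALSE)
  }
  list(fun = metric, smaller_is_better = smaller_is_better)
}

# Single entry point that dispatches on how the metric was supplied
get_metric <- function(metric, smaller_is_better = NULL) {
  if (is.function(metric)) {
    check_user_metric(metric, smaller_is_better)
  } else if (is.character(metric) && length(metric) == 1L) {
    check_builtin_metric(metric)
  } else {
    stop("`metric` must be a function or a single character string.",
         call. = FALSE)
  }
}

# Built-in metric: direction comes from the registry
get_metric("rmse")$smaller_is_better

# User-defined metric: direction must be given explicitly
my_mae <- function(truth, estimate) mean(abs(truth - estimate))
get_metric(my_mae, smaller_is_better = TRUE)$smaller_is_better
```

Keeping the direction next to the metric definition means the caller never has to repeat it for pre-defined metrics, while user-defined metrics fail early if it is missing.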

@bgreenwell
Member

Hey @topepo, sorry for the late reply. I'm not against making useful changes to the package as long as they don't break backward compatibility too much. And good callout with the directionality info! Were you planning on contributing PRs for the metric checking?

@topepo
Contributor Author

topepo commented Jan 19, 2023

A PR (or maybe more than one to make it more easily reviewed). I have a set of longish flights coming up and I'll probably work on this then. Maybe in the next few weeks.

@bgreenwell
Member

Sounds good @topepo 👌

@bgreenwell
Member

Just peeking through the package, it seems that using the *_vec() (e.g., roc_auc_vec()) family of functions would be relatively straightforward.

@bgreenwell
Member

@topepo I'll be moving on to this package soon, just trying to tidy up some issues with fastshap first.

This was referenced May 7, 2023
@bgreenwell
Member

Hey @topepo, started to integrate with yardstick if you want to take a look (changes are only on the devel branch and no unit tests or anything yet, but planning to rebuild the pkgdown site with plenty of examples to help):

library(ranger)
library(vip)
#> 
#> Attaching package: 'vip'
#> The following object is masked from 'package:utils':
#> 
#>     vi
library(yardstick)

# Complete (i.e., imputed) version of titanic data set
head(t3 <- titanic_mice[[1L]])
#>   survived pclass   age    sex sibsp parch
#> 1      yes      1 29.00 female     0     0
#> 2      yes      1  0.92   male     1     2
#> 3       no      1  2.00 female     1     2
#> 4       no      1 30.00   male     1     2
#> 5       no      1 25.00 female     1     2
#> 6      yes      1 48.00   male     0     0

#
# Predicting class labels
#

set.seed(1120)
rfo <- ranger(survived ~ ., data = t3, probability = FALSE)

# Prediction wrapper
pfun_rfo <- function(object, newdata) {
  predict(object, data = newdata)$predictions
}
pfun_rfo(rfo, newdata = head(t3))
#> [1] yes yes no  no  yes no 
#> Levels: no yes

# Should throw an error since ROC needs vector of probabilities
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = "roc_auc",
  smaller_is_better = FALSE
)
#> Warning: Consider setting the `event_level` argument when using "roc_auc" as
#> the metric; see `?vip::vi_permute` for details. Defaulting to `event_level =
#> "first"`.
#> Error in `metric_fun()` at vip/R/vi_permute.R:311:2:
#> ! `estimate` should be a numeric vector, not a `factor` vector.
#> Backtrace:
#>     ▆
#>  1. ├─vip::vi_permute(...)
#>  2. └─vip:::vi_permute.default(...) at vip/R/vi_permute.R:148:2
#>  3.   └─yardstick (local) metric_fun(truth = train_y, estimate = pred_wrapper(object, newdata = train_x)) at vip/R/vi_permute.R:311:2
#>  4.     └─yardstick::check_prob_metric(truth, estimate, case_weights, estimator)
#>  5.       └─yardstick:::validate_factor_truth_matrix_estimate(...)
#>  6.         └─rlang::abort(...)

# Use yardstick function directly; need to specify `smaller_is_better`
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = accuracy_vec,      # use yardstick function directly
  smaller_is_better = FALSE,  # needed when supplying a function
  nsim = 10
)
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass       0.0764 0.00414
#> 2 age          0.0728 0.00837
#> 3 sex          0.221  0.0134 
#> 4 sibsp        0.0348 0.00369
#> 5 parch        0.0146 0.00363

# Use built-in yardstick function; no need to specify `smaller_is_better`
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = "accuracy",      # uses yardstick internally
  nsim = 10
)
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass       0.0764 0.00414
#> 2 age          0.0728 0.00837
#> 3 sex          0.221  0.0134 
#> 4 sibsp        0.0348 0.00369
#> 5 parch        0.0146 0.00363

#
# Predicting probabilities
#

set.seed(1120)
pfo <- ranger(survived ~ ., data = t3, probability = TRUE)  # probability forest

# Prediction wrapper
pfun_pfo <- function(object, newdata) {
  predict(object, data = newdata)$predictions[, "yes"]
}
pfun_pfo(pfo, newdata = head(t3))
#> [1] 0.9383245 0.8597809 0.6351732 0.3805831 0.7631647 0.3235726

# Use default event level; should throw a warning message
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = "roc_auc",
  nsim = 10
)
#> Warning: Consider setting the `event_level` argument when using "roc_auc" as
#> the metric; see `?vip::vi_permute` for details. Defaulting to `event_level =
#> "first"`.
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass      -0.103  0.00455
#> 2 age         -0.0986 0.00678
#> 3 sex         -0.238  0.0124 
#> 4 sibsp       -0.0336 0.00339
#> 5 parch       -0.0226 0.00159

# Change the event level
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = "roc_auc",
  event_level = "second",
  nsim = 10
)
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass       0.103  0.00455
#> 2 age          0.0986 0.00678
#> 3 sex          0.238  0.0124 
#> 4 sibsp        0.0336 0.00339
#> 5 parch        0.0226 0.00159

# Could also do this with a wrapper function
mfun <- function(truth, estimate) {
  roc_auc_vec(truth = truth, estimate = estimate, event_level = "second")
}
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = mfun,
  smaller_is_better = FALSE,
  nsim = 10
)
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass       0.103  0.00455
#> 2 age          0.0986 0.00678
#> 3 sex          0.238  0.0124 
#> 4 sibsp        0.0336 0.00339
#> 5 parch        0.0226 0.00159

Created on 2023-05-08 with reprex v2.0.2
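A note on why the importances come out negated under the default event level: the prediction wrapper returns the probability of "yes", but the factor levels are "no" then "yes", so event_level = "first" scores those probabilities against the wrong class. For a binary problem that turns an AUC of a into 1 - a, which flips the sign of every difference between baseline and permuted scores. The base R sketch below illustrates the identity with a rank-based AUC; auc_rank() is a hypothetical helper (not yardstick's implementation) and the data are made up for the example.

```r
# Rank-based (Mann-Whitney) AUC; hypothetical helper for illustration only.
auc_rank <- function(truth, prob, event) {
  pos <- truth == event
  n_pos <- sum(pos)
  n_neg <- sum(!pos)
  r <- rank(prob)  # average ranks, so ties count as one half
  (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Made-up toy data: `p_yes` plays the role of the predicted P(yes)
truth <- factor(c("yes", "no", "yes", "no", "yes", "no"), levels = c("no", "yes"))
p_yes <- c(0.9, 0.8, 0.7, 0.3, 0.6, 0.2)

auc_yes <- auc_rank(truth, p_yes, event = "yes")  # event matches the probabilities
auc_no  <- auc_rank(truth, p_yes, event = "no")   # event mismatched

# The two values are complements of each other
all.equal(auc_yes + auc_no, 1)
```

This is consistent with the two tables above, where the importances under event_level = "second" are the same values with the opposite sign.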

@bgreenwell
Member

In devel now and will be part of next release.
