yardstick integration #120
Comments
The more metrics the merrier! Would you be able to put together a PR for this too 😁? If not, I could probably get around to it later this month. |
I'll be coming back to this soon (as well as some other PRs for model-based methods). Off-hand, where do you handle the direction of the metric (i.e. maximize or minimize)? |
Sounds good @topepo, there's an option called |
How much of the under-the-hood stuff would you be open to changing/refactoring? For example, it would be good to put the directionality bits into the metrics file and so on. I would also try to consolidate the metric checking into a separate function (or multiple functions for user- and pre-defined metrics). |
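One way to read the suggestion about keeping the directionality bits with the metrics is a small lookup table that pairs each metric name with its function and its direction, plus a single function that checks user input. The following is a minimal base R sketch under that assumption, not vip's actual internals: `metric_registry`, `get_metric()`, and the two hand-written metric functions are hypothetical names introduced only for illustration.

```r
# Hypothetical registry: each entry stores the metric function together
# with its direction, so callers never have to pass the direction by hand.
metric_registry <- list(
  rmse = list(
    fun = function(truth, estimate) sqrt(mean((truth - estimate)^2)),
    smaller_is_better = TRUE
  ),
  accuracy = list(
    fun = function(truth, estimate) mean(truth == estimate),
    smaller_is_better = FALSE
  )
)

# Resolve a user-supplied metric (a name or a function) in one place.
get_metric <- function(metric, smaller_is_better = NULL) {
  if (is.function(metric)) {
    # User-defined metrics carry no direction, so it must be supplied.
    if (is.null(smaller_is_better)) {
      stop("`smaller_is_better` is required when `metric` is a function.")
    }
    return(list(fun = metric, smaller_is_better = smaller_is_better))
  }
  if (!is.character(metric) || length(metric) != 1L ||
      !metric %in% names(metric_registry)) {
    stop("Unknown metric; expected one of: ",
         paste(names(metric_registry), collapse = ", "))
  }
  metric_registry[[metric]]
}

m <- get_metric("rmse")
m$fun(truth = c(1, 2, 3), estimate = c(1, 2, 5))
m$smaller_is_better
```

Keeping the direction next to the function means a pre-defined metric never needs a separate `smaller_is_better` argument, while a user-supplied function fails early if the direction is missing.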
Hey @topepo, sorry for the late reply. I'm not against making useful changes to the package as long as it doesn't break backward compatibility too much. And good call out with the directionality info! Were you planning on contributing PRs for the metric checking? |
A PR (or maybe more than one to make it more easily reviewed). I have a set of longish flights coming up and I'll probably work on this then. Maybe in the next few weeks. |
Sounds good @topepo 👌 |
Just peeking through the package, it seems that using the |
Hey @topepo, started to integrate with yardstick if you want to take a look (changes are only on the devel branch and no unit tests or anything yet, but planning to rebuild the pkgdown site with plenty of examples to help):
library(ranger)
library(vip)
#>
#> Attaching package: 'vip'
#> The following object is masked from 'package:utils':
#>
#> vi
library(yardstick)
# Complete (i.e., imputed) version of titanic data set
head(t3 <- titanic_mice[[1L]])
#>   survived pclass   age    sex sibsp parch
#> 1      yes      1 29.00 female     0     0
#> 2      yes      1  0.92   male     1     2
#> 3       no      1  2.00 female     1     2
#> 4       no      1 30.00   male     1     2
#> 5       no      1 25.00 female     1     2
#> 6      yes      1 48.00   male     0     0
#
# Predicting class labels
#
set.seed(1120)
rfo <- ranger(survived ~ ., data = t3, probability = FALSE)
# Prediction wrapper
pfun_rfo <- function(object, newdata) {
  predict(object, data = newdata)$predictions
}
pfun_rfo(rfo, newdata = head(t3))
#> [1] yes yes no no yes no
#> Levels: no yes
# Should throw an error since ROC needs vector of probabilities
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = "roc_auc",
  smaller_is_better = FALSE
)
#> Warning: Consider setting the `event_level` argument when using "roc_auc" as
#> the metric; see `?vip::vi_permute` for details. Defaulting to `event_level =
#> "first"`.
#> Error in `metric_fun()` at vip/R/vi_permute.R:311:2:
#> ! `estimate` should be a numeric vector, not a `factor` vector.
#> Backtrace:
#> ▆
#> 1. ├─vip::vi_permute(...)
#> 2. └─vip:::vi_permute.default(...) at vip/R/vi_permute.R:148:2
#> 3. └─yardstick (local) metric_fun(truth = train_y, estimate = pred_wrapper(object, newdata = train_x)) at vip/R/vi_permute.R:311:2
#> 4. └─yardstick::check_prob_metric(truth, estimate, case_weights, estimator)
#> 5. └─yardstick:::validate_factor_truth_matrix_estimate(...)
#> 6. └─rlang::abort(...)
# Use yardstick function directly; need to specify `smaller_is_better`
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = accuracy_vec,      # use yardstick function directly
  smaller_is_better = FALSE,  # needed when supplying a function
  nsim = 10
)
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass 0.0764 0.00414
#> 2 age 0.0728 0.00837
#> 3 sex 0.221 0.0134
#> 4 sibsp 0.0348 0.00369
#> 5 parch 0.0146 0.00363
# Use built-in yardstick function; no need to specify `smaller_is_better`
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = "accuracy",  # uses yardstick internally
  nsim = 10
)
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass 0.0764 0.00414
#> 2 age 0.0728 0.00837
#> 3 sex 0.221 0.0134
#> 4 sibsp 0.0348 0.00369
#> 5 parch 0.0146 0.00363
#
# Predicting probabilities
#
set.seed(1120)
pfo <- ranger(survived ~ ., data = t3, probability = TRUE) # probability forest
# Prediction wrapper
pfun_pfo <- function(object, newdata) {
  predict(object, data = newdata)$predictions[, "yes"]
}
pfun_pfo(pfo, newdata = head(t3))
#> [1] 0.9383245 0.8597809 0.6351732 0.3805831 0.7631647 0.3235726
# Use default event level; should throw a warning message
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = "roc_auc",
  nsim = 10
)
#> Warning: Consider setting the `event_level` argument when using "roc_auc" as
#> the metric; see `?vip::vi_permute` for details. Defaulting to `event_level =
#> "first"`.
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass -0.103 0.00455
#> 2 age -0.0986 0.00678
#> 3 sex -0.238 0.0124
#> 4 sibsp -0.0336 0.00339
#> 5 parch -0.0226 0.00159
# Change the event level
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = "roc_auc",
  event_level = "second",
  nsim = 10
)
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass 0.103 0.00455
#> 2 age 0.0986 0.00678
#> 3 sex 0.238 0.0124
#> 4 sibsp 0.0336 0.00339
#> 5 parch 0.0226 0.00159
# Could also do this with a wrapper function
mfun <- function(truth, estimate) {
  roc_auc_vec(truth = truth, estimate = estimate, event_level = "second")
}
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = mfun,
  smaller_is_better = FALSE,
  nsim = 10
)
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass 0.103 0.00455
#> 2 age 0.0986 0.00678
#> 3 sex 0.238 0.0124
#> 4 sibsp 0.0336 0.00339
#> 5 parch 0.0226 0.00159
Created on 2023-05-08 with reprex v2.0.2 |
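The sign flip between the two `roc_auc` tables above is consistent with how AUC behaves when the event level is swapped: the prediction wrapper returns the probability of "yes", while the default `event_level = "first"` treats "no" as the event. The sketch below illustrates this in base R only, so it does not depend on vip or yardstick. The `auc()` helper and the toy data are assumptions made up for illustration; the helper uses the standard rank-based (Mann-Whitney) formula rather than yardstick's implementation.

```r
# Rank-based (Mann-Whitney) AUC; `event` names the level treated as positive.
auc <- function(truth, estimate, event) {
  pos <- truth == event
  n_pos <- sum(pos)
  n_neg <- sum(!pos)
  r <- rank(estimate)  # average ranks handle ties
  (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Toy data (made up); the scores play the role of P("yes").
truth <- factor(c("no", "no", "yes", "no", "yes", "yes"),
                levels = c("no", "yes"))
p_yes <- c(0.10, 0.40, 0.35, 0.20, 0.80, 0.90)

auc_first  <- auc(truth, p_yes, event = "no")   # event is the first level
auc_second <- auc(truth, p_yes, event = "yes")  # event is the second level

c(first = auc_first, second = auc_second)
all.equal(auc_first + auc_second, 1)  # the two values are complements
```

Because the two AUC values sum to one, scoring P("yes") against the wrong event level mirrors the metric around 0.5, so any difference between a baseline and a permuted score changes sign but keeps its magnitude.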
In devel now and will be part of next release. |
If you were good with a yardstick dependency, we could expand the list of metrics that can be used.