pred_wrapper with multinomial classification #78

Closed
gagirob opened this issue Jul 19, 2024 · 2 comments
Comments

@gagirob

gagirob commented Jul 19, 2024

Hi @bgreenwell,

I am using `explain()` with a categorical (multiclass) dependent variable, predicted with random forests via `nestcv.train()`. One of the nestedcv vignettes (https://cran.r-project.org/web/packages/nestedcv/vignettes/nestedcv_shap.html) says that for multinomial classification I need to pass `pred_train_class1`, `pred_train_class2`, etc. as the `pred_wrapper`. This works for `pred_train_class1`, `pred_train_class2` and `pred_train_class3`, but not for `pred_train_class4` or any later class (I have 8 categories in total).

This is the code I am using:

```r
library(caret)      # trainControl(), mnLogLoss()
library(nestedcv)   # nestcv.train(), boruta_filter(), pred_train_class1..3
library(fastshap)   # explain()

ctrl <- trainControl(method = "cv", number = n_inner_folds, seeds = seeds,
                     classProbs = TRUE, summaryFunction = mnLogLoss,
                     allowParallel = FALSE)
ncv_boruta <- nestedcv::nestcv.train(
  y = response, x = data, method = "rf", savePredictions = "final",
  n_outer_folds = n_outer_folds, outer_train_predict = TRUE,
  n_inner_folds = n_inner_folds, filterFUN = boruta_filter,
  filter_options = list(select = c("Confirmed", "Tentative"), maxRuns = maxRuns),
  cv.cores = n_outer_folds, ntree = ntree, maximize = FALSE, tuneGrid = tg,
  balance = sampling, trControl = ctrl)

nsim <- 100
sh <- list()
set.seed(123)
sh[[1]] <- explain(ncv_boruta, X = data, pred_wrapper = pred_train_class1, nsim = nsim)
sh[[2]] <- explain(ncv_boruta, X = data, pred_wrapper = pred_train_class2, nsim = nsim)
sh[[3]] <- explain(ncv_boruta, X = data, pred_wrapper = pred_train_class3, nsim = nsim)
sh[[4]] <- explain(ncv_boruta, X = data, pred_wrapper = pred_train_class4, nsim = nsim)
```

And this is what I get:

```
> sh[[4]] <- explain(ncv_boruta, X=data, pred_wrapper = pred_train_class4, nsim = nsim)
Error in explain.default(ncv_boruta, X=data, pred_wrapper = pred_train_class4,  :
  object 'pred_train_class4' not found
```

Thank you.

@myles-lewis

Hi @gagirob,

I can answer the question about the missing pred_train_class4 object (I'm the author of the nestedcv package). The source code for pred_train_class3 is as follows:

```r
pred_train_class3 <- function(x, newdata) {
  predict(x, newdata, type = "prob")[, 3]
}
```

I only provided wrappers for the first 3 classes, as this covers a common use case. It's straightforward to make `pred_train_class4` yourself:

```r
pred_train_class4 <- function(x, newdata) {
  predict(x, newdata, type = "prob")[, 4]
}
```

In the same way you can write the prediction wrappers for classes 5 to 8, changing only the column index.

Bw, Myles
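With 8 classes, writing one wrapper per class by hand gets repetitive. A minimal sketch, assuming that `predict()` on the fitted `nestcv.train` object with `type = "prob"` returns one probability column per class (which is what the `pred_train_class*` wrappers above rely on), is a small function factory that builds the wrapper for any class index. The names `make_pred_class`, `pred_wrappers` and the toy model class `toy_prob_model` are hypothetical and exist only for illustration; the toy model merely mimics a probability-returning `predict()` method so the snippet runs on its own.

```r
# Hypothetical factory: returns a prediction wrapper for class k, mirroring
# nestedcv's pred_train_class1..3. The wrapper returns the k-th column of
# the class-probability predictions as a numeric vector.
# `k` may be a column index (e.g. 4) or a column/class name (e.g. "D").
make_pred_class <- function(k) {
  force(k)  # evaluate k now so each wrapper keeps its own class index
  function(x, newdata) {
    predict(x, newdata, type = "prob")[, k]
  }
}

# One wrapper per class; 8 classes as in the original question.
pred_wrappers <- lapply(seq_len(8), make_pred_class)

# --- Demo with a hypothetical stand-in model (not part of nestedcv) ---
# Its predict method returns a data frame of class probabilities whose
# columns are named after the class levels; column j has probability j/36.
predict.toy_prob_model <- function(object, newdata, type = "prob", ...) {
  n <- nrow(newdata)
  k <- length(object$levels)
  p <- matrix(rep(seq_len(k), each = n), nrow = n, ncol = k)
  p <- as.data.frame(p / rowSums(p))  # each row sums to 1
  colnames(p) <- object$levels
  p
}
toy_model <- structure(list(levels = LETTERS[1:8]), class = "toy_prob_model")
toy_data <- data.frame(x1 = 1:3, x2 = 4:6)

pred_wrappers[[4]](toy_model, toy_data)    # class 4 probabilities, one per row
make_pred_class("D")(toy_model, toy_data)  # same column, selected by name
```

With the objects from the original post, SHAP values for all eight classes could then be computed with something like `sh <- lapply(pred_wrappers, function(w) explain(ncv_boruta, X = data, pred_wrapper = w, nsim = nsim))`. The `force(k)` call makes each wrapper capture its own class index instead of relying on lazy evaluation of the argument.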

@bgreenwell
Owner

Thanks for posting a response @myles-lewis!
