VI Plot for each class in a multi-class classification task #156
Hi @viv-analytics, this is not possible in vip. In fact, I don't think xgboost even returns (or measures) importance scores on a per-class level. You could potentially roll your own solution using permutation importance, but you would have to run it for each class and tie it to a reasonable metric. Alternatively, you could use the prediction contributions from XGBoost. These can be aggregated into global importance scores by taking the mean absolute value, and XGBoost does indeed provide contributions for each class. Here's a quick and dirty example below:
# Using raw XGBoost model, so need matrix of features for predictions
X <- data.matrix(subset(penguins_train, select = -species))
bst <- extract_fit_parsnip(pens_fit)$fit # underlying model
# Will return a list with one element for each class; each element will have the
# same dimension as the features in the training data (namely, 258 x 8)
ex <- predict(bst, newdata = X, reshape = TRUE, predcontrib = TRUE)
# Compute the mean absolute value for each column and return a tibble like you
# would get from the vip package (for convenience plotting)
agg <- function(x) {
  res <- apply(x, MARGIN = 2, FUN = function(y) {
    mean(abs(y))
  })
  res <- data.frame("Variable" = names(res), "Importance" = res)
  res <- tibble::as_tibble(res)
  class(res) <- c("vi", class(res))
  res
}
# Compute SHAP-based VI for each class
agg(ex[[1]]) # |> vip::vip() # class 1
agg(ex[[2]]) # |> vip::vip() # class 2
agg(ex[[3]]) # |> vip::vip() # class 3
You might even consider rolling this all into one function.
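Rolling the per-class calls into one function could look something like the sketch below. The name `vi_by_class` and its arguments are hypothetical (not part of vip or xgboost), and it assumes each contribution matrix has named columns. It is run here on a small made-up list so it works without a fitted model.

```r
# Hypothetical helper (not part of vip or xgboost): aggregate a list of
# per-class contribution matrices into one long data frame.
# Assumes every matrix has column names.
vi_by_class <- function(contrib_list, class_names = NULL) {
  if (is.null(class_names)) {
    class_names <- paste0("class", seq_along(contrib_list))
  }
  res <- lapply(seq_along(contrib_list), function(i) {
    x <- contrib_list[[i]]
    # Drop the BIAS column so it does not show up as a "feature"
    x <- x[, colnames(x) != "BIAS", drop = FALSE]
    imp <- colMeans(abs(x))  # mean absolute contribution per feature
    data.frame(
      Class = class_names[i],
      Variable = names(imp),
      Importance = unname(imp),
      stringsAsFactors = FALSE
    )
  })
  do.call(rbind, res)
}

# Toy stand-in for the list returned by predict(..., predcontrib = TRUE)
toy <- list(
  matrix(c(1, -1, 2, -2, 0.5, 0.5), nrow = 2,
         dimnames = list(NULL, c("x1", "x2", "BIAS"))),
  matrix(c(3, -3, 0, 0, 0.1, 0.1), nrow = 2,
         dimnames = list(NULL, c("x1", "x2", "BIAS")))
)
vi_all <- vi_by_class(toy, class_names = c("A", "B"))
vi_all
```

With the real model this would presumably be called as `vi_by_class(ex, class_names = levels(penguins_train$species))`, assuming `species` is a factor whose level order matches the class order XGBoost used. The result can then be split by `Class` for plotting.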
Thanks a lot @bgreenwell. I've rolled that into a function using an ID column and it works smoothly. The predict() output also includes columns called BIAS. How should these be interpreted in that context?
Hi @viv-analytics, apologies for missing this. For your purposes you can ignore this value (and it should be constant). In most cases it represents the average training prediction and is useful in interpreting the individual explanations. It usually plays no role in the aggregated importance scores!
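To make the role of BIAS concrete, here is a small sketch on made-up numbers (the matrix `contrib` is invented for illustration, not taken from the penguins model). In XGBoost the feature contributions plus BIAS add up to the raw (margin) prediction for each row, which is why BIAS helps with individual explanations but can be left out of the aggregated scores.

```r
# Made-up contribution matrix for one class: two rows, two features + BIAS
contrib <- matrix(c(0.4, -0.2, -0.1, 0.3, 0.5, 0.5), nrow = 2,
                  dimnames = list(NULL, c("x1", "x2", "BIAS")))

# BIAS should take the same value in every row
bias_is_constant <- length(unique(contrib[, "BIAS"])) == 1

# Feature contributions plus BIAS add up to the margin prediction per row
margin <- rowSums(contrib)

# Importance is computed from the feature columns only
feature_contrib <- contrib[, colnames(contrib) != "BIAS", drop = FALSE]
importance <- colMeans(abs(feature_contrib))
```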
Hi @bgreenwell, @topepo
Would it be possible to get a feature importance plot for each class in a multi-class classification task using VI?
This only shows feature importance at the model level, not at the class level.
I'd highly appreciate any hint.