VI Plot for each class in a multi-class classification task #156

Closed
viv-analytics opened this issue Sep 27, 2023 · 3 comments

viv-analytics commented Sep 27, 2023

Hi @bgreenwell, @topepo

Would it be possible to get a feature importance plot for each class in a multi-class classification task using VI?


library(tidymodels)
library(palmerpenguins)
library(patchwork)


conflicted::conflicts_prefer(palmerpenguins::penguins_raw)

penguins |>
  group_by(species) |>
  count()

penguins_e <-
  penguins |>
  mutate(across(where(is.integer), as.double))

penguins_edited <-
  missRanger::missRanger(data = penguins_e,
                         seed = 1234,
                         verbose = T)

set.seed(123)
penguins_split <- initial_split(penguins_edited, strata = "species")
penguins_train <- training(penguins_split)

pens_rec <-
  recipe(species ~ ., data = penguins_train) |>
  step_impute_bag(seed_val = 1234) |>
  step_integer(all_nominal_predictors(), zero_based = TRUE)

boost_tree_xgboost_spec <-
  boost_tree(
    tree_depth = tune(),
    trees = tune(),
    learn_rate = tune(),
    min_n = tune(),
    stop_iter = 20
  ) %>%
  set_engine('xgboost', importance = TRUE, nthread = 4) %>%
  set_mode('classification')


set.seed(2)
pens_rs <-
  vfold_cv(penguins_train, strata = "species", v = 5)

ctrl <-
  control_grid(
    save_workflow = TRUE,
    verbose = TRUE,
    save_pred = TRUE,
    allow_par = T
  )

pens_wf <-
  workflow(preprocessor = pens_rec, spec = boost_tree_xgboost_spec)

set.seed(1234)
pens_tune_res <-
  tune_grid(
    object = pens_wf,
    resamples = pens_rs,
    grid = 10,
    control = ctrl
  )


pens_fit <-
  fit_best(pens_tune_res, verbose = TRUE)

pens_final_fit <-
  last_fit(pens_fit, split = penguins_split)

pens_final_fit |> extract_fit_parsnip() |> vip::vip()

This only shows feature importance at the model level, not at the class level.

I would highly appreciate any hint.

Member

bgreenwell commented Oct 3, 2023

Hi @viv-analytics, this is not possible in vip. In fact, I don't even think xgboost returns (or measures) importance scores on a per class level.

You could potentially roll your own solution using permutation importance, but you would have to run it for each class and tie it to a reasonable metric.
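As a rough illustration of the permutation idea, here is a minimal base R sketch. It is not part of vip: `perm_vi_per_class()` is a hypothetical helper, the one-vs-rest Brier score is just one possible choice of per-class metric, and the nearest-centroid classifier on `iris` is a toy stand-in for a fitted model so the example runs without extra packages.

```r
# Hypothetical helper: per-class permutation importance using a one-vs-rest
# Brier score. `pred_fun(data)` must return a matrix of class probabilities
# with one named column per class.
perm_vi_per_class <- function(pred_fun, data, truth, features, nsim = 10) {
  classes <- levels(truth)
  # One-vs-rest Brier score for each class (lower is better)
  brier <- function(prob) {
    sapply(classes, function(k) mean((prob[, k] - (truth == k))^2))
  }
  baseline <- brier(pred_fun(data))
  out <- lapply(features, function(f) {
    # Average the permuted score over nsim shuffles of feature f
    permuted <- rowMeans(replicate(nsim, {
      shuffled <- data
      shuffled[[f]] <- sample(shuffled[[f]])
      brier(pred_fun(shuffled))
    }))
    # Importance = increase in the class-specific score after permuting f
    data.frame(Class = classes, Variable = f,
               Importance = unname(permuted - baseline))
  })
  do.call(rbind, out)
}

# Toy stand-in for a fitted model: a nearest-centroid classifier on iris
features <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
centroids <- sapply(split(iris[features], iris$Species), colMeans)
pred_fun <- function(data) {
  x <- as.matrix(data[features])
  # Squared distance from each row to each class centroid
  d2 <- sapply(colnames(centroids), function(k) {
    rowSums(sweep(x, 2, centroids[, k])^2)
  })
  p <- exp(-d2)
  p / rowSums(p)
}

set.seed(1)
vi <- perm_vi_per_class(pred_fun, iris, iris$Species, features)
vi
```

With a real workflow, `pred_fun` would wrap something like `predict(fit, new_data, type = "prob")`, with the columns renamed to the class labels; that wiring is left out here and would need checking against your own objects.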

Alternatively, you could use the prediction contributions from XGBoost. These can be aggregated into global importance scores by taking the mean absolute value, and XGBoost does indeed provide contributions for each class. Here's a quick and dirty example:

# Using raw XGBoost model, so need matrix of features for predictions
X <- data.matrix(subset(penguins_train, select = -species))
bst <- extract_fit_parsnip(pens_fit)$fit  # underlying model

# Will return a list with one element for each class; each element has one row
# per training observation and one column per feature plus a BIAS column
# (here, 258 x 8)
ex <- predict(bst, newdata = X, reshape = TRUE, predcontrib = TRUE)

# Compute the mean absolute value for each column and return a tibble like you
# would get from the vip package (for convenience plotting)
agg <- function(x) {  
  res <- apply(x, MARGIN = 2, FUN = function(y) {
    mean(abs(y))
  })
  res <- data.frame("Variable" = names(res), "Importance" = res)
  res <- tibble::as_tibble(res)
  class(res) <- c("vi", class(res))
  res
}

# Compute SHAP-based VI for each class
agg(ex[[1]]) # |> vip::vip()  # class 1
agg(ex[[2]]) # |> vip::vip()  # class 2
agg(ex[[3]]) # |> vip::vip()  # class 3

You might even consider rolling this all into one function, say shap_vi_per_class(), and have it return a list of tibbles for plotting, or even stack them together with a class ID for easier plotting.
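A minimal sketch of what such a function could look like, in base R only. The name `shap_vi_per_class()` comes from the suggestion above; the argument names, the `Class` ID column, and the made-up contribution matrices are assumptions for illustration, not anything provided by vip or xgboost.

```r
# Hypothetical helper. `contrib` is a list of contribution matrices, one per
# class, shaped like the output of predict() with predcontrib = TRUE and
# reshape = TRUE on a multi-class xgboost model.
shap_vi_per_class <- function(contrib, class_names = seq_along(contrib)) {
  stopifnot(length(contrib) == length(class_names))
  out <- lapply(seq_along(contrib), function(i) {
    imp <- colMeans(abs(contrib[[i]]))  # mean |contribution| per column
    data.frame(Class = class_names[i], Variable = names(imp),
               Importance = unname(imp))
  })
  # Stack the per-class results into one long data frame
  do.call(rbind, out)
}

# Made-up contribution matrices, only to show the shape of the result
set.seed(42)
toy <- replicate(3, {
  m <- matrix(rnorm(20), nrow = 5)
  colnames(m) <- c("bill_length_mm", "bill_depth_mm",
                   "flipper_length_mm", "BIAS")
  m
}, simplify = FALSE)

vi <- shap_vi_per_class(toy, class_names = c("Adelie", "Chinstrap", "Gentoo"))
vi
```

With the objects from the example above, the call would be along the lines of `shap_vi_per_class(ex, levels(penguins_train$species))`, assuming the list elements follow the order of the factor levels. The long format with a `Class` column lends itself to a faceted plot, one panel per class.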

Author

viv-analytics commented Oct 10, 2023

Thanks a lot @bgreenwell.

I've rolled that into a function using an ID column and it works smoothly.

predict() adds a column called BIAS to each contribution matrix. How should it be interpreted in this context?

@bgreenwell
Member

Hi @viv-analytics, apologies for missing this. For your purposes you can ignore this value (it should be constant across rows). In most cases it represents the average training prediction and is useful when interpreting individual explanations. It usually plays no role in the aggregated importance scores!
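To make the role of BIAS concrete, here is a small base R sketch with made-up numbers. `drop_bias()` is a hypothetical helper, and the toy matrix only mimics the layout of one class's contribution matrix. In XGBoost, each row of contributions, BIAS included, adds up to that row's raw (untransformed) score for the class, which is why BIAS matters for individual explanations but can be dropped before aggregating.

```r
# Hypothetical helper: drop the BIAS column from one class's contribution matrix
drop_bias <- function(x) {
  x[, colnames(x) != "BIAS", drop = FALSE]
}

# Made-up contributions for a single class; BIAS is constant across rows
toy <- cbind(
  bill_length_mm = c(0.8, -0.3, 0.5),
  body_mass_g    = c(-0.1, 0.2, 0.4),
  BIAS           = rep(0.25, 3)
)

# BIAS takes a single value, so it says nothing about how features differ
length(unique(toy[, "BIAS"]))

# Each row of contributions (BIAS included) adds up to that row's raw score
rowSums(toy)

# Feature-only importance: mean absolute contribution per feature
colMeans(abs(drop_bias(toy)))
```

Dropping the column before aggregating keeps BIAS from showing up as a bar in the per-class importance plot.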
