Skip to content

Commit

Permalink
sv_importance and sv_interaction receive a sort_features option.
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Feb 7, 2024
1 parent 2713bda commit 163ba46
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 23 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: shapviz
Title: SHAP Visualizations
Version: 0.9.3
Version: 0.9.4
Authors@R: c(
person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre")),
person("Adrian", "Stando", , "adrian.j.stando@gmail.com", role = "ctb")
Expand All @@ -21,7 +21,7 @@ Depends:
R (>= 3.6.0)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Imports:
ggfittext (>= 0.8.0),
gggenes,
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# shapviz 0.9.4

## Improvements

- New argument `sort_features = TRUE` in `sv_importance()` and `sv_interaction()`. Set to `FALSE` to show the features as they appear in your SHAP matrix. In that case, the plots will show the *first* `max_display` features, not the *most important* features. Implements #136.

# shapviz 0.9.3

## `sv_dependence()`: Control over automatic color feature selection
Expand Down
25 changes: 17 additions & 8 deletions R/sv_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#' @param kind Should a "bar" plot (the default), a "beeswarm" plot, or "both" be shown?
#' Set to "no" in order to suppress plotting. In that case, the sorted
#' SHAP feature importances of all variables are returned.
#' @param max_display Maximum number of features (with highest importance) to plot.
#' @param max_display How many features should be plotted?
#' Set to `Inf` to show all features. Has no effect if `kind = "no"`.
#' @param fill Color used to fill the bars (only used if bars are shown).
#' @param bar_width Relative width of the bars (only used if bars are shown).
Expand All @@ -38,6 +38,7 @@
#' (only if `show_numbers = TRUE`). To change to scientific notation, use
#' `function(x) = prettyNum(x, scientific = TRUE)`.
#' @param number_size Text size of the numbers (if `show_numbers = TRUE`).
#' @param sort_features Should features be sorted or not? The default is `TRUE`.
#' @param ... Arguments passed to [ggplot2::geom_bar()] (if `kind = "bar"`) or to
#' [ggplot2::geom_point()] otherwise. For instance, passing `alpha = 0.2` will produce
#' semi-transparent beeswarms, and setting `size = 3` will produce larger dots.
Expand Down Expand Up @@ -75,10 +76,10 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Feature value",
show_numbers = FALSE, format_fun = format_max,
number_size = 3.2, ...) {
number_size = 3.2, sort_features = TRUE, ...) {
stopifnot("format_fun must be a function" = is.function(format_fun))
kind <- match.arg(kind)
imp <- .get_imp(get_shap_values(object))
imp <- .get_imp(get_shap_values(object), sort_features = sort_features)

if (kind == "no") {
return(imp)
Expand Down Expand Up @@ -162,13 +163,13 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Feature value",
show_numbers = FALSE, format_fun = format_max,
number_size = 3.2, ...) {
number_size = 3.2, sort_features = TRUE, ...) {
kind <- match.arg(kind)
bar_type <- match.arg(bar_type)

# All other cases are done via {patchwork}
if (kind %in% c("bar", "no") && bar_type != "separate") {
imp <- .get_imp(get_shap_values(object))
imp <- .get_imp(get_shap_values(object), sort_features = sort_features)
if (kind == "no") {
return(imp)
}
Expand Down Expand Up @@ -223,6 +224,7 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
show_numbers = show_numbers,
format_fun = format_fun,
number_size = number_size,
sort_features = sort_features,
...
)
if (kind == "no") {
Expand All @@ -243,13 +245,20 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
(z - r[1L]) /(r[2L] - r[1L])
}

.get_imp <- function(z) {
.get_imp <- function(z, sort_features = TRUE) {
if (is.matrix(z)) {
return(sort(colMeans(abs(z)), decreasing = TRUE))
imp <- colMeans(abs(z))
if (sort_features) {
imp <- sort(imp, decreasing = TRUE)
}
return(imp)
}
# list/mshapviz
imp <- sapply(z, function(x) colMeans(abs(x)))
imp[order(-rowSums(imp)), ]
if (sort_features) {
imp <- imp[order(-rowSums(imp)), ]
}
return(imp)
}

.scale_X <- function(X) {
Expand Down
9 changes: 6 additions & 3 deletions R/sv_interaction.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"),
max_display = 7L, alpha = 0.3,
bee_width = 0.3, bee_adjust = 0.5,
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Row feature value", ...) {
color_bar_title = "Row feature value",
sort_features = TRUE, ...) {
kind <- match.arg(kind)
if (is.null(get_shap_interactions(object))) {
stop("No SHAP interaction values available.")
}
ord <- names(.get_imp(get_shap_values(object)))
ord <- names(.get_imp(get_shap_values(object), sort_features = sort_features))
object <- object[, ord]

if (kind == "no") {
Expand Down Expand Up @@ -112,7 +113,8 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
max_display = 7L, alpha = 0.3,
bee_width = 0.3, bee_adjust = 0.5,
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Row feature value", ...) {
color_bar_title = "Row feature value",
sort_features = TRUE, ...) {
kind <- match.arg(kind)

plot_list <- lapply(
Expand All @@ -126,6 +128,7 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
bee_adjust = bee_adjust,
viridis_args = viridis_args,
color_bar_title = color_bar_title,
sort_features = sort_features,
...
)
if (kind == "no") {
Expand Down
2 changes: 1 addition & 1 deletion man/shapviz-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/sv_importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/sv_interaction.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion packaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ library(usethis)
use_description(
fields = list(
Title = "SHAP Visualizations",
Version = "0.9.3",
Version = "0.9.4",
Description = "Visualizations for SHAP (SHapley Additive exPlanations),
such as waterfall plots, force plots, various types of importance plots,
dependence plots, and interaction plots.
Expand Down
23 changes: 17 additions & 6 deletions tests/testthat/test-plots-mshapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,31 @@ test_that("plots work for non-syntactic column names", {
)
})

test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", {
X_pred <- data.matrix(iris[, -1L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
x <- c(m1 = x, m2 = x)
X_pred <- data.matrix(iris[, -1L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
x <- c(m1 = x, m2 = x)

test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", {
imp <- sv_importance(x, kind = "no")
expect_true(is.matrix(imp) && all(dim(imp) == c(4L, length(x))))

inter <- sv_interaction(x, kind = "no")
expect_true(is.list(inter) && all(dim(inter[[1L]]) == rep(ncol(X_pred), 2L)))
})


test_that("sv_importance() and sv_interaction() respect sort_features = FALSE", {
imp <- sv_importance(x, kind = "no", sort_features = FALSE)
expect_true(all(rownames(imp) == colnames(x$m1)))

inter <- sv_interaction(x, kind = "no", sort_features = FALSE)
expect_true(all(rownames(inter$m1) == colnames(x$m1)))
})



test_that("sv_dependence() does not work with multiple v", {
X_pred <- data.matrix(iris[, -1L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
Expand Down
13 changes: 13 additions & 0 deletions tests/testthat/test-plots-shapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,16 @@ test_that("sv_importance() and sv_interaction() and kind = 'no' gives numeric ou
expect_true(is.numeric(inter) && all(dim(inter) == rep(ncol(X_pred), 2L)))
})

test_that("sv_importance() and sv_interaction() respect sort_features = FALSE", {
X_pred <- data.matrix(iris[, -1L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)

imp <- sv_importance(x, kind = "no", sort_features = FALSE)
expect_true(all(names(imp) == colnames(x)))

inter <- sv_interaction(x, kind = "no", sort_features = FALSE)
expect_true(all(names(inter) == colnames(x)))
})

0 comments on commit 163ba46

Please sign in to comment.