Merge pull request #107 from ModelOriented/stacked_bar
Stacked bar
mayer79 authored Oct 13, 2023
2 parents 206e533 + da511ad commit 067fa82
Showing 12 changed files with 129 additions and 86 deletions.
8 changes: 8 additions & 0 deletions NEWS.md
@@ -1,9 +1,17 @@
# shapviz 0.9.2

## User-visible changes

- `sv_importance()` of a "mshapviz" object now returns a dodged barplot instead of separate barplots via {patchwork}. Use the new argument `bar_type` to switch to a stacked barplot (`bar_type = "stack"`), to "facets" (via {ggplot2}), or "separate" for the old behaviour.

## New features

- Added connector to [permshap](https://github.com/mayer79/permshap), a package calculating permutation SHAP values for regression and (probabilistic) classification.

## Other changes

- Revised vignette on "mshapviz".

# shapviz 0.9.1

## New features
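
The `bar_type` argument described in the NEWS entry could be tried out roughly as follows. This is a minimal sketch, not part of the commit: it assumes {shapviz} >= 0.9.2 and {xgboost} are installed, it reuses the model setup from the package tests, and the object names (`mx`, `p_dodge`, ...) are purely illustrative.

```r
library(shapviz)

# Model and data as in the package tests of this commit
X_pred <- data.matrix(iris[, -1L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L])
fit <- xgboost::xgb.train(
  params = list(nthread = 1L), data = dtrain, nrounds = 50L
)

# Combine two "shapviz" objects into one "mshapviz" object
x <- shapviz(fit, X_pred = X_pred)
mx <- c(m1 = x, m2 = x)

# New default: a single ggplot with dodged bars
p_dodge <- sv_importance(mx)

# Alternatives introduced by this change
p_stack <- sv_importance(mx, bar_type = "stack")
p_facets <- sv_importance(mx, bar_type = "facets")

# Old behaviour: separate plots glued together via {patchwork}
p_sep <- sv_importance(mx, bar_type = "separate", show_numbers = TRUE)

# Without plotting: importances as a matrix (features x models)
imp <- sv_importance(mx, kind = "no")
```

According to the updated tests, the first three calls should return "ggplot" objects, the "separate" variant a "patchwork" object, and `kind = "no"` a matrix with one column per model.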
2 changes: 1 addition & 1 deletion R/shapviz-package.R
@@ -5,7 +5,7 @@
#' @importFrom xgboost xgb.train

globalVariables(c("from", "i", "id", "label", "to", "x", "shap", "SHAP",
"feature", "value", "color", "Var2", "Var3", "S"))
"feature", "value", "color", "Var2", "Var3", "S", "ind", "values"))

.onLoad <- function(libname, pkgname) {
op <- options()
60 changes: 57 additions & 3 deletions R/sv_importance.R
@@ -18,6 +18,10 @@
#' Set to `Inf` to show all features. Has no effect if `kind = "no"`.
#' @param fill Color used to fill the bars (only used if bars are shown).
#' @param bar_width Relative width of the bars (only used if bars are shown).
#' @param bar_type For "mshapviz" objects with `kind = "bar"`: How should bars be
#' represented? The default is "dodge" for dodged bars. Other options are "stack",
#' "facets", or "separate" (via {patchwork}). Note that "separate" is currently
#' the only option that supports `show_numbers = TRUE`.
#' @param bee_width Relative width of the beeswarms.
#' @param bee_adjust Relative bandwidth adjustment factor used in
#' estimating the density of the beeswarms.
@@ -40,7 +44,7 @@
#' @returns
#' A "ggplot" (or "patchwork") object representing an importance plot, or - if
#' `kind = "no"` - a named numeric vector of sorted SHAP feature importances
#' (or a list of such vectors in case of an object of class "mshapviz").
#' (or a matrix in case of an object of class "mshapviz").
#' @examples
#' \dontrun{
#' X_train <- data.matrix(iris[, -1])
@@ -154,14 +158,59 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n
#' SHAP importance plot for an object of class "mshapviz".
#' @export
sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "no"),
max_display = 15L, fill = "#fca50a", bar_width = 2/3,
max_display = 15L, fill = "#fca50a",
bar_width = 2/3,
bar_type = c("dodge", "stack", "facets", "separate"),
bee_width = 0.4, bee_adjust = 0.5,
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Feature value",
show_numbers = FALSE, format_fun = format_max,
number_size = 3.2, ...) {
  kind <- match.arg(kind)
  bar_type <- match.arg(bar_type)

  # All other cases are done via {patchwork}
  if (kind %in% c("bar", "no") && bar_type != "separate") {
    imp <- .get_imp(get_shap_values(object))
    if (kind == "no") {
      return(imp)
    }
    if (nrow(imp) > max_display) {
      imp <- imp[seq_len(max_display), , drop = FALSE]
    }
    ord <- rownames(imp)
    imp_df <- data.frame(
      feature = factor(ord, rev(ord)), utils::stack(as.data.frame(imp))
    )

    if (bar_type %in% c("dodge", "stack")) {
      imp_df <- transform(imp_df, ind = factor(ind, rev(levels(ind))))
      if (is.null(viridis_args)) {
        viridis_args <- list()
      }
      p <- ggplot2::ggplot(imp_df, ggplot2::aes(x = values, y = feature)) +
        ggplot2::geom_bar(
          ggplot2::aes(fill = ind),
          width = bar_width,
          stat = "identity",
          position = bar_type,
          ...
        ) +
        ggplot2::labs(fill = ggplot2::element_blank()) +
        do.call(ggplot2::scale_fill_viridis_d, viridis_args) +
        ggplot2::guides(fill = ggplot2::guide_legend(reverse = TRUE))
    } else { # facets
      p <- ggplot2::ggplot(imp_df, ggplot2::aes(x = values, y = feature)) +
        ggplot2::geom_bar(fill = fill, width = bar_width, stat = "identity", ...) +
        ggplot2::facet_wrap("ind")
    }
    p <- p +
      ggplot2::xlab("mean(|SHAP value|)") +
      ggplot2::ylab(ggplot2::element_blank())
    return(p)
  }

  # Now, patchwork
  plot_list <- lapply(
    object,
    FUN = sv_importance,
@@ -198,7 +247,12 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
}

.get_imp <- function(z) {
  sort(colMeans(abs(z)), decreasing = TRUE)
  if (is.matrix(z)) {
    return(sort(colMeans(abs(z)), decreasing = TRUE))
  }
  # list/mshapviz
  imp <- sapply(z, function(x) colMeans(abs(x)))
  imp[order(-rowSums(imp)), ]
}
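
To see what the list branch of `.get_imp()` produces and how it is reshaped for plotting, the following base-R sketch may help. It is an illustration only: the two SHAP matrices are randomly generated stand-ins (hypothetical data, not from the package), and the names `shap_list` and `feats` are made up for this example.

```r
# Two hypothetical SHAP matrices (rows = observations, columns = features)
set.seed(1)
feats <- c("a", "b", "c")
shap_list <- list(
  m1 = matrix(rnorm(30), ncol = 3, dimnames = list(NULL, feats)),
  m2 = matrix(rnorm(30, sd = 2), ncol = 3, dimnames = list(NULL, feats))
)

# Same logic as the list branch of .get_imp() in the diff:
# one column of mean absolute SHAP values per model, rows sorted
# by total importance across models (largest first)
imp <- sapply(shap_list, function(x) colMeans(abs(x)))
imp <- imp[order(-rowSums(imp)), ]

# Reshape to long format as in sv_importance.mshapviz():
# utils::stack() creates the columns "values" and "ind"
ord <- rownames(imp)
imp_df <- data.frame(
  feature = factor(ord, rev(ord)), utils::stack(as.data.frame(imp))
)
print(imp)
print(head(imp_df))
```

The column names `values` and `ind` created by `utils::stack()` appear to be the reason both names were added to `globalVariables()` in R/shapviz-package.R.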

.scale_X <- function(X) {
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ To further simplify the use of {shapviz}, we added direct connectors to:
- [`kernelshap`](https://CRAN.R-project.org/package=kernelshap)
- [`fastshap`](https://CRAN.R-project.org/package=fastshap)
- [`shapr`](https://CRAN.R-project.org/package=shapr)
- [`treeshap`](https://github.com/ModelOriented/treeshap/)
- [`treeshap`](https://CRAN.R-project.org/package=treeshap)
- [`DALEX`](https://CRAN.R-project.org/package=DALEX)
- [`permshap`](https://github.com/mayer79/permshap)

Binary file modified man/figures/VIGNETTE-dep-ranger.png
Binary file modified man/figures/VIGNETTE-dep.png
Binary file modified man/figures/VIGNETTE-imp.png
8 changes: 7 additions & 1 deletion man/sv_importance.Rd

Generated file; diff not rendered.

26 changes: 18 additions & 8 deletions tests/testthat/test-plots-mshapviz.R
@@ -8,8 +8,12 @@ test_that("plots work for basic example", {
suppressMessages(expect_s3_class(sv_waterfall(x, 2:3), "patchwork"))
expect_s3_class(sv_force(x, 2), "patchwork")
suppressMessages(expect_s3_class(sv_force(x, 2:3), "patchwork"))
expect_s3_class(sv_importance(x), "patchwork")
expect_s3_class(sv_importance(x, show_numbers = TRUE), "patchwork")
expect_s3_class(sv_importance(x), "ggplot")
expect_s3_class(sv_importance(x, bar_type = "stack"), "ggplot")
expect_s3_class(sv_importance(x, bar_type = "facets"), "ggplot")
expect_s3_class(
sv_importance(x, show_numbers = TRUE, bar_type = "separate"), "patchwork"
)
expect_s3_class(sv_importance(x, kind = "beeswarm"), "patchwork")
expect_s3_class(sv_dependence(x, "Petal.Length"), "patchwork")
expect_s3_class(sv_dependence2D(x, x = "Petal.Length", y = "Species"), "patchwork")
@@ -20,8 +24,12 @@ test_that("using 'max_display' gives no error", {
suppressMessages(expect_s3_class(sv_waterfall(x, 2:10, max_display = 2L), "patchwork"))
expect_s3_class(sv_force(x, 2, max_display = 2L), "patchwork")
suppressMessages(expect_s3_class(sv_force(x, 2:10, max_display = 2L), "patchwork"))
expect_s3_class(sv_importance(x, max_display = 2L), "patchwork")
expect_s3_class(sv_importance(x, max_display = 2L, show_numbers = TRUE), "patchwork")
expect_s3_class(sv_importance(x, max_display = 2L), "ggplot")
expect_s3_class(sv_importance(x, max_display = 2L, bar_type = "stack"), "ggplot")
expect_s3_class(sv_importance(x, max_display = 2L, bar_type = "facets"), "ggplot")
expect_s3_class(
sv_importance(x, max_display = 2L, show_numbers = TRUE, bar_type = "separate"), "patchwork"
)
})

# SHAP interactions
@@ -71,8 +79,10 @@ x <- c(m1 = x, m2 = x)
test_that("plots work for non-syntactic column names", {
expect_s3_class(sv_waterfall(x, 2), "patchwork")
expect_s3_class(sv_force(x, 2), "patchwork")
expect_s3_class(sv_importance(x), "patchwork")
expect_s3_class(sv_importance(x, show_numbers = TRUE), "patchwork")
expect_s3_class(sv_importance(x), "ggplot")
expect_s3_class(
sv_importance(x, bar_type = "separate", show_numbers = TRUE), "patchwork"
)
expect_s3_class(sv_importance(x, max_display = 2, kind = "beeswarm"), "patchwork")
expect_s3_class(sv_importance(x, kind = "beeswarm"), "patchwork")
expect_s3_class(sv_dependence(x, "strange name"), "patchwork")
@@ -84,15 +94,15 @@ test_that("plots work for non-syntactic column names", {
)
})

test_that("sv_importance() and sv_interaction() and kind = 'no' gives list", {
test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", {
X_pred <- data.matrix(iris[, -1L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L])
fit <- xgboost::xgb.train(data = dtrain, nrounds = 50L, nthread = 1L)
x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
x <- c(m1 = x, m2 = x)

imp <- sv_importance(x, kind = "no")
expect_true(is.list(imp) && length(imp) == length(x))
expect_true(is.matrix(imp) && all(dim(imp) == c(4L, length(x))))

inter <- sv_interaction(x, kind = "no")
expect_true(is.list(inter) && all(dim(inter[[1L]]) == rep(ncol(X_pred), 2L)))
3 changes: 1 addition & 2 deletions vignettes/basic_use.Rmd
@@ -49,7 +49,7 @@ To further simplify the use of {shapviz}, we added direct connectors to:
- [`kernelshap`](https://CRAN.R-project.org/package=kernelshap)
- [`fastshap`](https://CRAN.R-project.org/package=fastshap)
- [`shapr`](https://CRAN.R-project.org/package=shapr)
- [`treeshap`](https://github.com/ModelOriented/treeshap/) (not on CRAN)
- [`treeshap`](https://CRAN.R-project.org/package=treeshap)
- [`DALEX`](https://CRAN.R-project.org/package=DALEX)
- [`permshap`](https://github.com/mayer79/permshap) (not on CRAN)

@@ -413,4 +413,3 @@ sv_dependence(shp, "clarity", alpha = 0.2, size = 1)
```

## References

6 changes: 3 additions & 3 deletions vignettes/geographic.Rmd
@@ -30,13 +30,13 @@ $$
$$
Like any feature, the effect of a single geographic feature $X^{\textrm{geo}, j}$ can be described using SHAP dependence plots. However, studying the effect of latitude (or any other location dependent feature) alone is often not very illuminating - simply due to strong interaction effects and correlations with other geographic features.

That's where the additivity of SHAP values comes into play: The sum of SHAP values of all geographic components represent the total effect of $X^\textrm{geo}$, and this sum can be visualized as a heatmap or 3D scatterplot against latitude/longitude (or any other geographic representation)
That's where the additivity of SHAP values comes into play: The sum of SHAP values of all geographic components represents the total effect of $X^\textrm{geo}$, and this sum can be visualized as a heatmap or 3D scatterplot against latitude/longitude (or any other geographic representation).

## A first example

For illustration, we will use a beautiful house price dataset containing information on about 14'000 houses sold in 2016 in Miami-Dade County. Some of the columns are as follows:

- **SALE_PRC**: Sale price in USD: Its logarithm will be our model response.
- **SALE_PRC**: Sale price in USD: Its logarithm will be our model **response**.
- *LATITUDE*, *LONGITUDE*: Coordinates
- *CNTR_DIST*: Distance to central business district
- *OCEAN_DIST*: Distance (ft) to the ocean
@@ -158,6 +158,6 @@ sv_dependence2D(sv2, x = "LONGITUDE", y = "LATITUDE", add_vars = more_geo) +
coord_equal()
```

Again, the resulting total geographic effect looks reasonable. Note that, unlike in the first example, there are no interactions to non-geographic components, leading to a Ceteris Paribus interpretation.
Again, the resulting total geographic effect looks reasonable. Note that, unlike in the first example, there are no interactions to non-geographic components, leading to a Ceteris Paribus interpretation. Furthermore, it contains the effect of the other regional features.

## References