diff --git a/DESCRIPTION b/DESCRIPTION index 01b3ad6..eceee63 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -21,7 +21,7 @@ Depends: R (>= 3.6.0) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Imports: ggfittext (>= 0.8.0), gggenes, diff --git a/NEWS.md b/NEWS.md index a3bed80..766c7e7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # shapviz 0.9.4 +## API improvements + +- Support both XGBoost 1.x.x as well as XGBoost 2.x.x, implemented in #144. + ## Improvements - New argument `sort_features = TRUE` in `sv_importance()` and `sv_interaction()`. Set to `FALSE` to show the features as they appear in your SHAP matrix. In that case, the plots will show the *first* `max_display` features, not the *most important* features. Implements #136. diff --git a/R/shapviz.R b/R/shapviz.R index d7f087c..8ef9235 100644 --- a/R/shapviz.R +++ b/R/shapviz.R @@ -201,36 +201,46 @@ shapviz.xgb.Booster = function(object, X_pred, X = X_pred, which_class = NULL, ) } - # Handle problem that S and S_inter lack a dimension if X_pred has only one row - # This might be fixed later directly in XGBoost. - if (nrow(X_pred) == 1L) { - if (is.list(S)) { # multiclass - S <- lapply(S, rbind) + if (utils::packageVersion("xgboost") >= "2") { + # Turn result of multi-output model into list of lower dim arrays + if (length(dim(S)) == 3L) { + S <- asplit(S, MARGIN = 2L) if (interactions) { - S_inter <- lapply(S_inter, .add_dim) + S_inter <- asplit(S_inter, MARGIN = 2L) } - } else { - S <- rbind(S) - if (interactions) { - S_inter <-.add_dim(S_inter) + } + } else { + # Handle problem that S and S_inter lack a dimension if X_pred has only one row + # This only applies to XGBoost < 2 + if (nrow(X_pred) == 1L) { + if (is.list(S)) { # multiclass + S <- lapply(S, rbind) + if (interactions) { + S_inter <- lapply(S_inter, .add_dim) + } + } else { + S <- rbind(S) + if (interactions) { + S_inter <-.add_dim(S_inter) + } } } } - # Multiclass + # Multi-class (or some other multi-output situation) if (is.list(S)) { if (is.null(which_class)) { - nms <- setdiff(colnames(S[[1L]]), "BIAS") + pp <- ncol(S[[1L]]) # = ncol(X_pred) + 1. The last column is the baseline if (interactions) { - S_inter <- lapply(S_inter, function(s) s[, nms, nms, drop = FALSE]) + S_inter <- lapply(S_inter, function(s) s[, -pp, -pp, drop = FALSE]) } else { # mapply() does not want to see a length 0 object like NULL S_inter <- replicate(length(S), NULL) } shapviz_list <- mapply( FUN = shapviz.matrix, - object = lapply(S, function(s) s[, nms, drop = FALSE]), - baseline = lapply(S, function(s) unname(s[1L, "BIAS"])), + object = lapply(S, function(s) s[, -pp, drop = FALSE]), + baseline = lapply(S, function(s) unname(s[1L, pp])), S_inter = S_inter, MoreArgs = list(X = X, collapse = collapse), SIMPLIFY = FALSE @@ -246,12 +256,12 @@ shapviz.xgb.Booster = function(object, X_pred, X = X_pred, which_class = NULL, } # Call matrix method - nms <- setdiff(colnames(S), "BIAS") + pp <- ncol(S) shapviz.matrix( - object = S[, nms, drop = FALSE], + object = S[, -pp, drop = FALSE], X = X, - baseline = unname(S[1L, "BIAS"]), - S_inter = if (interactions) S_inter[, nms, nms, drop = FALSE], + baseline = unname(S[1L, pp]), + S_inter = if (interactions) S_inter[, -pp, -pp, drop = FALSE], collapse = collapse ) } diff --git a/README.md b/README.md index 5345116..0d4adab 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ library(shapviz) library(ggplot2) library(xgboost) -set.seed(1) +set.seed(10) # Build model x <- c("carat", "cut", "color", "clarity") diff --git a/man/figures/README-dep.png b/man/figures/README-dep.png index f37cdd9..367b572 100644 Binary files a/man/figures/README-dep.png and b/man/figures/README-dep.png differ diff --git a/man/figures/README-force.svg b/man/figures/README-force.svg index 4857498..d35aefb 100644 --- a/man/figures/README-force.svg +++ b/man/figures/README-force.svg @@ -1,68 +1,320 @@ - - + + - - - - - - - - - + + + - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - -cut=Very Good -carat=0.4 -color=G -clarity=VS1 --3360 -+317 - - -E[f(x)]=3929 -f(x)=886 - - - - - -1000 -2000 -3000 -4000 -SHAP value + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/man/figures/README-imp.svg b/man/figures/README-imp.svg index 6a93567..0021a51 100644 --- a/man/figures/README-imp.svg +++ b/man/figures/README-imp.svg @@ -1,306 +1,294 @@ - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/man/figures/README-waterfall.svg b/man/figures/README-waterfall.svg index f9821ee..7d8dbb6 100644 --- a/man/figures/README-waterfall.svg +++ b/man/figures/README-waterfall.svg @@ -1,59 +1,319 @@ - - + + - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - -+317 --3360 - - - - - - -f(x)=886 -E[f(x)]=3929 -cut = Very Good -color = G -clarity = VS1 -carat = 0.4 - - - - - -1000 -2000 -3000 -4000 -SHAP value + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/man/figures/VIGNETTE-dep-ranger.png b/man/figures/VIGNETTE-dep-ranger.png index d083d8a..0153174 100644 Binary files a/man/figures/VIGNETTE-dep-ranger.png and b/man/figures/VIGNETTE-dep-ranger.png differ diff --git a/man/figures/VIGNETTE-dep.png b/man/figures/VIGNETTE-dep.png index 065e6b4..1eb5e12 100644 Binary files a/man/figures/VIGNETTE-dep.png and b/man/figures/VIGNETTE-dep.png differ diff --git a/vignettes/multiple_output.Rmd b/vignettes/multiple_output.Rmd index d9d08fd..034000e 100644 --- a/vignettes/multiple_output.Rmd +++ b/vignettes/multiple_output.Rmd @@ -109,7 +109,7 @@ sv_importance(shp) sv_dependence(shp, v = "Sepal.Width") + plot_layout(ncol = 2) & - ylim(-0.03, 0.035) + ylim(-0.06, 0.06) ``` ![](../man/figures/VIGNETTE-dep-ranger.png)