Skip to content

Commit

Permalink
Make model agnostic SHAP for H2O more visible.
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Jan 10, 2025
1 parent a80aee9 commit b95dd29
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 47 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ jobs:
"shapviz\\.shapr",
"shapviz\\.kernelshap",
"shapviz\\.permshap",
"shapviz\\.H2ORegressionModel",
"shapviz\\.H2OBinomialModel",
"shapviz\\.H2OModel",
"\\.onLoad"
)
Expand Down
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ S3method(print,mshapviz)
S3method(print,shapviz)
S3method(rbind,mshapviz)
S3method(rbind,shapviz)
S3method(shapviz,H2OBinomialModel)
S3method(shapviz,H2OModel)
S3method(shapviz,H2ORegressionModel)
S3method(shapviz,default)
S3method(shapviz,explain)
S3method(shapviz,kernelshap)
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

### Documentation

- H2O random forests (regression and binary classification) are now supported as well (fast TreeSHAP) [#163](https://github.com/ModelOriented/shapviz/pull/163).
- H2O now supports passing background data for model agnostic SHAP. This is now easier visible in {shapviz}, see https://github.com/h2oai/h2o-3/issues/16463.
- H2O random forests (regression and binary classification) now support TreeSHAP as well [#163](https://github.com/ModelOriented/shapviz/pull/163).

### Compatibility

Expand Down
51 changes: 29 additions & 22 deletions R/shapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' from a fitted model of type
#' - XGBoost,
#' - LightGBM, or
#' - H2O (tree-based models).
#' - H2O.
#'
#' Furthermore, [shapviz()] can digest the results of
#' - `fastshap::explain()`,
Expand Down Expand Up @@ -454,28 +454,26 @@ shapviz.kernelshap <- function(
}

#' @describeIn shapviz
#' Creates a "shapviz" object from a (tree-based) H2O regression model.
#' @export
shapviz.H2ORegressionModel = function(
object, X_pred, X = as.data.frame(X_pred), collapse = NULL, ...
) {
shapviz.H2OModel(object = object, X_pred = X_pred, X = X, collapse = collapse, ...)
}

#' @describeIn shapviz
#' Creates a "shapviz" object from a (tree-based) H2O binary classification model.
#' @export
shapviz.H2OBinomialModel = function(
object, X_pred, X = as.data.frame(X_pred), collapse = NULL, ...
) {
shapviz.H2OModel(object = object, X_pred = X_pred, X = X, collapse = collapse, ...)
}

#' @describeIn shapviz
#' Creates a "shapviz" object from a (tree-based) H2O model (base class).
#' Creates a "shapviz" object from an H2O model.
#' @param background_frame Background dataset for baseline SHAP or marginal SHAP.
#' Only for H2O models.
#' @param output_space If model has link function, this argument controls whether the
#' SHAP values should be linearly (= approximately) transformed to the original scale
#' (if `TRUE`). The default is to return the values on link scale.
#' Only for H2O models.
#' @param output_per_reference Switches between different algorithms, see
#' `?h2o::h2o.predict_contributions` for details.
#' Only for H2O models.
#' @export
shapviz.H2OModel = function(
object, X_pred, X = as.data.frame(X_pred), collapse = NULL, ...
object,
X_pred,
X = as.data.frame(X_pred),
collapse = NULL,
background_frame = NULL,
output_space = FALSE,
output_per_reference = FALSE,
...
) {
if (!requireNamespace("h2o", quietly = TRUE)) {
stop("Package 'h2o' not installed")
Expand All @@ -488,7 +486,16 @@ shapviz.H2OModel = function(
if (!inherits(X_pred, "H2OFrame")) {
X_pred <- h2o::as.h2o(X_pred)
}
S <- as.matrix(h2o::h2o.predict_contributions(object, newdata = X_pred, ...))
S <- as.matrix(
h2o::h2o.predict_contributions(
object,
newdata = X_pred,
background_frame = background_frame,
output_space = output_space,
output_per_reference = output_per_reference,
...
)
)
shapviz.matrix(
object = S[, setdiff(colnames(S), "BiasTerm"), drop = FALSE],
X = X,
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

SHAP and feature values are stored in a "shapviz" object that is built from:

1. Models that know how to calculate SHAP values: XGBoost, LightGBM, H2O (tree-based models).
1. Models that know how to calculate SHAP values: XGBoost, LightGBM, and H2O.
2. SHAP crunchers like {fastshap}, {kernelshap}, {treeshap}, {fastr}, and {DALEX}.
3. SHAP matrix and corresponding feature values.

Expand Down
37 changes: 24 additions & 13 deletions man/shapviz.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 14 additions & 6 deletions vignettes/basic_use.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ In particular, the following plots are available:

SHAP and feature values are stored in a "shapviz" object that is built from:

1. Models that know how to calculate SHAP values: XGBoost, LightGBM, h2o (boosted trees).
1. Models that know how to calculate SHAP values: XGBoost, LightGBM, and h2o.
2. SHAP crunchers like {fastshap}, {kernelshap}, {treeshap}, {fastr}, and {DALEX}.
3. SHAP matrix and corresponding feature values.

Expand Down Expand Up @@ -178,7 +178,7 @@ sv_dependence(shp, "Sepal.Width")

### H2O

If you work with a boosted trees H2O model:
H2O supports TreeSHAP for boosted trees and random forests. For other models, model agnostic method based on marginal expectations are used, requiring a background dataset.

```r
library(shapviz)
Expand All @@ -187,10 +187,18 @@ library(h2o)
h2o.init()

iris2 <- as.h2o(iris)
fit <- h2o.gbm(colnames(iris[-1]), "Sepal.Length", training_frame = iris2)
shp <- shapviz(fit, X_pred = iris)
sv_force(shp, row_id = 1)
sv_dependence(shp, "Species")

# Random forest
fit_rf <- h2o.randomForest(colnames(iris[-1]), "Sepal.Length", training_frame = iris2)
shp_rf <- shapviz(fit_rf, X_pred = iris)
sv_force(shp_rf, row_id = 1)
sv_dependence(shp_rf, "Species")

# Linear model
fit_lm <- h2o.glm(colnames(iris[-1]), "Sepal.Length", training_frame = iris2)
shp_lm <- shapviz(fit_lm, X_pred = iris, background_frame = iris2)
sv_force(shp_lm, row_id = 1)
sv_dependence(shp_lm, "Species")
```

### treeshap
Expand Down

0 comments on commit b95dd29

Please sign in to comment.